diff --git a/tests/integration/model_bridge/test_refactor_factored_attn_matrices.py b/tests/integration/model_bridge/test_refactor_factored_attn_matrices.py
new file mode 100644
index 000000000..4963ea650
--- /dev/null
+++ b/tests/integration/model_bridge/test_refactor_factored_attn_matrices.py
@@ -0,0 +1,99 @@
+"""Test refactor_factored_attn_matrices with TransformerBridge.
+
+Verifies that the refactored attention matrices produce correct results when
+used via TransformerBridge, matching HookedTransformer output.
+"""
+
+import pytest
+import torch
+
+from transformer_lens import HookedTransformer
+from transformer_lens.model_bridge import TransformerBridge
+
+
+@pytest.fixture(scope="module")
+def model_name():
+    return "distilgpt2"
+
+
+@pytest.fixture(scope="module")
+def device():
+    return "cpu"
+
+
+@pytest.fixture(scope="module")
+def test_text():
+    return "Natural language processing"
+
+
+@pytest.fixture(scope="module")
+def reference_ht(model_name, device):
+    """HookedTransformer with refactor_factored_attn_matrices=True."""
+    return HookedTransformer.from_pretrained(
+        model_name,
+        device=device,
+        refactor_factored_attn_matrices=True,
+    )
+
+
+def test_refactor_factored_attn_matrices_loss_matches(model_name, device, test_text, reference_ht):
+    """Bridge with refactor_factored_attn_matrices should match HookedTransformer."""
+    ref_loss = reference_ht(test_text, return_type="loss")
+
+    bridge = TransformerBridge.boot_transformers(model_name, device=device)
+    bridge.enable_compatibility_mode(refactor_factored_attn_matrices=True)
+    bridge_loss = bridge(test_text, return_type="loss")
+
+    assert not torch.isnan(bridge_loss), "Bridge produced NaN loss"
+    assert not torch.isinf(bridge_loss), "Bridge produced infinite loss"
+
+    loss_diff = abs(bridge_loss.item() - ref_loss.item())
+    assert loss_diff < 1.0, (
+        f"Loss difference too large: {loss_diff:.6f} "
+        f"(bridge={bridge_loss.item():.4f}, reference={ref_loss.item():.4f})"
+    )
+
+
+def test_refactor_factored_attn_matrices_logits_match(model_name, device, test_text, reference_ht):
+    """Bridge logits should closely match HookedTransformer logits after refactoring."""
+    tokens = reference_ht.to_tokens(test_text)
+    ref_logits = reference_ht(tokens)
+
+    bridge = TransformerBridge.boot_transformers(model_name, device=device)
+    bridge.enable_compatibility_mode(refactor_factored_attn_matrices=True)
+    bridge_logits = bridge(tokens)
+
+    # Check shapes match
+    assert (
+        ref_logits.shape == bridge_logits.shape
+    ), f"Shape mismatch: ref={ref_logits.shape}, bridge={bridge_logits.shape}"
+
+    # Check values are close
+    max_diff = (ref_logits - bridge_logits).abs().max().item()
+    assert max_diff < 1.0, f"Max logit difference too large: {max_diff:.6f}"
+
+
+def test_refactor_preserves_fold_ln(model_name, device, test_text):
+    """Refactoring should not undo fold_ln; both should be applied together."""
+    # Reference: fold_ln=True + refactor=True
+    ref = HookedTransformer.from_pretrained(
+        model_name,
+        device=device,
+        fold_ln=True,
+        refactor_factored_attn_matrices=True,
+    )
+    ref_loss = ref(test_text, return_type="loss")
+
+    # Bridge: same settings
+    bridge = TransformerBridge.boot_transformers(model_name, device=device)
+    bridge.enable_compatibility_mode(
+        fold_ln=True,
+        refactor_factored_attn_matrices=True,
+    )
+    bridge_loss = bridge(test_text, return_type="loss")
+
+    loss_diff = abs(bridge_loss.item() - ref_loss.item())
+    assert loss_diff < 1.0, (
+        f"fold_ln + refactor mismatch: {loss_diff:.6f} "
+        f"(bridge={bridge_loss.item():.4f}, ref={ref_loss.item():.4f})"
+    )
diff --git a/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py b/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py
index ce2eff8d8..3266920aa 100644
--- a/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py
+++ b/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py
@@ -105,6 +105,26 @@ def __init__(
         self._reference_model: Optional[Any] = None
         self._layer_idx: Optional[int] = None
 
+        # After splitting, the q/k/v LinearBridges hold the authoritative weights.
+        # The original qkv LinearBridge remains registered in _modules (so
+        # self.qkv is still accessible) but its parameters are stale copies of
+        # the pre-split combined weight. This hook excludes them from state_dict
+        # so weight processing steps never read unprocessed combined weights.
+        self._register_state_dict_hook(JointQKVAttentionBridge._filter_qkv_state_dict)
+
+    @staticmethod
+    def _filter_qkv_state_dict(
+        module: torch.nn.Module,
+        state_dict: Dict[str, Any],
+        prefix: str,
+        local_metadata: Dict[str, Any],
+    ) -> None:
+        """State dict hook that removes stale combined QKV entries."""
+        qkv_prefix = prefix + "qkv."
+        keys_to_remove = [k for k in state_dict if k.startswith(qkv_prefix)]
+        for k in keys_to_remove:
+            del state_dict[k]
+
     def _create_qkv_conversion_rule(self) -> BaseTensorConversion:
         """Create the appropriate conversion rule for the individual q, k, and v matrices.
 
diff --git a/transformer_lens/model_bridge/supported_architectures/gpt2.py b/transformer_lens/model_bridge/supported_architectures/gpt2.py
index 88b629ac7..2fb5acc41 100644
--- a/transformer_lens/model_bridge/supported_architectures/gpt2.py
+++ b/transformer_lens/model_bridge/supported_architectures/gpt2.py
@@ -27,14 +27,21 @@
 
 
 class QKVSplitRearrangeConversion(BaseTensorConversion):
-    """Custom conversion that splits QKV tensor and then rearranges."""
+    """Custom conversion that splits QKV tensor and then rearranges.
+
+    Handles two input formats:
+    - Combined QKV tensor (from HuggingFace): one dimension is ~3x the other.
+      Splits into Q/K/V parts, then rearranges to TL format.
+    - Already-split tensor (from bridge state dict): nn.Linear format
+      [n_heads*d_head, d_model]. Rearranges directly to TL format.
+    """
 
     def __init__(self, qkv_index: int, rearrange_pattern: str, **axes_lengths):
         """Initialize the conversion.
 
         Args:
             qkv_index: Index of Q (0), K (1), or V (2) in the QKV tensor
-            rearrange_pattern: Einops pattern for rearrangement
+            rearrange_pattern: Einops pattern for rearrangement (Conv1D format)
             **axes_lengths: Additional axes lengths for einops
         """
         super().__init__()
@@ -42,25 +49,52 @@ def __init__(self, qkv_index: int, rearrange_pattern: str, **axes_lengths):
         self.rearrange_pattern = rearrange_pattern
         self.axes_lengths = axes_lengths
 
+    def _is_combined_qkv(self, tensor: torch.Tensor) -> bool:
+        """Check if a tensor is a combined QKV tensor vs already-split."""
+        if tensor.ndim == 2:
+            d0, d1 = tensor.shape
+            return d1 > d0 * 2 or d0 > d1 * 2
+        if tensor.ndim == 1:
+            n = self.axes_lengths.get("n", 1)
+            # Combined bias has 3x the expected individual size
+            return tensor.shape[0] % 3 == 0 and tensor.shape[0] > n * 3
+        return False
+
     def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
         """Split QKV tensor and rearrange the selected part."""
-        # Determine the split dimension based on tensor shape
+        if not self._is_combined_qkv(input_value):
+            # Already-split tensor in nn.Linear format [n_heads*d_head, d_model].
+            # The original rearrange_pattern is "d_model (n h) -> n d_model h"
+            # (Conv1D format). For nn.Linear format, the dims are transposed:
+            return einops.rearrange(
+                input_value, "(n h) d_model -> n d_model h", **self.axes_lengths
+            )
+
+        # Combined QKV tensor: split then rearrange
         if len(input_value.shape) == 2:
             # Weight tensor: [d_model, 3*d_model] -> split along dim=1
-            split_dim = 1
+            split_dim = 1 if input_value.shape[1] > input_value.shape[0] else 0
         elif len(input_value.shape) == 1:
             # Bias tensor: [3*n_heads*d_head] -> split along dim=0
             split_dim = 0
         else:
             raise ValueError(f"Unexpected tensor shape: {input_value.shape}")
 
-        # Split the QKV tensor
         qkv_parts = torch.tensor_split(input_value, 3, dim=split_dim)
         selected_part = qkv_parts[self.qkv_index]
-
-        # Apply rearrangement
         return einops.rearrange(selected_part, self.rearrange_pattern, **self.axes_lengths)
 
+    def revert(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
+        """Revert from TL format [n_heads, d_model, d_head] to nn.Linear format."""
+        if input_value.ndim == 3:
+            return einops.rearrange(
+                input_value, "n d_model h -> (n h) d_model", **self.axes_lengths
+            )
+        if input_value.ndim == 2:
+            # Bias in TL format [n_heads, d_head] -> [n_heads*d_head]
+            return einops.rearrange(input_value, "n h -> (n h)", **self.axes_lengths)
+        return input_value
+
 
 class GPT2ArchitectureAdapter(ArchitectureAdapter):
     """Architecture adapter for GPT2 models.
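
For context on the state-dict hook pattern the `joint_qkv_attention.py` change relies on: `Module._register_state_dict_hook` is a private but long-standing PyTorch API that registers a post-hook which may mutate the assembled `state_dict` in place. Below is a minimal, self-contained sketch of the same filtering idea; the `SplitAttention` toy module and its dimensions are illustrative, not part of this PR:

```python
import torch


class SplitAttention(torch.nn.Module):
    """Toy stand-in for JointQKVAttentionBridge: a stale combined qkv
    module kept alongside the authoritative split q/k/v modules."""

    def __init__(self, d_model: int = 8):
        super().__init__()
        self.qkv = torch.nn.Linear(d_model, 3 * d_model)  # stale after split
        self.q = torch.nn.Linear(d_model, d_model)
        self.k = torch.nn.Linear(d_model, d_model)
        self.v = torch.nn.Linear(d_model, d_model)
        # The post-hook runs after state_dict() is assembled, so entries
        # for the stale combined module can simply be deleted.
        self._register_state_dict_hook(SplitAttention._filter_qkv)

    @staticmethod
    def _filter_qkv(module, state_dict, prefix, local_metadata):
        for key in [k for k in state_dict if k.startswith(prefix + "qkv.")]:
            del state_dict[key]


attn = SplitAttention()
keys = set(attn.state_dict())
assert not any(k.startswith("qkv.") for k in keys)
assert {"q.weight", "q.bias", "k.weight", "k.bias", "v.weight", "v.bias"} <= keys
```

The hook receives the `prefix` under which the module was nested, so the same filtering works when the attention block sits deep inside a larger model's `state_dict()`.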
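And a quick shape check of the two weight formats that `QKVSplitRearrangeConversion.handle_conversion` now distinguishes; the concrete dimensions here are made up for illustration:

```python
import einops
import torch

d_model, n, h = 12, 3, 4  # toy dims: n heads of size h, with n * h == d_model

# HuggingFace GPT-2 stores attention as a Conv1D weight of shape
# [d_model, 3 * d_model]: Q, K, V concatenated along the wider dim.
c_attn_weight = torch.randn(d_model, 3 * d_model)

# Combined path: split into thirds along the wider dim, pick Q (index 0),
# then rearrange to TransformerLens format [n_heads, d_model, d_head].
q_part = torch.tensor_split(c_attn_weight, 3, dim=1)[0]
W_Q = einops.rearrange(q_part, "d_model (n h) -> n d_model h", n=n, h=h)
assert W_Q.shape == (n, d_model, h)

# Already-split path: an nn.Linear weight is transposed relative to Conv1D,
# i.e. [n_heads * d_head, d_model], and is rearranged directly (no split).
q_linear = torch.randn(n * h, d_model)
W_Q_alt = einops.rearrange(q_linear, "(n h) d_model -> n d_model h", n=n, h=h)
assert W_Q_alt.shape == (n, d_model, h)
```

This also shows why the 2-D heuristic in `_is_combined_qkv` suffices: the combined weight has one dimension 3x the other, while the already-split GPT-2 weight is square (since `n_heads * d_head == d_model` there).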