diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index bd0ac41974..f9566aa9ef 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -8,7 +8,9 @@
 from typing import Any, Dict, Tuple, Union
 
 import pytest
+import math
 import torch
+import torch.nn.functional as F
 
 from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_fp8_te_dtype
 from transformer_engine.common import recipe
@@ -893,6 +895,63 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
     )
 
 
+@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
+@pytest.mark.skipif(
+    get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+."
+)
+@pytest.mark.parametrize("dtype", param_types_lean)
+def test_dpa_thd_qv_head_dim_mismatch(dtype):
+    """Test THD DotProductAttention when Q/V head dims differ."""
+    seq_len = 32
+    num_heads = 4
+    head_dim_qk = 128
+    head_dim_v = 64
+
+    q = torch.randn(seq_len, num_heads, head_dim_qk, device="cuda", dtype=dtype, requires_grad=True)
+    k = torch.randn(seq_len, num_heads, head_dim_qk, device="cuda", dtype=dtype, requires_grad=True)
+    v = torch.randn(seq_len, num_heads, head_dim_v, device="cuda", dtype=dtype, requires_grad=True)
+
+    cu_seqlens = torch.tensor([0, seq_len], device="cuda", dtype=torch.int32)
+
+    dpa = DotProductAttention(
+        num_heads,
+        (head_dim_qk, head_dim_v),
+        qkv_format="thd",
+        attn_mask_type="padding_causal",
+    ).to(device="cuda", dtype=dtype)
+
+    out = dpa(
+        q,
+        k,
+        v,
+        qkv_format="thd",
+        cu_seqlens_q=cu_seqlens,
+        cu_seqlens_kv=cu_seqlens,
+        max_seqlen_q=seq_len,
+        max_seqlen_kv=seq_len,
+        attn_mask_type="padding_causal",
+    )
+
+    assert out.shape == (seq_len, num_heads * head_dim_v)
+    out.sum().backward()
+    assert q.grad is not None
+    assert k.grad is not None
+    assert v.grad is not None
+
+    # Reference attention (causal) in float32 for numerical check.
+    q_ref = q.detach().float()
+    k_ref = k.detach().float()
+    v_ref = v.detach().float()
+    scores = torch.einsum("thd,shd->ths", q_ref, k_ref) / math.sqrt(head_dim_qk)
+    causal_mask = torch.triu(
+        torch.ones(seq_len, seq_len, device="cuda", dtype=torch.bool), diagonal=1
+    )
+    scores = scores.masked_fill(causal_mask.unsqueeze(1), float("-inf"))
+    probs = F.softmax(scores, dim=-1)
+    ref = torch.einsum("ths,shd->thd", probs, v_ref).reshape(seq_len, -1)
+    torch.testing.assert_close(out.detach().float(), ref, rtol=5e-2, atol=5e-2)
+
+
 def _run_dot_product_attention(
     dtype: torch.dtype,
     config: ModelConfig,
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
index 5a554d86ec..8dc6cbc83b 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
@@ -11,6 +11,7 @@
 import logging
 
 import torch
+import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 import transformer_engine_torch as tex
@@ -1234,6 +1235,21 @@ def forward(
             inference_params=inference_params,
         )
 
+        thd_qkv_format = q_format == "thd" and kv_format == "thd"
+        orig_v_dim = None
+        pad_v_for_thd = False
+        if (
+            thd_qkv_format
+            and not isinstance(value_layer, Float8Tensor)
+            and head_dim_qk != head_dim_v
+        ):
+            orig_v_dim = value_layer.shape[-1]
+            if orig_v_dim < head_dim_qk:
+                # Pad V so THD attention can run when Q/V head dims differ.
+                value_layer = F.pad(value_layer, (0, head_dim_qk - orig_v_dim))
+                head_dim_v = value_layer.shape[-1]
+                pad_v_for_thd = True
+
         # adjust max_seqlen and cu_seqlens for CP
         cp_size = 1
         if isinstance(self.cp_group, dist_group_type):
@@ -1437,6 +1453,35 @@
             else None
         )
 
+        def _trim_thd_output(attn_out):
+            if not pad_v_for_thd:
+                return attn_out
+
+            def _trim_data(data):
+                if data.ndim == 2:
+                    data = data.reshape(data.shape[0], num_attention_heads, head_dim_v)
+                    data = data[..., :orig_v_dim]
+                    return data.reshape(data.shape[0], -1)
+                if data.ndim == 3:
+                    return data[..., :orig_v_dim]
+                return data
+
+            def _trim_tensor(out):
+                if out is None:
+                    return out
+                if isinstance(out, Float8Tensor):
+                    out_data = _trim_data(out._data)
+                    return Float8Tensor.make_like(out, data=out_data, shape=out_data.shape)
+                return _trim_data(out)
+
+            if isinstance(attn_out, tuple):
+                return (_trim_tensor(attn_out[0]),) + attn_out[1:]
+            if isinstance(attn_out, list):
+                if attn_out:
+                    attn_out[0] = _trim_tensor(attn_out[0])
+                return attn_out
+            return _trim_tensor(attn_out)
+
         if use_flash_attention:
             if core_attention_bias_type == "alibi":
                 alibi_slopes, _ = dpa_utils.get_alibi(
@@ -1446,7 +1491,7 @@
                     max_seqlen_kv,
                     alibi_slopes=alibi_slopes,
                 )
-            return self.flash_attention(
+            attn_out = self.flash_attention(
                 query_layer,
                 key_layer,
                 value_layer,
@@ -1471,6 +1516,7 @@
                 fp8_output=fp8_output,
                 num_splits=num_splits,
             )
+            return _trim_thd_output(attn_out)
 
         if use_fused_attention:
             fu_core_attention_bias_type = core_attention_bias_type
@@ -1487,7 +1533,7 @@
                     bottom_right_alignment=bottom_right_diagonal,
                 )
             if checkpoint_core_attention:
-                return self._checkpointed_attention_forward(
+                attn_out = self._checkpointed_attention_forward(
                     self.fused_attention,
                     query_layer,
                     key_layer,
@@ -1519,7 +1565,8 @@
                     softmax_offset=softmax_offset,
                     fp8_output=fp8_output,
                 )
-            return self.fused_attention(
+                return _trim_thd_output(attn_out)
+            attn_out = self.fused_attention(
                 query_layer,
                 key_layer,
                 value_layer,
@@ -1550,11 +1597,12 @@
                 softmax_offset=softmax_offset,
                 fp8_output=fp8_output,
             )
+            return _trim_thd_output(attn_out)
 
         if use_unfused_attention:
             allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
             if checkpoint_core_attention:
-                return self._checkpointed_attention_forward(
+                attn_out = self._checkpointed_attention_forward(
                     self.unfused_attention,
                     _alibi_cache,
                     query_layer,
@@ -1579,7 +1627,8 @@
                     quantizers=self.quantizers,
                     fp8_output=fp8_output,
                 )
-            return self.unfused_attention(
+                return _trim_thd_output(attn_out)
+            attn_out = self.unfused_attention(
                 _alibi_cache,
                 query_layer,
                 key_layer,
@@ -1603,4 +1652,5 @@
                 quantizers=self.quantizers,
                 fp8_output=fp8_output,
             )
+            return _trim_thd_output(attn_out)
         return None