tests/pytorch/attention/test_attention.py (59 additions, 0 deletions)
@@ -8,7 +8,9 @@
from typing import Any, Dict, Tuple, Union

import pytest
import math
import torch
import torch.nn.functional as F

from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_fp8_te_dtype
from transformer_engine.common import recipe
@@ -893,6 +895,63 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
)


@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
@pytest.mark.skipif(
get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+."
)
@pytest.mark.parametrize("dtype", param_types_lean)
def test_dpa_thd_qv_head_dim_mismatch(dtype):
"""Test THD DotProductAttention when Q/V head dims differ."""
seq_len = 32
num_heads = 4
head_dim_qk = 128
head_dim_v = 64

q = torch.randn(seq_len, num_heads, head_dim_qk, device="cuda", dtype=dtype, requires_grad=True)
k = torch.randn(seq_len, num_heads, head_dim_qk, device="cuda", dtype=dtype, requires_grad=True)
v = torch.randn(seq_len, num_heads, head_dim_v, device="cuda", dtype=dtype, requires_grad=True)

cu_seqlens = torch.tensor([0, seq_len], device="cuda", dtype=torch.int32)

dpa = DotProductAttention(
num_heads,
(head_dim_qk, head_dim_v),
qkv_format="thd",
attn_mask_type="padding_causal",
).to(device="cuda", dtype=dtype)

out = dpa(
q,
k,
v,
qkv_format="thd",
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_kv=seq_len,
attn_mask_type="padding_causal",
)

assert out.shape == (seq_len, num_heads * head_dim_v)
out.sum().backward()
assert q.grad is not None
assert k.grad is not None
assert v.grad is not None

# Reference attention (causal) in float32 for numerical check.
q_ref = q.detach().float()
k_ref = k.detach().float()
v_ref = v.detach().float()
scores = torch.einsum("thd,shd->ths", q_ref, k_ref) / math.sqrt(head_dim_qk)
causal_mask = torch.triu(
torch.ones(seq_len, seq_len, device="cuda", dtype=torch.bool), diagonal=1
)
scores = scores.masked_fill(causal_mask.unsqueeze(1), float("-inf"))
probs = F.softmax(scores, dim=-1)
ref = torch.einsum("ths,shd->thd", probs, v_ref).reshape(seq_len, -1)
torch.testing.assert_close(out.detach().float(), ref, rtol=5e-2, atol=5e-2)


def _run_dot_product_attention(
dtype: torch.dtype,
config: ModelConfig,
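Aside (not part of the diff): the manual reference computation in the new test could be cross-checked with torch.nn.functional.scaled_dot_product_attention, which also accepts a value head dim different from the query/key head dim. A minimal sketch reusing the tensors from the test:

    # Heads must lead the sequence axis for F.scaled_dot_product_attention.
    ref_sdpa = F.scaled_dot_product_attention(
        q_ref.transpose(0, 1),  # [num_heads, seq_len, head_dim_qk]
        k_ref.transpose(0, 1),  # [num_heads, seq_len, head_dim_qk]
        v_ref.transpose(0, 1),  # [num_heads, seq_len, head_dim_v]
        is_causal=True,
    )
    ref_sdpa = ref_sdpa.transpose(0, 1).reshape(seq_len, -1)  # same layout as `ref` in the test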
@@ -11,6 +11,7 @@
import logging

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

import transformer_engine_torch as tex
@@ -1234,6 +1235,21 @@ def forward(
inference_params=inference_params,
)

thd_qkv_format = q_format == "thd" and kv_format == "thd"
orig_v_dim = None
pad_v_for_thd = False
if (
thd_qkv_format
and not isinstance(value_layer, Float8Tensor)
and head_dim_qk != head_dim_v
):
orig_v_dim = value_layer.shape[-1]
if orig_v_dim < head_dim_qk:
# Pad V so THD attention can run when Q/V head dims differ; the padded
# channels come out of attention as zeros and are trimmed off again below.
value_layer = F.pad(value_layer, (0, head_dim_qk - orig_v_dim))
head_dim_v = value_layer.shape[-1]
pad_v_for_thd = True

# adjust max_seqlen and cu_seqlens for CP
cp_size = 1
if isinstance(self.cp_group, dist_group_type):
@@ -1437,6 +1453,35 @@ def forward(
else None
)

def _trim_thd_output(attn_out):
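# Undo the V padding on the backend output, which may be flattened to
# [tokens, heads * head_dim] or kept as [tokens, heads, head_dim].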
if not pad_v_for_thd:
return attn_out

def _trim_data(data):
if data.ndim == 2:
data = data.reshape(data.shape[0], num_attention_heads, head_dim_v)
data = data[..., :orig_v_dim]
return data.reshape(data.shape[0], -1)
if data.ndim == 3:
return data[..., :orig_v_dim]
return data

def _trim_tensor(out):
if out is None:
return out
if isinstance(out, Float8Tensor):
out_data = _trim_data(out._data)
return Float8Tensor.make_like(out, data=out_data, shape=out_data.shape)
return _trim_data(out)

if isinstance(attn_out, tuple):
return (_trim_tensor(attn_out[0]),) + attn_out[1:]
if isinstance(attn_out, list):
if attn_out:
attn_out[0] = _trim_tensor(attn_out[0])
return attn_out
return _trim_tensor(attn_out)

if use_flash_attention:
if core_attention_bias_type == "alibi":
alibi_slopes, _ = dpa_utils.get_alibi(
@@ -1446,7 +1491,7 @@ def forward(
max_seqlen_kv,
alibi_slopes=alibi_slopes,
)
return self.flash_attention(
attn_out = self.flash_attention(
query_layer,
key_layer,
value_layer,
@@ -1471,6 +1516,7 @@ def forward(
fp8_output=fp8_output,
num_splits=num_splits,
)
return _trim_thd_output(attn_out)

if use_fused_attention:
fu_core_attention_bias_type = core_attention_bias_type
@@ -1487,7 +1533,7 @@ def forward(
bottom_right_alignment=bottom_right_diagonal,
)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
attn_out = self._checkpointed_attention_forward(
self.fused_attention,
query_layer,
key_layer,
@@ -1519,7 +1565,8 @@ def forward(
softmax_offset=softmax_offset,
fp8_output=fp8_output,
)
return self.fused_attention(
return _trim_thd_output(attn_out)
attn_out = self.fused_attention(
query_layer,
key_layer,
value_layer,
@@ -1550,11 +1597,12 @@ def forward(
softmax_offset=softmax_offset,
fp8_output=fp8_output,
)
return _trim_thd_output(attn_out)

if use_unfused_attention:
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
attn_out = self._checkpointed_attention_forward(
self.unfused_attention,
_alibi_cache,
query_layer,
@@ -1579,7 +1627,8 @@ def forward(
quantizers=self.quantizers,
fp8_output=fp8_output,
)
return self.unfused_attention(
return _trim_thd_output(attn_out)
attn_out = self.unfused_attention(
_alibi_cache,
query_layer,
key_layer,
@@ -1603,4 +1652,5 @@ def forward(
quantizers=self.quantizers,
fp8_output=fp8_output,
)
return _trim_thd_output(attn_out)
return None
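
The dispatcher change amounts to a pad/trim round trip around whichever backend runs: V is zero-padded up to the Q/K head dim before the call, and the padded channels are sliced back off the output. A self-contained sketch of that idea (the function and names below are illustrative, not the module's actual locals):

    import torch
    import torch.nn.functional as F

    def attention_with_narrow_v(q, k, v, attn_fn):
        # q, k: [tokens, heads, d_qk]; v: [tokens, heads, d_v] with d_v <= d_qk.
        tokens, heads, d_qk = q.shape
        d_v = v.shape[-1]
        if d_v < d_qk:
            # The attention output is probs @ V, so zero-padded V columns come out
            # as zeros while the original columns are untouched.
            v = F.pad(v, (0, d_qk - d_v))
        out = attn_fn(q, k, v)  # backend sees uniform head dims; assume it returns [tokens, heads * d_qk]
        # Drop the padded channels and re-flatten to [tokens, heads * d_v].
        return out.reshape(tokens, heads, d_qk)[..., :d_v].reshape(tokens, heads * d_v)

With any backend that returns the flattened layout, the trimmed result equals attention computed directly on the unpadded V, which is what the new test checks against its float32 reference.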