[PyTorch] Pad V when Q/V head dims differ (MLA) for THD #2629
HollowMan6 wants to merge 1 commit into NVIDIA:main from
Conversation
Greptile Overview
Greptile Summary
Adds support for Multi-head Latent Attention (MLA) in THD format when Q and V have different head dimensions. When V's head dimension is smaller than Q's, the implementation pads V to match Q's dimension before the attention computation, then trims the output back to the original V dimension.
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant DotProductAttention
    participant AttentionBackend as Attention Backend<br/>(Flash/Fused/Unfused)
    Caller->>DotProductAttention: forward(Q, K, V)<br/>head_dim_qk=128, head_dim_v=64
    Note over DotProductAttention: Check THD format &<br/>head dim mismatch
    alt head_dim_v < head_dim_qk
        DotProductAttention->>DotProductAttention: Save orig_v_dim = 64
        DotProductAttention->>DotProductAttention: Pad V: 64 → 128<br/>Set pad_v_for_thd = True
        DotProductAttention->>DotProductAttention: Update head_dim_v = 128
    end
    DotProductAttention->>AttentionBackend: attention(Q, K, V_padded)
    AttentionBackend-->>DotProductAttention: attn_out (head_dim=128)
    alt pad_v_for_thd == True
        DotProductAttention->>DotProductAttention: _trim_thd_output()
        Note over DotProductAttention: Reshape using head_dim_v (128)<br/>Trim to orig_v_dim (64)
        DotProductAttention->>DotProductAttention: attn_out[..., :64]
    end
    DotProductAttention-->>Caller: Return trimmed output<br/>(head_dim=64)
```
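The diagram above can be read as plain tensor operations. The snippet below is a minimal, self-contained sketch of the pad-then-trim flow, not the PR's actual code: it assumes THD layout `[total_tokens, num_heads, head_dim]`, the names `pad_v_for_thd` and `orig_v_dim` mirror the diagram, and `attention_backend` is a purely illustrative stand-in for whichever backend is dispatched to (the example backend below ignores packed-sequence boundaries).

```python
import torch
import torch.nn.functional as F


def attention_with_v_padding(q, k, v, attention_backend):
    """Illustrative sketch: pad V up to Q's head dim, run attention, trim the output.

    q, k: [total_tokens, num_heads, head_dim_qk]
    v:    [total_tokens, num_heads, head_dim_v] with head_dim_v <= head_dim_qk
    attention_backend: any callable taking (q, k, v) with equal head dims.
    """
    head_dim_qk, head_dim_v = q.shape[-1], v.shape[-1]
    pad_v_for_thd = head_dim_v < head_dim_qk
    orig_v_dim = head_dim_v

    if pad_v_for_thd:
        # Zero-pad the last (head) dimension of V, e.g. 64 -> 128 in the diagram.
        v = F.pad(v, (0, head_dim_qk - head_dim_v))

    attn_out = attention_backend(q, k, v)  # output head dim == head_dim_qk

    if pad_v_for_thd:
        # The padded V channels are all zeros, so the corresponding output
        # channels carry no information and can simply be sliced away.
        attn_out = attn_out[..., :orig_v_dim].contiguous()
    return attn_out


# Example backend: PyTorch SDPA over the packed tokens as one sequence
# (illustrative only; a real THD backend would respect cu_seqlens).
backend = lambda q, k, v: F.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1)
).transpose(0, 1)

q = torch.randn(32, 8, 128)
k = torch.randn(32, 8, 128)
v = torch.randn(32, 8, 64)
out = attention_with_v_padding(q, k, v, backend)  # -> [32, 8, 64]
```

Because the padded V columns are zero, softmax(QKᵀ) @ V_padded is exactly softmax(QKᵀ) @ V in its first `orig_v_dim` channels, so trimming loses nothing.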
Pull request overview
This PR adds support for Multi-head Latent Attention (MLA) with mismatched Q/V head dimensions in the THD format (packed variable-length layout of shape [total tokens, heads, head dim]). When the value tensor has a smaller head dimension than the query/key tensors, the code pads the value tensor to match the Q/K head dimension, runs the attention operation, and then trims the output back to the original V dimension (a sketch of the trim step follows the change list below).
Changes:
- Added padding logic for V tensor when head dimensions differ in THD format
- Implemented trimming function to restore correct output dimensions after attention
- Added test case for THD attention with mismatched Q/V head dimensions
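To make the trimming step concrete, here is a hypothetical helper, not the PR's actual `_trim_thd_output` implementation. It assumes the backend returns the THD output flattened to `[total_tokens, num_heads * padded_head_dim]`, which is why a reshape is needed before the per-head slice.

```python
import torch


def trim_thd_output(attn_out: torch.Tensor,
                    num_heads: int,
                    padded_head_dim: int,
                    orig_v_dim: int) -> torch.Tensor:
    """Hypothetical trim step (sketch only, not the PR's exact code).

    attn_out: [total_tokens, num_heads * padded_head_dim]
    returns:  [total_tokens, num_heads * orig_v_dim]
    """
    total_tokens = attn_out.shape[0]
    # Un-flatten per head, drop the padded channels, flatten back.
    out = attn_out.view(total_tokens, num_heads, padded_head_dim)
    out = out[..., :orig_v_dim]
    return out.reshape(total_tokens, num_heads * orig_v_dim)


# e.g. [1024, 8 * 128] -> [1024, 8 * 64]
trimmed = trim_thd_output(torch.randn(1024, 8 * 128), 8, 128, 64)
assert trimmed.shape == (1024, 8 * 64)
```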
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py | Implements padding of V tensor before attention and trimming of output after attention for THD format with mismatched Q/V head dimensions |
| tests/pytorch/attention/test_attention.py | Adds test case to verify THD attention works with different Q/V head dimensions |
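To illustrate the kind of check the new test in tests/pytorch/attention/test_attention.py performs, here is a hedged pytest-style sketch. It is not the actual test; it assumes that `DotProductAttention` accepts a `(head_dim_qk, head_dim_v)` tuple for `kv_channels`, that `qkv_format="thd"` with the `cu_seqlens_*`/`max_seqlen_*` forward arguments shown is supported, and that the THD output comes back flattened to `[total_tokens, heads * head_dim_v]`; verify these against the installed Transformer Engine version.

```python
import pytest
import torch
from transformer_engine.pytorch import DotProductAttention


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_thd_mla_mismatched_qv_head_dims():
    # Assumed setup: 2 packed sequences of 512 tokens, 16 heads,
    # head_dim_qk=128, head_dim_v=64 (MLA-style mismatch).
    total_tokens, heads, dim_qk, dim_v = 1024, 16, 128, 64
    cu_seqlens = torch.tensor([0, 512, 1024], dtype=torch.int32, device="cuda")

    q = torch.randn(total_tokens, heads, dim_qk, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(total_tokens, heads, dim_qk, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(total_tokens, heads, dim_v, dtype=torch.bfloat16, device="cuda")

    # Assumption: kv_channels may be a (head_dim_qk, head_dim_v) tuple for MLA.
    attn = DotProductAttention(heads, (dim_qk, dim_v), qkv_format="thd")
    out = attn(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=512, max_seqlen_kv=512,
    )

    # Assumed output layout: the original (unpadded) V head dim, flattened over heads.
    assert out.shape == (total_tokens, heads * dim_v)
```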
Signed-off-by: Hollow Man <hollowman@opensuse.org>
Description
For MLA, V needs to be padded when the Q/V head dimensions differ in THD format.
Similar to NVIDIA/Megatron-LM#3003
Fixes NVIDIA/Megatron-LM#1698