[2/n] Add sparse softmax to the Triton flash attention kernel#1078
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
📝 Walkthrough

Adds Triton-backed N:M sparse softmax to flash attention: kernel tile helpers and constexpr params, autograd and public API plumbing, method/config registration and HF prefill gating, plus comprehensive GPU tests and a changelog entry.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller as "Caller / Model"
    participant API as "attention(...) / _Attention"
    participant Autograd as "Autograd Function"
    participant Triton as "Triton FA Kernel"
    participant KV as "KV Storage / Tiles"
    rect rgba(100,149,237,0.5)
        Caller->>API: call attention(q,k,v,..., sparsity_n, sparsity_m, num_sink_tokens, dense_window_size)
        API->>Autograd: _Attention.forward(sparse params)
        Autograd->>Triton: launch forward kernel (constexpr sparsity params)
        Triton->>KV: load QK tile
        Triton->>Triton: apply N:M mask (tile-level) and set pruned scores to -inf
        Triton->>Triton: softmax & compute output context
        Triton-->>Autograd: return outputs + saved tensors (incl. sparsity params)
    end
    rect rgba(60,179,113,0.5)
        Autograd->>Triton: backward launch with same sparsity params
        Triton->>Triton: recompute masked scores (respect sink/window), compute dq/dk/dv
        Triton-->>Autograd: return gradients
        Autograd-->>Caller: propagate gradients
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ Passed checks (4 passed)
Codecov Report

❌ Patch coverage is … Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1078      +/-   ##
==========================================
- Coverage   70.18%   70.18%   -0.01%
==========================================
  Files         228      229       +1
  Lines       25952    26008      +56
==========================================
+ Hits        18215    18254      +39
- Misses       7737     7754      +17
```

☔ View full report in Codecov by Sentry.
Force-pushed 8ba6efe to 7aa6960.
Force-pushed 7aa6960 to 31655ce.
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 279-292: The N:M sparsity branch is being applied during decode
because it only checks SPARSITY_N > 0; change the condition to skip
sparsification when doing cached decode (seq_len_q == 1). Update the if guarding
the block (the one that currently reads "if SPARSITY_N > 0:") to also require
not decoding (e.g., "if SPARSITY_N > 0 and seq_len_q != 1:") so
_apply_sparse_nm_to_qk_tile(scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M) is
only called during prefill/non-decoding paths; keep the existing local/sink
logic (kv_start, tile_q, q_abs_block, is_local/is_sink) unchanged.
- Around line 279-292: The sparse-mask logic currently mixes tile-sized units
(BLOCK_M/BLOCK_N) with sparsity parameters (NUM_SINK_BLOCKS,
DENSE_WINDOW_BLOCKS) leading to inconsistent masks between forward/backward; fix
by computing mask membership in token-space instead of tile-space: derive each
KV block index and query row absolute token position from actual token counts
(use seq_len_kv, seq_len_q, kv_start and per-row start = tile_q * BLOCK_M +
row_offset or for whole-tile use tile_token_start = tile_q * BLOCK_M) and then
map those token positions into logical token-blocks of a fixed reference block
size (choose the constant used by backward kernels, e.g., 64 tokens) before
comparing to NUM_SINK_BLOCKS and DENSE_WINDOW_BLOCKS; update q_abs_block,
kv_block_idx, is_sink and is_local computations (the branch that calls
_apply_sparse_nm_to_qk_tile) and apply same normalization in the other
occurrences mentioned (around the other two locations) so forward and backward
use the same token-blocking semantics.
In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 99-130: Add validation to the config so invalid N:M and negative
counts are rejected when building the config instead of failing in the Triton
kernel: in the class/constructor where ModeloptField(s) sparsity_n, sparsity_m,
num_sink_blocks, and dense_window_blocks are defined, validate that sparsity_m
is either 4 or 8, sparsity_n is in {1,2,3} when sparsity_m==4 and in {1,2,3,4}
when sparsity_m==8 (or 0/disabled as your semantics require), and that
num_sink_blocks and dense_window_blocks are non-negative; also ensure that the
chosen sparsity mode is only allowed when triton_sparse_softmax is selected or
available and raise a clear config validation error if any rule is violated so
bad configs fail early.
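The early-validation rule described above can be sketched as a standalone check in plain Python (the function name, signature, and error messages are illustrative assumptions, not the actual ModeloptField schema):

```python
# Hypothetical validator mirroring the rules in the prompt above:
# sparsity_m in {4, 8}, sparsity_n limited per M, block counts non-negative.
def validate_sparse_softmax_config(sparsity_n, sparsity_m,
                                   num_sink_blocks, dense_window_blocks):
    if sparsity_n == 0:
        return  # sparsity disabled; other fields are not checked
    allowed_n = {4: {1, 2, 3}, 8: {1, 2, 3, 4}}
    if sparsity_m not in allowed_n:
        raise ValueError(f"sparsity_m must be 4 or 8, got {sparsity_m}")
    if sparsity_n not in allowed_n[sparsity_m]:
        raise ValueError(
            f"sparsity_n={sparsity_n} is invalid for sparsity_m={sparsity_m}"
        )
    if num_sink_blocks < 0 or dense_window_blocks < 0:
        raise ValueError("num_sink_blocks and dense_window_blocks must be >= 0")
```

Failing at config-build time this way surfaces a readable error instead of a Triton compilation failure deep in the kernel launch.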
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py`:
- Around line 350-358: Add an explicit assertion after calling
mtsa.sparsify(...) to verify the Triton backend was applied: check that
model_sparse.config._attn_implementation == "modelopt_triton" (this should be
done right after mtsa.sparsify(...) returns and before comparing logits). Locate
the sparsification call (mtsa.sparsify(..., backend="triton", ...)) and add the
assertion referencing model_sparse.config._attn_implementation to ensure the
Triton kernel registration took effect.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: ce048172-0588-4aec-9474-44eb6c4cbe3b
📥 Commits
Reviewing files that changed from the base of the PR and between 7aa6960f1914d3ed5276afa346d4817d136af320 and 31655ce.
📒 Files selected for processing (9)

- CHANGELOG.rst
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/conftest.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
✅ Files skipped from review due to trivial changes (2)
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/gpu/torch/sparsity/attention_sparsity/conftest.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
```python
# --- Optional N:M sparse softmax ---
if SPARSITY_N > 0:
    # Check if this KV tile should be kept dense
    kv_block_idx = kv_start // BLOCK_N
    is_sink = kv_block_idx < NUM_SINK_BLOCKS
    # causal_offset handles chunked prefill: q starts at (seq_len_kv - seq_len_q)
    causal_offset = seq_len_kv - seq_len_q
    q_abs_block = (tile_q * BLOCK_M + causal_offset) // BLOCK_N
    block_distance = q_abs_block - kv_block_idx
    is_local = (block_distance < DENSE_WINDOW_BLOCKS) and (block_distance >= 0)
    if not is_sink and not is_local:
        scores = _apply_sparse_nm_to_qk_tile(
            scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M
        )
```
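For intuition, here is a plain-Python sketch of the N:M pruning that `_apply_sparse_nm_to_qk_tile` is described as performing on a tile. The helper itself is Triton; this reference implementation is an assumption about its semantics, not the kernel code:

```python
# Hedged reference (not the Triton kernel): per-row N:M pruning of a QK tile.
# For every contiguous group of M scores in a row, keep the N largest and set
# the rest to -inf so softmax assigns them zero probability.
NEG_INF = float("-inf")

def apply_sparse_nm_reference(tile, n, m):
    """tile: list of rows; each row's length divisible by m. Returns a pruned copy."""
    out = []
    for row in tile:
        new_row = []
        for g in range(0, len(row), m):
            group = row[g : g + m]
            # threshold = N-th largest value in the group
            thresh = sorted(group, reverse=True)[n - 1]
            kept = 0
            for v in group:
                if v >= thresh and kept < n:
                    new_row.append(v)
                    kept += 1
                else:
                    new_row.append(NEG_INF)
        out.append(new_row)
    return out
```

With n=2, m=4 each group of four scores retains its two largest entries, which is the invariant the GPU tests later assert per group.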
Skip N:M masking during decode.
This branch only checks SPARSITY_N > 0, so cached decode (seq_len_q == 1 with separate KV metadata) gets sparsified too. The feature is described as prefill-only; without a decode guard, generation will start pruning the KV cache as soon as sparse mode is enabled.
Don't key the sparse mask to autotuned tile sizes.
num_sink_blocks and dense_window_blocks are interpreted in units of BLOCK_N, but forward autotunes BLOCK_N over 32/64/128 while both backward kernels hardcode BLOCK_N=64. That means forward and backward can apply different sparse masks, and the forward q_abs_block check is already wrong for later rows in a tile whenever BLOCK_M > BLOCK_N. Please normalize these regions to a fixed logical token block size (or token counts) and derive locality from each row's absolute position before reusing the same rule in forward and backward.
Also applies to: 480-491, 620-631
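The normalization being asked for can be sketched in plain Python (REF_BLOCK, the function name, and its signature are illustrative assumptions, not kernel code; the review suggests matching the backward kernels' fixed 64-token block):

```python
# Sketch only: token-space sink/local classification with a fixed reference
# block size, so forward and backward agree regardless of autotuned
# BLOCK_M/BLOCK_N. REF_BLOCK matches the backward kernels' hardcoded size.
REF_BLOCK = 64

def classify_kv_tile(kv_start, q_tile_start, seq_len_q, seq_len_kv,
                     num_sink_blocks, dense_window_blocks):
    # Map token positions into logical token-blocks of REF_BLOCK tokens
    causal_offset = seq_len_kv - seq_len_q
    kv_block = kv_start // REF_BLOCK
    q_abs_block = (q_tile_start + causal_offset) // REF_BLOCK
    is_sink = kv_block < num_sink_blocks
    dist = q_abs_block - kv_block
    is_local = 0 <= dist < dense_window_blocks
    return is_sink, is_local
```

Because the classification depends only on token positions and REF_BLOCK, a forward pass tiled at 128 and a backward pass tiled at 64 would classify the same tokens identically.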
```python
model_dense = AutoModelForCausalLM.from_pretrained(
    tiny_llama_dir,
    attn_implementation="modelopt_triton",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
```

```python
o = attention(
    q,
    k,
    v,
    b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32),
    b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32),
    max_input_len=seq_len,
    is_causal=True,
    softmax_scale=scale,
```

```python
model_dense.eval()
with torch.no_grad():
    logits_dense = model_dense(input_ids=ids).logits
```
🧩 Analysis chain

🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
echo "== attention sparsify entrypoints =="
rg -n -C3 '\bdef\s+sparsify\s*\(' modelopt/torch --type=py
echo
echo "== Triton backend-switch plumbing =="
rg -n -C3 'modelopt_triton|attn_implementation|_set_attn_implementation|register_triton_attention' modelopt/torch --type=py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 24602

🏁 Script executed:

```shell
# Read the test file to see the complete context of the test
cat -n tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py | sed -n '340,400p'
```

Repository: NVIDIA/Model-Optimizer
Length of output: 2564
Add explicit assertions to verify Triton backend was applied to sparse model.
When backend="triton" is configured in sparse_cfg, mtsa.sparsify() automatically calls _set_attn_implementation() which registers the Triton kernel and sets model.config._attn_implementation = "modelopt_triton". The current test at line 390-391 only asserts that sparse logits differ from dense logits, but does not verify the backend was actually switched. Add an explicit assertion after sparsification to confirm model_sparse.config._attn_implementation == "modelopt_triton" before comparing logits.
Also applies to: 376-381, 390-392
```python
        mask = torch.ones_like(attention_scores, dtype=torch.bool)
        return mask, {}

    def apply_sparsity(self, attention_scores, sparse_mask=None):
```
It would be better if, for Triton-backend attention sparsity methods, we don't have to add the boilerplate in this file (i.e. not have to add a new file to methods/ altogether)
Good catch. I've made calculate_sparsity and apply_sparsity optional in the base class.
@coderabbitai full review

✅ Actions performed

Full review triggered.
♻️ Duplicate comments (4)

modelopt/torch/kernels/triton_fa.py (4)

619-631: ⚠️ Potential issue | 🟠 Major

Backward dK/dV: same tile-based locality issue applies here.

See comment on forward kernel (lines 279-292). The fix should be applied consistently here.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 619 - 631, The backward dK/dV sparse-N mask is missing the same tile-locality handling as the forward kernel; update the backward block (the code around where scores are processed in the backward pass) to perform the same token-level/tile locality checks (compute is_sink using kv_start and NUM_SINK_TOKENS, compute causal_offset = seq_len_kv - seq_len_q, q_abs_pos = qi * BLOCK_M + causal_offset, token_distance = q_abs_pos - kv_start, is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE)) and only call _apply_sparse_nm_to_qk_tile(scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M) when not is_sink and not is_local, mirroring the forward kernel's logic so dK/dV uses identical N:M sparsity masking decisions.
280-280:⚠️ Potential issue | 🟠 MajorAdd decode guard — sparsity should be prefill-only per PR description.
The PR states "sparsity masking is applied during prefill only," but the kernel only checks
SPARSITY_N > 0. During decode (seq_len_q == 1), the KV cache will still be sparsified, which could degrade generation quality.🔧 Proposed fix
if SPARSITY_N > 0: + # Skip sparsity during decode (seq_len_q == 1) — apply only during prefill + is_decode = seq_len_q == 1 is_sink = kv_start < NUM_SINK_TOKENS causal_offset = seq_len_kv - seq_len_q q_abs_pos = tile_q * BLOCK_M + causal_offset token_distance = q_abs_pos - kv_start is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE) - if not is_sink and not is_local: + if not is_decode and not is_sink and not is_local: scores = _apply_sparse_nm_to_qk_tile(...)Apply the same guard in
_attn_bwd_dq(line 481) and_attn_bwd_dkdv(line 621).🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` at line 280, The sparsity mask is currently applied unconditionally when SPARSITY_N > 0; change the guard in the backward kernels _attn_bwd_dq and _attn_bwd_dkdv so sparsity is only applied during prefill (i.e., skip when decoding with seq_len_q == 1). Locate the existing "if SPARSITY_N > 0:" checks inside _attn_bwd_dq and _attn_bwd_dkdv and strengthen them to also require seq_len_q != 1 (for example: if SPARSITY_N > 0 and seq_len_q != 1) so the KV cache is not sparsified during decode.
480-491: ⚠️ Potential issue | 🟠 Major

Backward dQ: same tile-based locality issue applies here.

See comment on forward kernel (lines 279-292). The fix should be applied consistently here.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 480 - 491, This backward dQ block repeats the same tile-locality bug as the forward kernel; replicate the forward-kernel fix (the block-based locality check used around lines 279-292) here: compute the tile/block-level positions using tile_q, BLOCK_M and the corresponding kv tile start (using tile_k or kv_start aligned to BLOCK_N), derive q_abs_pos and token_distance at block granularity, set is_local based on the block/window bounds (and preserve the is_sink check with NUM_SINK_TOKENS), and only call _apply_sparse_nm_to_qk_tile(scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M) when the tile is neither sink nor local—i.e., exactly mirror the forward kernel's locality logic and conditions so forward and backward agree.
279-292: ⚠️ Potential issue | 🟠 Major

Forward/backward sparse mask mismatch due to autotuned vs fixed tile sizes.

q_abs_pos = tile_q * BLOCK_M + causal_offset depends on BLOCK_M, but forward autotunes BLOCK_M over {64, 128} while backward hardcodes BLOCK = 64. For the same query row, the computed q_abs_pos and thus is_local can differ between forward and backward, causing gradient mismatch.

For example, query position 70 with causal_offset=0:

- Forward (BLOCK_M=128): tile_q=0, q_abs_pos=0
- Backward (BLOCK_M=64): tile_q=1, q_abs_pos=64

This changes whether tiles fall within DENSE_WINDOW_SIZE, applying different sparse masks.

Consider computing locality per-row (using the actual row offset within the tile) rather than per-tile:

```diff
 if SPARSITY_N > 0:
     is_sink = kv_start < NUM_SINK_TOKENS
     causal_offset = seq_len_kv - seq_len_q
-    q_abs_pos = tile_q * BLOCK_M + causal_offset
-    token_distance = q_abs_pos - kv_start
-    is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE)
-    if not is_sink and not is_local:
+    # Per-row locality: check if ANY row in this Q tile is within dense window
+    q_tile_start = tile_q * BLOCK_M + causal_offset
+    q_tile_end = q_tile_start + BLOCK_M - 1
+    # Tile overlaps dense window if any Q row is within DENSE_WINDOW_SIZE of kv_start
+    tile_overlaps_window = (q_tile_start - kv_start < DENSE_WINDOW_SIZE) and (q_tile_start >= kv_start)
+    if not is_sink and not tile_overlaps_window:
         scores = _apply_sparse_nm_to_qk_tile(...)
```

Alternatively, fix BLOCK_M to match backward's block size when sparsity is enabled.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 279 - 292, The sparse/locality check uses q_abs_pos = tile_q * BLOCK_M which diverges between forward (autotuned BLOCK_M) and backward (fixed BLOCK=64); fix by computing locality per-row instead of per-tile: derive the absolute query row index as (tile_q * BLOCK_M + row_within_tile) where row_within_tile is the actual row offset inside the current tile (from the loop/index that produces scores), then use that q_abs_pos for is_local/is_sink logic before calling _apply_sparse_nm_to_qk_tile; alternatively, if you prefer the simpler change, force BLOCK_M to the backward block size (64) whenever SPARSITY_N > 0 so forward and backward use the same tile size.
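The forward/backward divergence in the example above can be checked numerically. A minimal sketch, assuming causal_offset = 0 (the function names are illustrative, not kernel code):

```python
# Per-tile q_abs_pos as both kernels compute it: the tile's first row index.
def per_tile_q_abs_pos(q_row, block_m, causal_offset=0):
    tile_q = q_row // block_m
    return tile_q * block_m + causal_offset

def tile_is_local(q_row, block_m, kv_start, dense_window_size, causal_offset=0):
    # Locality as currently derived from the per-tile position
    d = per_tile_q_abs_pos(q_row, block_m, causal_offset) - kv_start
    return 0 <= d < dense_window_size

# Query row 70: forward (autotuned BLOCK_M=128) vs backward (fixed 64)
assert per_tile_q_abs_pos(70, 128) == 0
assert per_tile_q_abs_pos(70, 64) == 64

# With DENSE_WINDOW_SIZE=64 and kv_start=0, the same row is classified
# local in forward but not in backward, so the sparse masks disagree.
assert tile_is_local(70, 128, kv_start=0, dense_window_size=64)
assert not tile_is_local(70, 64, kv_start=0, dense_window_size=64)
```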
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 515c68d0-8207-4d6f-9009-6aa727c66189
📒 Files selected for processing (6)

- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
🚧 Files skipped from review as they are similar to previous changes (3)
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
🧹 Nitpick comments (4)

tests/gpu/torch/sparsity/attention_sparsity/conftest.py (1)

1-1: Update copyright year to 2026.

The license header has copyright year 2024, but the current year is 2026. Consider updating for consistency with the PR date.

📝 Suggested fix

```diff
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py (1)

1-1: Update copyright year to 2026.

The license header has copyright year 2024, but this is a new file created in 2026.

📝 Suggested fix

```diff
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py (2)

232-243: Clarify assertion expectation for M=8 pattern.

The test parametrizes over both M=4 and M=8 patterns, but line 241's assertion (kept == n).all() expects exactly N kept. For M=8 with tl.sort-based thresholding, ties may keep ≥N elements. While random inputs (line 236) make ties unlikely, consider adding a comment or splitting the test.

The separate test_sparsity_structure_ties at lines 250-264 correctly handles this distinction, so this is a minor documentation concern.

📝 Suggested clarification

```diff
 def test_sparsity_structure(self, n, m):
-    """Verify N:M structure: exactly N kept per group of M."""
+    """Verify N:M structure: exactly N kept per group of M (random input avoids ties)."""
     bm, bn = 32, 64
     torch.manual_seed(88)
     tile = torch.randn(bm, bn, device="cuda", dtype=torch.float32)
```

1-1: Update copyright year to 2026.

New file should have the current copyright year; the same SPDX fix as above applies.
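The tie behavior flagged in the test_sparsity_structure nitpick can be seen with a small plain-Python sketch. This is a stand-in for the tl.sort-based thresholding, not the actual Triton helper: a rule of "keep every value ≥ the N-th largest" keeps all values tied at the threshold.

```python
# Threshold-based N:M keep rule: count how many group members survive
# "keep if value >= N-th largest". Ties at the threshold all survive.
def kept_with_threshold(group, n):
    thresh = sorted(group, reverse=True)[n - 1]
    return sum(1 for v in group if v >= thresh)
```

With distinct values the count is exactly N, which is why random inputs make the strict `(kept == n).all()` assertion pass; with ties at the threshold the count can exceed N.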
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py`:
- Line 1: Update the SPDX copyright header in the file that currently reads
"Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES" to use the year 2026;
locate the header at the top of the file (the SPDX/FileCopyrightText comment)
and change the year to 2026 so the license header reflects the file creation
year.
In `@tests/gpu/torch/sparsity/attention_sparsity/conftest.py`:
- Line 1: Update the SPDX license header year from 2024 to 2026 at the top of
the conftest.py file (the file-level comment starting with
"SPDX-FileCopyrightText") so the copyright line reflects the current year.
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py`:
- Around line 232-243: test_sparsity_structure currently asserts exactly N kept
per group with (kept == n).all() but for M=8 the tl.sort-based thresholding can
legitimately keep >=N on ties; update the test to either relax the check for
m==8 (use (kept >= n).all() when m == 8) or add a clear comment next to the
assertion explaining that ties for M=8 may produce >=N and that
test_sparsity_structure_ties covers strict tie behavior; reference
test_sparsity_structure, the kept variable, and the call to
_test_apply_sparse_nm so reviewers can find and apply the change.
- Line 1: Update the SPDX copyright header line that currently reads "Copyright
(c) 2024 NVIDIA CORPORATION & AFFILIATES." to the current year 2026; locate the
SPDX header (the line beginning with "SPDX-FileCopyrightText") in
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py and
replace 2024 with 2026 so the SPDX header reflects Copyright (c) 2026 NVIDIA
CORPORATION & AFFILIATES.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: dac2310d-e99e-47c1-b62c-962ab622bdd4
📒 Files selected for processing (9)

- CHANGELOG.rst
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/conftest.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
Signed-off-by: Kai Xu <kaix@nvidia.com>
Force-pushed 9882dbb to 67ae67b.
🧹 Nitpick comments (2)
modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1)
84-94: Consider adding@abstractmethoddecorator for consistency.
get_sparse_contextraisesNotImplementedErrorbut lacks the@abstractmethoddecorator, unlike thenameproperty (line 109). This inconsistency means subclasses won't get a clear error at instantiation time if they forget to implement it—they'll only fail at runtime when the method is called.✨ Suggested change
+ `@abstractmethod` def get_sparse_context(self, module: torch.nn.Module): """Return a context manager that activates this method's sparsity during forward. Each method subclass implements its own activation mechanism: - Softmax-patching methods replace F.softmax during the forward pass. - Kernel-fused methods set flags on ``module`` that the kernel reads. Args: module: The SparseAttentionModule wrapping the attention layer. """ - raise NotImplementedError(f"{type(self).__name__} must implement get_sparse_context()")🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/methods/registry.py` around lines 84 - 94, get_sparse_context currently raises NotImplementedError but isn't decorated as abstract, causing subclasses to only error at runtime; mark get_sparse_context with the `@abstractmethod` decorator (same style used for the name property) so subclasses of the registry base class must implement get_sparse_context (which should return a context manager for activating sparsity on the SparseAttentionModule) and ensure imports/ABC usage are consistent with the existing abstract methods.modelopt/torch/kernels/triton_fa.py (1)
866-871: Consider adding input validation for sparsity parameters.The kernel enforces constraints via
tl.static_assert(SPARSITY_M must be 4 or 8), but invalid combinations at the Python API level could produce unexpected behavior. For example,sparsity_n >= sparsity_morsparsity_n < 0won't be caught until kernel compilation.✨ Suggested validation
```diff
 def attention(
     q: torch.Tensor,
     ...
     *,
     sparsity_n: int = 0,
     sparsity_m: int = 4,
     num_sink_tokens: int = 0,
     dense_window_size: int = 64,
 ) -> torch.Tensor:
+    if sparsity_n < 0:
+        raise ValueError(f"sparsity_n must be non-negative, got {sparsity_n}")
+    if sparsity_n > 0:
+        if sparsity_m not in (4, 8):
+            raise ValueError(f"sparsity_m must be 4 or 8, got {sparsity_m}")
+        if sparsity_n >= sparsity_m:
+            raise ValueError(f"sparsity_n ({sparsity_n}) must be < sparsity_m ({sparsity_m})")
     ...
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `modelopt/torch/kernels/triton_fa.py` around lines 866-871, add Python-level input validation at the start of the function that accepts sparsity_n and sparsity_m (the function with parameters sparsity_n: int = 0, sparsity_m: int = 4, num_sink_tokens, dense_window_size) to prevent invalid combos before kernel compilation: check that sparsity_m is one of the supported values (4 or 8), sparsity_n is an int >= 0 and strictly less than sparsity_m, and that both are integers; if any check fails, raise a ValueError with a clear message referencing sparsity_n/sparsity_m so users see the invalid values. Ensure these checks run before any tl.static_assert or kernel compilation logic so invalid inputs are caught early.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 8083fa36-d78b-4fde-9ad9-9ea3c8eb1666
📒 Files selected for processing (10)
- CHANGELOG.rst
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- modelopt/torch/sparsity/attention_sparsity/methods/registry.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/conftest.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
✅ Files skipped from review due to trivial changes (3)
- CHANGELOG.rst
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
🚧 Files skipped from review as they are similar to previous changes (4)
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
Edwardf0t1 left a comment:
LGTM overall, left some comments.
```python
# N:M sparse softmax — prefill only (decode should not sparsify KV)
if not is_decode and getattr(module, "_apply_sparse_nm", False):
    method = getattr(module, "_sparse_method_instance", None)
```
Where _sparse_method_instance gets set, if it's outside this PR, please add a comment.
```python
is_sink = kv_start < NUM_SINK_TOKENS
# causal_offset handles chunked prefill: q starts at (seq_len_kv - seq_len_q)
causal_offset = seq_len_kv - seq_len_q
q_abs_pos = tile_q * BLOCK_M + causal_offset
token_distance = q_abs_pos - kv_start
is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE)
if not is_sink and not is_local:
    scores = _apply_sparse_nm_to_qk_tile(
        scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M
    )
```
Consider extracting the duplicated sink/window check into a shared helper function.
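One way to factor out that duplicated sink/window check is a shared predicate. The sketch below restates the quoted kernel logic in plain Python; the helper name and signature are hypothetical, and real kernel code would express this with Triton constexpr parameters rather than regular arguments.

```python
def is_dense_tile(kv_start, tile_q, block_m, seq_len_q, seq_len_kv,
                  num_sink_tokens, dense_window_size):
    """Return True if this KV tile must stay dense (sink tokens or local window).

    Mirrors the duplicated forward/backward check quoted above; illustrative only.
    """
    # Sink tiles at the start of the KV sequence are never sparsified.
    is_sink = kv_start < num_sink_tokens
    # causal_offset handles chunked prefill: q starts at (seq_len_kv - seq_len_q).
    causal_offset = seq_len_kv - seq_len_q
    q_abs_pos = tile_q * block_m + causal_offset
    token_distance = q_abs_pos - kv_start
    # Tiles within the dense local window of the query also stay dense.
    is_local = 0 <= token_distance < dense_window_size
    return is_sink or is_local
```

Both the forward and backward kernels could then gate `_apply_sparse_nm_to_qk_tile` on a single `if not is_dense_tile(...)`.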
What does this PR do?
Type of change: New feature
Add N:M structured sparsity support to the Triton flash attention kernel (`modelopt/torch/kernels/triton_fa.py`). For every M consecutive key positions in the attention score tile, the kernel keeps the top-N values and sets the rest to -inf before softmax. This is applied during prefill only.

Supported patterns: any N:M where M=4 (N=1,2,3) or M=8 (N=1..4).
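The pruning rule above can be written as a tiny dense reference in plain Python (a sketch for clarity; the actual kernel operates on BLOCK_M x BLOCK_N tiles in Triton, and the function names here are illustrative):

```python
import math

def nm_sparse_row(scores, n, m):
    """Keep the top-n values in every group of m consecutive scores; set the rest to -inf.

    Reference semantics only, not the Triton implementation.
    """
    assert len(scores) % m == 0 and 0 < n < m
    out = []
    for g in range(0, len(scores), m):
        group = scores[g:g + m]
        # cutoff: the n-th largest value in this group of m
        cutoff = sorted(group, reverse=True)[n - 1]
        taken = 0
        for s in group:
            if s >= cutoff and taken < n:
                out.append(s)
                taken += 1
            else:
                out.append(-math.inf)
    return out

def softmax(xs):
    """Numerically stable softmax; exp(-inf) contributes exactly 0."""
    mx = max(xs)
    exps = [math.exp(x - mx) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]
```

Positions set to -inf receive exactly zero probability after softmax, so the remaining mass is redistributed over the kept top-N entries of each group.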
Performance (TFLOPS at seq_len=16384, RTX 6000):
Usage
Testing
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

CONTRIBUTING.md: ✅ / ❌ / N/A

Additional Information
Summary by CodeRabbit
New Features
API
Configuration
Tests
Documentation