
[2/n] Add sparse softmax to the Triton flash attention kernel#1078

Open
kaix-nv wants to merge 4 commits into main from kaix/triton_fa_sparse24

Conversation

@kaix-nv
Contributor

@kaix-nv kaix-nv commented Mar 19, 2026

What does this PR do?


Type of change: New feature

Add N:M structured sparsity support to the Triton flash attention kernel (modelopt/torch/kernels/triton_fa.py). For every M consecutive key positions in the attention score tile, the kernel keeps the top-N values and sets the rest to -inf before softmax. Sparsity is applied during prefill only.

Supported patterns: any N:M with M=4 (N=1,2,3) or M=8 (N=1..4).

  • Sink tokens and dense window blocks preserve attention sinks and local attention

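The N:M rule above can be sketched as a dense PyTorch reference (an illustration only; the PR fuses the equivalent logic into the Triton kernel per tile, and `sparse_nm_mask` is our name, not part of the API):

```python
import torch

def sparse_nm_mask(scores: torch.Tensor, n: int, m: int) -> torch.Tensor:
    """Keep the top-n of every m consecutive key positions; set the rest to -inf."""
    *lead, k = scores.shape
    assert k % m == 0, "key length must be divisible by m"
    groups = scores.reshape(*lead, k // m, m)
    # Indices of the top-n entries within each group of m key positions.
    topk = groups.topk(n, dim=-1).indices
    keep = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, topk, True)
    return groups.masked_fill(~keep, float("-inf")).reshape(*lead, k)

scores = torch.randn(2, 8)          # 2 query rows, 8 key positions
masked = sparse_nm_mask(scores, n=2, m=4)
probs = masked.softmax(dim=-1)      # sparse softmax: pruned slots get probability 0
```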
Performance (TFLOPS at seq_len=16384, RTX 6000):

Pattern     TFLOPS   % of Dense
Dense       89.3     100%
2:4 (M=4)   69.5     78%
4:8 (M=8)   57.3     64%

Usage

from modelopt.torch.kernels import attention

# 2:4 sparsity (keep top 2 of every 4 K positions)
out = attention(q, k, v, b_start_loc, b_seq_len, max_len,
                sparsity_n=2, sparsity_m=4)

# 4:8 sparsity with sink tokens and dense window
out = attention(q, k, v, b_start_loc, b_seq_len, max_len,
                sparsity_n=4, sparsity_m=8,
                num_sink_tokens=4, dense_window_blocks=2)

# Dense (default, zero overhead)
out = attention(q, k, v, b_start_loc, b_seq_len, max_len)

# Via mtsa.sparsify() on HuggingFace models
import torch
import modelopt.torch.sparsity.attention_sparsity as mtsa
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B",
                                              torch_dtype=torch.bfloat16,
                                              device_map="cuda")

# Default config
mtsa.sparsify(model, mtsa.SPARSE_SOFTMAX_DEFAULT)

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • N:M structured sparse softmax added to Triton flash-attention prefill with configurable dense-window and sink-token behavior.
  • API

    • attention(...) gains keyword-only sparsity params: sparsity_n, sparsity_m, num_sink_tokens, dense_window_size; HF/Triton integration applies these for prefill.
  • Configuration

    • New exported default preset to control sparse-softmax method, sparsity, window, and sink settings.
  • Tests

    • Expanded GPU tests for N:M correctness, tile structure, and backward gradients.
  • Documentation

    • CHANGELOG updated with the new feature entry.

@copy-pr-bot

copy-pr-bot bot commented Mar 19, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Mar 19, 2026

📝 Walkthrough

Walkthrough

Adds Triton-backed N:M sparse-softmax to flash attention: kernel tile helpers and constexpr params, autograd and public API plumbing, method/config registration and HF prefill gating, plus comprehensive GPU tests and a changelog entry.

Changes

Cohort / File(s) Summary
Changelog
CHANGELOG.rst
Added 0.44 entry documenting N:M sparse softmax support for the Triton flash-attention kernel.
Triton flash-attention core
modelopt/torch/kernels/triton_fa.py
Added Triton helpers (_sparse_nm_masks_m4, _apply_sparse_nm_to_qk_tile), new constexpr params (SPARSITY_N, SPARSITY_M, NUM_SINK_TOKENS, DENSE_WINDOW_SIZE), conditional N:M masking in forward and mirrored gating in backward recomputation, and propagated sparsity/dense-window args through autograd _Attention and public attention(...).
HF wrapper integration
modelopt/torch/kernels/hf_triton_attention.py
Prefill path now supplies Triton kernel kwargs (sparsity_n, sparsity_m, num_sink_tokens, dense_window_size) when a sparse method instance signals _apply_sparse_nm; decode path unchanged.
Sparsity config & default
modelopt/torch/sparsity/attention_sparsity/config.py
Extended SparseAttentionAttributeConfig with sparsity_n, sparsity_m, num_sink_tokens, dense_window_size, added validators (e.g., sparsity_m ∈ {4,8}, non-negative bounds, cross-field max sparsity_n), and added exported SPARSE_SOFTMAX_DEFAULT.
Methods package init
modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
Now imports triton_sparse_softmax at package init to register the Triton-backed method.
Triton sparse-softmax method
modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
Added TritonSparseSoftmaxMethod (registered as triton_sparse_softmax) that records sparsity params and exposes get_sparse_context to toggle module._apply_sparse_nm; method relies on base-class defaults for mask/apply behavior.
Method base / registry
modelopt/torch/sparsity/attention_sparsity/methods/registry.py
Made calculate_sparsity() return an all-True boolean mask by default and apply_sparsity() raise a NotImplementedError indicating kernel fusion; removed abstract requirement for these two methods.
Test utilities
tests/gpu/torch/sparsity/attention_sparsity/conftest.py
New shared GPU test helpers and fixtures: make_qkv, make_varlen_meta, sdpa_reference, and tiny_llama_dir.
Tests (refactor + dense checks)
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
Refactored to use shared conftest utilities, reorganized forward/backward tests (TestForward, TestBackward), added test_sparse_disabled_matches_dense, and updated HF integration tests to exercise sparsify config.
New N:M sparsity tests
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
New Triton-gated GPU tests covering prefill N:M behavior: end-to-end correctness/behavior (TestSparseNM), tile-level unit tests for _apply_sparse_nm_to_qk_tile (TestSparseTileStructure), and backward gradient sanity under sparsity (TestSparseBackward).

Sequence Diagram(s)

sequenceDiagram
    participant Caller as "Caller / Model"
    participant API as "attention(...) / _Attention"
    participant Autograd as "Autograd Function"
    participant Triton as "Triton FA Kernel"
    participant KV as "KV Storage / Tiles"

    rect rgba(100,149,237,0.5)
    Caller->>API: call attention(q,k,v,..., sparsity_n, sparsity_m, num_sink_tokens, dense_window_size)
    API->>Autograd: _Attention.forward(sparse params)
    Autograd->>Triton: launch forward kernel (constexpr sparsity params)
    Triton->>KV: load QK tile
    Triton->>Triton: apply N:M mask (tile-level) and set pruned scores to -inf
    Triton->>Triton: softmax & compute output context
    Triton-->>Autograd: return outputs + saved tensors (incl. sparsity params)
    end

    rect rgba(60,179,113,0.5)
    Autograd->>Triton: backward launch with same sparsity params
    Triton->>Triton: recompute masked scores (respect sink/window), compute dq/dk/dv
    Triton-->>Autograd: return gradients
    Autograd-->>Caller: propagate gradients
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and concisely describes the main change: adding N:M sparse softmax capability to the Triton flash attention kernel, which is the primary feature across all modified files.
Docstring Coverage ✅ Passed Docstring coverage is 90.57% which is sufficient. The required threshold is 80.00%.
Security Anti-Patterns ✅ Passed No security anti-patterns found. Triton kernels, Pydantic validation, and test fixtures follow security best practices.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@codecov

codecov bot commented Mar 19, 2026

Codecov Report

❌ Patch coverage is 58.92857% with 23 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.18%. Comparing base (d698864) to head (67ae67b).

Files with missing lines Patch % Lines
...ttention_sparsity/methods/triton_sparse_softmax.py 36.36% 14 Missing ⚠️
...delopt/torch/sparsity/attention_sparsity/config.py 75.00% 8 Missing ⚠️
...ch/sparsity/attention_sparsity/methods/registry.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1078      +/-   ##
==========================================
- Coverage   70.18%   70.18%   -0.01%     
==========================================
  Files         228      229       +1     
  Lines       25952    26008      +56     
==========================================
+ Hits        18215    18254      +39     
- Misses       7737     7754      +17     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.


@kaix-nv kaix-nv force-pushed the kaix/triton_fa_sparse24 branch 2 times, most recently from 8ba6efe to 7aa6960 Compare March 20, 2026 04:47
@kaix-nv kaix-nv marked this pull request as ready for review March 20, 2026 05:16
@kaix-nv kaix-nv requested a review from a team as a code owner March 20, 2026 05:16
@kaix-nv kaix-nv requested review from ChenhanYu, Edwardf0t1, cjluo-nv, kevalmorabia97 and rohansjoshi and removed request for ChenhanYu and Edwardf0t1 March 20, 2026 05:16
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_sparse24 branch from 7aa6960 to 31655ce Compare March 21, 2026 19:43
@kaix-nv kaix-nv requested a review from a team as a code owner March 21, 2026 19:43
@kaix-nv kaix-nv changed the title Add 2:4 sparse softmax to the Triton flash attention kernel Add sparse softmax to the Triton flash attention kernel Mar 21, 2026
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 279-292: The N:M sparsity branch is being applied during decode
because it only checks SPARSITY_N > 0; change the condition to skip
sparsification when doing cached decode (seq_len_q == 1). Update the if guarding
the block (the one that currently reads "if SPARSITY_N > 0:") to also require
not decoding (e.g., "if SPARSITY_N > 0 and seq_len_q != 1:") so
_apply_sparse_nm_to_qk_tile(scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M) is
only called during prefill/non-decoding paths; keep the existing local/sink
logic (kv_start, tile_q, q_abs_block, is_local/is_sink) unchanged.
- Around line 279-292: The sparse-mask logic currently mixes tile-sized units
(BLOCK_M/BLOCK_N) with sparsity parameters (NUM_SINK_BLOCKS,
DENSE_WINDOW_BLOCKS) leading to inconsistent masks between forward/backward; fix
by computing mask membership in token-space instead of tile-space: derive each
KV block index and query row absolute token position from actual token counts
(use seq_len_kv, seq_len_q, kv_start and per-row start = tile_q * BLOCK_M +
row_offset or for whole-tile use tile_token_start = tile_q * BLOCK_M) and then
map those token positions into logical token-blocks of a fixed reference block
size (choose the constant used by backward kernels, e.g., 64 tokens) before
comparing to NUM_SINK_BLOCKS and DENSE_WINDOW_BLOCKS; update q_abs_block,
kv_block_idx, is_sink and is_local computations (the branch that calls
_apply_sparse_nm_to_qk_tile) and apply same normalization in the other
occurrences mentioned (around the other two locations) so forward and backward
use the same token-blocking semantics.

In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 99-130: Add validation to the config so invalid N:M and negative
counts are rejected when building the config instead of failing in the Triton
kernel: in the class/constructor where ModeloptField(s) sparsity_n, sparsity_m,
num_sink_blocks, and dense_window_blocks are defined, validate that sparsity_m
is either 4 or 8, sparsity_n is in {1,2,3} when sparsity_m==4 and in {1,2,3,4}
when sparsity_m==8 (or 0/disabled as your semantics require), and that
num_sink_blocks and dense_window_blocks are non-negative; also ensure that the
chosen sparsity mode is only allowed when triton_sparse_softmax is selected or
available and raise a clear config validation error if any rule is violated so
bad configs fail early.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py`:
- Around line 350-358: Add an explicit assertion after calling
mtsa.sparsify(...) to verify the Triton backend was applied: check that
model_sparse.config._attn_implementation == "modelopt_triton" (this should be
done right after mtsa.sparsify(...) returns and before comparing logits). Locate
the sparsification call (mtsa.sparsify(..., backend="triton", ...)) and add the
assertion referencing model_sparse.config._attn_implementation to ensure the
Triton kernel registration took effect.
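The config-validation rules called out above can be sketched with a plain dataclass. This is only an illustration under stated assumptions: field names mirror the PR, but the real `SparseAttentionAttributeConfig` uses `ModeloptField` with pydantic-style validators.

```python
from dataclasses import dataclass

@dataclass
class SparseSoftmaxConfig:
    """Illustrative stand-in for the new sparsity fields and their bounds."""
    sparsity_n: int = 0          # 0 disables sparsity
    sparsity_m: int = 4
    num_sink_blocks: int = 0
    dense_window_blocks: int = 0

    def __post_init__(self):
        if self.sparsity_m not in (4, 8):
            raise ValueError("sparsity_m must be 4 or 8")
        max_n = 3 if self.sparsity_m == 4 else 4
        if not 0 <= self.sparsity_n <= max_n:
            raise ValueError(f"sparsity_n must be in 0..{max_n} for M={self.sparsity_m}")
        if self.num_sink_blocks < 0 or self.dense_window_blocks < 0:
            raise ValueError("sink/window block counts must be non-negative")
```

Validation then fails at config-build time rather than inside the Triton kernel, which is the behavior the review asks for.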

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ce048172-0588-4aec-9474-44eb6c4cbe3b

📥 Commits

Reviewing files that changed from the base of the PR and between 7aa6960f1914d3ed5276afa346d4817d136af320 and 31655ce.

📒 Files selected for processing (9)
  • CHANGELOG.rst
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/conftest.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
✅ Files skipped from review due to trivial changes (2)
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/gpu/torch/sparsity/attention_sparsity/conftest.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py

Comment on lines +279 to +292
# --- Optional N:M sparse softmax ---
if SPARSITY_N > 0:
    # Check if this KV tile should be kept dense
    kv_block_idx = kv_start // BLOCK_N
    is_sink = kv_block_idx < NUM_SINK_BLOCKS
    # causal_offset handles chunked prefill: q starts at (seq_len_kv - seq_len_q)
    causal_offset = seq_len_kv - seq_len_q
    q_abs_block = (tile_q * BLOCK_M + causal_offset) // BLOCK_N
    block_distance = q_abs_block - kv_block_idx
    is_local = (block_distance < DENSE_WINDOW_BLOCKS) and (block_distance >= 0)
    if not is_sink and not is_local:
        scores = _apply_sparse_nm_to_qk_tile(
            scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M
        )
Contributor


⚠️ Potential issue | 🟠 Major

Skip N:M masking during decode.

This branch only checks SPARSITY_N > 0, so cached decode (seq_len_q == 1 with separate KV metadata) gets sparsified too. The feature is described as prefill-only; without a decode guard, generation will start pruning the KV cache as soon as sparse mode is enabled.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 279 - 292, The N:M sparsity
branch is being applied during decode because it only checks SPARSITY_N > 0;
change the condition to skip sparsification when doing cached decode (seq_len_q
== 1). Update the if guarding the block (the one that currently reads "if
SPARSITY_N > 0:") to also require not decoding (e.g., "if SPARSITY_N > 0 and
seq_len_q != 1:") so _apply_sparse_nm_to_qk_tile(scores, BLOCK_M, BLOCK_N,
SPARSITY_N, SPARSITY_M) is only called during prefill/non-decoding paths; keep
the existing local/sink logic (kv_start, tile_q, q_abs_block, is_local/is_sink)
unchanged.

⚠️ Potential issue | 🔴 Critical

Don't key the sparse mask to autotuned tile sizes.

num_sink_blocks and dense_window_blocks are interpreted in units of BLOCK_N, but forward autotunes BLOCK_N over 32/64/128 while both backward kernels hardcode BLOCK_N=64. That means forward and backward can apply different sparse masks, and the forward q_abs_block check is already wrong for later rows in a tile whenever BLOCK_M > BLOCK_N. Please normalize these regions to a fixed logical token block size (or token counts) and derive locality from each row's absolute position before reusing the same rule in forward and backward.

Also applies to: 480-491, 620-631
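The token-space normalization this comment proposes can be sketched as follows. Names and the helper itself are illustrative, not the kernel's; the fixed 64-token reference block matches the backward kernels' hardcoded BLOCK_N.

```python
REF_BLOCK = 64  # fixed logical token-block size (the backward kernels' BLOCK_N)

def region_is_dense(q_token: int, kv_token: int,
                    num_sink_blocks: int, dense_window_blocks: int) -> bool:
    """Decide sink/window membership from absolute token positions mapped to
    fixed 64-token logical blocks, so forward and backward agree regardless of
    the autotuned tile sizes."""
    kv_block = kv_token // REF_BLOCK
    q_block = q_token // REF_BLOCK
    if kv_block < num_sink_blocks:
        return True  # attention-sink region stays dense
    distance = q_block - kv_block
    return 0 <= distance < dense_window_blocks
```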

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 279 - 292, The sparse-mask
logic currently mixes tile-sized units (BLOCK_M/BLOCK_N) with sparsity
parameters (NUM_SINK_BLOCKS, DENSE_WINDOW_BLOCKS) leading to inconsistent masks
between forward/backward; fix by computing mask membership in token-space
instead of tile-space: derive each KV block index and query row absolute token
position from actual token counts (use seq_len_kv, seq_len_q, kv_start and
per-row start = tile_q * BLOCK_M + row_offset or for whole-tile use
tile_token_start = tile_q * BLOCK_M) and then map those token positions into
logical token-blocks of a fixed reference block size (choose the constant used
by backward kernels, e.g., 64 tokens) before comparing to NUM_SINK_BLOCKS and
DENSE_WINDOW_BLOCKS; update q_abs_block, kv_block_idx, is_sink and is_local
computations (the branch that calls _apply_sparse_nm_to_qk_tile) and apply same
normalization in the other occurrences mentioned (around the other two
locations) so forward and backward use the same token-blocking semantics.

Comment on lines +350 to +358
model_dense = AutoModelForCausalLM.from_pretrained(
tiny_llama_dir,
attn_implementation="modelopt_triton",
torch_dtype=torch.bfloat16,
device_map="cuda",
)

o = attention(
q,
k,
v,
b_start_loc=torch.tensor([0], device="cuda", dtype=torch.int32),
b_seq_len=torch.tensor([seq_len], device="cuda", dtype=torch.int32),
max_input_len=seq_len,
is_causal=True,
softmax_scale=scale,
model_dense.eval()
with torch.no_grad():
logits_dense = model_dense(input_ids=ids).logits
Contributor


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== attention sparsify entrypoints =="
rg -n -C3 '\bdef\s+sparsify\s*\(' modelopt/torch --type=py

echo
echo "== Triton backend-switch plumbing =="
rg -n -C3 'modelopt_triton|attn_implementation|_set_attn_implementation|register_triton_attention' modelopt/torch --type=py

Repository: NVIDIA/Model-Optimizer

Length of output: 24602


🏁 Script executed:

# Read the test file to see the complete context of the test
cat -n tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py | sed -n '340,400p'

Repository: NVIDIA/Model-Optimizer

Length of output: 2564


Add explicit assertions to verify Triton backend was applied to sparse model.

When backend="triton" is configured in sparse_cfg, mtsa.sparsify() automatically calls _set_attn_implementation() which registers the Triton kernel and sets model.config._attn_implementation = "modelopt_triton". The current test at line 390-391 only asserts that sparse logits differ from dense logits, but does not verify the backend was actually switched. Add an explicit assertion after sparsification to confirm model_sparse.config._attn_implementation == "modelopt_triton" before comparing logits.

Also applies to: 376-381, 390-392

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py` around lines
350 - 358, Add an explicit assertion after calling mtsa.sparsify(...) to verify
the Triton backend was applied: check that
model_sparse.config._attn_implementation == "modelopt_triton" (this should be
done right after mtsa.sparsify(...) returns and before comparing logits). Locate
the sparsification call (mtsa.sparsify(..., backend="triton", ...)) and add the
assertion referencing model_sparse.config._attn_implementation to ensure the
Triton kernel registration took effect.

@kaix-nv kaix-nv changed the title Add sparse softmax to the Triton flash attention kernel [2/n] Add sparse softmax to the Triton flash attention kernel Mar 23, 2026
@kaix-nv kaix-nv requested review from Edwardf0t1 March 24, 2026 00:16
        mask = torch.ones_like(attention_scores, dtype=torch.bool)
        return mask, {}

    def apply_sparsity(self, attention_scores, sparse_mask=None):
Contributor


It would be better if, for Triton-backend attention sparsity methods, we don't have to add the boilerplate in this file (i.e. not have to add a new file to methods/ altogether)

Contributor Author


Good catch. I've made calculate_sparsity and apply_sparsity optional in base class.
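The resulting base-class defaults can be sketched like this (a hypothetical illustration; the class name and exact signatures are assumptions, the real base lives in methods/registry.py):

```python
import torch

class SparseAttnMethod:
    """Sketch of the optional base-class behavior: calculate_sparsity returns
    an all-True mask and apply_sparsity defers to kernel fusion, so
    Triton-backed methods need not override either."""

    def calculate_sparsity(self, attention_scores: torch.Tensor):
        # Default: nothing is pruned outside the kernel.
        mask = torch.ones_like(attention_scores, dtype=torch.bool)
        return mask, {}

    def apply_sparsity(self, attention_scores, sparse_mask=None):
        raise NotImplementedError("sparsity is fused into the Triton kernel")
```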

@kaix-nv
Contributor Author

kaix-nv commented Mar 24, 2026

@coderabbitai full review

@coderabbitai
Contributor

coderabbitai bot commented Mar 24, 2026

✅ Actions performed

Full review triggered.

@github-actions
Contributor

github-actions bot commented Mar 24, 2026

PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1078/

Built to branch gh-pages at 2026-03-24 21:01 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (4)
modelopt/torch/kernels/triton_fa.py (4)

619-631: ⚠️ Potential issue | 🟠 Major

Backward dK/dV: same tile-based locality issue applies here.

See comment on forward kernel (lines 279-292). The fix should be applied consistently here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 619 - 631, The backward
dK/dV sparse-N mask is missing the same tile-locality handling as the forward
kernel; update the backward block (the code around where scores are processed in
the backward pass) to perform the same token-level/tile locality checks (compute
is_sink using kv_start and NUM_SINK_TOKENS, compute causal_offset = seq_len_kv -
seq_len_q, q_abs_pos = qi * BLOCK_M + causal_offset, token_distance = q_abs_pos
- kv_start, is_local = (token_distance >= 0) and (token_distance <
DENSE_WINDOW_SIZE)) and only call _apply_sparse_nm_to_qk_tile(scores, BLOCK_M,
BLOCK_N, SPARSITY_N, SPARSITY_M) when not is_sink and not is_local, mirroring
the forward kernel's logic so dK/dV uses identical N:M sparsity masking
decisions.

280-280: ⚠️ Potential issue | 🟠 Major

Add decode guard — sparsity should be prefill-only per PR description.

The PR states "sparsity masking is applied during prefill only," but the kernel only checks SPARSITY_N > 0. During decode (seq_len_q == 1), the KV cache will still be sparsified, which could degrade generation quality.

🔧 Proposed fix
 if SPARSITY_N > 0:
+    # Skip sparsity during decode (seq_len_q == 1) — apply only during prefill
+    is_decode = seq_len_q == 1
     is_sink = kv_start < NUM_SINK_TOKENS
     causal_offset = seq_len_kv - seq_len_q
     q_abs_pos = tile_q * BLOCK_M + causal_offset
     token_distance = q_abs_pos - kv_start
     is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE)
-    if not is_sink and not is_local:
+    if not is_decode and not is_sink and not is_local:
         scores = _apply_sparse_nm_to_qk_tile(...)

Apply the same guard in _attn_bwd_dq (line 481) and _attn_bwd_dkdv (line 621).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` at line 280, The sparsity mask is
currently applied unconditionally when SPARSITY_N > 0; change the guard in the
backward kernels _attn_bwd_dq and _attn_bwd_dkdv so sparsity is only applied
during prefill (i.e., skip when decoding with seq_len_q == 1). Locate the
existing "if SPARSITY_N > 0:" checks inside _attn_bwd_dq and _attn_bwd_dkdv and
strengthen them to also require seq_len_q != 1 (for example: if SPARSITY_N > 0
and seq_len_q != 1) so the KV cache is not sparsified during decode.

480-491: ⚠️ Potential issue | 🟠 Major

Backward dQ: same tile-based locality issue applies here.

See comment on forward kernel (lines 279-292). The fix should be applied consistently here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 480 - 491, This backward dQ
block repeats the same tile-locality bug as the forward kernel; replicate the
forward-kernel fix (the block-based locality check used around lines 279-292)
here: compute the tile/block-level positions using tile_q, BLOCK_M and the
corresponding kv tile start (using tile_k or kv_start aligned to BLOCK_N),
derive q_abs_pos and token_distance at block granularity, set is_local based on
the block/window bounds (and preserve the is_sink check with NUM_SINK_TOKENS),
and only call _apply_sparse_nm_to_qk_tile(scores, BLOCK_M, BLOCK_N, SPARSITY_N,
SPARSITY_M) when the tile is neither sink nor local—i.e., exactly mirror the
forward kernel's locality logic and conditions so forward and backward agree.

279-292: ⚠️ Potential issue | 🟠 Major

Forward/backward sparse mask mismatch due to autotuned vs fixed tile sizes.

q_abs_pos = tile_q * BLOCK_M + causal_offset depends on BLOCK_M, but forward autotunes BLOCK_M over {64, 128} while backward hardcodes BLOCK = 64. For the same query row, the computed q_abs_pos and thus is_local can differ between forward and backward, causing gradient mismatch.

For example, query position 70 with causal_offset=0:

  • Forward (BLOCK_M=128): tile_q=0, q_abs_pos=0
  • Backward (BLOCK_M=64): tile_q=1, q_abs_pos=64

This changes whether tiles fall within DENSE_WINDOW_SIZE, applying different sparse masks.
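The mismatch can be checked numerically with a small helper mirroring the kernel's per-tile arithmetic (the function is ours, for illustration only):

```python
def q_abs_pos(q_row: int, block_m: int, causal_offset: int = 0) -> int:
    """Tile-start position the kernel derives for a query row."""
    tile_q = q_row // block_m
    return tile_q * block_m + causal_offset

row = 70
fwd = q_abs_pos(row, block_m=128)  # forward: autotuned BLOCK_M = 128
bwd = q_abs_pos(row, block_m=64)   # backward: fixed BLOCK = 64
```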

Consider computing locality per-row (using the actual row offset within the tile) rather than per-tile:

 if SPARSITY_N > 0:
     is_sink = kv_start < NUM_SINK_TOKENS
     causal_offset = seq_len_kv - seq_len_q
-    q_abs_pos = tile_q * BLOCK_M + causal_offset
-    token_distance = q_abs_pos - kv_start
-    is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE)
-    if not is_sink and not is_local:
+    # Per-row locality: check if ANY row in this Q tile is within dense window
+    q_tile_start = tile_q * BLOCK_M + causal_offset
+    q_tile_end = q_tile_start + BLOCK_M - 1
+    # Tile overlaps dense window if any Q row is within DENSE_WINDOW_SIZE of kv_start
+    tile_overlaps_window = (q_tile_start - kv_start < DENSE_WINDOW_SIZE) and (q_tile_start >= kv_start)
+    if not is_sink and not tile_overlaps_window:
         scores = _apply_sparse_nm_to_qk_tile(...)

Alternatively, fix BLOCK_M to match backward's block size when sparsity is enabled.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 279 - 292, The
sparse/locality check uses q_abs_pos = tile_q * BLOCK_M which diverges between
forward (autotuned BLOCK_M) and backward (fixed BLOCK=64); fix by computing
locality per-row instead of per-tile: derive the absolute query row index as
(tile_q * BLOCK_M + row_within_tile) where row_within_tile is the actual row
offset inside the current tile (from the loop/index that produces scores), then
use that q_abs_pos for is_local/is_sink logic before calling
_apply_sparse_nm_to_qk_tile; alternatively, if you prefer the simpler change,
force BLOCK_M to the backward block size (64) whenever SPARSITY_N > 0 so forward
and backward use the same tile size.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 619-631: The backward dK/dV sparse-N mask is missing the same
tile-locality handling as the forward kernel; update the backward block (the
code around where scores are processed in the backward pass) to perform the same
token-level/tile locality checks (compute is_sink using kv_start and
NUM_SINK_TOKENS, compute causal_offset = seq_len_kv - seq_len_q, q_abs_pos = qi
* BLOCK_M + causal_offset, token_distance = q_abs_pos - kv_start, is_local =
(token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE)) and only call
_apply_sparse_nm_to_qk_tile(scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M)
when not is_sink and not is_local, mirroring the forward kernel's logic so dK/dV
uses identical N:M sparsity masking decisions.
- Line 280: The sparsity mask is currently applied unconditionally when
SPARSITY_N > 0; change the guard in the backward kernels _attn_bwd_dq and
_attn_bwd_dkdv so sparsity is only applied during prefill (i.e., skip when
decoding with seq_len_q == 1). Locate the existing "if SPARSITY_N > 0:" checks
inside _attn_bwd_dq and _attn_bwd_dkdv and strengthen them to also require
seq_len_q != 1 (for example: if SPARSITY_N > 0 and seq_len_q != 1) so the KV
cache is not sparsified during decode.
- Around line 480-491: This backward dQ block repeats the same tile-locality bug
as the forward kernel; replicate the forward-kernel fix (the block-based
locality check used around lines 279-292) here: compute the tile/block-level
positions using tile_q, BLOCK_M and the corresponding kv tile start (using
tile_k or kv_start aligned to BLOCK_N), derive q_abs_pos and token_distance at
block granularity, set is_local based on the block/window bounds (and preserve
the is_sink check with NUM_SINK_TOKENS), and only call
_apply_sparse_nm_to_qk_tile(scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M)
when the tile is neither sink nor local—i.e., exactly mirror the forward
kernel's locality logic and conditions so forward and backward agree.
- Around line 279-292: The sparse/locality check uses q_abs_pos = tile_q *
BLOCK_M which diverges between forward (autotuned BLOCK_M) and backward (fixed
BLOCK=64); fix by computing locality per-row instead of per-tile: derive the
absolute query row index as (tile_q * BLOCK_M + row_within_tile) where
row_within_tile is the actual row offset inside the current tile (from the
loop/index that produces scores), then use that q_abs_pos for is_local/is_sink
logic before calling _apply_sparse_nm_to_qk_tile; alternatively, if you prefer
the simpler change, force BLOCK_M to the backward block size (64) whenever
SPARSITY_N > 0 so forward and backward use the same tile size.
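A toy comparison of the two granularities (plain Python, illustrative names only) shows why the tile-level check can diverge when forward and backward pick different BLOCK_M for the same query row:

```python
def per_tile_local(tile_q, block_m, kv_start, window):
    # Tile granularity: the answer depends on block_m.
    q_abs_pos = tile_q * block_m
    d = q_abs_pos - kv_start
    return 0 <= d < window

def per_row_local(q_row, kv_start, window):
    # Row granularity: block-size independent.
    d = q_row - kv_start
    return 0 <= d < window

# Same query row 100, same kv_start 64:
print(per_tile_local(tile_q=0, block_m=128, kv_start=64, window=64))  # False (fwd tiling)
print(per_tile_local(tile_q=1, block_m=64,  kv_start=64, window=64))  # True  (bwd tiling)
print(per_row_local(q_row=100, kv_start=64, window=64))               # True  (both)
```

Either fix restores agreement: compute per-row positions, or pin BLOCK_M to the backward block size whenever sparsity is enabled.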

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 515c68d0-8207-4d6f-9009-6aa727c66189

📥 Commits

Reviewing files that changed from the base of the PR and between 31655ce and 9882dbb.

📒 Files selected for processing (6)
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py

@coderabbitai coderabbitai bot left a comment
🧹 Nitpick comments (4)
tests/gpu/torch/sparsity/attention_sparsity/conftest.py (1)

1-1: Update copyright year to 2026.

The license header has copyright year 2024, but the current year is 2026. Consider updating for consistency with the PR date.

📝 Suggested fix
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/sparsity/attention_sparsity/conftest.py` at line 1, Update
the SPDX license header year from 2024 to 2026 at the top of the conftest.py
file (the file-level comment starting with "SPDX-FileCopyrightText") so the
copyright line reflects the current year.
modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py (1)

1-1: Update copyright year to 2026.

The license header has copyright year 2024, but this is a new file created in 2026.

📝 Suggested fix
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py`
at line 1, Update the SPDX copyright header in the file that currently reads
"Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES" to use the year 2026;
locate the header at the top of the file (the SPDX/FileCopyrightText comment)
and change the year to 2026 so the license header reflects the file creation
year.
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py (2)

232-243: Clarify assertion expectation for M=8 pattern.

The test parametrizes over both M=4 and M=8 patterns, but line 241's assertion (kept == n).all() expects exactly N kept. For M=8 with tl.sort-based thresholding, ties may keep ≥N elements. While random inputs (line 236) make ties unlikely, consider adding a comment or splitting the test.

The separate test_sparsity_structure_ties at lines 250-264 correctly handles this distinction, so this is a minor documentation concern.
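The tie behavior in question can be reproduced with a small pure-Python model of threshold-based selection (illustrative only; the kernel uses tl.sort over tiles):

```python
def kept_by_threshold(group, n):
    """Keep every value >= the N-th largest (threshold-based selection)."""
    thresh = sorted(group, reverse=True)[n - 1]
    return [v for v in group if v >= thresh]

# Distinct values: exactly N survive.
print(len(kept_by_threshold([4.0, 3.0, 2.0, 1.0, 0.5, 0.4, 0.3, 0.2], 4)))  # 4
# Ties at the threshold: more than N survive.
print(len(kept_by_threshold([4.0, 3.0, 2.0, 2.0, 2.0, 0.4, 0.3, 0.2], 4)))  # 5
```

With random float inputs, exact ties are vanishingly unlikely, which is why the strict `(kept == n).all()` assertion passes in practice.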

📝 Suggested clarification
     def test_sparsity_structure(self, n, m):
-        """Verify N:M structure: exactly N kept per group of M."""
+        """Verify N:M structure: exactly N kept per group of M (random input avoids ties)."""
         bm, bn = 32, 64
         torch.manual_seed(88)
         tile = torch.randn(bm, bn, device="cuda", dtype=torch.float32)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py`
around lines 232 - 243, test_sparsity_structure currently asserts exactly N kept
per group with (kept == n).all() but for M=8 the tl.sort-based thresholding can
legitimately keep >=N on ties; update the test to either relax the check for
m==8 (use (kept >= n).all() when m == 8) or add a clear comment next to the
assertion explaining that ties for M=8 may produce >=N and that
test_sparsity_structure_ties covers strict tie behavior; reference
test_sparsity_structure, the kept variable, and the call to
_test_apply_sparse_nm so reviewers can find and apply the change.

1-1: Update copyright year to 2026.

New file should have current copyright year.

📝 Suggested fix
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py` at
line 1, Update the SPDX copyright header line that currently reads "Copyright
(c) 2024 NVIDIA CORPORATION & AFFILIATES." to the current year 2026; locate the
SPDX header (the line beginning with "SPDX-FileCopyrightText") in
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py and
replace 2024 with 2026 so the SPDX header reflects Copyright (c) 2026 NVIDIA
CORPORATION & AFFILIATES.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py`:
- Line 1: Update the SPDX copyright header in the file that currently reads
"Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES" to use the year 2026;
locate the header at the top of the file (the SPDX/FileCopyrightText comment)
and change the year to 2026 so the license header reflects the file creation
year.

In `@tests/gpu/torch/sparsity/attention_sparsity/conftest.py`:
- Line 1: Update the SPDX license header year from 2024 to 2026 at the top of
the conftest.py file (the file-level comment starting with
"SPDX-FileCopyrightText") so the copyright line reflects the current year.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py`:
- Around line 232-243: test_sparsity_structure currently asserts exactly N kept
per group with (kept == n).all() but for M=8 the tl.sort-based thresholding can
legitimately keep >=N on ties; update the test to either relax the check for
m==8 (use (kept >= n).all() when m == 8) or add a clear comment next to the
assertion explaining that ties for M=8 may produce >=N and that
test_sparsity_structure_ties covers strict tie behavior; reference
test_sparsity_structure, the kept variable, and the call to
_test_apply_sparse_nm so reviewers can find and apply the change.
- Line 1: Update the SPDX copyright header line that currently reads "Copyright
(c) 2024 NVIDIA CORPORATION & AFFILIATES." to the current year 2026; locate the
SPDX header (the line beginning with "SPDX-FileCopyrightText") in
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py and
replace 2024 with 2026 so the SPDX header reflects Copyright (c) 2026 NVIDIA
CORPORATION & AFFILIATES.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: dac2310d-e99e-47c1-b62c-962ab622bdd4

📥 Commits

Reviewing files that changed from the base of the PR and between 08e5f92 and 9882dbb.

📒 Files selected for processing (9)
  • CHANGELOG.rst
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/conftest.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py

kaix-nv added 2 commits March 24, 2026 13:56
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_sparse24 branch from 9882dbb to 67ae67b Compare March 24, 2026 20:56
@coderabbitai coderabbitai bot left a comment
🧹 Nitpick comments (2)
modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1)

84-94: Consider adding @abstractmethod decorator for consistency.

get_sparse_context raises NotImplementedError but lacks the @abstractmethod decorator, unlike the name property (line 109). This inconsistency means subclasses won't get a clear error at instantiation time if they forget to implement it—they'll only fail at runtime when the method is called.
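The failure-timing difference can be illustrated with a generic ABC snippet (names are illustrative, not from this codebase):

```python
from abc import ABC, abstractmethod

class Base(ABC):
    @abstractmethod
    def get_sparse_context(self):
        """Abstract: forgetting to override this fails at instantiation."""

    def other_hook(self):
        # Plain NotImplementedError: only fails when actually called.
        raise NotImplementedError

class Incomplete(Base):
    pass  # does not implement get_sparse_context

try:
    Incomplete()
except TypeError as e:
    print("caught at instantiation:", type(e).__name__)  # TypeError
```

The decorator turns a latent runtime failure into an immediate, easy-to-diagnose construction error.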

✨ Suggested change
+    `@abstractmethod`
     def get_sparse_context(self, module: torch.nn.Module):
         """Return a context manager that activates this method's sparsity during forward.

         Each method subclass implements its own activation mechanism:
         - Softmax-patching methods replace F.softmax during the forward pass.
         - Kernel-fused methods set flags on ``module`` that the kernel reads.

         Args:
             module: The SparseAttentionModule wrapping the attention layer.
         """
-        raise NotImplementedError(f"{type(self).__name__} must implement get_sparse_context()")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/registry.py` around lines
84 - 94, get_sparse_context currently raises NotImplementedError but isn't
decorated as abstract, causing subclasses to only error at runtime; mark
get_sparse_context with the `@abstractmethod` decorator (same style used for the
name property) so subclasses of the registry base class must implement
get_sparse_context (which should return a context manager for activating
sparsity on the SparseAttentionModule) and ensure imports/ABC usage are
consistent with the existing abstract methods.
modelopt/torch/kernels/triton_fa.py (1)

866-871: Consider adding input validation for sparsity parameters.

The kernel enforces constraints via tl.static_assert (SPARSITY_M must be 4 or 8), but invalid combinations at the Python API level could produce unexpected behavior. For example, sparsity_n >= sparsity_m or sparsity_n < 0 won't be caught until kernel compilation.

✨ Suggested validation
 def attention(
     q: torch.Tensor,
     ...
     *,
     sparsity_n: int = 0,
     sparsity_m: int = 4,
     num_sink_tokens: int = 0,
     dense_window_size: int = 64,
 ) -> torch.Tensor:
+    if sparsity_n < 0:
+        raise ValueError(f"sparsity_n must be non-negative, got {sparsity_n}")
+    if sparsity_n > 0:
+        if sparsity_m not in (4, 8):
+            raise ValueError(f"sparsity_m must be 4 or 8, got {sparsity_m}")
+        if sparsity_n >= sparsity_m:
+            raise ValueError(f"sparsity_n ({sparsity_n}) must be < sparsity_m ({sparsity_m})")
     ...
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 866 - 871, Add Python-level
input validation at the start of the function that accepts sparsity_n and
sparsity_m (the function with parameters sparsity_n: int = 0, sparsity_m: int =
4, num_sink_tokens, dense_window_size) to prevent invalid combos before kernel
compilation: check that sparsity_m is one of the supported values (4 or 8),
sparsity_n is an int >= 0 and strictly less than sparsity_m, and that both are
integers; if any check fails, raise a ValueError with a clear message
referencing sparsity_n/sparsity_m so users see the invalid values. Ensure these
checks run before any tl.static_assert or kernel compilation logic so invalid
inputs are caught early.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 866-871: Add Python-level input validation at the start of the
function that accepts sparsity_n and sparsity_m (the function with parameters
sparsity_n: int = 0, sparsity_m: int = 4, num_sink_tokens, dense_window_size) to
prevent invalid combos before kernel compilation: check that sparsity_m is one
of the supported values (4 or 8), sparsity_n is an int >= 0 and strictly less
than sparsity_m, and that both are integers; if any check fails, raise a
ValueError with a clear message referencing sparsity_n/sparsity_m so users see
the invalid values. Ensure these checks run before any tl.static_assert or
kernel compilation logic so invalid inputs are caught early.

In `@modelopt/torch/sparsity/attention_sparsity/methods/registry.py`:
- Around line 84-94: get_sparse_context currently raises NotImplementedError but
isn't decorated as abstract, causing subclasses to only error at runtime; mark
get_sparse_context with the `@abstractmethod` decorator (same style used for the
name property) so subclasses of the registry base class must implement
get_sparse_context (which should return a context manager for activating
sparsity on the SparseAttentionModule) and ensure imports/ABC usage are
consistent with the existing abstract methods.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 8083fa36-d78b-4fde-9ad9-9ea3c8eb1666

📥 Commits

Reviewing files that changed from the base of the PR and between 9882dbb and 67ae67b.

📒 Files selected for processing (10)
  • CHANGELOG.rst
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/methods/registry.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/conftest.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
✅ Files skipped from review due to trivial changes (3)
  • CHANGELOG.rst
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_sparse_nm.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_sparse_softmax.py

@Edwardf0t1 Edwardf0t1 left a comment
LGTM overall, left some comments.


```python
# N:M sparse softmax — prefill only (decode should not sparsify KV)
if not is_decode and getattr(module, "_apply_sparse_nm", False):
    method = getattr(module, "_sparse_method_instance", None)
```
Where does _sparse_method_instance get set? If that happens outside this PR, please add a comment.

Comment on lines +283 to +292
```python
is_sink = kv_start < NUM_SINK_TOKENS
# causal_offset handles chunked prefill: q starts at (seq_len_kv - seq_len_q)
causal_offset = seq_len_kv - seq_len_q
q_abs_pos = tile_q * BLOCK_M + causal_offset
token_distance = q_abs_pos - kv_start
is_local = (token_distance >= 0) and (token_distance < DENSE_WINDOW_SIZE)
if not is_sink and not is_local:
    scores = _apply_sparse_nm_to_qk_tile(
        scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M
    )
```
Consider extracting the duplicated sink/window check into a shared helper function.
