
[3/n] Add skip-softmax to Triton flash attention kernel #1081

Open
kaix-nv wants to merge 2 commits into main from kaix/triton_fa_skip_softmax

Conversation

@kaix-nv
Contributor

@kaix-nv kaix-nv commented Mar 20, 2026

What does this PR do?

Type of change: New feature

Adds skip-softmax tile skipping to the Triton flash attention kernel.

Usage

```python
import torch
from modelopt.torch.kernels import attention

# Skip-softmax with threshold 0.1 (tiles contributing < 10% are skipped)
out = attention(q, k, v, b_start_loc, b_seq_len, max_len,
                skip_softmax_threshold=0.1)

# Via mtsa.sparsify() on HuggingFace models
import modelopt.torch.sparsity.attention_sparsity as mtsa
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B",
                                             torch_dtype=torch.bfloat16,
                                             device_map="cuda")

# Default config
mtsa.sparsify(model, mtsa.SKIP_SOFTMAX_TRITON_DEFAULT)
```

Testing

Performance (TFLOPS, RTX 6000 Pro):

| SEQ_LEN | ModelOpt Triton | PyTorch SDPA | Flash Attention 2 | Skip-Softmax t=0.01 | Skip-Softmax t=0.1 |
|---|---|---|---|---|---|
| 16384 | 188.85 | 211.72 | 224.24 | 172.90 | 279.86 |
| 32768 | 175.32 | 212.82 | 224.83 | 146.15 | 262.49 |
| 65536 | 167.30 | 214.93 | 226.46 | 145.08 | 243.34 |

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Added skip-softmax tile-skipping to Triton flash attention with a configurable threshold (default 0.1) and a new keyword to the attention API to enable it.
    • Exposed a default sparse configuration enabling the Triton skip-softmax method.
  • Tests

    • Added comprehensive GPU tests covering threshold behavior, correctness, and HuggingFace integration.
  • Documentation

    • Updated changelog with the new feature.

@copy-pr-bot

copy-pr-bot bot commented Mar 20, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Mar 20, 2026

📝 Walkthrough

Walkthrough

Adds a Triton-side skip-softmax tile-skipping optimization to flash attention, integrates it into the sparsity config/registration and HF wrapper, exposes a runtime skip_softmax_threshold in the attention API, and adds unit and integration tests plus a changelog entry.
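The walkthrough above describes the forward-pass mechanics. A rough pure-PyTorch sketch of the tile-skip rule (hypothetical function name; the real kernel is a Triton implementation that skips per row and handles the backward pass, whereas this sketch skips a tile only when every row qualifies) might look like:

```python
import math
import torch

def skip_softmax_attention(q, k, v, block_n=64, skip_softmax_threshold=0.0):
    """Online-softmax attention over KV tiles; a tile is skipped when its max
    score is more than log2(threshold) below the running row max (log2 space)."""
    scale = q.shape[-1] ** -0.5
    log2e = 1.4426950408889634
    m = torch.full(q.shape[:-1], float("-inf"))   # running row max (log2 space)
    l = torch.zeros(q.shape[:-1])                 # running softmax denominator
    acc = torch.zeros_like(q)
    thr_log2 = math.log2(skip_softmax_threshold) if skip_softmax_threshold > 0 else None

    for start in range(0, k.shape[0], block_n):
        # Scores in log2 space so exp2() matches the usual exp() softmax
        scores = (q @ k[start:start + block_n].T) * scale * log2e
        tile_max = scores.max(dim=-1).values
        if thr_log2 is not None and bool((tile_max < m + thr_log2).all()):
            continue  # tile's contribution is below threshold for all rows: skip V load
        m_new = torch.maximum(m, tile_max)
        alpha = torch.exp2(m - m_new)             # rescale previous partial sums
        p = torch.exp2(scores - m_new[:, None])
        l = l * alpha + p.sum(dim=-1)
        acc = acc * alpha[:, None] + p @ v[start:start + block_n]
        m = m_new
    return acc / l[:, None]
```

With `skip_softmax_threshold=0.0` this reduces to exact online-softmax attention; note the first tile can never be skipped because the running max starts at negative infinity.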

Changes

| Cohort / File(s) | Summary |
|---|---|
| Triton FA kernel: `modelopt/torch/kernels/triton_fa.py` | Adds `APPLY_SKIP_SOFTMAX` and `SKIP_THRESHOLD_LOG2` constexprs; implements a per-row/tile max check to optionally skip softmax updates and V loads/accumulation in forward; backward kernels reapply the skip mask so skipped rows do not contribute to gradients; kernels accept runtime flags. |
| HF attention wrapper: `modelopt/torch/kernels/hf_triton_attention.py` | `triton_attention_forward` reads module `_apply_skip_softmax` and the sparse-method instance threshold, and conditionally passes `skip_softmax_threshold` into Triton kernel kwargs. |
| Public API: `modelopt/torch/kernels/triton_fa.py` | `attention(...)` gains a keyword-only `skip_softmax_threshold: float` parameter. |
| Sparsity config & registration: `modelopt/torch/sparsity/attention_sparsity/config.py`, `modelopt/torch/sparsity/attention_sparsity/methods/__init__.py` | Adds `SparseAttentionAttributeConfig.skip_softmax_threshold: float = 0.1`, a new `SKIP_SOFTMAX_TRITON_DEFAULT` preset (method=`triton_skip_softmax`, backend=`triton`, enable=True), and imports `triton_skip_softmax` in the package init to register the method. |
| New sparse method: `modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py` | Adds `TritonSkipSoftmaxMethod` (registered as `triton_skip_softmax`) that stores `skip_softmax_threshold`, returns an all-True mask from `calculate_sparsity`, raises in `apply_sparsity`, and provides a context manager to toggle `module._apply_skip_softmax`. |
| Tests & changelog: `tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py`, `CHANGELOG.rst` | Adds GPU/Triton tests covering disabled/low/high thresholds, shape invariants, monotonic error checks, decode-mode behavior, and an HF integration test applying `mtsa.SKIP_SOFTMAX_TRITON_DEFAULT`; appends a changelog entry for the skip-softmax feature. |

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant HF as HF attention wrapper
    participant Sparsity as Sparsity method / config
    participant Triton as Triton FA kernel
    participant Autograd as Autograd ctx
    HF->>Sparsity: get sparse method instance / threshold
    Sparsity-->>HF: return skip_softmax_threshold (or None)
    HF->>Triton: call attention(..., skip_softmax_threshold=val)
    Triton->>Triton: compute tile_row_max, derive can_skip / all_skip
    alt tile skippable
        Triton-->>Triton: skip softmax update; skip V load/accumulation
    else not skippable
        Triton-->>Triton: perform online softmax update; accumulate V
    end
    Triton-->>Autograd: store skip flags in ctx for backward
    Autograd->>Triton: backward launch (reapply skip mask to gradients)
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 78.26%, below the required 80.00% threshold. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title 'Add skip-softmax to Triton flash attention kernel' clearly summarizes the main change, as evidenced by changes across triton_fa.py, configuration files, test coverage, and integration with the sparsity API. |
| Security Anti-Patterns | ✅ Passed | No security anti-patterns found in modified Python files. All code avoids unsafe deserialization, hardcoded trust_remote_code=True, eval/exec on external input, nosec bypass comments, and new dependencies. |



@codecov

codecov bot commented Mar 20, 2026

Codecov Report

❌ Patch coverage is 51.85185% with 13 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.21%. Comparing base (b61fb4e) to head (ecc5540).
⚠️ Report is 2 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| .../attention_sparsity/methods/triton_skip_softmax.py | 45.83% | 13 Missing ⚠️ |

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main    #1081      +/-   ##
==========================================
- Coverage   70.24%   70.21%   -0.04%
==========================================
  Files         227      228       +1
  Lines       25909    25935      +26
==========================================
+ Hits        18201    18211      +10
- Misses       7708     7724      +16
```

☔ View full report in Codecov by Sentry.

@kaix-nv kaix-nv changed the title Add skip-softmax tile skipping to Triton flash attention kernel Add skip-softmax to Triton flash attention kernel Mar 21, 2026
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_skip_softmax branch from 9225466 to cc0e9b3 Compare March 21, 2026 00:51
@kaix-nv kaix-nv marked this pull request as ready for review March 21, 2026 21:33
@kaix-nv kaix-nv requested a review from a team as a code owner March 21, 2026 21:33
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_skip_softmax branch from cc0e9b3 to 270b94e Compare March 21, 2026 21:33
@kaix-nv kaix-nv requested a review from a team as a code owner March 21, 2026 21:33
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_skip_softmax branch from 270b94e to 6c65ef3 Compare March 21, 2026 21:35
@kaix-nv kaix-nv requested review from rohansjoshi and removed request for shengliangxu March 21, 2026 21:35
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_skip_softmax branch from 6c65ef3 to 012fb20 Compare March 21, 2026 21:37
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (1)
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py (1)

494-505: Avoid asserting monotonic MAE from random samples.

Lines 500-505 assume that increasing the threshold must increase mean(abs(out_skip - out_dense)), but that is not guaranteed; extra skipped tiles can still reduce the final error through cancellation on a fixed input. This is likely to be flaky across seeds and GPU/dtype combinations. Prefer a directly monotonic signal, or weaken the expectation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py` around lines
494 - 505, The test `test_monotonic_approximation_error` assumes mean absolute
error increases strictly with skip_softmax_threshold, which is flaky; change the
assertion to a weaker, robust check: compute errors for thresholds via
attention(q,k,v,locs,lens,512,softmax_scale=scale,skip_softmax_threshold=...),
then either remove the strict stepwise monotonic assertion and instead assert a
single inequality between the smallest and largest thresholds with a tolerance
(e.g., errors[0] <= errors[-1] + tol) or allow small per-step regressions by
checking non-decrease within a small relative/absolute tolerance; update the
final assert accordingly and keep references to the variables/functions used
(attention, out_dense, out_skip, errors, skip_softmax_threshold).
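A weakened check along the lines suggested above could look like the following (hypothetical helper; `errors` stands for the per-threshold MAE values, and the tolerances are illustrative):

```python
def assert_roughly_nondecreasing(errors, rel_tol=0.1, abs_tol=1e-4):
    """Allow small per-step regressions, but require the overall trend to hold."""
    for prev, cur in zip(errors, errors[1:]):
        # Each step may regress by at most abs_tol + rel_tol * prev
        assert cur >= prev - (abs_tol + rel_tol * prev), (prev, cur)
    # Endpoint check: the smallest threshold should not beat the largest by more than tol
    assert errors[0] <= errors[-1] + abs_tol
```

This keeps the spirit of the original test (more skipping should not reduce error overall) without asserting strict stepwise monotonicity on noisy samples.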

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 4a8c1dda-739d-4a3e-b939-e729f5e6858d

📥 Commits

Reviewing files that changed from the base of the PR and between 08e5f92 and 270b94eb73c2d1ce98f0ca3e7e478e51c1ef342f.

📒 Files selected for processing (6)
  • CHANGELOG.rst
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py

Comment on lines +380 to +384
```python
            # Re-apply skip-softmax: zero out rows that were skipped in forward
            if APPLY_SKIP_SOFTMAX:
                tile_row_max = tl.max(scores, 1)
                can_skip = tile_row_max < (lse + SKIP_THRESHOLD_LOG2)
                p = tl.where(can_skip[:, None], 0.0, p)
```

⚠️ Potential issue | 🔴 Critical

Backward is not replaying the forward skip decisions.

Lines 380-384 and 510-514 rebuild can_skip from final lse, but forward skipped against the per-tile running row_max. Since lse >= row_max_pre_tile, backward can zero gradients for tiles that were kept in forward, so skip_softmax_threshold gives silently wrong grads. Please save the exact forward skip mask / pre-tile max for backward, or gate this mode to inference until that metadata exists.

🛡️ Safe short-term guard

```diff
         apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
+        if apply_skip and (q.requires_grad or k.requires_grad or v.requires_grad):
+            raise NotImplementedError(
+                "skip_softmax_threshold is inference-only until backward can replay "
+                "the exact forward skip decisions."
+            )
         if apply_skip:
             import math
```
Also applies to: 510-514, 616-628

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 380 - 384, The backward
pass is recomputing the skip mask from final lse which can differ from the
forward per-tile running row_max; instead persist the exact forward skip
decisions (or the pre-tile row_max used in forward) so the backward replays them
exactly: modify the forward path that computes tile_row_max / can_skip (used
when APPLY_SKIP_SOFTMAX) to store the boolean skip mask (or the pre-tile max)
alongside tensors needed for backward and have the backward use that saved mask
when zeroing p (rather than recomputing can_skip from lse and
SKIP_THRESHOLD_LOG2); as a short-term alternative, gate APPLY_SKIP_SOFTMAX to
inference-only until you add this saved metadata so gradients remain correct.

Comment on lines +99 to +107
```python
    skip_softmax_threshold: float = ModeloptField(
        default=0.1,
        title="Skip-softmax threshold.",
        description=(
            "Tiles contributing less than this fraction are skipped entirely. "
            "Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. "
            "Set to 0 to disable."
        ),
    )
```

⚠️ Potential issue | 🟡 Minor

Validate skip_softmax_threshold in the typed config.

Line 99 introduces a public fraction, but negative or >1 values currently pass validation and change kernel behavior in surprising ways. Reject them at parse time instead of silently treating them as “disabled” or “skip almost everything.”

🧩 Suggested constraint

```diff
     skip_softmax_threshold: float = ModeloptField(
         default=0.1,
         title="Skip-softmax threshold.",
         description=(
             "Tiles contributing less than this fraction are skipped entirely. "
             "Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. "
             "Set to 0 to disable."
         ),
+        ge=0.0,
+        le=1.0,
     )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 99 - 107,
The skip_softmax_threshold field must be validated to ensure it is a fraction in
[0, 1]; update the config parsing/validation so negative values or values >1
raise during parse rather than silently changing kernel behavior. Modify the
typed config that defines skip_softmax_threshold (the ModeloptField) to enforce
0.0 <= skip_softmax_threshold <= 1.0 — either by adding a pydantic validator for
skip_softmax_threshold or adding an explicit check in the config class
constructor/__post_init__ that raises a ValueError with a clear message if the
constraint is violated. Ensure the error triggers during config
parse/instantiation so callers get immediate feedback.

Comment on lines +562 to +600
```python
    def test_skip_softmax_via_sparsify(self, tiny_llama_dir):
        """mtsa.sparsify() with triton_skip_softmax produces finite logits."""
        pytest.importorskip("transformers")
        from transformers import AutoModelForCausalLM, AutoTokenizer

        import modelopt.torch.sparsity.attention_sparsity as mtsa

        tok = AutoTokenizer.from_pretrained(tiny_llama_dir)
        if tok.pad_token_id is None:
            tok.pad_token_id = tok.eos_token_id
        ids = torch.randint(1, tok.vocab_size, (1, 64), device="cuda")

        # Dense baseline (triton backend, no skip)
        model_dense = AutoModelForCausalLM.from_pretrained(
            tiny_llama_dir,
            attn_implementation="modelopt_triton",
            torch_dtype=torch.bfloat16,
            device_map="cuda",
        )
        model_dense.eval()
        with torch.no_grad():
            logits_dense = model_dense(input_ids=ids).logits
        del model_dense

        # Skip-softmax via mtsa.sparsify()
        model_skip = AutoModelForCausalLM.from_pretrained(
            tiny_llama_dir,
            torch_dtype=torch.bfloat16,
            device_map="cuda",
        )
        mtsa.sparsify(model_skip, mtsa.SKIP_SOFTMAX_TRITON_DEFAULT)
        model_skip.eval()
        with torch.no_grad():
            logits_skip = model_skip(input_ids=ids).logits

        assert not torch.isnan(logits_skip).any(), "NaN in skip-softmax logits"
        assert not torch.isinf(logits_skip).any(), "Inf in skip-softmax logits"
        # On short sequences (64 tokens), no tiles are skipped — output should match dense
        torch.testing.assert_close(logits_skip, logits_dense, rtol=1e-3, atol=1e-3)
```

⚠️ Potential issue | 🟡 Minor

This HF integration test never turns skip-softmax on.

With BLOCK_N=64 pinned in modelopt/torch/kernels/triton_fa.py Lines 44-45 under pytest, the 64-token input on Line 572 yields exactly one KV tile, and the first tile can never be skipped. That means this still passes even if triton_attention_forward() stops forwarding skip_softmax_threshold. Use a multi-tile input and a deliberately aggressive threshold so the new plumbing is actually exercised.
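The tile-count arithmetic behind this observation, assuming the pinned BLOCK_N=64 (the first tile initializes the running max, so it can never be skipped):

```python
BLOCK_N = 64  # assumed from the pinned test config

def num_kv_tiles(seq_len, block_n=BLOCK_N):
    """Number of KV tiles the kernel iterates over for a given sequence length."""
    return -(-seq_len // block_n)  # ceiling division

assert num_kv_tiles(64) == 1    # single tile: the skip path is never taken
assert num_kv_tiles(256) == 4   # a multi-tile input would actually exercise skipping
```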

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/torch/kernels/triton_fa.py (1)

380-384: ⚠️ Potential issue | 🔴 Critical

Backward still cannot replay the forward skip decisions.

ctx only saves the scalar skip flag/threshold, so these backward kernels rebuild can_skip from final lse instead of the pre-tile row_max used in forward. Since lse is always at least as large as the forward running max, backward can zero gradients for tiles that were kept in forward. Please either persist the exact forward mask / pre-tile max or keep skip_softmax_threshold inference-only until backward can replay the same predicate. The public docstring should not claim “the same skip decision” until this is fixed.

🛡️ Safe short-term guard

```diff
         apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
+        if apply_skip and (q.requires_grad or k.requires_grad or v.requires_grad):
+            raise NotImplementedError(
+                "skip_softmax_threshold is inference-only until backward can replay "
+                "the exact forward skip decisions."
+            )
         if apply_skip:
             import math
```
Also applies to: 510-514, 627-628

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 380 - 384, The backward
kernels are recomputing the skip predicate from lse, which differs from the
forward pre-tile max and causes incorrect gradient zeroing; change the forward
pass to save the exact per-tile skip mask or the pre-tile row_max into ctx (not
just the scalar skip_softmax_threshold) and have the backward kernels (the code
paths using APPLY_SKIP_SOFTMAX where can_skip is computed) read that saved
mask/value from ctx to reconstruct the exact same can_skip used in forward;
alternatively, make skip_softmax_threshold inference-only until backward can
replay the same predicate and update the public docstring to stop claiming “the
same skip decision” until fixed.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: fa0b5612-c9b2-47f7-a1bf-cb211e19a57e

📥 Commits

Reviewing files that changed from the base of the PR and between 270b94eb73c2d1ce98f0ca3e7e478e51c1ef342f and 012fb20.

📒 Files selected for processing (6)
  • CHANGELOG.rst
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
✅ Files skipped from review due to trivial changes (1)
  • CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (4)
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py

Comment on lines +568 to +575
```python
        # Skip-softmax: convert threshold to log2 space for the kernel
        apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
        if apply_skip:
            import math

            skip_threshold_log2 = math.log2(skip_softmax_threshold)
        else:
            skip_threshold_log2 = 0.0
```

⚠️ Potential issue | 🟡 Minor

Reject invalid skip_softmax_threshold values up front.

This knob is documented as a contribution fraction, but the host-side parsing currently accepts nan, inf, negatives, and values above 1. That means a typo can either silently disable the feature or make later tiles overly skippable. Please reserve None/0 as the only disable cases and raise on anything outside (0, 1].

🧪 Proposed fix

```diff
-        apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
-        if apply_skip:
-            import math
-
-            skip_threshold_log2 = math.log2(skip_softmax_threshold)
+        import math
+
+        if skip_softmax_threshold is None or skip_softmax_threshold == 0.0:
+            apply_skip = False
+        else:
+            if not math.isfinite(skip_softmax_threshold) or not (0.0 < skip_softmax_threshold <= 1.0):
+                raise ValueError(
+                    "skip_softmax_threshold must be a finite float in (0, 1], or None/0 to disable."
+                )
+            apply_skip = True
+
+        if apply_skip:
+            skip_threshold_log2 = math.log2(skip_softmax_threshold)
         else:
             skip_threshold_log2 = 0.0
```

Also applies to: 762-768

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 568 - 575, Validate the
skip_softmax_threshold value before computing skip_threshold_log2: treat only
None or 0.0 as disabled, and raise a ValueError for NaN, infinite, negative, or
>1 values (accept only values in the open interval (0, 1] for enabling). Update
the logic around apply_skip, skip_softmax_threshold, and skip_threshold_log2 to
perform this check and raise early with a clear message, and apply the same
validation to the other occurrence of the same pattern in this file (the block
around the second occurrence noted in the comment).
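A standalone version of the host-side check described above (hypothetical helper name; mirrors the "(0, 1] or None/0 to disable" contract) could be:

```python
import math

def resolve_skip_threshold(skip_softmax_threshold):
    """Return (apply_skip, skip_threshold_log2), raising on invalid input."""
    if skip_softmax_threshold is None or skip_softmax_threshold == 0.0:
        return False, 0.0  # only documented disable cases
    if not math.isfinite(skip_softmax_threshold) or not (0.0 < skip_softmax_threshold <= 1.0):
        raise ValueError(
            "skip_softmax_threshold must be a finite float in (0, 1], or None/0 to disable."
        )
    return True, math.log2(skip_softmax_threshold)
```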

@kaix-nv kaix-nv changed the title Add skip-softmax to Triton flash attention kernel [3/n] Add skip-softmax to Triton flash attention kernel Mar 23, 2026
kaix-nv added 2 commits March 23, 2026 15:35
Signed-off-by: Kai Xu <kaix@nvidia.com>
@kaix-nv kaix-nv force-pushed the kaix/triton_fa_skip_softmax branch from 012fb20 to ecc5540 Compare March 23, 2026 22:35
@github-actions
Contributor

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1081/

Built to branch gh-pages at 2026-03-23 22:39 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@coderabbitai coderabbitai bot left a comment

♻️ Duplicate comments (3)
modelopt/torch/kernels/triton_fa.py (2)

380-384: ⚠️ Potential issue | 🔴 Critical

Do not enable skip_softmax_threshold during training yet.

Forward skips against the pre-tile running row_max, but these backward paths rebuild can_skip from final lse. That can zero gradients for tiles that were not skipped in forward, so training with this flag is still incorrect. Please either persist the exact forward skip mask / pre-tile max or gate this mode to inference only.

🛡️ Safe short-term guard

```diff
         apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
+        if apply_skip and (q.requires_grad or k.requires_grad or v.requires_grad):
+            raise NotImplementedError(
+                "skip_softmax_threshold is inference-only until backward can replay "
+                "the exact forward skip decisions."
+            )
         if apply_skip:
             import math
```
Also applies to: 510-514

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 380 - 384, The backward
code recomputes can_skip from lse and SKIP_THRESHOLD_LOG2 which can differ from
the forward decision (tile_row_max) causing incorrect zeroed gradients when
APPLY_SKIP_SOFTMAX (skip_softmax_threshold) is enabled; fix by either (A)
persisting the exact forward skip mask (compute and store tile_row_max and/or
can_skip from the forward pass and reuse that mask in the backward path when
restoring p) or (B) disallowing this mode during training by gating
APPLY_SKIP_SOFTMAX/skip_softmax_threshold to inference-only; update logic that
references tile_row_max, can_skip, scores, lse and p to use the persisted mask
or the inference-only guard accordingly.

568-575: ⚠️ Potential issue | 🟡 Minor

Validate skip_softmax_threshold before computing log2.

Only None and 0 are documented disable cases. Negative, non-finite, or >1 values currently either get silently treated as off or make skipping much more aggressive than the API contract suggests.

🧪 Suggested input validation

```diff
-        # Skip-softmax: convert threshold to log2 space for the kernel
-        apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
-        if apply_skip:
-            import math
-
-            skip_threshold_log2 = math.log2(skip_softmax_threshold)
-        else:
-            skip_threshold_log2 = 0.0
+        # Skip-softmax: convert threshold to log2 space for the kernel
+        import math
+
+        if skip_softmax_threshold is None or skip_softmax_threshold == 0.0:
+            apply_skip = False
+            skip_threshold_log2 = 0.0
+        else:
+            if not math.isfinite(skip_softmax_threshold) or not (0.0 < skip_softmax_threshold <= 1.0):
+                raise ValueError(
+                    "skip_softmax_threshold must be a finite float in (0, 1], or None/0 to disable."
+                )
+            apply_skip = True
+            skip_threshold_log2 = math.log2(skip_softmax_threshold)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 568 - 575, Validate
skip_softmax_threshold before computing log2: when computing apply_skip and
skip_threshold_log2 (use the variables skip_softmax_threshold, apply_skip,
skip_threshold_log2 and the math.log2 call), ensure that if
skip_softmax_threshold is not None it is a finite numeric value and within the
documented range (0 < value <= 1); treat 0 or None as “off”; for values that are
negative, non-finite (NaN/inf) or >1 raise a clear ValueError (or TypeError for
wrong type) with a message explaining allowed values so the code never silently
treats invalid inputs as off or miscomputes the log2.
modelopt/torch/sparsity/attention_sparsity/config.py (1)

99-107: ⚠️ Potential issue | 🟡 Minor

Validate skip_softmax_threshold during config parsing.

This new public fraction still accepts negatives, non-finite values, and values above 1, which makes the Triton path either silently disable skipping or skip far too aggressively. Reject invalid values when the config is instantiated instead of relying on runtime behavior.
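A minimal sketch of such instantiation-time validation, using a stdlib dataclass with `__post_init__` as a stand-in for the real `ModeloptField`-based config class (the class name and field placement here are assumptions; the actual config would use a pydantic validator):

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class SkipSoftmaxConfig:  # illustrative stand-in for the real config class
    skip_softmax_threshold: Optional[float] = None

    def __post_init__(self):
        t = self.skip_softmax_threshold
        if t is None:
            return  # None disables skipping
        if isinstance(t, bool) or not isinstance(t, (int, float)):
            raise TypeError("skip_softmax_threshold must be a number or None")
        if not math.isfinite(t) or not (0.0 <= t <= 1.0):
            raise ValueError(
                "skip_softmax_threshold must be finite and in [0, 1] "
                "(0 or None disables skipping)"
            )
```

With this in place, a bad value fails when the config is built, not deep inside the Triton path.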

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 99 - 107,
The public field skip_softmax_threshold can be negative, non-finite, or >1; add
validation at config instantiation so invalid values are rejected early: in the
config class that defines skip_softmax_threshold (the class using
ModeloptField), implement a validation step (e.g. __post_init__ or a
pydantic/ModeloptField validator) that checks the value is finite and 0.0 <=
skip_softmax_threshold <= 1.0 and raise a ValueError with a clear message if
not; this ensures invalid inputs are caught when the config is created rather
than at runtime in triton_skip_softmax.
🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py (1)

65-71: Restore the previous module flag in finally.

This context manager always writes False on exit, so nested or stacked uses on the same module can clobber an outer active context. Restoring the prior value makes the activation state composable.

♻️ Suggested fix
```diff
         @contextmanager
         def _skip_softmax_context():
+            prev_apply_skip_softmax = getattr(module, "_apply_skip_softmax", False)
             module._apply_skip_softmax = True
             try:
                 yield
             finally:
-                module._apply_skip_softmax = False
+                module._apply_skip_softmax = prev_apply_skip_softmax
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`
around lines 65 - 71, The _skip_softmax_context context manager currently
overwrites module._apply_skip_softmax to False on exit, which breaks nested
contexts; modify _skip_softmax_context to save the prior value (e.g., prev =
module._apply_skip_softmax), set module._apply_skip_softmax = True on entry, and
in the finally block restore module._apply_skip_softmax = prev so nested or
stacked uses of the context preserve outer states (apply this change inside the
_skip_softmax_context definition).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 380-384: The backward code recomputes can_skip from lse and
SKIP_THRESHOLD_LOG2 which can differ from the forward decision (tile_row_max)
causing incorrect zeroed gradients when APPLY_SKIP_SOFTMAX
(skip_softmax_threshold) is enabled; fix by either (A) persisting the exact
forward skip mask (compute and store tile_row_max and/or can_skip from the
forward pass and reuse that mask in the backward path when restoring p) or (B)
disallowing this mode during training by gating
APPLY_SKIP_SOFTMAX/skip_softmax_threshold to inference-only; update logic that
references tile_row_max, can_skip, scores, lse and p to use the persisted mask
or the inference-only guard accordingly.
- Around line 568-575: Validate skip_softmax_threshold before computing log2:
when computing apply_skip and skip_threshold_log2 (use the variables
skip_softmax_threshold, apply_skip, skip_threshold_log2 and the math.log2 call),
ensure that if skip_softmax_threshold is not None it is a finite numeric value
and within the documented range (0 < value <= 1); treat 0 or None as “off”; for
values that are negative, non-finite (NaN/inf) or >1 raise a clear ValueError
(or TypeError for wrong type) with a message explaining allowed values so the
code never silently treats invalid inputs as off or miscomputes the log2.

In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 99-107: The public field skip_softmax_threshold can be negative,
non-finite, or >1; add validation at config instantiation so invalid values are
rejected early: in the config class that defines skip_softmax_threshold (the
class using ModeloptField), implement a validation step (e.g. __post_init__ or a
pydantic/ModeloptField validator) that checks the value is finite and 0.0 <=
skip_softmax_threshold <= 1.0 and raise a ValueError with a clear message if
not; this ensures invalid inputs are caught when the config is created rather
than at runtime in triton_skip_softmax.

---

Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`:
- Around line 65-71: The _skip_softmax_context context manager currently
overwrites module._apply_skip_softmax to False on exit, which breaks nested
contexts; modify _skip_softmax_context to save the prior value (e.g., prev =
module._apply_skip_softmax), set module._apply_skip_softmax = True on entry, and
in the finally block restore module._apply_skip_softmax = prev so nested or
stacked uses of the context preserve outer states (apply this change inside the
_skip_softmax_context definition).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b3835fd4-3e45-4467-a16f-8477c8ba3c2c

📥 Commits

Reviewing files that changed from the base of the PR and between 012fb20 and ecc5540.

📒 Files selected for processing (7)
  • CHANGELOG.rst
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
✅ Files skipped from review due to trivial changes (2)
  • CHANGELOG.rst
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py

@kaix-nv kaix-nv requested review from Edwardf0t1 and jingyu-ml March 24, 2026 00:16