[3/n] Add skip-softmax to Triton flash attention kernel#1081
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
📝 Walkthrough

Adds a Triton-side skip-softmax tile-skipping optimization to flash attention, integrates it into the sparsity config/registration and the HF wrapper, and exposes a runtime `skip_softmax_threshold` knob.

Sequence Diagram(s): (mermaid diagram omitted)

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Codecov Report

❌ Patch coverage dropped: 10 of the 26 new lines are hit (project coverage -0.04%).

```
@@            Coverage Diff             @@
##             main    #1081      +/-   ##
==========================================
- Coverage   70.24%   70.21%   -0.04%
==========================================
  Files         227      228       +1
  Lines       25909    25935      +26
==========================================
+ Hits        18201    18211      +10
- Misses       7708     7724      +16
```

☔ View full report in Codecov by Sentry.
Force-pushed: 9225466 → cc0e9b3
Force-pushed: cc0e9b3 → 270b94e
Force-pushed: 270b94e → 6c65ef3
Force-pushed: 6c65ef3 → 012fb20
Actionable comments posted: 3
🧹 Nitpick comments (1)
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py (1)
494-505: Avoid asserting monotonic MAE from random samples.

Lines 500-505 assume that increasing the threshold must increase `mean(abs(out_skip - out_dense))`, but that is not guaranteed; extra skipped tiles can still reduce the final error through cancellation on a fixed input. This is likely to be flaky across seeds and GPU/dtype combinations. Prefer a directly monotonic signal, or weaken the expectation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py` around lines 494 - 505, The test `test_monotonic_approximation_error` assumes mean absolute error increases strictly with skip_softmax_threshold, which is flaky; change the assertion to a weaker, robust check: compute errors for thresholds via attention(q,k,v,locs,lens,512,softmax_scale=scale,skip_softmax_threshold=...), then either remove the strict stepwise monotonic assertion and instead assert a single inequality between the smallest and largest thresholds with a tolerance (e.g., errors[0] <= errors[-1] + tol) or allow small per-step regressions by checking non-decrease within a small relative/absolute tolerance; update the final assert accordingly and keep references to the variables/functions used (attention, out_dense, out_skip, errors, skip_softmax_threshold).
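A tolerant version of the suggested check could look like this; `assert_loosely_monotonic` and the tolerance values are illustrative, not part of the PR's test suite:

```python
def assert_loosely_monotonic(errors, rel_tol=0.05, abs_tol=1e-4):
    """Allow small per-step regressions in the error sequence, but require
    that the largest threshold ends up at least as lossy as the smallest."""
    for prev, curr in zip(errors, errors[1:]):
        # Non-decrease up to a small relative/absolute tolerance.
        assert curr >= prev - (rel_tol * prev + abs_tol), (
            f"error regressed beyond tolerance: {prev} -> {curr}"
        )
    # Single end-to-end inequality instead of strict stepwise monotonicity.
    assert errors[0] <= errors[-1] + abs_tol
```

Feeding it the per-threshold `mean(abs(out_skip - out_dense))` values keeps the spirit of the original test while tolerating cancellation noise.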
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 4a8c1dda-739d-4a3e-b939-e729f5e6858d
📥 Commits
Reviewing files that changed from the base of the PR and between 08e5f92 and 270b94eb73c2d1ce98f0ca3e7e478e51c1ef342f.
📒 Files selected for processing (6)
- CHANGELOG.rst
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
modelopt/torch/kernels/triton_fa.py
Outdated
```python
# Re-apply skip-softmax: zero out rows that were skipped in forward
if APPLY_SKIP_SOFTMAX:
    tile_row_max = tl.max(scores, 1)
    can_skip = tile_row_max < (lse + SKIP_THRESHOLD_LOG2)
    p = tl.where(can_skip[:, None], 0.0, p)
```
Backward is not replaying the forward skip decisions.
Lines 380-384 and 510-514 rebuild can_skip from final lse, but forward skipped against the per-tile running row_max. Since lse >= row_max_pre_tile, backward can zero gradients for tiles that were kept in forward, so skip_softmax_threshold gives silently wrong grads. Please save the exact forward skip mask / pre-tile max for backward, or gate this mode to inference until that metadata exists.
🛡️ Safe short-term guard

```diff
  apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
+ if apply_skip and (q.requires_grad or k.requires_grad or v.requires_grad):
+     raise NotImplementedError(
+         "skip_softmax_threshold is inference-only until backward can replay "
+         "the exact forward skip decisions."
+     )
  if apply_skip:
      import math
```

Also applies to: 510-514, 616-628
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/kernels/triton_fa.py` around lines 380 - 384, The backward
pass is recomputing the skip mask from final lse which can differ from the
forward per-tile running row_max; instead persist the exact forward skip
decisions (or the pre-tile row_max used in forward) so the backward replays them
exactly: modify the forward path that computes tile_row_max / can_skip (used
when APPLY_SKIP_SOFTMAX) to store the boolean skip mask (or the pre-tile max)
alongside tensors needed for backward and have the backward use that saved mask
when zeroing p (rather than recomputing can_skip from lse and
SKIP_THRESHOLD_LOG2); as a short-term alternative, gate APPLY_SKIP_SOFTMAX to
inference-only until you add this saved metadata so gradients remain correct.
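The save-and-replay pattern the prompt describes can be sketched with a toy autograd function; `SkipMaskedSum` is illustrative and far simpler than the real tiled kernel, but it shows `ctx.save_for_backward` carrying the forward decision so backward never recomputes it:

```python
import torch

class SkipMaskedSum(torch.autograd.Function):
    """Toy sketch: the forward pass makes a data-dependent skip decision,
    saves the boolean mask in ctx, and backward replays that exact mask
    instead of recomputing it from post-hoc statistics (like the final lse)."""

    @staticmethod
    def forward(ctx, x, threshold):
        keep = x.abs() >= threshold      # forward-time skip decision
        ctx.save_for_backward(keep)
        return (x * keep).sum()

    @staticmethod
    def backward(ctx, grad_out):
        (keep,) = ctx.saved_tensors      # replay, never recompute
        return grad_out * keep, None

x = torch.tensor([0.1, 2.0, -3.0], requires_grad=True)
SkipMaskedSum.apply(x, 1.0).backward()
# Gradients are zeroed exactly where forward skipped, and nowhere else.
```

In the actual kernel the mask (or the pre-tile row max) would be an extra tensor written by the forward Triton launch and read back in the backward launch.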
```python
skip_softmax_threshold: float = ModeloptField(
    default=0.1,
    title="Skip-softmax threshold.",
    description=(
        "Tiles contributing less than this fraction are skipped entirely. "
        "Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. "
        "Set to 0 to disable."
    ),
)
```
Validate skip_softmax_threshold in the typed config.
Line 99 introduces a public fraction, but negative or >1 values currently pass validation and change kernel behavior in surprising ways. Reject them at parse time instead of silently treating them as “disabled” or “skip almost everything.”
🧩 Suggested constraint

```diff
  skip_softmax_threshold: float = ModeloptField(
      default=0.1,
      title="Skip-softmax threshold.",
      description=(
          "Tiles contributing less than this fraction are skipped entirely. "
          "Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. "
          "Set to 0 to disable."
      ),
+     ge=0.0,
+     le=1.0,
  )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 99 - 107,
The skip_softmax_threshold field must be validated to ensure it is a fraction in
[0, 1]; update the config parsing/validation so negative values or values >1
raise during parse rather than silently changing kernel behavior. Modify the
typed config that defines skip_softmax_threshold (the ModeloptField) to enforce
0.0 <= skip_softmax_threshold <= 1.0 — either by adding a pydantic validator for
skip_softmax_threshold or adding an explicit check in the config class
constructor/__post_init__ that raises a ValueError with a clear message if the
constraint is violated. Ensure the error triggers during config
parse/instantiation so callers get immediate feedback.
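The `__post_init__` route can be sketched standalone with a plain dataclass (the real config uses a pydantic-style `ModeloptField`, so the class below is only illustrative):

```python
import math
from dataclasses import dataclass

@dataclass
class SkipSoftmaxConfig:
    """Sketch: reject out-of-range thresholds at construction time."""

    skip_softmax_threshold: float = 0.1

    def __post_init__(self):
        t = self.skip_softmax_threshold
        if not math.isfinite(t) or not (0.0 <= t <= 1.0):
            raise ValueError(
                f"skip_softmax_threshold must be a finite fraction in [0, 1], got {t!r}"
            )

SkipSoftmaxConfig(0.05)               # accepted
# SkipSoftmaxConfig(1.5) raises ValueError at instantiation
```

Either this or the declarative `ge`/`le` bounds achieves the same goal: the error surfaces when the config is parsed, not deep inside the kernel.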
```python
def test_skip_softmax_via_sparsify(self, tiny_llama_dir):
    """mtsa.sparsify() with triton_skip_softmax produces finite logits."""
    pytest.importorskip("transformers")
    from transformers import AutoModelForCausalLM, AutoTokenizer

    import modelopt.torch.sparsity.attention_sparsity as mtsa

    tok = AutoTokenizer.from_pretrained(tiny_llama_dir)
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id
    ids = torch.randint(1, tok.vocab_size, (1, 64), device="cuda")

    # Dense baseline (triton backend, no skip)
    model_dense = AutoModelForCausalLM.from_pretrained(
        tiny_llama_dir,
        attn_implementation="modelopt_triton",
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )
    model_dense.eval()
    with torch.no_grad():
        logits_dense = model_dense(input_ids=ids).logits
    del model_dense

    # Skip-softmax via mtsa.sparsify()
    model_skip = AutoModelForCausalLM.from_pretrained(
        tiny_llama_dir,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )
    mtsa.sparsify(model_skip, mtsa.SKIP_SOFTMAX_TRITON_DEFAULT)
    model_skip.eval()
    with torch.no_grad():
        logits_skip = model_skip(input_ids=ids).logits

    assert not torch.isnan(logits_skip).any(), "NaN in skip-softmax logits"
    assert not torch.isinf(logits_skip).any(), "Inf in skip-softmax logits"
    # On short sequences (64 tokens), no tiles are skipped — output should match dense
    torch.testing.assert_close(logits_skip, logits_dense, rtol=1e-3, atol=1e-3)
```
This HF integration test never turns skip-softmax on.
With BLOCK_N=64 pinned in modelopt/torch/kernels/triton_fa.py Lines 44-45 under pytest, the 64-token input on Line 572 yields exactly one KV tile, and the first tile can never be skipped. That means this still passes even if triton_attention_forward() stops forwarding skip_softmax_threshold. Use a multi-tile input and a deliberately aggressive threshold so the new plumbing is actually exercised.
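The tile arithmetic behind this observation, assuming the BLOCK_N=64 pin mentioned above (a minimal sketch, not the kernel's actual launch logic):

```python
BLOCK_N = 64  # KV tile size pinned under pytest per the comment above

def num_kv_tiles(seq_len: int, block_n: int = BLOCK_N) -> int:
    """Number of KV tiles the kernel iterates over; the first tile can
    never be skipped, so skip-softmax only fires with 2 or more tiles."""
    return -(-seq_len // block_n)  # ceiling division

print(num_kv_tiles(64))   # single tile: the skip path is never exercised
print(num_kv_tiles(512))  # multiple tiles: skipping can actually trigger
```

A 512-token input (or longer) with an aggressive threshold would make the assertion fail if the threshold plumbing were dropped.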
Actionable comments posted: 1
♻️ Duplicate comments (1)
modelopt/torch/kernels/triton_fa.py (1)
380-384: ⚠️ Potential issue | 🔴 Critical: Backward still cannot replay the forward skip decisions.

`ctx` only saves the scalar skip flag/threshold, so these backward kernels rebuild `can_skip` from the final `lse` instead of the pre-tile `row_max` used in forward. Since `lse` is always at least as large as the forward running max, backward can zero gradients for tiles that were kept in forward. Please either persist the exact forward mask / pre-tile max, or keep `skip_softmax_threshold` inference-only until backward can replay the same predicate. The public docstring should not claim "the same skip decision" until this is fixed.

🛡️ Safe short-term guard

```diff
  apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
+ if apply_skip and (q.requires_grad or k.requires_grad or v.requires_grad):
+     raise NotImplementedError(
+         "skip_softmax_threshold is inference-only until backward can replay "
+         "the exact forward skip decisions."
+     )
  if apply_skip:
      import math
```

Also applies to: 510-514, 627-628
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 380 - 384, The backward kernels are recomputing the skip predicate from lse, which differs from the forward pre-tile max and causes incorrect gradient zeroing; change the forward pass to save the exact per-tile skip mask or the pre-tile row_max into ctx (not just the scalar skip_softmax_threshold) and have the backward kernels (the code paths using APPLY_SKIP_SOFTMAX where can_skip is computed) read that saved mask/value from ctx to reconstruct the exact same can_skip used in forward; alternatively, make skip_softmax_threshold inference-only until backward can replay the same predicate and update the public docstring to stop claiming “the same skip decision” until fixed.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: fa0b5612-c9b2-47f7-a1bf-cb211e19a57e
📥 Commits
Reviewing files that changed from the base of the PR and between 270b94eb73c2d1ce98f0ca3e7e478e51c1ef342f and 012fb20.
📒 Files selected for processing (6)
- CHANGELOG.rst
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
✅ Files skipped from review due to trivial changes (1)
- CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (4)
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
```python
# Skip-softmax: convert threshold to log2 space for the kernel
apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
if apply_skip:
    import math

    skip_threshold_log2 = math.log2(skip_softmax_threshold)
else:
    skip_threshold_log2 = 0.0
```
Reject invalid skip_softmax_threshold values up front.
This knob is documented as a contribution fraction, but the host-side parsing currently accepts nan, inf, negatives, and values above 1. That means a typo can either silently disable the feature or make later tiles overly skippable. Please reserve None/0 as the only disable cases and raise on anything outside (0, 1].
🧪 Proposed fix

```diff
- apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
- if apply_skip:
-     import math
-
-     skip_threshold_log2 = math.log2(skip_softmax_threshold)
+ import math
+
+ if skip_softmax_threshold is None or skip_softmax_threshold == 0.0:
+     apply_skip = False
+ else:
+     if not math.isfinite(skip_softmax_threshold) or not (0.0 < skip_softmax_threshold <= 1.0):
+         raise ValueError(
+             "skip_softmax_threshold must be a finite float in (0, 1], or None/0 to disable."
+         )
+     apply_skip = True
+
+ if apply_skip:
+     skip_threshold_log2 = math.log2(skip_softmax_threshold)
  else:
      skip_threshold_log2 = 0.0
```

Also applies to: 762-768
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/kernels/triton_fa.py` around lines 568 - 575, Validate the
skip_softmax_threshold value before computing skip_threshold_log2: treat only
None or 0.0 as disabled, and raise a ValueError for NaN, infinite, negative, or
>1 values (accept only values in the open interval (0, 1] for enabling). Update
the logic around apply_skip, skip_softmax_threshold, and skip_threshold_log2 to
perform this check and raise early with a clear message, and apply the same
validation to the other occurrence of the same pattern in this file (the block
around the second occurrence noted in the comment).
Signed-off-by: Kai Xu <kaix@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Force-pushed: 012fb20 → ecc5540
♻️ Duplicate comments (3)
modelopt/torch/kernels/triton_fa.py (2)
380-384: ⚠️ Potential issue | 🔴 Critical: Do not enable `skip_softmax_threshold` during training yet.

Forward skips against the pre-tile running `row_max`, but these backward paths rebuild `can_skip` from the final `lse`. That can zero gradients for tiles that were not skipped in forward, so training with this flag is still incorrect. Please either persist the exact forward skip mask / pre-tile max, or gate this mode to inference only.

🛡️ Safe short-term guard

```diff
  apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
+ if apply_skip and (q.requires_grad or k.requires_grad or v.requires_grad):
+     raise NotImplementedError(
+         "skip_softmax_threshold is inference-only until backward can replay "
+         "the exact forward skip decisions."
+     )
  if apply_skip:
      import math
```

Also applies to: 510-514
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 380 - 384, The backward code recomputes can_skip from lse and SKIP_THRESHOLD_LOG2 which can differ from the forward decision (tile_row_max) causing incorrect zeroed gradients when APPLY_SKIP_SOFTMAX (skip_softmax_threshold) is enabled; fix by either (A) persisting the exact forward skip mask (compute and store tile_row_max and/or can_skip from the forward pass and reuse that mask in the backward path when restoring p) or (B) disallowing this mode during training by gating APPLY_SKIP_SOFTMAX/skip_softmax_threshold to inference-only; update logic that references tile_row_max, can_skip, scores, lse and p to use the persisted mask or the inference-only guard accordingly.
568-575: ⚠️ Potential issue | 🟡 Minor: Validate `skip_softmax_threshold` before computing `log2`.

Only `None` and `0` are documented disable cases. Negative, non-finite, or `>1` values currently either get silently treated as off or make skipping much more aggressive than the API contract suggests.

🧪 Suggested input validation

```diff
- # Skip-softmax: convert threshold to log2 space for the kernel
- apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
- if apply_skip:
-     import math
-
-     skip_threshold_log2 = math.log2(skip_softmax_threshold)
- else:
-     skip_threshold_log2 = 0.0
+ # Skip-softmax: convert threshold to log2 space for the kernel
+ import math
+
+ if skip_softmax_threshold is None or skip_softmax_threshold == 0.0:
+     apply_skip = False
+     skip_threshold_log2 = 0.0
+ else:
+     if not math.isfinite(skip_softmax_threshold) or not (0.0 < skip_softmax_threshold <= 1.0):
+         raise ValueError(
+             "skip_softmax_threshold must be a finite float in (0, 1], or None/0 to disable."
+         )
+     apply_skip = True
+     skip_threshold_log2 = math.log2(skip_softmax_threshold)
```
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 568 - 575, Validate skip_softmax_threshold before computing log2: when computing apply_skip and skip_threshold_log2 (use the variables skip_softmax_threshold, apply_skip, skip_threshold_log2 and the math.log2 call), ensure that if skip_softmax_threshold is not None it is a finite numeric value and within the documented range (0 < value <= 1); treat 0 or None as "off"; for values that are negative, non-finite (NaN/inf) or >1 raise a clear ValueError (or TypeError for wrong type) with a message explaining allowed values so the code never silently treats invalid inputs as off or miscomputes the log2.

modelopt/torch/sparsity/attention_sparsity/config.py (1)
99-107: ⚠️ Potential issue | 🟡 Minor: Validate `skip_softmax_threshold` during config parsing.

This new public fraction still accepts negatives, non-finite values, and values above `1`, which makes the Triton path either silently disable skipping or skip far too aggressively. Reject invalid values when the config is instantiated instead of relying on runtime behavior.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 99 - 107, The public field skip_softmax_threshold can be negative, non-finite, or >1; add validation at config instantiation so invalid values are rejected early: in the config class that defines skip_softmax_threshold (the class using ModeloptField), implement a validation step (e.g. __post_init__ or a pydantic/ModeloptField validator) that checks the value is finite and 0.0 <= skip_softmax_threshold <= 1.0 and raise a ValueError with a clear message if not; this ensures invalid inputs are caught when the config is created rather than at runtime in triton_skip_softmax.
🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py (1)
65-71: Restore the previous module flag in `finally`.

This context manager always writes `False` on exit, so nested or stacked uses on the same module can clobber an outer active context. Restoring the prior value makes the activation state composable.

♻️ Suggested fix

```diff
  @contextmanager
  def _skip_softmax_context():
+     prev_apply_skip_softmax = getattr(module, "_apply_skip_softmax", False)
      module._apply_skip_softmax = True
      try:
          yield
      finally:
-         module._apply_skip_softmax = False
+         module._apply_skip_softmax = prev_apply_skip_softmax
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py` around lines 65 - 71, The _skip_softmax_context context manager currently overwrites module._apply_skip_softmax to False on exit, which breaks nested contexts; modify _skip_softmax_context to save the prior value (e.g., prev = module._apply_skip_softmax), set module._apply_skip_softmax = True on entry, and in the finally block restore module._apply_skip_softmax = prev so nested or stacked uses of the context preserve outer states (apply this change inside the _skip_softmax_context definition).
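The restore-the-prior-value pattern, extracted into a runnable standalone sketch (a stand-in `module` object replaces the real attention module):

```python
from contextlib import contextmanager
from types import SimpleNamespace

module = SimpleNamespace()  # stand-in for the real attention module

@contextmanager
def skip_softmax_context():
    """Save the prior flag and restore it on exit so nested uses compose."""
    prev = getattr(module, "_apply_skip_softmax", False)
    module._apply_skip_softmax = True
    try:
        yield
    finally:
        module._apply_skip_softmax = prev

with skip_softmax_context():
    with skip_softmax_context():
        pass
    # The inner exit restored True (the outer state), not a hard-coded False.
    assert module._apply_skip_softmax is True
assert module._apply_skip_softmax is False
```

With the hard-coded `False` version, the inner exit would have deactivated the outer context, which is exactly the clobbering the review describes.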
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: b3835fd4-3e45-4467-a16f-8477c8ba3c2c
📒 Files selected for processing (7)
- CHANGELOG.rst
- modelopt/torch/kernels/hf_triton_attention.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
✅ Files skipped from review due to trivial changes (2)
- CHANGELOG.rst
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
What does this PR do?
Type of change: New feature

Add skip-softmax tile skipping to the Triton flash attention kernel.
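The per-tile skip predicate can be sketched in pure Python; the log2-space comparison mirrors the `tile_row_max < running_max + log2(threshold)` test quoted in the review, and all names here are illustrative:

```python
import math

def tile_can_skip(tile_scores_log2, running_row_max_log2, threshold=0.1):
    """A tile is skippable when its row max sits more than -log2(threshold)
    below the running row max, i.e. its softmax weights contribute less
    than roughly `threshold` relative to the dominant entries."""
    return max(tile_scores_log2) < running_row_max_log2 + math.log2(threshold)

scores = [0.0, -1.0, -20.0, -25.0]  # log2-space attention scores for one row
running_max = max(scores)
print(tile_can_skip(scores[2:], running_max))   # far-below-max tile: skip
print(tile_can_skip(scores[:2], running_max))   # dominant tile: keep
```

Skipped tiles contribute (approximately) nothing to the softmax-weighted sum, so the kernel can avoid the exp2, the P·V matmul, and the accumulator rescale for them, which is where the TFLOPS gain comes from.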
Usage
Testing
Performance (TFLOPS at seq_len=16384, RTX 6000 Pro):
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: ✅ / ❌ / N/A

Additional Information
Summary by CodeRabbit
New Features
Tests
Documentation