feat(ws1): add NativeAttentionOp pure-PyTorch standard-softmax reference #188
maxiaosong1124 wants to merge 7 commits into
WS1 ground-truth attention op for issue RL-Align#108 (Qwen3-8B GQA attention):
- NativeAttentionOp: out = softmax(Q Kᵀ * scale + masks) @ V, a hand-written naive softmax (NOT F.scaled_dot_product_attention / flash) so the reduction order is fixed for the batch-invariance contract. GQA 32/8 via repeat_interleave, causal offset Skv-Sq+1 (prefill + decode), key_padding_mask (True=valid), scale default 1/sqrt(128). Exposes the forward / forward_fp32 dual-path contract (fp32 ground truth + dtype-behavior path); forward_fp32 disables TF32/autocast for a strict fp32 reference. Pure function, fp32 accumulation.
- register PYTORCH_NATIVE_ATTENTION in OpBackend and the cuda/rocm/cpu priority maps under op_type "attention" (distinct from the production "attn" / PYTORCH_ATTN SDPA fallback)
- tests/test_attention.py: forward_fp32 vs independent fp32 reference, closed-form causal/decode, GQA replication + divisibility guard, scale, key-padding, dtype-path accuracy (Axis-B), Axis-A batch invariance (slice + chunked + padding), purity, gradient flow, registry dispatch, GPU-only LARGE Qwen3-8B smoke
- docs/operators/attention.md + nav/index wiring
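As a rough illustration of the "causal offset Skv-Sq+1 (prefill + decode)" item, here is a minimal pure-Python sketch. It assumes the offset marks the first masked diagonal, so query row `i` may attend to key column `j` whenever `j - i < Skv - Sq + 1`. That reading and the helper name `causal_allowed` are assumptions for illustration, not code from the PR.

```python
def causal_allowed(sq, skv):
    """Return allowed[i][j]: True when query row i may attend to key column j.

    Assumed reading of the PR text: positions with j - i >= skv - sq + 1
    are masked, which right-aligns the queries against the keys.
    """
    offset = skv - sq + 1  # first masked diagonal (assumption)
    return [[(j - i) < offset for j in range(skv)] for i in range(sq)]


# Prefill (Sq == Skv): plain lower-triangular mask.
prefill = causal_allowed(3, 3)

# Single-token decode (Sq == 1 < Skv): the new query sees every cached key.
decode = causal_allowed(1, 4)

# Chunked decode (Sq == 2, Skv == 4): the first query misses only the last key.
chunked = causal_allowed(2, 4)
```

Under this reading one formula covers both prefill and decode, which matches the "(prefill + decode)" note in the description.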
Walkthrough: adds a PyTorch-native reference softmax attention operator and registers it under …
Changes: NativeAttentionOp (implementation, registry, tests, and docs)
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/operators/attention.md`:
- Around line 16-20: Add a language tag to the fenced block in the attention
markdown doc so it no longer triggers MD040; update the existing fence around
the Q/K/V diagram to use a text-style code block while keeping the diagram
content unchanged.
- Around line 77-80: The attention dispatch description conflates backend
resolution with dtype-specific execution; update the
`kernel_registry.get_op("attention")` explanation to say it selects
`NativeAttentionOp` via the `OpBackend` priority map, while the public
`__call__` path dispatches through `forward(...)` based on input dtype rather
than always using `forward_fp32(...)`. Clarify that backend selection and API
path selection are separate, and keep the wording aligned with the
`PYTORCH_NATIVE_ATTENTION`/`NativeAttentionOp` behavior described in this
section.
In `@rl_engine/kernels/ops/pytorch/attention/standard_attn.py`:
- Around line 164-169: The attention path in standard_attn.py can produce NaNs
when key_padding_mask masks every key in a row, because scores becomes all -inf
before torch.softmax. Update the logic around the key_padding_mask handling in
the attention computation to detect fully masked rows before calling softmax,
and force those rows to produce zero attention output instead of passing them
into softmax; keep the fix localized to the existing attention flow around
scores, probs, and out.
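The failure mode described in this finding can be reproduced on a single row without PyTorch. The sketch below is a pure-Python stand-in (the name `masked_softmax_row` is hypothetical, not the PR's code): a subtract-max softmax hits `-inf - (-inf) = nan` when every key is masked, and a guard placed before the softmax defines such a row as all zeros instead.

```python
import math

NEG_INF = float("-inf")


def masked_softmax_row(scores, valid):
    """Subtract-max softmax over one row of scores.

    valid[j] is True for real keys and False for padding (True=valid,
    as in the PR's key_padding_mask convention).
    """
    if not any(valid):
        # Every key is padding: define the row as zeros rather than NaN.
        return [0.0] * len(scores)
    masked = [s if ok else NEG_INF for s, ok in zip(scores, valid)]
    m = max(masked)
    exps = [math.exp(s - m) for s in masked]  # exp(-inf) == 0.0 for padded keys
    total = sum(exps)
    return [e / total for e in exps]


# What the unguarded path would compute for a fully masked row:
unguarded = math.exp(NEG_INF - NEG_INF)  # nan

partially_masked = masked_softmax_row([0.0, 0.0, 5.0], [True, True, False])
fully_masked = masked_softmax_row([1.0, 2.0], [False, False])
```

Because the guard only looks at the row's own mask, it does not couple rows together, which is the property the follow-up commit relies on when it says the fix is row-independent.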
In `@tests/test_attention.py`:
- Around line 360-367: The LARGE-GPU smoke test skip logic is overestimating
required memory because `_enough_gpu_memory()` already multiplies by 1.5, and
the caller currently passes `scores_bytes * 3`, which doubles the intended
threshold. Update the caller around the LARGE-GPU check in `test_attention.py`
to pass a value that matches the documented ~50GB peak, using the existing
`_enough_gpu_memory()` helper and the `scores_bytes` calculation so the combined
threshold aligns with the comment and does not skip on valid 80GB GPUs.
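A back-of-the-envelope check of the numbers in this finding, under stated assumptions: the LARGE load point is taken to be a prefill with `Sq = Skv = 4096` and fp32 scores (the PR text gives only `B=8` and `Skv=4096`), and `scores_bytes` is taken to be the size of one `[B, Hq, Sq, Skv]` tensor. `scores_bytes_for` is a hypothetical helper, not the test file's code.

```python
GIB = 1024 ** 3


def scores_bytes_for(b, hq, sq, skv, bytes_per_elem=4):
    """Size in bytes of one [B, Hq, Sq, Skv] scores tensor (fp32 by default)."""
    return b * hq * sq * skv * bytes_per_elem


# Assumed LARGE shape: B=8, Hq=32, Sq=Skv=4096.
one_tensor = scores_bytes_for(8, 32, 4096, 4096)

# What the caller passed before the fix: three scores-sized tensors.
requested = one_tensor * 3

# What the helper then required after applying its own 1.5x margin.
with_margin = requested * 1.5

print(one_tensor / GIB, requested / GIB, with_margin / GIB)
```

Under these assumptions the request alone is 48 GiB (about 51.5 GB, in line with the documented ~50 GB peak), while the stacked margin asks for 72 GiB (about 77 GB), which leaves almost no room on an 80 GB card once other allocations are counted.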
Files selected for processing (6):
- docs/.nav.yml
- docs/operators/README.md
- docs/operators/attention.md
- rl_engine/kernels/ops/pytorch/attention/standard_attn.py
- rl_engine/kernels/registry.py
- tests/test_attention.py
- standard_attn: define fully key-padding-masked query rows as 0 (was NaN); guarded to the padding branch so the no-pad path is unchanged, row-independent so Axis-A holds; add test_fully_masked_query_returns_zero_not_nan
- test: drop the double 1.5x margin in _enough_gpu_memory (LARGE skip now ~50 GB as documented, no longer over-skips 80 GB GPUs)
- docs/attention.md: add text lang to the diagram fence (MD040); clarify that dispatch uses forward() input-dtype path, forward_fp32() is the explicit fp32 path
    # scale defaults to 1/sqrt(head_dim); `is not None` so an explicit 0.0 is kept.
    scale = scale if scale is not None else (1.0 / math.sqrt(D))
    scores = torch.matmul(qf, kf.transpose(-1, -2)) * scale  # [B, Hq, Sq, Skv]
Padding changes the QK/softmax shape here, so a sequence padded to Skv=10 is not bitwise-equal to the same sequence at Skv=6. The new padding test fails locally with ~1e-6 drift. We need either a fixed reduction path that actually skips masked keys, or a relaxed Axis-A bitwise claim and matching tests.
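For background on why a different reduction shape can move the low-order bits at all: floating-point addition is not associative, so the same values summed in a different grouping may round differently. The scalar float64 sketch below shows only that general effect. It does not reproduce the Skv=10 vs Skv=6 drift, which depends on how the backend tiles the reduction.

```python
values = [0.1, 0.2, 0.3]

# Same three numbers, two groupings of the same sum.
left_first = (values[0] + values[1]) + values[2]
right_first = values[0] + (values[1] + values[2])

# The results differ in the last bit even though the math is identical.
bitwise_equal = left_first == right_first
drift = abs(left_first - right_first)
```

A bitwise contract therefore holds only when the grouping is pinned, which is the choice the comment above puts to the author.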
This PR adds tests/test_attention.py, but CI never runs it. That means the new attention contract can be broken and the PR still goes green. Please add the new CPU-safe attention tests to this job.
@inaniloquentee thanks! Both addressed in [c194db1]:
Thanks for adding the CI step, that part looks good.
    python -m pytest tests/test_attention.py -v -k "not large and not gpu"
    test_key_padding_mask_excludes_padded_keys FAILED
So the direction of relaxing padding from bitwise to near-equality makes sense, but atol=1e-6 looks too tight / platform-sensitive. Could you bump it with a little headroom, e.g. 2e-6 or a small peak-relative tolerance?
Also a minor doc nit: the detailed Axis-A section now correctly excludes key_padding_mask from bitwise, but the test coverage line still says “slice + chunked + padding”, which can read like padding is still part of the bitwise Axis-A claim.
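One possible shape for the "small peak-relative tolerance" mentioned here, as a pure-Python sketch. It assumes "peak-relative" means the largest absolute error divided by the largest absolute reference value. The thread does not define the term, and `rel_peak_error` / `close_rel_peak` are hypothetical names.

```python
def rel_peak_error(actual, reference):
    """max |actual - reference| divided by max |reference| (assumed definition)."""
    peak = max(abs(r) for r in reference)
    worst = max(abs(a - r) for a, r in zip(actual, reference))
    # Fall back to the absolute error if the reference is identically zero.
    return worst / peak if peak > 0.0 else worst


def close_rel_peak(actual, reference, tol):
    """True when the peak-relative error is within tol."""
    return rel_peak_error(actual, reference) <= tol


reference = [1.0, -2.0, 0.5]
actual = [1.0, -2.0, 0.5 + 2e-6]  # one element off by about 2e-6

passes_with_headroom = close_rel_peak(actual, reference, tol=2e-6)
fails_when_tight = close_rel_peak(actual, reference, tol=1e-7)
```

Scaling by the reference peak keeps the threshold meaningful if output magnitudes change, whereas a fixed `atol` silently becomes looser or tighter.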
key_padding_mask drift over differing reduction widths (Skv=10 vs 6) is ~1.3e-6 and platform-sensitive; atol=1e-6 failed locally for the reviewer. Bump the threshold to 2e-6 for headroom, and update the test-coverage doc line so padding reads as near-equality, not part of the bitwise Axis-A claim.
@inaniloquentee thanks for the careful follow-up! Both points addressed in 1e3a990:
KJLdefeated left a comment
Overall good and clean to me. Happy to approve once the requests are fulfilled.
    def test_gradient_flows():
        """fp32 autograd (the backward golden source) yields finite grads for q, k, v."""
        op = NativeAttentionOp()
        q, k, v = _qkv(2, 8, 8, seed=8)
        q, k, v = q.requires_grad_(True), k.requires_grad_(True), v.requires_grad_(True)
        op.forward_fp32(q, k, v, causal=True).sum().backward()
        for t in (q, k, v):
            assert t.grad is not None and t.grad.shape == t.shape
            assert torch.isfinite(t.grad).all()
isfinite can't tell a correct gradient from a wrong-but-finite one, and attention's backward (softmax Jacobian + dQ/dK/dV contractions) is the most error-prone in the stack.
Fix: check grads against autograd through the independent double-precision reference, with a random cotangent (not .sum(), which collapses the contraction):
    def test_gradient_matches_reference():
        """fp32 grads must match autograd through the independent float64 reference."""
        op = NativeAttentionOp()
        q, k, v = _qkv(2, 8, 8, seed=8)
        q, k, v = q.requires_grad_(True), k.requires_grad_(True), v.requires_grad_(True)
        out = op.forward_fp32(q, k, v, causal=True)
        dy = torch.randn_like(out)  # random cotangent instead of .sum()
        out.backward(dy)
        # Same inputs in float64, differentiated through the reference implementation.
        qd, kd, vd = (t.detach().double().requires_grad_(True) for t in (q, k, v))
        _ref_softmax_attn(qd, kd, vd, causal=True).backward(dy.double())
        for t, td in ((q, qd), (k, kd), (v, vd)):
            torch.testing.assert_close(t.grad, td.grad.float(), rtol=1e-4, atol=1e-4)
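To illustrate why a structured cotangent such as `.sum()` can hide backward errors, the pure-Python sketch below looks at the softmax stage in isolation (an assumption made for simplicity; the full attention op also contracts with V, so its `.sum()` gradient is not identically zero). Each softmax row sums to 1, so an all-ones cotangent gives a zero gradient with respect to the scores whatever the scores are, and only a non-uniform cotangent exercises the Jacobian.

```python
import math


def softmax(scores):
    """Subtract-max softmax over one row."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_vjp(scores, cotangent):
    """Gradient of sum_i cotangent[i] * softmax(scores)[i] w.r.t. scores."""
    p = softmax(scores)
    inner = sum(pi * gi for pi, gi in zip(p, cotangent))
    return [pi * (gi - inner) for pi, gi in zip(p, cotangent)]


scores = [0.3, -1.2, 2.0]

# All-ones cotangent (what .sum() feeds back): gradient vanishes up to rounding.
ones_grad = softmax_vjp(scores, [1.0, 1.0, 1.0])

# Non-uniform cotangent: every Jacobian entry now contributes.
mixed_grad = softmax_vjp(scores, [0.7, -0.4, 1.3])
```

A backward pass that got the softmax Jacobian wrong in a way that still maps ones to zero would pass a `.sum()`-based check on this stage, which is one reason to prefer a random cotangent.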
Summary
This pull request introduces a pure-PyTorch ground-truth reference implementation of standard softmax attention for the WS1 batch-invariant forward chain. It adds the `NativeAttentionOp` operator with numerical contract tests validating both batch invariance and per-dtype accuracy paths.
Terminology
The PR employs WS1 alignment vocabulary from issue #108:
- … `torch.equal` to ensure train/sample-time numerics stay synchronized
Motivation
Issue #108 establishes the numerical contract for WS1's batch-invariant forward chain. The Qwen3-8B attention block uses grouped-query attention (`Hq=32`, `Hkv=8`, group 4, `head_dim=128`, `scale=1/√128`). This PR covers the softmax attention stage: `out = softmax(Q Kᵀ · scale + masks) @ V`, computed with a hand-written naive softmax (subtract-max) over a fixed reduction order, deliberately avoiding `F.scaled_dot_product_attention` / flash, whose reduction order is unspecified and would break the batch-invariance contract.
Changes
Operator Implementation (`rl_engine/kernels/ops/pytorch/attention/standard_attn.py`):
- `NativeAttentionOp`: implements `softmax(Q Kᵀ · scale + masks) @ V` with `forward()` and `forward_fp32()` dual paths
- GQA via `repeat_interleave` (query head `h` → KV head `h // g`); causal offset `Skv - Sq + 1` valid for prefill (`Sq==Skv`) and decode (`Sq<Skv`); `key_padding_mask [B, Skv]` with `True`=valid; `scale` default `1/√128`
- `forward_fp32()` disables TF32 + autocast for a strict fp32 reference; pure function, fp32 accumulation, no in-place mutation
Registry Integration (`rl_engine/kernels/registry.py`):
- `PYTORCH_NATIVE_ATTENTION` in `OpBackend`
- `attention` dispatch across cuda, rocm, and cpu priority maps, distinct from the production `attn` / `PYTORCH_ATTN` SDPA fallback
Test Suite (`tests/test_attention.py`):
Documentation (`docs/operators/attention.md`):
- … (`docs/.nav.yml` / `docs/operators/README.md` wiring)
Contract Fulfillment
- `forward_fp32()` with fixed-seed `torch.Generator`, fixed reduction order, no SDPA/flash
- … `torch.equal`); Axis-B rel-peak tolerance (bf16 ~0.56% / 3%, fp16 ~0.07% / 0.5%) vs independent fp32 reference
- … `Hq=32`, `Hkv=8`, `head_dim=128`, `scale=1/√128`)
Test Environment
Checklist
- … `F.scaled_dot_product_attention` / flashattn)
Notes
The naive path materializes the full `[B, Hq, Sq, Skv]` scores tensor (no query-chunking), so the LARGE Qwen3-8B load point (`B=8`, `Skv=4096`) is memory-heavy and GPU-only; the corresponding smoke test skips at runtime when free GPU memory is insufficient.
Summary by CodeRabbit
- … `attention` dispatch on CPU/CUDA/ROCm, including GQA head expansion and configurable scaling.
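The GQA head expansion mentioned in this summary maps query head `h` to KV head `h // g`, per the PR description. A minimal pure-Python sketch of that mapping for the stated `Hq=32`, `Hkv=8` shape follows; `kv_head_for` and the list-based `repeat_interleave` are illustrative stand-ins, not the operator's code.

```python
def kv_head_for(query_head, hq=32, hkv=8):
    """Index of the KV head that serves a given query head under GQA."""
    if hq % hkv != 0:
        # Mirrors the idea of a divisibility guard: groups must be equal-sized.
        raise ValueError("Hq must be a multiple of Hkv")
    group = hq // hkv  # g = 4 for Qwen3-8B's 32/8 layout
    return query_head // group


def repeat_interleave(items, repeats):
    """List analogue of repeating each element `repeats` times in place."""
    return [x for x in items for _ in range(repeats)]


# Expanding the 8 KV head indices by g=4 lines them up with the 32 query heads.
expanded = repeat_interleave(list(range(8)), 4)
mapping = [kv_head_for(h) for h in range(32)]
```

Interleaved repetition (0,0,0,0,1,1,1,1,...) matches `h // g`, whereas tiling the KV heads (0,1,...,7,0,1,...) would correspond to `h % Hkv`.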