Skip to content

:feat(ws1): add NativeAttentionOp pure-PyTorch standard-softmax reference#188

Open
maxiaosong1124 wants to merge 7 commits into
RL-Align:mainfrom
maxiaosong1124:feat/ws1-attention-pytorch-op
Open

:feat(ws1): add NativeAttentionOp pure-PyTorch standard-softmax reference#188
maxiaosong1124 wants to merge 7 commits into
RL-Align:mainfrom
maxiaosong1124:feat/ws1-attention-pytorch-op

Conversation

@maxiaosong1124

@maxiaosong1124 maxiaosong1124 commented Jun 24, 2026

Copy link
Copy Markdown
Collaborator

Summary

This pull request introduces a pure-PyTorch ground-truth reference implementation of standard softmax attention for the WS1 batch-invariant forward chain. It adds the NativeAttentionOp operator with numerical contract tests validating both batch invariance and per-dtype accuracy paths.

Terminology

The PR employs WS1 alignment vocabulary from issue #108:

  • Axis-A (batch invariance): Row outputs must remain independent of batch composition, validated bitwise using torch.equal to ensure train/sample-time numerics stay synchronized
  • Axis-B (accuracy): Low-precision (bf16/fp16) forward paths computing in the input dtype, asserted within a per-dtype tolerance against an independent fp32 reference formula

Motivation

Issue #108 establishes the numerical contract for WS1's batch-invariant forward chain. The Qwen3-8B attention block uses grouped-query attention (Hq=32, Hkv=8, group 4, head_dim=128, scale=1/√128). This PR covers the softmax attention stage: out = softmax(Q Kᵀ · scale + masks) @ V, computed with a hand-written naive softmax (subtract-max) over a fixed reduction order, deliberately avoiding F.scaled_dot_product_attention / flash whose reduction order is unspecified and would break the batch-invariance contract.

Changes

Operator Implementation (rl_engine/kernels/ops/pytorch/attention/standard_attn.py):

  • NativeAttentionOp: implements softmax(Q Kᵀ · scale + masks) @ V with forward() and forward_fp32() dual paths
  • GQA via repeat_interleave (query head h → KV head h // g); causal offset Skv - Sq + 1 valid for prefill (Sq==Skv) and decode (Sq<Skv); key_padding_mask [B, Skv] with True=valid; scale default 1/√128
  • forward_fp32() disables TF32 + autocast for a strict fp32 reference; pure function, fp32 accumulation, no in-place mutation

Registry Integration (rl_engine/kernels/registry.py):

  • Registers PYTORCH_NATIVE_ATTENTION in OpBackend
  • Routes attention dispatch across cuda, rocm, and cpu priority maps, distinct from the production attn / PYTORCH_ATTN SDPA fallback

Test Suite (tests/test_attention.py):

  • 22 test cases covering fp32 correctness vs an independent reference, closed-form causal/decode checks, GQA replication and divisibility guard, scale defaults, key-padding masking, dtype-path accuracy, batch/chunked/padding invariance, input purity, gradient flow, registry dispatch, and a GPU-only LARGE Qwen3-8B real-shape smoke test

Documentation (docs/operators/attention.md):

  • Mathematical formula, tensor contract, dual-path semantics, dispatch behavior, accuracy table, and test coverage details (plus docs/.nav.yml / docs/operators/README.md wiring)

Contract Fulfillment

#108 Requirement Implementation
Deterministic reference forward_fp32() with fixed-seed torch.Generator, fixed reduction order, no SDPA/flash
Per-dtype policy Axis-A bitwise (torch.equal); Axis-B rel-peak tolerance (bf16 ~0.56% / 3%, fp16 ~0.07% / 0.5%) vs independent fp32 reference
Batch-config sweep Full-batch computation with bitwise-identical sliced / chunked / padded row assertions
Realistic shapes Tests run at Qwen3-8B attention dims (Hq=32, Hkv=8, head_dim=128, scale=1/√128)

Test Environment

  • OS: Ubuntu 22.04 (kernel 5.15.0-124-generic)
  • Python: 3.12.3
  • PyTorch: 2.8.0+cu128
  • CUDA/cuDNN: 12.8 / 9.10.02

Checklist

  • ✓ Pure-PyTorch reference, no custom extension required
  • ✓ Hand-written naive softmax (fixed reduction order); no F.scaled_dot_product_attention / flash
  • ✓ Covered at Qwen3-8B GQA dims (32/8/128, scale 1/√128); causal prefill + decode
  • ✓ Axis-A bitwise batch invariance enforced (slice + chunked + padding)
  • ✓ Axis-B per-dtype tolerance calibrated against measured drift
  • ✓ Registered in OpBackend + cuda/rocm/cpu priority maps (distinct from production attn)
  • ✓ 21 tests pass locally; 1 GPU-only LARGE smoke skips without sufficient GPU memory

Notes

The naive path materializes the full [B, Hq, Sq, Skv] scores tensor (no query-chunking), so the LARGE Qwen3-8B load point (B=8, Skv=4096) is memory-heavy and GPU-only; the corresponding smoke test skips at runtime when free GPU memory is insufficient.

Summary by CodeRabbit

  • New Features
    • Added a PyTorch-native “standard” softmax attention operator and enabled attention dispatch on CPU/CUDA/ROCm, including GQA head expansion and configurable scaling.
  • Bug Fixes
    • Improved fully-masked query-row handling to avoid NaNs and introduced a strict fp32 reference path for consistent results under autocast/TF32.
  • Documentation
    • Added “Standard Attention” documentation and linked it in the Operators navigation.
  • Tests
    • Added an end-to-end attention correctness/semantics test suite and updated CI to run CPU-safe attention tests by default (GPU/large cases where applicable).

WS1 ground-truth attention op for issue RL-Align#108 (Qwen3-8B GQA attention):
- NativeAttentionOp: out = softmax(Q Kᵀ * scale + masks) @ V, a hand-written
  naive softmax (NOT F.scaled_dot_product_attention / flash) so the reduction
  order is fixed for the batch-invariance contract. GQA 32/8 via
  repeat_interleave, causal offset Skv-Sq+1 (prefill + decode), key_padding_mask
  (True=valid), scale default 1/sqrt(128).
Exposes the forward / forward_fp32 dual-path contract (fp32 ground truth +
dtype-behavior path); forward_fp32 disables TF32/autocast for a strict fp32
reference. Pure function, fp32 accumulation.

- register PYTORCH_NATIVE_ATTENTION in OpBackend and the cuda/rocm/cpu priority
  maps under op_type "attention" (distinct from the production "attn" /
  PYTORCH_ATTN SDPA fallback)
- tests/test_attention.py: forward_fp32 vs independent fp32 reference, closed-form
  causal/decode, GQA replication + divisibility guard, scale, key-padding,
  dtype-path accuracy (Axis-B), Axis-A batch invariance (slice + chunked +
  padding), purity, gradient flow, registry dispatch, GPU-only LARGE Qwen3-8B smoke
- docs/operators/attention.md + nav/index wiring
@coderabbitai

coderabbitai Bot commented Jun 24, 2026

Copy link
Copy Markdown

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 3c6fdd52-01a8-42fb-8fe5-b87b1b743d3f

📥 Commits

Reviewing files that changed from the base of the PR and between c194db1 and 1e3a990.

📒 Files selected for processing (2)
  • docs/operators/attention.md
  • tests/test_attention.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/test_attention.py
  • docs/operators/attention.md

📝 Walkthrough

Walkthrough

Adds a PyTorch-native reference softmax attention operator, registers it under "attention", and documents and tests its fp32, masking, GQA, scaling, and accuracy behavior.

Changes

NativeAttentionOp — implementation, registry, tests, and docs

Layer / File(s) Summary
NativeAttentionOp implementation
rl_engine/kernels/ops/pytorch/attention/standard_attn.py
Defines NativeAttentionOp, adds forward and forward_fp32, implements GQA expansion, causal and key-padding masking, fully masked row handling, scale defaulting, and strict fp32 math handling.
Registry backend and dispatch
rl_engine/kernels/registry.py
Adds PYTORCH_NATIVE_ATTENTION and routes the "attention" operator key to it on CUDA, ROCm, and CPU alongside the existing "attn" mapping.
Attention test suite
tests/test_attention.py
Adds an independent fp32 reference plus tests for fp32 correctness, strict fp32 behavior, causal and decode semantics, GQA rules, scaling and masking, low-precision accuracy, output shape, batch invariance, purity, gradients, registry dispatch, and a GPU-only smoke test.
Attention docs and navigation
docs/.nav.yml, docs/operators/README.md, docs/operators/attention.md
Adds the new attention operator documentation and links it into the docs nav and operators index, covering the contract, dispatch, accuracy expectations, tests, references, and limitations.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested labels

component: kernels

Suggested reviewers

  • Flink-ddd
  • KJLdefeated
  • inaniloquentee

Poem

🐇 Hop, hop—attention takes the stage,
Softmax sings through every page,
Q, K, V now dance in line,
fp32 keeps the truth just fine,
The registry found its new trail bright,
This bunny says: "Looks softmax-right!"

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly summarizes the main change: adding a pure-PyTorch NativeAttentionOp reference implementation.
Docstring Coverage ✅ Passed Docstring coverage is 93.33% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/operators/attention.md`:
- Around line 16-20: Add a language tag to the fenced block in the attention
markdown doc so it no longer triggers MD040; update the existing fence around
the Q/K/V diagram to use a text-style code block while keeping the diagram
content unchanged.
- Around line 77-80: The attention dispatch description conflates backend
resolution with dtype-specific execution; update the
`kernel_registry.get_op("attention")` explanation to say it selects
`NativeAttentionOp` via the `OpBackend` priority map, while the public
`__call__` path dispatches through `forward(...)` based on input dtype rather
than always using `forward_fp32(...)`. Clarify that backend selection and API
path selection are separate, and keep the wording aligned with the
`PYTORCH_NATIVE_ATTENTION`/`NativeAttentionOp` behavior described in this
section.

In `@rl_engine/kernels/ops/pytorch/attention/standard_attn.py`:
- Around line 164-169: The attention path in standard_attn.py can produce NaNs
when key_padding_mask masks every key in a row, because scores becomes all -inf
before torch.softmax. Update the logic around the key_padding_mask handling in
the attention computation to detect fully masked rows before calling softmax,
and force those rows to produce zero attention output instead of passing them
into softmax; keep the fix localized to the existing attention flow around
scores, probs, and out.

In `@tests/test_attention.py`:
- Around line 360-367: The LARGE-GPU smoke test skip logic is overestimating
required memory because `_enough_gpu_memory()` already multiplies by 1.5, and
the caller currently passes `scores_bytes * 3`, which doubles the intended
threshold. Update the caller around the LARGE-GPU check in `test_attention.py`
to pass a value that matches the documented ~50GB peak, using the existing
`_enough_gpu_memory()` helper and the `scores_bytes` calculation so the combined
threshold aligns with the comment and does not skip on valid 80GB GPUs.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2085bbcc-4114-4a65-b260-305416a3e967

📥 Commits

Reviewing files that changed from the base of the PR and between be5ec9b and 39df0d2.

📒 Files selected for processing (6)
  • docs/.nav.yml
  • docs/operators/README.md
  • docs/operators/attention.md
  • rl_engine/kernels/ops/pytorch/attention/standard_attn.py
  • rl_engine/kernels/registry.py
  • tests/test_attention.py

Comment thread docs/operators/attention.md Outdated
Comment thread docs/operators/attention.md Outdated
Comment thread rl_engine/kernels/ops/pytorch/attention/standard_attn.py
Comment thread tests/test_attention.py Outdated
- standard_attn: define fully key-padding-masked query rows as 0 (was NaN);
  guarded to the padding branch so the no-pad path is unchanged, row-independent
  so Axis-A holds; add test_fully_masked_query_returns_zero_not_nan
- test: drop the double 1.5x margin in _enough_gpu_memory (LARGE skip now ~50 GB
  as documented, no longer over-skips 80 GB GPUs)
- docs/attention.md: add text lang to the diagram fence (MD040); clarify that
  dispatch uses forward() input-dtype path, forward_fp32() is the explicit fp32 path

# scale defaults to 1/sqrt(head_dim); `is not None` so an explicit 0.0 is kept.
scale = scale if scale is not None else (1.0 / math.sqrt(D))
scores = torch.matmul(qf, kf.transpose(-1, -2)) * scale # [B, Hq, Sq, Skv]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

padding changes the QK/softmax shape here, so a sequence padded to Skv=10 is not bitwise-equal to the same sequence at Skv=6. The new padding test fails locally with ~1e-6 drift. We need a fixed reduction path that actually skips masked keys, or we need to relax the Axis-A bitwise claim/tests.

@inaniloquentee

Copy link
Copy Markdown
Collaborator

this PR adds tests/test_attention.py, but CI never runs it. That means the new attention contract can be broken and the PR still goes green. Please add the new CPU-safe attention tests to this job.

@maxiaosong1124 maxiaosong1124 requested a review from bitborne as a code owner June 25, 2026 07:13
@maxiaosong1124

Copy link
Copy Markdown
Collaborator Author

@inaniloquentee thanks! Both addressed in [c194db1]:

  1. CI: added a CPU-safe step running pytest tests/test_attention.py -k "not large and not gpu" (excludes the two GPU-only tests). Locally 21 passed.

  2. Padding bitwise: relaxed the claim/test. Padding changes the softmax reduction width (Skv=10 vs 6), so bitwise equality isn't guaranteed in IEEE 754. Test now uses allclose(atol=1e-6), and the Axis-A bullet in the docs excludes key_padding_mask from the atol=0 claim (slicing + chunked stay bitwise).

@Flink-ddd Flink-ddd requested a review from EthanZero2Hero June 25, 2026 15:53
@inaniloquentee

Copy link
Copy Markdown
Collaborator

@inaniloquentee thanks! Both addressed in [c194db1]:

  1. CI: added a CPU-safe step running pytest tests/test_attention.py -k "not large and not gpu" (excludes the two GPU-only tests). Locally 21 passed.
  2. Padding bitwise: relaxed the claim/test. Padding changes the softmax reduction width (Skv=10 vs 6), so bitwise equality isn't guaranteed in IEEE 754. Test now uses allclose(atol=1e-6), and the Axis-A bullet in the docs excludes key_padding_mask from the atol=0 claim (slicing + chunked stay bitwise).

Thanks for adding the CI step, that part looks good.
One issue is still not fully addressed on my side though: the new CPU-safe command still fails locally with the current padding tolerance.
I ran:

python -m pytest tests/test_attention.py -v -k "not large and not gpu"
and got:

test_key_padding_mask_excludes_padded_keys FAILED
Padding-masked result diverges from valid-only by 1.31e-06 > 1e-06

So the direction of relaxing padding from bitwise to near-equality makes sense, but atol=1e-6 looks too tight / platform-sensitive. Could you bump it with a little headroom, e.g. 2e-6 or a small peak-relative tolerance?

Also minor doc nit: the detailed Axis-A section now correctly excludes key_padding_mask from bitwise, but the test coverage line still says “slice + chunked + padding”, which can read like padding is still part of the bitwise Axis-A claim.

key_padding_mask drift over differing reduction widths (Skv=10 vs 6) is
~1.3e-6 and platform-sensitive; atol=1e-6 failed locally for the reviewer.
Bump the threshold to 2e-6 for headroom, and update the test-coverage doc
line so padding reads as near-equality, not part of the bitwise Axis-A claim.
@maxiaosong1124

Copy link
Copy Markdown
Collaborator Author

@inaniloquentee thanks! Both addressed in [c194db1]:

  1. CI: added a CPU-safe step running pytest tests/test_attention.py -k "not large and not gpu" (excludes the two GPU-only tests). Locally 21 passed.
  2. Padding bitwise: relaxed the claim/test. Padding changes the softmax reduction width (Skv=10 vs 6), so bitwise equality isn't guaranteed in IEEE 754. Test now uses allclose(atol=1e-6), and the Axis-A bullet in the docs excludes key_padding_mask from the atol=0 claim (slicing + chunked stay bitwise).

Thanks for adding the CI step, that part looks good. One issue is still not fully addressed on my side though: the new CPU-safe command still fails locally with the current padding tolerance. I ran:

python -m pytest tests/test_attention.py -v -k "not large and not gpu" and got:

test_key_padding_mask_excludes_padded_keys FAILED Padding-masked result diverges from valid-only by 1.31e-06 > 1e-06

So the direction of relaxing padding from bitwise to near-equality makes sense, but atol=1e-6 looks too tight / platform-sensitive. Could you bump it with a little headroom, e.g. 2e-6 or a small peak-relative tolerance?

Also minor doc nit: the detailed Axis-A section now correctly excludes key_padding_mask from bitwise, but the test coverage line still says “slice + chunked + padding”, which can read like padding is still part of the bitwise Axis-A claim.

@inaniloquentee thanks for the careful follow-up! Both points addressed in 1e3a990:

  1. Padding tolerance. You're right that atol=1e-6 was too tight / platform-sensitive — the drift comes from the differing softmax reduction width (Skv=10 vs 6), not from the masking semantics. Bumped _PADDING_ATOL to 2e-6, which clears the observed ~1.3e-6 with headroom. pytest tests/test_attention.py -k "not large and not gpu" → 21 passed locally, including test_key_padding_mask_excludes_padded_keys. Kept rtol=0.0 so it stays an absolute-error check.
  2. Doc nit. Good catch — the coverage line did still read as if padding were part of the bitwise Axis-A claim. Updated it to slice + chunked, bitwise; padding is near-equality only, see below, so it now matches the detailed Axis-A section that excludes key_padding_mask from the atol=0 claim.

@KJLdefeated KJLdefeated left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall good and clean to me. Happy to approve once the requests are fullfilled.

Comment thread tests/test_attention.py
Comment on lines +380 to +388
def test_gradient_flows():
"""fp32 autograd (the backward golden source) yields finite grads for q, k, v."""
op = NativeAttentionOp()
q, k, v = _qkv(2, 8, 8, seed=8)
q, k, v = q.requires_grad_(True), k.requires_grad_(True), v.requires_grad_(True)
op.forward_fp32(q, k, v, causal=True).sum().backward()
for t in (q, k, v):
assert t.grad is not None and t.grad.shape == t.shape
assert torch.isfinite(t.grad).all()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isfinite can't tell a correct gradient from a wrong-but-finite one, and attention's backward (softmax Jacobian + dQ/dK/dV contractions) is the most error-prone in the stack.
Fix: check grads against autograd through the independent double-precision reference, with a random cotangent (not .sum(), which collapses the contraction):

def test_gradient_matches_reference():
    q, k, v = _qkv(2, 8, 8, seed=8)
    q, k, v = q.requires_grad_(True), k.requires_grad_(True), v.requires_grad_(True)
    dy = torch.randn_like(NativeAttentionOp().forward_fp32(q, k, v, causal=True))
    NativeAttentionOp().forward_fp32(q, k, v, causal=True).backward(dy)

    qd, kd, vd = (t.detach().double().requires_grad_(True) for t in (q, k, v))
    _ref_softmax_attn(qd, kd, vd, causal=True).backward(dy.double())
    for t, td in ((q, qd), (k, kd), (v, vd)):
        torch.testing.assert_close(t.grad, td.grad.float(), rtol=1e-4, atol=1e-4)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants