Skip to content

feat(ws1): NativeRMSNormOp pure-PyTorch ground-truth reference + numerical contract tests#160

Open
maxiaosong1124 wants to merge 7 commits into
RL-Align:mainfrom
maxiaosong1124:feat/ws1-rms_norm-pytorch-op
Open

feat(ws1): NativeRMSNormOp pure-PyTorch ground-truth reference + numerical contract tests#160
maxiaosong1124 wants to merge 7 commits into
RL-Align:mainfrom
maxiaosong1124:feat/ws1-rms_norm-pytorch-op

Conversation

@maxiaosong1124

@maxiaosong1124 maxiaosong1124 commented Jun 20, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds the pure-PyTorch ground-truth reference op for RMSNorm (pre-norm / QK-Norm)
as the first WS1 batch-invariant operator built on top of the numerical contract
defined in #108. Ships the op, its registry wiring, and a 16-case test suite that
pins down both alignment axes (Axis-A bitwise batch invariance, Axis-B per-dtype
tolerance).

Refs #108

Terminology

This PR uses the WS1 alignment vocabulary from #108:

  • Axis-A — batch invariance (reproducibility). A row's output must not depend on
    how many rows share the batch (batch size, slicing, padding). Asserted bitwise
    (torch.equal). This is what keeps train-time (large batch) and sample-time
    (small batch / dynamic padding) numerics identical so the policy ratio doesn't drift.
  • Axis-B — accuracy. The low-precision (bf16 / fp16) forward must stay within a
    documented per-dtype tolerance of the fp32 ground-truth. Asserted with torch.allclose
    • per-dtype thresholds.

Motivation / Context

#108 establishes the ground-truth harness and numerical contract for the WS1
batch-invariant forward chain. RMSNorm is required on two normalized dims of the
target model (Qwen3-8B dense):

  • hidden = 4096 — input / post-attention norm
  • head_dim = 128 — QK-Norm (per-head RMSNorm on Q and K)

This PR provides the deterministic fp32 reference path those downstream kernels
(Triton / CUDA / ROCm RMSNorm) will be validated against.

Changes

  • rl_engine/kernels/ops/pytorch/norm/rms_norm.pyNativeRMSNormOp
    • forward() — accumulate in fp32, cast result back to x.dtype (Axis-B candidate path)
    • forward_fp32() — fp32 accumulation, forced fp32 output (ground-truth / backward golden source)
    • Formula: out = x * rsqrt(mean(x^2, dim=-1) + eps) * weight
    • eps lives inside the sqrt; plain weight scaling (not the 1 + weight variant)
    • Shape guard: weight must be 1-D of size x.shape[-1]
  • rl_engine/kernels/registry.py — register PYTORCH_NATIVE_RMS_NORM
    and add rms_norm dispatch to the cuda / rocm / cpu priority maps
  • tests/test_rms_norm.py — 16 tests (details below)

How this satisfies the #108 contract

#108 requirement How it's met here
Deterministic reference path, fixed reduction order forward_fp32() accumulates in fp32 along dim=-1; tests use fixed-seed torch.Generator so outputs are reproducible
Per-dtype tolerance policy (bitwise vs tight-tolerance) Axis-A asserted bitwise (torch.equal); Axis-B asserted within documented per-dtype thresholds — bf16 atol=2e-2, rtol=1.6e-2, fp16 atol=1e-3, rtol=1e-3
Batch-config sweep / validation helper Batch-invariance checks compute on the full batch, then assert sliced/padded rows are bitwise identical to their full-batch counterparts
Both normalized dims covered Every correctness/invariance test is parametrized over hidden=4096 and head_dim=128

Test Environment

OS Ubuntu 22.04.5 LTS (kernel 5.15.0-122-generic)
Python 3.12.3
PyTorch 2.8.0+cu128
CUDA / cuDNN 12.8 / 9.10.02 (driver 580.65.06)
pytest 9.0.3
GPU NVIDIA H20

Testing

Run from the repo root with python -m pytest (the -m form puts the repo on

python -m pytest tests/test_rms_norm.py

→ 16 passed, covering:

- Correctness vs an independent hand-written fp32 formula (bitwise, both dims)
- Axis-A batch invariance: row output is independent of batch size — slice and
padding variants, asserted bitwise
- dtype paths: forward follows input dtype; forward_fp32 forces fp32
- Axis-B low-precision (bf16 / fp16) within tolerance of the fp32 reference
- eps inside sqrt (zero input → finite zero output)
- plain weight scaling (rules out the 1 + weight variant)
- shape guard fires on wrong-size / non-1-D weight
- purity (inputs not mutated in place)
- gradient flow (fp32 autograd = backward golden source)
- registry dispatch resolves rms_norm → NativeRMSNormOp

Rebased onto latest upstream/main; registry dispatch for the neighboring
ratio_kl / grpo_loss ops verified intact after conflict resolution.

Checklist

- [x] Pure-PyTorch reference, no custom extension required
- [x] Both Qwen3-8B normalized dims (4096, 128) covered
- [x] Axis-A bitwise batch invariance enforced
- [x] Axis-B per-dtype tolerance documented and tested
- [x] Registered in OpBackend + cuda/rocm/cpu priority maps
- [x] All 16 tests pass locally

---

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

* **New Features**
  * Added RMSNorm backend support across CUDA, ROCm, and CPU, including a native pure-PyTorch reference implementation.
  * Supports fp16/bf16 execution while preserving input dtype, plus an option to force fp32 outputs.
  * Enforces weight shape requirements and correct `eps` handling, with proper output dtype casting.

* **Tests**
  * Added comprehensive pytest coverage for correctness vs fp32 references, dtype behavior, shape/guard errors, input immutability, gradient finiteness, batch-slice invariance, and backend dispatch for `"rms_norm"`.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

- NativeRMSNormOp with forward / forward_fp32 (fp32 ground-truth path)
- covers both normalized dims: hidden=4096 and head_dim=128 (Qwen3 QK-Norm)
- register PYTORCH_NATIVE_RMS_NORM in OpBackend + cpu/cuda/rocm priority map
- tests/test_rms_norm.py: axis-A bitwise batch invariance + dtype tolerance,
  shape guard, purity, gradient flow, registry dispatch (16 tests)

Refs RL-Align#108
@coderabbitai

coderabbitai Bot commented Jun 20, 2026

Copy link
Copy Markdown

Review Change Stack

📝 Walkthrough

Walkthrough

Adds NativeRMSNormOp, a pure-PyTorch RMSNorm reference implementation, to rl_engine/kernels/ops/pytorch/norm/rms_norm.py. Registers it as OpBackend.PYTORCH_NATIVE_RMS_NORM in KernelRegistry with dispatch entries for cuda, rocm, and cpu. A new pytest module validates correctness, dtype routing, batch invariance, shape guards, purity, gradients, and registry dispatch.

Changes

NativeRMSNormOp: Implementation, Registry Wiring, and Tests

Layer / File(s) Summary
NativeRMSNormOp class and core _rms_norm math
rl_engine/kernels/ops/pytorch/norm/rms_norm.py
Defines NativeRMSNormOp with __call__/forward (fp32 accumulation, casts output to x.dtype), forward_fp32 (forces float32 output), and static _rms_norm that validates weight shape, computes rsqrt(mean(x²) + eps) * weight, and casts to output_dtype.
OpBackend enum and KernelRegistry dispatch wiring
rl_engine/kernels/registry.py
Adds PYTORCH_NATIVE_RMS_NORM to OpBackend with the NativeRMSNormOp import path, and extends KernelRegistry._priority_map with rms_norm entries for cuda, rocm, and cpu.
Test suite: correctness, dtype, guards, purity, gradients, registry
tests/test_rms_norm.py
Validates NativeRMSNormOp against a manual fp32 reference for two normalized dimensions, batch/padding invariance (bitwise equality), dtype routing, bf16/fp16 tolerances, eps/zero-input finiteness, linear weight scaling, ValueError on bad weight shapes, input non-mutation, gradient finiteness, and kernel_registry dispatch.

Possibly Related Issues

Suggested Reviewers

  • KJLdefeated
  • inaniloquentee
  • Flink-ddd

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🐇 A little kernel woke up one day,
And said, “Let’s norm these weights the PyTorch way!”
With rsqrt and mean(x²) in sight,
fp32 accumulation shining bright,
On CUDA, ROCm, CPU it’ll run — hooray! 🌟

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 23.81% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly summarizes the main change: a native pure-PyTorch RMSNorm reference implementation with contract tests.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/kernels/ops/pytorch/norm/rms_norm.py`:
- Around line 5-6: The file rms_norm.py has linting and formatting violations
from black, isort, and trailing-whitespace checks that are blocking CI. Run the
project's pre-commit formatting hooks (typically via a command like 'pre-commit
run --all-files' or 'black' and 'isort' individually) to automatically reformat
the file and fix signature spacing issues around lines 25-31, expression
formatting issues around lines 38-44 and 57-67, and any trailing whitespace
violations. Commit the reformatted file after the hooks complete.

In `@tests/test_rms_norm.py`:
- Around line 10-11: The test_rms_norm.py file has formatting inconsistencies
detected by Black. Run the Black formatter on this file to automatically fix
alignment and spacing issues in inline comments (like those on the _HIDDEN and
_HEAD_DIM constant definitions) and long assertions (around the test assertion
blocks). Apply Black's output and commit the formatted result to resolve the CI
formatting check.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 34ccd0aa-23c0-4f09-8d2a-e7d333cea8a2

📥 Commits

Reviewing files that changed from the base of the PR and between 9dfcdbc and 5396d27.

📒 Files selected for processing (3)
  • rl_engine/kernels/ops/pytorch/norm/rms_norm.py
  • rl_engine/kernels/registry.py
  • tests/test_rms_norm.py

Comment thread rl_engine/kernels/ops/pytorch/norm/rms_norm.py
Comment thread tests/test_rms_norm.py Outdated
Resolve CodeRabbit formatting findings on RL-Align#160: black (line-length=100),
isort (profile=black), trailing-whitespace and EOF fixes. No logic change;
16 tests still pass.
@Flink-ddd Flink-ddd added platform: cuda Specific optimizations or bugs in NVIDIA graphics cards (such as FlashInfer, TMA optimizations) priority: high Severe congestion issues require the highest priority for resolution. sprint-0615 labels Jun 21, 2026
@Flink-ddd Flink-ddd requested a review from EthanZero2Hero June 21, 2026 13:10

@Flink-ddd Flink-ddd left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a solid, production-ready PR that strictly adheres to the #108 numerical contract. Here are a few professional suggestions for refinement before merging:

Comment thread tests/test_rms_norm.py
w = _rand((_HIDDEN,), seed=14).requires_grad_(True)
op.forward_fp32(x, w).sum().backward()
assert torch.isfinite(x.grad).all() and torch.isfinite(w.grad).all()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have thoroughly covered forward Axis-A (Batch Invariance), but WS1 requires gradients to be batch-invariant as well (vital for #153). Currently, test_gradient_flows only checks if gradients exist and are finite. You need to prove that slicing a batch yields identical gradients to computing the full batch.

Code Example:

# Add this new test case
def test_backward_batch_invariance_slice():
    """Axis A: Gradients must be bitwise identical regardless of batch size."""
    op = NativeRMSNormOp()
    
    # Full batch forward & backward
    w_full = _rand((_HIDDEN,), seed=1).requires_grad_(True)
    x_full = _rand((8, 32, _HIDDEN), seed=2).requires_grad_(True)
    out_full = op.forward_fp32(x_full, w_full)
    
    dy_full = _rand(out_full.shape, seed=3)
    out_full.backward(dy_full)
    grad_x_full_sliced = x_full.grad[:1].clone()
    
    # Sliced batch forward & backward (batch size = 1)
    w_slice = _rand((_HIDDEN,), seed=1).requires_grad_(True)
    x_slice = _rand((8, 32, _HIDDEN), seed=2)[:1].detach().requires_grad_(True)
    out_slice = op.forward_fp32(x_slice, w_slice)
    
    out_slice.backward(dy_full[:1]) # Use matching slice of upstream grad
    
    # Assert x.grad is bitwise identical
    assert torch.equal(x_slice.grad, grad_x_full_sliced)

f"got tuple(weight.shape)={tuple(weight.shape)}"
)
x_f = x.float()
var = x_f.pow(2).mean(dim=-1, keepdim=True)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The line var = x_f.pow(2).mean(...) explicitly materializes a full-sized FP32 tensor in memory before reducing. For a reference/ground-truth operator, code readability is more important than VRAM efficiency, so this is perfectly fine here. However, keep this memory spike in mind when reviewing the downstream Triton/CUDA implementations for long-context workloads. No code change is strictly required for this PR.

@KJLdefeated KJLdefeated left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implementation looks correct — standard RMSNorm, fp32 accumulation, eps inside the sqrt, plain weight scaling, clean shape guard. Registry wiring matches the existing attn pattern. Happy to approve once request changes are addressed.

Comment thread tests/test_rms_norm.py
Comment on lines +30 to +33
def test_forward_fp32_matches_manual_reference(N):
op = NativeRMSNormOp()
x, w = _rand((2, 16, N), seed=0), _rand((N,), seed=1)
assert torch.equal(op.forward_fp32(x, w), _manual_rms_norm(x, w))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_manual_rms_norm is the same expression as _rms_norm in the same float order, which is why torch.equal passes bitwise. The downside is that if there is a bug in the formula (eps placement, wrong reduction dim) would be present in both sides and still go green.
Note it has to be allclose, not torch.equal since F.rms_norm may reduce in a different order, and that non-identical reduction is exactly what makes it a real check:

import torch.nn.functional as F

@pytest.mark.parametrize("N", [_HIDDEN, _HEAD_DIM])
def test_forward_fp32_matches_torch_reference(N):
    op = NativeRMSNormOp()
    x, w = _rand((2, 16, N), seed=0), _rand((N,), seed=1)
    ref = F.rms_norm(x.float(), (N,), weight=w.float(), eps=_EPS)
    torch.testing.assert_close(op.forward_fp32(x, w), ref, rtol=1e-6, atol=1e-6)

Fine to keep the hand-written reference as a secondary sanity check, but it shouldn't be the primary correctness assertion.

…eck + backward batch-invariance

- KJLdefeated: add test_forward_fp32_matches_torch_reference comparing against
  PyTorch's F.rms_norm via assert_close (tolerance, not torch.equal, since its
  reduction order may differ) so a shared formula bug can't pass green. Keep the
  hand-written _manual_rms_norm test as a secondary bitwise sanity check.
- Flink-ddd: add test_backward_batch_invariance_slice proving input gradients are
  bitwise identical regardless of batch size (Axis-A for gradients, needed for RL-Align#153).
@maxiaosong1124

Copy link
Copy Markdown
Collaborator Author

@KJLdefeated You're right — _manual_rms_norm uses the same expression and float order as the op, so torch.equal only proves they're copies of each other, not that the formula is correct; a shared bug (eps placement, wrong reduction dim) would pass green on both sides. Addressed in 64f6f56:

  • Added test_forward_fp32_matches_torch_reference as the primary correctness check, comparing forward_fp32 against PyTorch's own F.rms_norm. Used torch.testing.assert_close(rtol=1e-6, atol=1e-6) rather than torch.equal exactly because F.rms_norm is free to reduce in a different float order — that non-identical reduction is what makes it a real independent check.
  • Kept the hand-written _manual_rms_norm test as a secondary bitwise sanity check that pins the exact reference semantics.

Thanks for the careful catch!

@maxiaosong1124

Copy link
Copy Markdown
Collaborator Author

This is a solid, production-ready PR that strictly adheres to the #108 numerical contract. Here are a few professional suggestions for refinement before merging:

@Flink-ddd Good point — forward batch-invariance isn't enough; WS1 (and #153) needs the backward to be batch-invariant too, and test_gradient_flows only checked isfinite. Addressed in 64f6f56:

Added test_backward_batch_invariance_slice following your example — it runs the full-batch forward+backward, then a batch-of-1 recompute fed the matching slice of the upstream gradient, and asserts torch.equal(x_slice.grad, grad_x_full_sliced) so the input gradient is bitwise identical regardless of batch size.

On your other note (var = x_f.pow(2).mean(...) materializing a full fp32 tensor): left as-is for this PR, since you flagged it as no-change-required — readability over VRAM for the reference op. I've noted it as a watch-item for the downstream Triton/CUDA kernels on long-context workloads.

All 19 tests in tests/test_rms_norm.py pass and black --check is clean. Thanks for the review!

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/test_rms_norm.py`:
- Around line 144-159: The backward invariance test only covers the `_HIDDEN`
normalization width, so it should be expanded to validate both Qwen3 RMSNorm
widths. Update `test_backward_batch_invariance_slice` to run the same
forward/backward slice comparison for both `_HIDDEN` and `_HEAD_DIM` (for
example via parametrization), keeping the existing `NativeRMSNormOp`, `_rand`,
and gradient slice assertions unchanged.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5bc24ea0-96eb-4323-a0b7-59f424add2dc

📥 Commits

Reviewing files that changed from the base of the PR and between 6c50a87 and fb6cfc5.

📒 Files selected for processing (2)
  • rl_engine/kernels/registry.py
  • tests/test_rms_norm.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • rl_engine/kernels/registry.py

Comment thread tests/test_rms_norm.py
Comment on lines +144 to +159
def test_backward_batch_invariance_slice():
op = NativeRMSNormOp()

w_full = _rand((_HIDDEN,), seed=1).requires_grad_(True)
x_full = _rand((8, 32, _HIDDEN), seed=2).requires_grad_(True)
out_full = op.forward_fp32(x_full, w_full)
dy_full = _rand(out_full.shape, seed=3)
out_full.backward(dy_full)
grad_x_full_sliced = x_full.grad[:1].clone()

w_slice = _rand((_HIDDEN,), seed=1).requires_grad_(True)
x_slice = _rand((8, 32, _HIDDEN), seed=2)[:1].detach().requires_grad_(True)
out_slice = op.forward_fp32(x_slice, w_slice)
out_slice.backward(dy_full[:1]) # matching slice of the upstream gradient

assert torch.equal(x_slice.grad, grad_x_full_sliced)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

Parameterize the backward invariance check over both normalization widths.

This new Axis-A regression only exercises _HIDDEN, so a width-specific bug in the _HEAD_DIM path would still miss coverage even though this suite is meant to validate both Qwen3 RMSNorm widths.

Proposed fix
+@pytest.mark.parametrize("N", [_HIDDEN, _HEAD_DIM])
-def test_backward_batch_invariance_slice():
+def test_backward_batch_invariance_slice(N):
     op = NativeRMSNormOp()
 
-    w_full = _rand((_HIDDEN,), seed=1).requires_grad_(True)
-    x_full = _rand((8, 32, _HIDDEN), seed=2).requires_grad_(True)
+    w_full = _rand((N,), seed=1).requires_grad_(True)
+    x_full = _rand((8, 32, N), seed=2).requires_grad_(True)
     out_full = op.forward_fp32(x_full, w_full)
     dy_full = _rand(out_full.shape, seed=3)
     out_full.backward(dy_full)
     grad_x_full_sliced = x_full.grad[:1].clone()
 
-    w_slice = _rand((_HIDDEN,), seed=1).requires_grad_(True)
-    x_slice = _rand((8, 32, _HIDDEN), seed=2)[:1].detach().requires_grad_(True)
+    w_slice = _rand((N,), seed=1).requires_grad_(True)
+    x_slice = _rand((8, 32, N), seed=2)[:1].detach().requires_grad_(True)
     out_slice = op.forward_fp32(x_slice, w_slice)
     out_slice.backward(dy_full[:1])  # matching slice of the upstream gradient
 
     assert torch.equal(x_slice.grad, grad_x_full_sliced)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def test_backward_batch_invariance_slice():
op = NativeRMSNormOp()
w_full = _rand((_HIDDEN,), seed=1).requires_grad_(True)
x_full = _rand((8, 32, _HIDDEN), seed=2).requires_grad_(True)
out_full = op.forward_fp32(x_full, w_full)
dy_full = _rand(out_full.shape, seed=3)
out_full.backward(dy_full)
grad_x_full_sliced = x_full.grad[:1].clone()
w_slice = _rand((_HIDDEN,), seed=1).requires_grad_(True)
x_slice = _rand((8, 32, _HIDDEN), seed=2)[:1].detach().requires_grad_(True)
out_slice = op.forward_fp32(x_slice, w_slice)
out_slice.backward(dy_full[:1]) # matching slice of the upstream gradient
assert torch.equal(x_slice.grad, grad_x_full_sliced)
`@pytest.mark.parametrize`("N", [_HIDDEN, _HEAD_DIM])
def test_backward_batch_invariance_slice(N):
op = NativeRMSNormOp()
w_full = _rand((N,), seed=1).requires_grad_(True)
x_full = _rand((8, 32, N), seed=2).requires_grad_(True)
out_full = op.forward_fp32(x_full, w_full)
dy_full = _rand(out_full.shape, seed=3)
out_full.backward(dy_full)
grad_x_full_sliced = x_full.grad[:1].clone()
w_slice = _rand((N,), seed=1).requires_grad_(True)
x_slice = _rand((8, 32, N), seed=2)[:1].detach().requires_grad_(True)
out_slice = op.forward_fp32(x_slice, w_slice)
out_slice.backward(dy_full[:1]) # matching slice of the upstream gradient
assert torch.equal(x_slice.grad, grad_x_full_sliced)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_rms_norm.py` around lines 144 - 159, The backward invariance test
only covers the `_HIDDEN` normalization width, so it should be expanded to
validate both Qwen3 RMSNorm widths. Update
`test_backward_batch_invariance_slice` to run the same forward/backward slice
comparison for both `_HIDDEN` and `_HEAD_DIM` (for example via parametrization),
keeping the existing `NativeRMSNormOp`, `_rand`, and gradient slice assertions
unchanged.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

needs-gpu-ci platform: cuda Specific optimizations or bugs in NVIDIA graphics cards (such as FlashInfer, TMA optimizations) priority: high Severe congestion issues require the highest priority for resolution. sprint-0615

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants