feat(ws1): NativeRMSNormOp pure-PyTorch ground-truth reference + numerical contract tests#160
Conversation
- NativeRMSNormOp with forward / forward_fp32 (fp32 ground-truth path) - covers both normalized dims: hidden=4096 and head_dim=128 (Qwen3 QK-Norm) - register PYTORCH_NATIVE_RMS_NORM in OpBackend + cpu/cuda/rocm priority map - tests/test_rms_norm.py: axis-A bitwise batch invariance + dtype tolerance, shape guard, purity, gradient flow, registry dispatch (16 tests) Refs RL-Align#108
📝 WalkthroughWalkthroughAdds ChangesNativeRMSNormOp: Implementation, Registry Wiring, and Tests
Possibly Related Issues
Suggested Reviewers
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/kernels/ops/pytorch/norm/rms_norm.py`:
- Around line 5-6: The file rms_norm.py has linting and formatting violations
from black, isort, and trailing-whitespace checks that are blocking CI. Run the
project's pre-commit formatting hooks (typically via a command like 'pre-commit
run --all-files' or 'black' and 'isort' individually) to automatically reformat
the file and fix signature spacing issues around lines 25-31, expression
formatting issues around lines 38-44 and 57-67, and any trailing whitespace
violations. Commit the reformatted file after the hooks complete.
In `@tests/test_rms_norm.py`:
- Around line 10-11: The test_rms_norm.py file has formatting inconsistencies
detected by Black. Run the Black formatter on this file to automatically fix
alignment and spacing issues in inline comments (like those on the _HIDDEN and
_HEAD_DIM constant definitions) and long assertions (around the test assertion
blocks). Apply Black's output and commit the formatted result to resolve the CI
formatting check.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 34ccd0aa-23c0-4f09-8d2a-e7d333cea8a2
📒 Files selected for processing (3)
rl_engine/kernels/ops/pytorch/norm/rms_norm.pyrl_engine/kernels/registry.pytests/test_rms_norm.py
Resolve CodeRabbit formatting findings on RL-Align#160: black (line-length=100), isort (profile=black), trailing-whitespace and EOF fixes. No logic change; 16 tests still pass.
| w = _rand((_HIDDEN,), seed=14).requires_grad_(True) | ||
| op.forward_fp32(x, w).sum().backward() | ||
| assert torch.isfinite(x.grad).all() and torch.isfinite(w.grad).all() | ||
|
|
There was a problem hiding this comment.
You have thoroughly covered forward Axis-A (Batch Invariance), but WS1 requires gradients to be batch-invariant as well (vital for #153). Currently, test_gradient_flows only checks if gradients exist and are finite. You need to prove that slicing a batch yields identical gradients to computing the full batch.
Code Example:
# Add this new test case
def test_backward_batch_invariance_slice():
"""Axis A: Gradients must be bitwise identical regardless of batch size."""
op = NativeRMSNormOp()
# Full batch forward & backward
w_full = _rand((_HIDDEN,), seed=1).requires_grad_(True)
x_full = _rand((8, 32, _HIDDEN), seed=2).requires_grad_(True)
out_full = op.forward_fp32(x_full, w_full)
dy_full = _rand(out_full.shape, seed=3)
out_full.backward(dy_full)
grad_x_full_sliced = x_full.grad[:1].clone()
# Sliced batch forward & backward (batch size = 1)
w_slice = _rand((_HIDDEN,), seed=1).requires_grad_(True)
x_slice = _rand((8, 32, _HIDDEN), seed=2)[:1].detach().requires_grad_(True)
out_slice = op.forward_fp32(x_slice, w_slice)
out_slice.backward(dy_full[:1]) # Use matching slice of upstream grad
# Assert x.grad is bitwise identical
assert torch.equal(x_slice.grad, grad_x_full_sliced)
| f"got tuple(weight.shape)={tuple(weight.shape)}" | ||
| ) | ||
| x_f = x.float() | ||
| var = x_f.pow(2).mean(dim=-1, keepdim=True) |
There was a problem hiding this comment.
The line var = x_f.pow(2).mean(...) explicitly materializes a full-sized FP32 tensor in memory before reducing. For a reference/ground-truth operator, code readability is more important than VRAM efficiency, so this is perfectly fine here. However, keep this memory spike in mind when reviewing the downstream Triton/CUDA implementations for long-context workloads. No code change is strictly required for this PR.
KJLdefeated
left a comment
There was a problem hiding this comment.
Implementation looks correct — standard RMSNorm, fp32 accumulation, eps inside the sqrt, plain weight scaling, clean shape guard. Registry wiring matches the existing attn pattern. Happy to approve once request changes are addressed.
| def test_forward_fp32_matches_manual_reference(N): | ||
| op = NativeRMSNormOp() | ||
| x, w = _rand((2, 16, N), seed=0), _rand((N,), seed=1) | ||
| assert torch.equal(op.forward_fp32(x, w), _manual_rms_norm(x, w)) |
There was a problem hiding this comment.
_manual_rms_norm is the same expression as _rms_norm in the same float order, which is why torch.equal passes bitwise. The downside is that if there is a bug in the formula (eps placement, wrong reduction dim) would be present in both sides and still go green.
Note it has to be allclose, not torch.equal since F.rms_norm may reduce in a different order, and that non-identical reduction is exactly what makes it a real check:
import torch.nn.functional as F
@pytest.mark.parametrize("N", [_HIDDEN, _HEAD_DIM])
def test_forward_fp32_matches_torch_reference(N):
op = NativeRMSNormOp()
x, w = _rand((2, 16, N), seed=0), _rand((N,), seed=1)
ref = F.rms_norm(x.float(), (N,), weight=w.float(), eps=_EPS)
torch.testing.assert_close(op.forward_fp32(x, w), ref, rtol=1e-6, atol=1e-6)Fine to keep the hand-written reference as a secondary sanity check, but it shouldn't be the primary correctness assertion.
…eck + backward batch-invariance - KJLdefeated: add test_forward_fp32_matches_torch_reference comparing against PyTorch's F.rms_norm via assert_close (tolerance, not torch.equal, since its reduction order may differ) so a shared formula bug can't pass green. Keep the hand-written _manual_rms_norm test as a secondary bitwise sanity check. - Flink-ddd: add test_backward_batch_invariance_slice proving input gradients are bitwise identical regardless of batch size (Axis-A for gradients, needed for RL-Align#153).
|
@KJLdefeated You're right — _manual_rms_norm uses the same expression and float order as the op, so torch.equal only proves they're copies of each other, not that the formula is correct; a shared bug (eps placement, wrong reduction dim) would pass green on both sides. Addressed in 64f6f56:
Thanks for the careful catch! |
@Flink-ddd Good point — forward batch-invariance isn't enough; WS1 (and #153) needs the backward to be batch-invariant too, and test_gradient_flows only checked isfinite. Addressed in 64f6f56: Added test_backward_batch_invariance_slice following your example — it runs the full-batch forward+backward, then a batch-of-1 recompute fed the matching slice of the upstream gradient, and asserts torch.equal(x_slice.grad, grad_x_full_sliced) so the input gradient is bitwise identical regardless of batch size. On your other note (var = x_f.pow(2).mean(...) materializing a full fp32 tensor): left as-is for this PR, since you flagged it as no-change-required — readability over VRAM for the reference op. I've noted it as a watch-item for the downstream Triton/CUDA kernels on long-context workloads. All 19 tests in tests/test_rms_norm.py pass and black --check is clean. Thanks for the review! |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/test_rms_norm.py`:
- Around line 144-159: The backward invariance test only covers the `_HIDDEN`
normalization width, so it should be expanded to validate both Qwen3 RMSNorm
widths. Update `test_backward_batch_invariance_slice` to run the same
forward/backward slice comparison for both `_HIDDEN` and `_HEAD_DIM` (for
example via parametrization), keeping the existing `NativeRMSNormOp`, `_rand`,
and gradient slice assertions unchanged.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 5bc24ea0-96eb-4323-a0b7-59f424add2dc
📒 Files selected for processing (2)
rl_engine/kernels/registry.pytests/test_rms_norm.py
🚧 Files skipped from review as they are similar to previous changes (1)
- rl_engine/kernels/registry.py
| def test_backward_batch_invariance_slice(): | ||
| op = NativeRMSNormOp() | ||
|
|
||
| w_full = _rand((_HIDDEN,), seed=1).requires_grad_(True) | ||
| x_full = _rand((8, 32, _HIDDEN), seed=2).requires_grad_(True) | ||
| out_full = op.forward_fp32(x_full, w_full) | ||
| dy_full = _rand(out_full.shape, seed=3) | ||
| out_full.backward(dy_full) | ||
| grad_x_full_sliced = x_full.grad[:1].clone() | ||
|
|
||
| w_slice = _rand((_HIDDEN,), seed=1).requires_grad_(True) | ||
| x_slice = _rand((8, 32, _HIDDEN), seed=2)[:1].detach().requires_grad_(True) | ||
| out_slice = op.forward_fp32(x_slice, w_slice) | ||
| out_slice.backward(dy_full[:1]) # matching slice of the upstream gradient | ||
|
|
||
| assert torch.equal(x_slice.grad, grad_x_full_sliced) |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win
Parameterize the backward invariance check over both normalization widths.
This new Axis-A regression only exercises _HIDDEN, so a width-specific bug in the _HEAD_DIM path would still miss coverage even though this suite is meant to validate both Qwen3 RMSNorm widths.
Proposed fix
+@pytest.mark.parametrize("N", [_HIDDEN, _HEAD_DIM])
-def test_backward_batch_invariance_slice():
+def test_backward_batch_invariance_slice(N):
op = NativeRMSNormOp()
- w_full = _rand((_HIDDEN,), seed=1).requires_grad_(True)
- x_full = _rand((8, 32, _HIDDEN), seed=2).requires_grad_(True)
+ w_full = _rand((N,), seed=1).requires_grad_(True)
+ x_full = _rand((8, 32, N), seed=2).requires_grad_(True)
out_full = op.forward_fp32(x_full, w_full)
dy_full = _rand(out_full.shape, seed=3)
out_full.backward(dy_full)
grad_x_full_sliced = x_full.grad[:1].clone()
- w_slice = _rand((_HIDDEN,), seed=1).requires_grad_(True)
- x_slice = _rand((8, 32, _HIDDEN), seed=2)[:1].detach().requires_grad_(True)
+ w_slice = _rand((N,), seed=1).requires_grad_(True)
+ x_slice = _rand((8, 32, N), seed=2)[:1].detach().requires_grad_(True)
out_slice = op.forward_fp32(x_slice, w_slice)
out_slice.backward(dy_full[:1]) # matching slice of the upstream gradient
assert torch.equal(x_slice.grad, grad_x_full_sliced)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def test_backward_batch_invariance_slice(): | |
| op = NativeRMSNormOp() | |
| w_full = _rand((_HIDDEN,), seed=1).requires_grad_(True) | |
| x_full = _rand((8, 32, _HIDDEN), seed=2).requires_grad_(True) | |
| out_full = op.forward_fp32(x_full, w_full) | |
| dy_full = _rand(out_full.shape, seed=3) | |
| out_full.backward(dy_full) | |
| grad_x_full_sliced = x_full.grad[:1].clone() | |
| w_slice = _rand((_HIDDEN,), seed=1).requires_grad_(True) | |
| x_slice = _rand((8, 32, _HIDDEN), seed=2)[:1].detach().requires_grad_(True) | |
| out_slice = op.forward_fp32(x_slice, w_slice) | |
| out_slice.backward(dy_full[:1]) # matching slice of the upstream gradient | |
| assert torch.equal(x_slice.grad, grad_x_full_sliced) | |
| `@pytest.mark.parametrize`("N", [_HIDDEN, _HEAD_DIM]) | |
| def test_backward_batch_invariance_slice(N): | |
| op = NativeRMSNormOp() | |
| w_full = _rand((N,), seed=1).requires_grad_(True) | |
| x_full = _rand((8, 32, N), seed=2).requires_grad_(True) | |
| out_full = op.forward_fp32(x_full, w_full) | |
| dy_full = _rand(out_full.shape, seed=3) | |
| out_full.backward(dy_full) | |
| grad_x_full_sliced = x_full.grad[:1].clone() | |
| w_slice = _rand((N,), seed=1).requires_grad_(True) | |
| x_slice = _rand((8, 32, N), seed=2)[:1].detach().requires_grad_(True) | |
| out_slice = op.forward_fp32(x_slice, w_slice) | |
| out_slice.backward(dy_full[:1]) # matching slice of the upstream gradient | |
| assert torch.equal(x_slice.grad, grad_x_full_sliced) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/test_rms_norm.py` around lines 144 - 159, The backward invariance test
only covers the `_HIDDEN` normalization width, so it should be expanded to
validate both Qwen3 RMSNorm widths. Update
`test_backward_batch_invariance_slice` to run the same forward/backward slice
comparison for both `_HIDDEN` and `_HEAD_DIM` (for example via parametrization),
keeping the existing `NativeRMSNormOp`, `_rand`, and gradient slice assertions
unchanged.
Summary
Adds the pure-PyTorch ground-truth reference op for RMSNorm (pre-norm / QK-Norm)
as the first WS1 batch-invariant operator built on top of the numerical contract
defined in #108. Ships the op, its registry wiring, and a 16-case test suite that
pins down both alignment axes (Axis-A bitwise batch invariance, Axis-B per-dtype
tolerance).
Refs #108
Terminology
This PR uses the WS1 alignment vocabulary from #108:
how many rows share the batch (batch size, slicing, padding). Asserted bitwise
(
torch.equal). This is what keeps train-time (large batch) and sample-time(small batch / dynamic padding) numerics identical so the policy ratio doesn't drift.
documented per-dtype tolerance of the fp32 ground-truth. Asserted with
torch.allcloseMotivation / Context
#108 establishes the ground-truth harness and numerical contract for the WS1
batch-invariant forward chain. RMSNorm is required on two normalized dims of the
target model (Qwen3-8B dense):
hidden = 4096— input / post-attention normhead_dim = 128— QK-Norm (per-head RMSNorm on Q and K)This PR provides the deterministic fp32 reference path those downstream kernels
(Triton / CUDA / ROCm RMSNorm) will be validated against.
Changes
rl_engine/kernels/ops/pytorch/norm/rms_norm.py—NativeRMSNormOpforward()— accumulate in fp32, cast result back tox.dtype(Axis-B candidate path)forward_fp32()— fp32 accumulation, forced fp32 output (ground-truth / backward golden source)out = x * rsqrt(mean(x^2, dim=-1) + eps) * weightepslives inside the sqrt; plain weight scaling (not the1 + weightvariant)weightmust be 1-D of sizex.shape[-1]rl_engine/kernels/registry.py— registerPYTORCH_NATIVE_RMS_NORMand add
rms_normdispatch to the cuda / rocm / cpu priority mapstests/test_rms_norm.py— 16 tests (details below)How this satisfies the #108 contract
forward_fp32()accumulates in fp32 alongdim=-1; tests use fixed-seedtorch.Generatorso outputs are reproducibletorch.equal); Axis-B asserted within documented per-dtype thresholds — bf16atol=2e-2, rtol=1.6e-2, fp16atol=1e-3, rtol=1e-3hidden=4096andhead_dim=128Test Environment
Testing
Run from the repo root with
python -m pytest(the-mform puts the repo on