feat(ws1): NativeSiLUOp + NativeSwiGLUOp pure-PyTorch ground-truth references + numerical contract tests#166
Conversation
WS1 ground-truth activation ops for issue RL-Align#108 (Qwen3-8B gated MLP): - NativeSiLUOp: silu(x) = x * sigmoid(x) - NativeSwiGLUOp: silu(gate) * up (gate/up at intermediate dim) Both expose the forward / forward_fp32 dual-path contract (fp32 ground truth + dtype-behavior path), pure functions, fp32 accumulation. - register PYTORCH_NATIVE_SILU / PYTORCH_NATIVE_SWIGLU in OpBackend and the cuda/rocm/cpu priority maps - tests/test_swiglu.py: correctness vs fp32 formula, dtype paths, Axis-A batch invariance (slice + padding), purity, gradient flow, shape guard, registry dispatch - docs/operators/activation.md + nav/index wiring
|
Warning Review limit reached
More reviews will be available in 22 minutes and 52 seconds. Learn how PR review limits work. Your organization has run out of usage credits. Purchase more credits in the billing tab to continue. ⌛ How to resolve this issue?After more reviews become available, a review can be triggered using the To avoid repeated limits, reduce automatic review volume by pausing incremental auto-reviews earlier, using label-based review opt-in, excluding WIP or generated PR titles, or requesting reviews manually when the PR is ready. If your team needs uninterrupted high-volume reviews, an organization admin can enable usage-based credits. 🚦 How do rate limits work?CodeRabbit enforces per-developer PR review limits for each organization. Most developers receive the normal plan review availability. For paid Pro and Pro+ PR reviews, CodeRabbit uses adaptive limits for sustained high-volume activity. When a developer's recent PR review activity reaches the 95th percentile or higher among CodeRabbit users, additional reviews become available more gradually as earlier reviews age out of the rolling window. Please see our Fair Usage Limits Policy for further information. ℹ️ Review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (6)
📝 WalkthroughWalkthroughAdds ChangesSiLU / SwiGLU activation operators
Estimated code review effort🎯 2 (Simple) | ⏱️ ~12 minutes Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/operators/activation.md`:
- Around line 12-16: The fenced code block containing the ASCII diagram (showing
hidden, gate_proj, gate, swiglu, down_proj, and up_proj) is missing a language
identifier, which violates the MD040 markdownlint rule. Add "text" as the
language identifier to the opening fence by changing the opening triple
backticks to include the language specifier, making it ```text instead of just
```.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 52c72f53-6442-4854-927a-13addf467820
📒 Files selected for processing (7)
docs/.nav.ymldocs/operators/README.mddocs/operators/activation.mdrl_engine/kernels/ops/pytorch/activation/__init__.pyrl_engine/kernels/ops/pytorch/activation/swiglu.pyrl_engine/kernels/registry.pytests/test_swiglu.py
Flink-ddd
left a comment
There was a problem hiding this comment.
This PR maintains the same high engineering standards as #160. Writing out the explicit mathematical formulas (like x * sigmoid(x)) is excellent documentation for whoever builds the downstream CUDA/Triton fused kernels. Since these two PRs were likely developed in parallel, this one shares the exact same two minor omissions.
| up = _rand((2, _INTERMEDIATE), seed=13).requires_grad_(True) | ||
| op.forward_fp32(gate, up).sum().backward() | ||
| assert torch.isfinite(gate.grad).all() and torch.isfinite(up.grad).all() | ||
|
|
There was a problem hiding this comment.
Currently, test_silu_gradient_flows and test_swiglu_gradient_flows only check if the gradients are numerically valid (isfinite). To strictly satisfy the WS1 contract, we must use a slice test to prove that the backward pass (gradient computation) for both SiLU and SwiGLU is completely independent of the batch size.
Code Example (for SwiGLU):
# Add this to tests/test_swiglu.py
def test_swiglu_backward_batch_invariance_slice():
"""Axis A: Gradients must be bitwise identical regardless of batch size."""
op = NativeSwiGLUOp()
# 1. Full batch forward & backward
gate_full = _rand((8, 32, _INTERMEDIATE), seed=1).requires_grad_(True)
up_full = _rand((8, 32, _INTERMEDIATE), seed=2).requires_grad_(True)
out_full = op.forward_fp32(gate_full, up_full)
dy_full = _rand(out_full.shape, seed=3)
out_full.backward(dy_full)
grad_gate_full_sliced = gate_full.grad[:1].clone()
grad_up_full_sliced = up_full.grad[:1].clone()
# 2. Sliced batch (batch size = 1)
gate_slice = _rand((8, 32, _INTERMEDIATE), seed=1)[:1].detach().requires_grad_(True)
up_slice = _rand((8, 32, _INTERMEDIATE), seed=2)[:1].detach().requires_grad_(True)
out_slice = op.forward_fp32(gate_slice, up_slice)
# Use matching slice of upstream grad
out_slice.backward(dy_full[:1])
# 3. Assert gradients are bitwise identical
assert torch.equal(gate_slice.grad, grad_gate_full_sliced)
assert torch.equal(up_slice.grad, grad_up_full_sliced)
# Note: Please add a similar test for SiLU (test_silu_backward_batch_invariance_slice)
| Qwen3-8B intermediate dim (12288) is just one valid last-dim size. | ||
| """ | ||
|
|
||
| def __init__(self) -> None: |
There was a problem hiding this comment.
Just like the RMSNorm op, these pure PyTorch reference operators should inherit from nn.Module to ensure seamless integration with the broader PyTorch ecosystem (Hooks, state_dict, Dynamo tracing) later down the line.
Bonus: Because nn.Module automatically handles routing call to forward, you can safely delete the manually defined call methods in both classes, making the code even cleaner.
…kward batch invariance tests, MD040 fix - Refactor NativeSiLUOp and NativeSwiGLUOp to inherit from nn.Module, remove manual __call__ (nn.Module routes __call__ to forward). - Add test_silu_backward_batch_invariance_slice and test_swiglu_backward_batch_invariance_slice proving backward pass is completely independent of batch size (Axis-A contract). - Add `text` language identifier to fenced code block in activation.md (MD040 markdownlint compliance).
|
@Flink-ddd @KJLdefeated Thanks for the careful review! All requested changes have been pushed in 2b3db46. Summary of what was addressed:
|
Flink-ddd
left a comment
There was a problem hiding this comment.
LGTM now, Thank you for update.
|
Let's merge this PR first |
Summary
Adds the pure-PyTorch ground-truth reference ops for the gated MLP activation
of the WS1 batch-invariant forward chain: SiLU (Swish) and SwiGLU, built on
top of the numerical contract defined in #108. Ships the two ops, their registry
wiring, docs, and a 16-case test suite that pins down both alignment axes
(Axis-A bitwise batch invariance, Axis-B per-dtype path).
Refs #108
Terminology
This PR uses the WS1 alignment vocabulary from #108:
how many rows share the batch (batch size, slicing, padding). Asserted bitwise
(
torch.equal). This is what keeps train-time (large batch) and sample-time(small batch / dynamic padding) numerics identical so the policy ratio doesn't drift.
compute in fp32 and cast the result back to the input dtype, so the dtype output is
bitwise equal to the fp32 formula cast down — asserted with
torch.equalagainstan independent fp32 reference (no tolerance window needed for an element-wise op).
Motivation / Context
#108 establishes the ground-truth harness and numerical contract for the WS1
batch-invariant forward chain. The Qwen3-8B dense MLP is a gated (SwiGLU) MLP:
down_proj( silu(gate_proj(x)) * up_proj(x) )
This PR covers the activation stage in the middle:
silu— element-wisex * sigmoid(x)(hidden_act="silu"), shape-agnostic.swiglu—silu(gate) * up, wheregate/upare the gate_proj / up_projoutputs at the intermediate dim (Qwen3-8B: 12288). The trailing
down_projis aplain matmul and lives in a separate op.
This PR provides the deterministic fp32 reference path those downstream kernels
(Triton / CUDA / ROCm fused activation) will be validated against.
Changes
rl_engine/kernels/ops/pytorch/activation/swiglu.py—NativeSiLUOp,NativeSwiGLUOpforward()— accumulate in fp32, cast result back to input dtype (Axis-B path)forward_fp32()— fp32 accumulation, forced fp32 output (ground-truth / backward golden source)silu(x) = x * sigmoid(x);swiglu(gate, up) = gate * sigmoid(gate) * upgateandupmust share shaperl_engine/kernels/registry.py— registerPYTORCH_NATIVE_SILU/PYTORCH_NATIVE_SWIGLUinOpBackendand addsilu/swigludispatch to thecuda / rocm / cpu priority maps
tests/test_swiglu.py— 16 tests (details below)docs/operators/activation.md+ nav / index wiringHow this satisfies the #108 contract
forward_fp32()computes element-wise in fp32; tests use fixed-seedtorch.Generatorso outputs are reproducibletorch.equal); Axis-B dtype output is fp32-compute-then-cast, asserted bitwise against the independent fp32 formula cast to dtype (element-wise op needs no tolerance window)12288)Test Environment
──────────────────────────
padding variants, asserted bitwise
Checklist
Summary by CodeRabbit
New Features
Documentation