Skip to content

feat(ws1): NativeSiLUOp + NativeSwiGLUOp pure-PyTorch ground-truth references + numerical contract tests#166

Merged
Flink-ddd merged 4 commits into
RL-Align:mainfrom
maxiaosong1124:feat/ws1-silu-swiglu-pytorch-op
Jun 28, 2026
Merged

feat(ws1): NativeSiLUOp + NativeSwiGLUOp pure-PyTorch ground-truth references + numerical contract tests#166
Flink-ddd merged 4 commits into
RL-Align:mainfrom
maxiaosong1124:feat/ws1-silu-swiglu-pytorch-op

Conversation

@maxiaosong1124

@maxiaosong1124 maxiaosong1124 commented Jun 21, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds the pure-PyTorch ground-truth reference ops for the gated MLP activation
of the WS1 batch-invariant forward chain: SiLU (Swish) and SwiGLU, built on
top of the numerical contract defined in #108. Ships the two ops, their registry
wiring, docs, and a 16-case test suite that pins down both alignment axes
(Axis-A bitwise batch invariance, Axis-B per-dtype path).

Refs #108

Terminology

This PR uses the WS1 alignment vocabulary from #108:

  • Axis-A — batch invariance (reproducibility). A row's output must not depend on
    how many rows share the batch (batch size, slicing, padding). Asserted bitwise
    (torch.equal). This is what keeps train-time (large batch) and sample-time
    (small batch / dynamic padding) numerics identical so the policy ratio doesn't drift.
  • Axis-B — accuracy. The low-precision (bf16 / fp16) forward path. These activations
    compute in fp32 and cast the result back to the input dtype, so the dtype output is
    bitwise equal to the fp32 formula cast down — asserted with torch.equal against
    an independent fp32 reference (no tolerance window needed for an element-wise op).

Motivation / Context

#108 establishes the ground-truth harness and numerical contract for the WS1
batch-invariant forward chain. The Qwen3-8B dense MLP is a gated (SwiGLU) MLP:

down_proj( silu(gate_proj(x)) * up_proj(x) )

This PR covers the activation stage in the middle:

  • silu — element-wise x * sigmoid(x) (hidden_act="silu"), shape-agnostic.
  • swiglusilu(gate) * up, where gate / up are the gate_proj / up_proj
    outputs at the intermediate dim (Qwen3-8B: 12288). The trailing down_proj is a
    plain matmul and lives in a separate op.

This PR provides the deterministic fp32 reference path those downstream kernels
(Triton / CUDA / ROCm fused activation) will be validated against.

Changes

  • rl_engine/kernels/ops/pytorch/activation/swiglu.pyNativeSiLUOp, NativeSwiGLUOp
    • forward() — accumulate in fp32, cast result back to input dtype (Axis-B path)
    • forward_fp32() — fp32 accumulation, forced fp32 output (ground-truth / backward golden source)
    • Formulas: silu(x) = x * sigmoid(x); swiglu(gate, up) = gate * sigmoid(gate) * up
    • Pure functions — inputs never mutated in place
    • SwiGLU shape guard: gate and up must share shape
  • rl_engine/kernels/registry.py — register PYTORCH_NATIVE_SILU /
    PYTORCH_NATIVE_SWIGLU in OpBackend and add silu / swiglu dispatch to the
    cuda / rocm / cpu priority maps
  • tests/test_swiglu.py — 16 tests (details below)
  • docs/operators/activation.md + nav / index wiring

How this satisfies the #108 contract

#108 requirement How it's met here
Deterministic reference path forward_fp32() computes element-wise in fp32; tests use fixed-seed torch.Generator so outputs are reproducible
Per-dtype tolerance policy (bitwise vs tight-tolerance) Axis-A asserted bitwise (torch.equal); Axis-B dtype output is fp32-compute-then-cast, asserted bitwise against the independent fp32 formula cast to dtype (element-wise op needs no tolerance window)
Batch-config sweep / validation helper Batch-invariance checks compute on the full batch, then assert sliced/padded rows are bitwise identical to their full-batch counterparts
Realistic shapes covered Batch-invariance tests run at the Qwen3-8B intermediate dim (12288)

Test Environment

OS Ubuntu 22.04.5 LTS (kernel 5.15.0-122-generic)
Python 3.12.3
PyTorch 2.8.0+cu128
CUDA / cuDNN 12.8 / 9.10.02 (driver 580.65.06)

──────────────────────────
padding variants, asserted bitwise

  • purity (inputs not mutated in place)
  • gradient flow (fp32 autograd = backward golden source)
  • SwiGLU shape guard fires on mismatched gate / up shapes
  • registry dispatch resolves silu → NativeSiLUOp, swiglu → NativeSwiGLUOp

Checklist

  • Pure-PyTorch reference, no custom extension required
  • SwiGLU covered at the Qwen3-8B intermediate dim (12288)
  • Axis-A bitwise batch invariance enforced
  • Axis-B fp32-compute-then-cast dtype path tested
  • Registered in OpBackend + cuda/rocm/cpu priority maps
  • All 16 tests pass locally

Summary by CodeRabbit

  • New Features

    • Added SiLU and SwiGLU activation operators with PyTorch implementations and registry support.
  • Documentation

    • Added comprehensive documentation for SiLU/SwiGLU operators, including mathematical definitions, tensor shape requirements, and backend dispatch behavior.
    • Added validation tests covering correctness, input validation, batch invariance, and gradient propagation.

WS1 ground-truth activation ops for issue RL-Align#108 (Qwen3-8B gated MLP):
- NativeSiLUOp: silu(x) = x * sigmoid(x)
- NativeSwiGLUOp: silu(gate) * up (gate/up at intermediate dim)
Both expose the forward / forward_fp32 dual-path contract (fp32
ground truth + dtype-behavior path), pure functions, fp32 accumulation.

- register PYTORCH_NATIVE_SILU / PYTORCH_NATIVE_SWIGLU in OpBackend and
  the cuda/rocm/cpu priority maps
- tests/test_swiglu.py: correctness vs fp32 formula, dtype paths,
  Axis-A batch invariance (slice + padding), purity, gradient flow,
  shape guard, registry dispatch
- docs/operators/activation.md + nav/index wiring
@coderabbitai

coderabbitai Bot commented Jun 21, 2026

Copy link
Copy Markdown

Review Change Stack

Warning

Review limit reached

@maxiaosong1124, we couldn't start this review because you've reached your PR review rate limit.

More reviews will be available in 22 minutes and 52 seconds. Learn how PR review limits work.

Your organization has run out of usage credits. Purchase more credits in the billing tab to continue.

⌛ How to resolve this issue?

After more reviews become available, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

To avoid repeated limits, reduce automatic review volume by pausing incremental auto-reviews earlier, using label-based review opt-in, excluding WIP or generated PR titles, or requesting reviews manually when the PR is ready. If your team needs uninterrupted high-volume reviews, an organization admin can enable usage-based credits.

🚦 How do rate limits work?

CodeRabbit enforces per-developer PR review limits for each organization. Most developers receive the normal plan review availability.

For paid Pro and Pro+ PR reviews, CodeRabbit uses adaptive limits for sustained high-volume activity. When a developer's recent PR review activity reaches the 95th percentile or higher among CodeRabbit users, additional reviews become available more gradually as earlier reviews age out of the rolling window.

Please see our Fair Usage Limits Policy for further information.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: d4c9bdf8-5220-4f6b-87ef-273bdbed620e

📥 Commits

Reviewing files that changed from the base of the PR and between 9bcd65b and 2b3db46.

📒 Files selected for processing (6)
  • docs/.nav.yml
  • docs/operators/README.md
  • docs/operators/activation.md
  • rl_engine/kernels/ops/pytorch/activation/swiglu.py
  • rl_engine/kernels/registry.py
  • tests/test_swiglu.py
📝 Walkthrough

Walkthrough

Adds NativeSiLUOp and NativeSwiGLUOp PyTorch reference implementations with fp32-accumulation semantics and dual forward/forward_fp32 paths. Two new OpBackend enum members are registered in KernelRegistry for cuda/rocm/cpu dispatch. A 127-line test module validates correctness, invariance, purity, gradients, and registry integration. Documentation for the operator contract is added under docs/operators/activation.md.

Changes

SiLU / SwiGLU activation operators

Layer / File(s) Summary
Op implementation and registry wiring
rl_engine/kernels/ops/pytorch/activation/__init__.py, rl_engine/kernels/ops/pytorch/activation/swiglu.py, rl_engine/kernels/registry.py
NativeSiLUOp and NativeSwiGLUOp are implemented with forward (fp32 compute, cast to input dtype) and forward_fp32 (fp32 output) paths; NativeSwiGLUOp._swiglu raises ValueError on shape mismatch. OpBackend gains two new enum members, and KernelRegistry._priority_map routes "silu"/"swiglu" on all platforms to those backends.
Test suite
tests/test_swiglu.py
Covers dtype-preserving correctness against fp32 reference, shape-mismatch guard, batch/padding invariance, input purity, finite-gradient backprop, and kernel_registry.get_op dispatch for both ops.
Operator documentation and navigation
docs/operators/activation.md, docs/operators/README.md, docs/.nav.yml
Documents math formulas, tensor contract, dual-path semantics, dispatch behavior, accuracy axes, test coverage, and current limitations (no fused CUDA/Triton backend). Navigation and README index updated.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Poem

🐇 Hop! The sigmoid blooms at last,
gate meets up with fp32 cast,
SwiGLU purrs through every dtype lane,
no mutation, no broadcast pain.
The registry points, the tests all pass—
this little rabbit ships some class! ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 20.83% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main change: introducing two pure-PyTorch reference implementations (NativeSiLUOp and NativeSwiGLUOp) with ground-truth references and numerical contract tests.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/operators/activation.md`:
- Around line 12-16: The fenced code block containing the ASCII diagram (showing
hidden, gate_proj, gate, swiglu, down_proj, and up_proj) is missing a language
identifier, which violates the MD040 markdownlint rule. Add "text" as the
language identifier to the opening fence by changing the opening triple
backticks to include the language specifier, making it ```text instead of just
```.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 52c72f53-6442-4854-927a-13addf467820

📥 Commits

Reviewing files that changed from the base of the PR and between d6db6bf and 9bcd65b.

📒 Files selected for processing (7)
  • docs/.nav.yml
  • docs/operators/README.md
  • docs/operators/activation.md
  • rl_engine/kernels/ops/pytorch/activation/__init__.py
  • rl_engine/kernels/ops/pytorch/activation/swiglu.py
  • rl_engine/kernels/registry.py
  • tests/test_swiglu.py

Comment thread docs/operators/activation.md Outdated

@Flink-ddd Flink-ddd left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR maintains the same high engineering standards as #160. Writing out the explicit mathematical formulas (like x * sigmoid(x)) is excellent documentation for whoever builds the downstream CUDA/Triton fused kernels. Since these two PRs were likely developed in parallel, this one shares the exact same two minor omissions.

Comment thread tests/test_swiglu.py
up = _rand((2, _INTERMEDIATE), seed=13).requires_grad_(True)
op.forward_fp32(gate, up).sum().backward()
assert torch.isfinite(gate.grad).all() and torch.isfinite(up.grad).all()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, test_silu_gradient_flows and test_swiglu_gradient_flows only check if the gradients are numerically valid (isfinite). To strictly satisfy the WS1 contract, we must use a slice test to prove that the backward pass (gradient computation) for both SiLU and SwiGLU is completely independent of the batch size.

Code Example (for SwiGLU):

# Add this to tests/test_swiglu.py
def test_swiglu_backward_batch_invariance_slice():
    """Axis A: Gradients must be bitwise identical regardless of batch size."""
    op = NativeSwiGLUOp()
    
    # 1. Full batch forward & backward
    gate_full = _rand((8, 32, _INTERMEDIATE), seed=1).requires_grad_(True)
    up_full = _rand((8, 32, _INTERMEDIATE), seed=2).requires_grad_(True)
    out_full = op.forward_fp32(gate_full, up_full)
    
    dy_full = _rand(out_full.shape, seed=3)
    out_full.backward(dy_full)
    
    grad_gate_full_sliced = gate_full.grad[:1].clone()
    grad_up_full_sliced = up_full.grad[:1].clone()
    
    # 2. Sliced batch (batch size = 1)
    gate_slice = _rand((8, 32, _INTERMEDIATE), seed=1)[:1].detach().requires_grad_(True)
    up_slice = _rand((8, 32, _INTERMEDIATE), seed=2)[:1].detach().requires_grad_(True)
    out_slice = op.forward_fp32(gate_slice, up_slice)
    
    # Use matching slice of upstream grad
    out_slice.backward(dy_full[:1])
    
    # 3. Assert gradients are bitwise identical
    assert torch.equal(gate_slice.grad, grad_gate_full_sliced)
    assert torch.equal(up_slice.grad, grad_up_full_sliced)

# Note: Please add a similar test for SiLU (test_silu_backward_batch_invariance_slice)

Qwen3-8B intermediate dim (12288) is just one valid last-dim size.
"""

def __init__(self) -> None:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just like the RMSNorm op, these pure PyTorch reference operators should inherit from nn.Module to ensure seamless integration with the broader PyTorch ecosystem (Hooks, state_dict, Dynamo tracing) later down the line.
Bonus: Because nn.Module automatically handles routing call to forward, you can safely delete the manually defined call methods in both classes, making the code even cleaner.

@KJLdefeated KJLdefeated left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Math and structure are correct, dual-path contract is consistent with the rest of WS1, and Vensen's two points cover most of what I'd raise. Pls resolve Vensen's requests. LGTM.

…kward batch invariance tests, MD040 fix

- Refactor NativeSiLUOp and NativeSwiGLUOp to inherit from nn.Module,
  remove manual __call__ (nn.Module routes __call__ to forward).
- Add test_silu_backward_batch_invariance_slice and
  test_swiglu_backward_batch_invariance_slice proving backward pass
  is completely independent of batch size (Axis-A contract).
- Add `text` language identifier to fenced code block in activation.md
  (MD040 markdownlint compliance).
@maxiaosong1124

Copy link
Copy Markdown
Collaborator Author

@Flink-ddd @KJLdefeated Thanks for the careful review! All requested changes have been pushed in 2b3db46. Summary of what was addressed:
both points resolved:

  1. nn.Module inheritance — NativeSiLUOp and NativeSwiGLUOp now inherit from nn.Module (with super().init()), matching the RMSNorm op. As you noted, the manual call methods are now removed since nn.Module routes call → forward automatically.
  2. Backward batch-invariance slice tests — Added test_silu_backward_batch_invariance_slice and test_swiglu_backward_batch_invariance_slice. Instead of only checking isfinite, they now assert (via torch.equal) that the gradients of the batch-size-1 slice are bitwise identical to the corresponding slice of the full-batch gradients — i.e. the Axis-A backward contract, following the example you provided.

@Flink-ddd Flink-ddd left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM now, Thank you for update.

@Flink-ddd Flink-ddd merged commit 9480500 into RL-Align:main Jun 28, 2026
5 checks passed
@Flink-ddd

Copy link
Copy Markdown
Collaborator

Let's merge this PR first

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants