Skip to content

feat(ws1): Add PyTorch matmul reference operator#168

Open
a-kaa wants to merge 3 commits into
RL-Align:mainfrom
a-kaa:issue-108-matmul
Open

feat(ws1): Add PyTorch matmul reference operator#168
a-kaa wants to merge 3 commits into
RL-Align:mainfrom
a-kaa:issue-108-matmul

Conversation

@a-kaa

@a-kaa a-kaa commented Jun 21, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds the PyTorch reference GEMM/Matmul operator for Issue #108.

This implements NativeMatmulOp as the fp32 ground-truth baseline for dense projection
matmuls. The operator follows the frozen #108 interface contract:

  • forward_fp32(a, b) casts inputs to fp32 and calls torch.matmul once
  • op_class = "reduction"
  • registry dispatch via kernel_registry.get_op("matmul")

Also adds operator documentation under docs/operators/matmul.md.

Implementation

  • Added rl_engine/kernels/ops/pytorch/linear/matmul.py
  • Added rl_engine/kernels/ops/pytorch/linear/__init__.py
  • Registered PYTORCH_NATIVE_MATMUL in rl_engine/kernels/registry.py
  • Added tests/test_matmul.py
  • Added Matmul operator docs and linked them from docs/operators/README.md
image

Summary by CodeRabbit

  • New Features

    • Added a new matmul operator with a native PyTorch reference implementation.
    • Enabled automatic matmul dispatch on CPU, CUDA, and ROCm.
  • Documentation

    • Added a dedicated matmul operator documentation page.
    • Updated the operators index to include matmul.
  • Tests

    • Added comprehensive matmul tests covering correctness, dtype behavior (including FP32 casting), supported shapes, batch invariance, and registry integration.

@coderabbitai

coderabbitai Bot commented Jun 21, 2026

Copy link
Copy Markdown

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 15198113-4727-4cb8-8cd7-0127d7642df6

📥 Commits

Reviewing files that changed from the base of the PR and between 18a6640 and c545946.

📒 Files selected for processing (1)
  • tests/test_matmul.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/test_matmul.py

📝 Walkthrough

Walkthrough

Adds NativeMatmulOp, a PyTorch fp32-reference matrix multiplication operator, to rl_engine/kernels/ops/pytorch/linear/. The class is exported from the linear package init, registered under a new OpBackend.PYTORCH_NATIVE_MATMUL enum value, and wired into KernelRegistry._priority_map for CUDA, ROCm, and CPU. A new test module and operator documentation page are included.

Changes

NativeMatmulOp Feature

Layer / File(s) Summary
NativeMatmulOp implementation and registry wiring
rl_engine/kernels/ops/pytorch/linear/matmul.py, rl_engine/kernels/ops/pytorch/linear/__init__.py, rl_engine/kernels/registry.py
NativeMatmulOp defines forward (casts inputs to fp32, calls torch.matmul, returns result in original dtype), forward_fp32 (returns fp32 result directly), and __call__ (delegates to forward). __init__.py exports it via __all__. registry.py adds OpBackend.PYTORCH_NATIVE_MATMUL enum member and inserts "matmul" entries into _priority_map for "cuda", "rocm", and "cpu" platforms.
NativeMatmulOp test suite
tests/test_matmul.py
Five test classes validate output shape, forward/forward_fp32 dtype contract, __call__/forward equivalence, bitwise fp32 match to torch.matmul, non-mutation of inputs, op_class = "reduction" metadata, batch invariance (per-row vs. full batch, padded batch), dtype-parameterized accuracy for float32/bfloat16/float16, Qwen3 projection shapes, and registry dispatch returning a NativeMatmulOp instance.
Matmul operator documentation
docs/operators/matmul.md, docs/operators/README.md
New matmul.md documents entry points, backend dispatch, tensor contract, Qwen3 projection shape coverage, reference semantics, accuracy expectations, test command, and relevant files. README.md adds a bullet linking to matmul.md in the "Current Pages" list.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Poem

🐇 Hop hop, a matrix joins the fold,
Two tensors meet in fp32 gold,
The registry maps the path with care,
Batch-invariant, dtype-aware!
The docs are fresh, the tests are bright—
torch.matmul shines in reference light. ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and specifically describes the main change: adding a PyTorch matmul reference operator, which aligns with the primary objective of the PR across all modified files.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@a-kaa a-kaa changed the title Add PyTorch matmul reference operator feat(ws1): Add PyTorch matmul reference operator Jun 22, 2026

@Flink-ddd Flink-ddd left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here are some review comments, everything else is fine.

Comment thread tests/test_matmul.py
assert torch.allclose(out_typed, out_fp32, atol=atol, rtol=rtol), (
f"dtype={dtype}, max_abs_error={diff:.3e} exceeds " f"atol={atol}, rtol={rtol}"
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are no backward pass tests (test_gradient_flows) in this PR. Matrix multiplication is the primary engine of the neural network and the most common source of gradient reduction drift.

You must verify that gradients flow correctly to both a (activations) and b (weights), and that they strictly adhere to the Axis-A Batch-Invariance contract:

For the activation gradient (a.grad), a slice of the full-batch gradient must be bitwise identical to the gradient computed from a sliced single-batch forward/backward pass.

For the weight gradient (b.grad), because it accumulates across the batch dimension, you should ideally verify that the full-batch weight gradient equals the sum of the single-batch weight gradients.

def __init__(self) -> None:
pass

def __call__(self, a: Tensor, b: Tensor) -> Tensor:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As with the previous WS1 reference operators, NativeMatmulOp must inherit from torch.nn.Module to ensure upstream compatibility with Dynamo tracing and PyTorch hooks.

Please inherit from nn.Module, initialize super().init(), and remove the manually defined call method.

Comment thread tests/test_matmul.py
Comment on lines +77 to +97
class TestNativeMatmulOpBatchInvariance:
def test_batch1_vs_batchN_bitwise(self):
op = NativeMatmulOp()
a, b = _make_inputs(4, 16, 64, 32, seed=321)
full_out = op.forward_fp32(a, b)
for row in range(a.shape[0]):
single_out = op.forward_fp32(a[row : row + 1], b)
assert torch.equal(
full_out[row], single_out[0]
), f"Batch invariance broken at row {row}"

def test_batch_invariance_with_padding(self):
op = NativeMatmulOp()
a_valid, b = _make_inputs(2, 16, 64, 32, seed=456)
gen = torch.Generator().manual_seed(789)
padding = torch.randn(3, 16, 64, generator=gen)
a_padded = torch.cat([a_valid, padding], dim=0)
out_valid = op.forward_fp32(a_valid, b)
out_padded = op.forward_fp32(a_padded, b)
assert torch.equal(out_valid[0], out_padded[0])
assert torch.equal(out_valid[1], out_padded[1])

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cublas is not guarantee batch invariant, with larger size of matrix (M, K, N - Bigger K), batch-invariance may fail. Can you test on larger matrix?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should alse add this to make sure batch invariant in CPU.

@contextlib.contextmanager
def _single_thread():
    """Pin CPU GEMM to one thread so the matmul reduction order is batch-independent."""
    prev = torch.get_num_threads()
    torch.set_num_threads(1)
    try:
        yield
    finally:
        torch.set_num_threads(prev)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants