feat(ws1): Add PyTorch matmul reference operator#168
Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 WalkthroughWalkthroughAdds ChangesNativeMatmulOp Feature
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Flink-ddd
left a comment
There was a problem hiding this comment.
Here are some review comments, everything else is fine.
| assert torch.allclose(out_typed, out_fp32, atol=atol, rtol=rtol), ( | ||
| f"dtype={dtype}, max_abs_error={diff:.3e} exceeds " f"atol={atol}, rtol={rtol}" | ||
| ) | ||
|
|
There was a problem hiding this comment.
There are no backward pass tests (test_gradient_flows) in this PR. Matrix multiplication is the primary engine of the neural network and the most common source of gradient reduction drift.
You must verify that gradients flow correctly to both a (activations) and b (weights), and that they strictly adhere to the Axis-A Batch-Invariance contract:
For the activation gradient (a.grad), a slice of the full-batch gradient must be bitwise identical to the gradient computed from a sliced single-batch forward/backward pass.
For the weight gradient (b.grad), because it accumulates across the batch dimension, you should ideally verify that the full-batch weight gradient equals the sum of the single-batch weight gradients.
| def __init__(self) -> None: | ||
| pass | ||
|
|
||
| def __call__(self, a: Tensor, b: Tensor) -> Tensor: |
There was a problem hiding this comment.
As with the previous WS1 reference operators, NativeMatmulOp must inherit from torch.nn.Module to ensure upstream compatibility with Dynamo tracing and PyTorch hooks.
Please inherit from nn.Module, initialize super().init(), and remove the manually defined call method.
| class TestNativeMatmulOpBatchInvariance: | ||
| def test_batch1_vs_batchN_bitwise(self): | ||
| op = NativeMatmulOp() | ||
| a, b = _make_inputs(4, 16, 64, 32, seed=321) | ||
| full_out = op.forward_fp32(a, b) | ||
| for row in range(a.shape[0]): | ||
| single_out = op.forward_fp32(a[row : row + 1], b) | ||
| assert torch.equal( | ||
| full_out[row], single_out[0] | ||
| ), f"Batch invariance broken at row {row}" | ||
|
|
||
| def test_batch_invariance_with_padding(self): | ||
| op = NativeMatmulOp() | ||
| a_valid, b = _make_inputs(2, 16, 64, 32, seed=456) | ||
| gen = torch.Generator().manual_seed(789) | ||
| padding = torch.randn(3, 16, 64, generator=gen) | ||
| a_padded = torch.cat([a_valid, padding], dim=0) | ||
| out_valid = op.forward_fp32(a_valid, b) | ||
| out_padded = op.forward_fp32(a_padded, b) | ||
| assert torch.equal(out_valid[0], out_padded[0]) | ||
| assert torch.equal(out_valid[1], out_padded[1]) |
There was a problem hiding this comment.
cublas is not guarantee batch invariant, with larger size of matrix (M, K, N - Bigger K), batch-invariance may fail. Can you test on larger matrix?
There was a problem hiding this comment.
You should alse add this to make sure batch invariant in CPU.
@contextlib.contextmanager
def _single_thread():
"""Pin CPU GEMM to one thread so the matmul reduction order is batch-independent."""
prev = torch.get_num_threads()
torch.set_num_threads(1)
try:
yield
finally:
torch.set_num_threads(prev)
Summary
Adds the PyTorch reference GEMM/Matmul operator for Issue #108.
This implements
NativeMatmulOpas the fp32 ground-truth baseline for dense projectionmatmuls. The operator follows the frozen #108 interface contract:
forward_fp32(a, b)casts inputs to fp32 and callstorch.matmulonceop_class = "reduction"kernel_registry.get_op("matmul")Also adds operator documentation under
docs/operators/matmul.md.Implementation
rl_engine/kernels/ops/pytorch/linear/matmul.pyrl_engine/kernels/ops/pytorch/linear/__init__.pyPYTORCH_NATIVE_MATMULinrl_engine/kernels/registry.pytests/test_matmul.pydocs/operators/README.mdSummary by CodeRabbit
New Features
matmuloperator with a native PyTorch reference implementation.matmuldispatch on CPU, CUDA, and ROCm.Documentation
matmuloperator documentation page.matmul.Tests
matmultests covering correctness, dtype behavior (including FP32 casting), supported shapes, batch invariance, and registry integration.