Skip to content

[WS1][kernels] Batch-invariant deterministic GEMM (fwd + bwd)#180

Open
Flink-ddd wants to merge 4 commits into
mainfrom
feat/add-ws1-gemm
Open

[WS1][kernels] Batch-invariant deterministic GEMM (fwd + bwd)#180
Flink-ddd wants to merge 4 commits into
mainfrom
feat/add-ws1-gemm

Conversation

@Flink-ddd

@Flink-ddd Flink-ddd commented Jun 22, 2026

Copy link
Copy Markdown
Collaborator

Fixes #146

Purpose

In RL train–inference alignment, a token's logprob must not drift between rollout (vLLM-side, varying batch shapes) and training (Megatron-side). A major source of this drift is the GEMM: cuBLAS selects kernels by problem shape and may use split-K, so the K-reduction order — and therefore the low bits of the output changes with batch size, chunked-prefill splitting, and padding layout. The same row produces different results depending on the batch it rides in.

This PR implements a batch-invariant deterministic GEMM (forward + backward), one op in the WS1 forward chain (#146). A row's output is bitwise invariant to batch size, chunked-prefill splitting, and padding layout. Invariance is achieved by pinning the tile shape, fixing the K-accumulation order, and forbidding split-K / shape-based kernel selection. BF16 inputs, FP32 accumulation, no TF32.

This is PR2 of the planned series (design note → kernel → tests → LM-head wiring → benchmark). Scope here: the kernel(s), op wiring, and invariance tests.

Invariance contract. The same (M, N, K) row produces the same output regardless of the surrounding batch, chunked-prefill split, or padding. Two design points protect this:

  • M is padded to a multiple of the tile (BM) so every M — including M=1 and non-tile-aligned M — takes the same kernel. Selecting a kernel by M would itself break batch-invariance, since M is the batch dimension.
  • Single CTA per output tile, fixed ascending K loop. No split-K, no cross-CTA K reduction, no atomics.

Backends. CUDA (SM90 TMA + mma.sync tensor cores; naive FP32 scalar fallback below SM90), Triton (portable / ROCm fallback + cross-backend invariance eference), and a PyTorch reference (torch.matmul, intentionally non-deterministic — reference/benchmark only, deliberately excluded from dispatch, since a non-deterministic fallback would silently defeat the op's purpose).

Out of scope (per #146): tensor-parallel GEMM (WS2), FP8, and occupancy/throughput tuning of the tensor-core path — correctness and invariance first; a slower-than-cuBLAS deterministic baseline is the accepted first milestone. Throughput tuning will be a separate perf PR (see Follow-ups).

Test Plan

  1. Build: compile the _C extension with the SM90 path enabled (KERNEL_ALIGN_DET_GEMM_SM90=1) on H100 (SM90), CUDA 12.4.
  2. Single-tile correctness: SM90 TMA + mma.sync kernel vs an FP32 reference on a 64×64 tile, then on multi-tile shapes (M-tiles, long K), to validate the ldmatrix addressing, B-operand layout, and epilogue.
  3. Invariance (hard gate, bitwise via torch.equal):
    • forward batch-invariance (same row, different batch sizes),
    • chunked-prefill split == full GEMM,
    • padding rows do not affect valid rows,
    • backward dA batch-invariance,
      on both the CUDA tensor-core kernel and the Triton path.
  4. Correctness vs FP32: forward and backward, placeholder tolerances pending the [WS1] Ground-truth harness + numerical contract for batch-invariant ops #108 numerical contract.
  5. Target shapes: invariance across the 5 real projection shapes (QKV / o_proj / MLP up / MLP down / LM-head).
  6. Benchmark: overhead vs cuBLAS (TF32 disabled), as the fair baseline.

Test Result

Build: compiled and linked successfully with the SM90 TMA + mma.sync path enabled (CUDA 12.4, H100). det_gemm uses its own csrc/cuda/gemm/det_gemm_tma.cuh (shared::cluster); the shared csrc/utils/tma_utils.cuh emits shared::cta, which CUDA 12.4 ptxas rejects for cp.async.bulk.tensor — scoped the fix to this PR rather than touching the shared logp helpers (tracked separately in #). The KERNEL_ALIGN_DET_GEMM_SM90 build flag is independent of the fused_logp SM90 sources, which are left untouched.

Tests: pytest tests/test_det_gemm.py → 22/22 passed on H100 SXM (CUDA 12.4) — forward/backward batch-invariance, chunked-prefill,
padding-invariance (all bitwise), correctness vs FP32, and all 5 target
projection shapes, on both the CUDA tensor-core kernel and Triton.

-ltorch_cuda -o /tmp/tmptxrii14y.build-lib/rl_engine/_C.cpython-311-x86_64-linux-gnu.so -lcuda
===================================================================================== test session starts ======================================================================================
platform linux -- Python 3.11.10, pytest-9.1.1, pluggy-1.6.0 -- /usr/bin/python
cachedir: .pytest_cache
rootdir: /root/RL-Kernel
configfile: pyproject.toml
plugins: anyio-4.6.0
collected 22 items                                                                                                                                                                             

tests/test_det_gemm.py::test_forward_batch_invariance[cuda-deterministic_gemm] PASSED                                                                                                    [  4%]
tests/test_det_gemm.py::test_forward_batch_invariance[triton-deterministic_gemm_triton] PASSED                                                                                           [  9%]
tests/test_det_gemm.py::test_forward_chunked_prefill[cuda-deterministic_gemm] PASSED                                                                                                     [ 13%]
tests/test_det_gemm.py::test_forward_chunked_prefill[triton-deterministic_gemm_triton] PASSED                                                                                            [ 18%]
tests/test_det_gemm.py::test_forward_padding_invariance[cuda-deterministic_gemm] PASSED                                                                                                  [ 22%]
tests/test_det_gemm.py::test_forward_padding_invariance[triton-deterministic_gemm_triton] PASSED                                                                                         [ 27%]
tests/test_det_gemm.py::test_forward_correctness[cuda-deterministic_gemm] PASSED                                                                                                         [ 31%]
tests/test_det_gemm.py::test_forward_correctness[triton-deterministic_gemm_triton] PASSED                                                                                                [ 36%]
tests/test_det_gemm.py::test_backward_batch_invariance[cuda-deterministic_gemm] PASSED                                                                                                   [ 40%]
tests/test_det_gemm.py::test_backward_batch_invariance[triton-deterministic_gemm_triton] PASSED                                                                                          [ 45%]
tests/test_det_gemm.py::test_backward_correctness[cuda-deterministic_gemm] PASSED                                                                                                        [ 50%]
tests/test_det_gemm.py::test_backward_correctness[triton-deterministic_gemm_triton] PASSED                                                                                               [ 54%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape0-cuda-deterministic_gemm] PASSED                                                                                             [ 59%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape0-triton-deterministic_gemm_triton] PASSED                                                                                    [ 63%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape1-cuda-deterministic_gemm] PASSED                                                                                             [ 68%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape1-triton-deterministic_gemm_triton] PASSED                                                                                    [ 72%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape2-cuda-deterministic_gemm] PASSED                                                                                             [ 77%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape2-triton-deterministic_gemm_triton] PASSED                                                                                    [ 81%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape3-cuda-deterministic_gemm] PASSED                                                                                             [ 86%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape3-triton-deterministic_gemm_triton] PASSED                                                                                    [ 90%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape4-cuda-deterministic_gemm] PASSED                                                                                             [ 95%]
tests/test_det_gemm.py::test_target_shapes_invariance[shape4-triton-deterministic_gemm_triton] PASSED                                                                                    [100%]

====================================================================================== 22 passed in 2.85s ======================================================================================

Benchmark (NVIDIA H100 80GB HBM3, SM90): overhead = det CUDA vs cuBLAS (TF32 disabled). Both deterministic paths trade speed for bitwise invariance by design (no split-K, fixed accumulation, FP32, no TF32).

shape M K N cuBLAS tf32 cuBLAS fp32 det CUDA det Triton overhead
qkv 4096 4096 12288 0.538 0.538 3.280 1.421 6.1x
o_proj 4096 4096 4096 0.190 0.190 1.164 0.478 6.1x
mlp_up 4096 4096 14336 0.656 0.704 3.800 1.688 5.4x
mlp_dn 4096 14336 4096 0.629 0.685 3.779 1.787 5.5x
lm_head 4096 4096 32000 1.513 1.528 8.269 3.897 5.4x

(The det CUDA path uses SM90 TMA + mma.sync tensor cores, 128×128 tile, single-CTA-per-tile, no split-K. Occupancy/throughput tuning is deferred per #146 and will be an ncu-driven perf PR; this milestone is correctness- and invariance-first.)

Files

  • csrc/cuda/gemm/det_gemm_kernel.cu — naive + SM90 TMA + mma.sync kernels, dispatch by compute capability (M padded to tile).
  • csrc/cuda/gemm/det_gemm_tma.cuh — det_gemm-local TMA / mbarrier primitives (shared::cluster).
  • csrc/ops.cpp — pybind registration (det_gemm_fwd / det_gemm_da / det_gemm_db).
  • rl_engine/kernels/ops/{cuda,triton,pytorch}/matmul/det_gemm.py — autograd wrappers + ops.
  • rl_engine/kernels/registry.py — det_gemm dispatch (CUDA primary, Triton fallback; PyTorch reference excluded).
  • tests/test_det_gemm.py — invariance + correctness (CUDA + Triton).
  • benchmarks/benchmark_det_gemm.py — overhead vs cuBLAS.
  • setup.py — det_gemm SM90 build flag; CUTLASS include removed; gencode fix.

Follow-ups

  • Perf PR (separate, ncu-driven): occupancy/throughput tuning of the SM90 tensor-core path to close the gap vs cuBLAS — to be tracked in a new perf issue.
  • PR3: replace placeholder test tolerances with the [WS1] Ground-truth harness + numerical contract for batch-invariant ops #108 threshold table once it lands.
  • PR4: wire one real projection (LM head) through the deterministic path.
  • PR5: benchmark doc — overhead + supported shapes.
  • Fix csrc/utils/tma_utils.cuh (shared::cta → shared::cluster) so fused_logp SM90 compiles on CUDA 12.4 — separate issue #.

Summary by CodeRabbit

  • New Features

    • Added a new deterministic GEMM option for CUDA workloads, with support for a native PyTorch baseline and an optional Triton-backed implementation.
    • Exposed the new GEMM capability through the operator interface and documentation, including usage guidance and backend selection.
    • Added a benchmark report and script to compare performance across common GEMM shapes.
  • Bug Fixes

    • Improved batch-size and padding invariance for GEMM results.
    • Added validation and test coverage for forward and backward correctness on supported CUDA devices.

@coderabbitai

coderabbitai Bot commented Jun 22, 2026

Copy link
Copy Markdown

Review Change Stack

📝 Walkthrough

Walkthrough

This PR adds a deterministic GEMM CUDA path with SM90 TMA support, Python/C++ bindings, Triton and native reference backends, registry wiring, build flags, docs, benchmark artifacts, and CUDA tests covering batch invariance and backward correctness.

Changes

Deterministic GEMM rollout

Layer / File(s) Summary
Kernel primitives and SM90 path
csrc/cuda/gemm/det_gemm_tma.cuh, csrc/cuda/gemm/det_gemm_kernel.cu
BF16 TMA helpers and deterministic GEMM kernels add the naive fixed-order path and the SM90 tiled path.
CUDA dispatch and C++ exports
csrc/cuda/gemm/det_gemm_kernel.cu, csrc/ops.cpp
The dispatch path pads M for batch invariance, selects SM90 or naive execution, and exposes forward/backward CUDA entrypoints through PyBind11.
CUDA Python op surface
rl_engine/kernels/ops/cuda/__init__.py, rl_engine/kernels/ops/cuda/matmul/__init__.py, rl_engine/kernels/ops/cuda/matmul/det_gemm.py, csrc/ops.cpp
The CUDA matmul package exports the deterministic op, and the wrapper calls the new CUDA entrypoints while enforcing BF16 CUDA inputs.
Native and Triton backends
rl_engine/kernels/ops/pytorch/matmul/__init__.py, rl_engine/kernels/ops/pytorch/matmul/det_gemm.py, rl_engine/kernels/ops/triton/matmul/__init__.py, rl_engine/kernels/ops/triton/matmul/det_gemm.py, rl_engine/kernels/registry.py
The reference matmul, Triton deterministic matmul, and registry backend enums and priority maps add alternate implementations and dispatch targets for det_gemm.
Build flags and extension sources
setup.py
The CUDA extension build includes the new det_gemm source and adds SM90-specific gencode, linking, and compiler definitions controlled by the new environment flags.
Docs, benchmark, and tests
docs/.nav.yml, docs/operators/det-gemm.md, benchmarks/benchmark_det_gemm.py, benchmarks/results/det_gemm_h100_tma.md, tests/test_det_gemm.py
The docs sidebar and det_gemm page describe the operator, the benchmark script and H100 result file report timings and overhead, and the CUDA tests cover batch invariance and gradient correctness across supported shapes.

Sequence Diagram(s)

sequenceDiagram
  participant DetGemmOp
  participant _DetGemmFn
  participant det_gemm_fwd
  participant gemm_dispatch
  participant det_gemm_da
  participant det_gemm_db
  DetGemmOp->>_DetGemmFn: apply(a, b)
  _DetGemmFn->>det_gemm_fwd: forward(a, b)
  det_gemm_fwd->>gemm_dispatch: choose SM90 or naive
  gemm_dispatch-->>det_gemm_fwd: C
  _DetGemmFn-->>DetGemmOp: C
  _DetGemmFn->>det_gemm_da: grad_out, b
  _DetGemmFn->>det_gemm_db: a, grad_out
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related issues

  • Issue 151: The new deterministic CUDA, Triton, and registry wiring matches the batch-invariant GEMM implementation scope.
  • Issue 153: The new backward entrypoints and gradient-invariance tests match the backward-consistency objective for deterministic GEMM.

Possibly related PRs

  • RL-Align/RL-Kernel#91: Its SM90 build-flag gating and extension-flag changes overlap with the new det_gemm SM90 build path.

Suggested labels

component: kernels

Suggested reviewers

  • bitborne
  • EthanZero2Hero
  • inaniloquentee

Poem

A rabbit hopped through kernels bright,
With fixed-order steps and BF16 light.
SM90 hummed, Triton chimed,
The batch stayed steady, row by row aligned.
🐇 Crunch! Determinism tastes just right.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 11.11% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and concisely summarizes the main change: adding a batch-invariant deterministic GEMM for forward and backward paths.
Linked Issues check ✅ Passed The PR matches #146 by adding a fixed-order BF16/FP32 deterministic GEMM, validating batch and padding invariance, and covering forward/backward paths.
Out of Scope Changes check ✅ Passed The changes stay focused on the deterministic GEMM, its exports, tests, docs, benchmark, and required build wiring, with no obvious unrelated additions.
✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/add-ws1-gemm

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@Flink-ddd Flink-ddd changed the title WS1][kernels] Batch-invariant deterministic GEMM (fwd + bwd) [WS1][kernels] Batch-invariant deterministic GEMM (fwd + bwd) Jun 22, 2026
Signed-off-by: vensen <vensenmu@gmail.com>
@Flink-ddd Flink-ddd force-pushed the feat/add-ws1-gemm branch from e797bc9 to 90d120c Compare June 24, 2026 07:21
@Flink-ddd Flink-ddd marked this pull request as ready for review June 26, 2026 05:21

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 9

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
rl_engine/kernels/registry.py (1)

99-112: 🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

det_gemm falls back to the logp backend on CPU.

CUDA and ROCm get explicit det_gemm entries, but CPU does not. Because get_op() defaults missing op types to OpBackend.PYTORCH_NATIVE, get_op("det_gemm") on CPU will try to instantiate NativeLogpOp instead of failing fast.

Suggested fix
             "cpu": {
                 "logp": [OpBackend.PYTORCH_NATIVE],
                 "attn": [OpBackend.PYTORCH_ATTN],
                 "grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS],
                 "linear_logp": [OpBackend.PYTORCH_LINEAR_LOGP],
                 "ratio_kl": [OpBackend.PYTORCH_RATIO_KL],
+                "det_gemm": [],
             },
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/registry.py` around lines 99 - 112, The
`KernelRegistry.get_op()` fallback is incorrectly treating missing `det_gemm` as
`OpBackend.PYTORCH_NATIVE`, which routes CPU requests to `NativeLogpOp` instead
of failing or using a valid `det_gemm` backend. Update the registry/lookup logic
so `det_gemm` is explicitly handled for CPU in `KernelRegistry` (or excluded
from the generic native fallback), and ensure `get_op("det_gemm")` selects only
a real `det_gemm` implementation or raises a clear unsupported-op error.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@benchmarks/benchmark_det_gemm.py`:
- Around line 52-63: The benchmark currently starts timing in run() without
checking that the CUDA backend binding is actually available, so a None handle
can still fail later inside deterministic_gemm as NoneType.det_gemm_fwd. Add an
explicit preflight check before the SHAPES loop to verify the CUDA
wrapper/binding used by native_gemm and deterministic_gemm is initialized, and
if not, exit or skip with a clear message instead of letting _time() trigger the
opaque failure. If Triton is part of the path, gate deterministic_gemm_triton
the same way so the benchmark only runs when the required backend symbols are
present.

In `@csrc/cuda/gemm/det_gemm_kernel.cu`:
- Around line 230-238: The gemm dispatch path currently assumes the active CUDA
device matches the input tensors and does not reject mixed-GPU inputs. Update
check_in() and gemm_dispatch() to verify both tensors are on CUDA, are bf16, and
reside on the same device, then add an explicit device guard in gemm_dispatch()
so the allocation of c and the kernel launch via
at::cuda::getCurrentCUDAStream() are pinned to the inputs’ device. Use the
existing gemm_dispatch, check_in, and stream setup to place the guard and
same-device validation close to the launch site.

In `@docs/operators/det-gemm.md`:
- Around line 28-29: The CUDA backend description in the DetGemm documentation
is outdated and still describes a naive FP32 milestone with deferred tensor-core
work, while the merged implementation already uses the SM90 TMA plus mma.sync
path. Update the wording in the DetGemm operator docs table and any related CUDA
backend references so they accurately reflect the implemented DetGemmOp behavior
and current performance profile, keeping the TritonDetGemmOp description
consistent with the same section.

In `@rl_engine/kernels/ops/cuda/matmul/det_gemm.py`:
- Around line 16-20: The functional deterministic GEMM path is bypassing the
compiled-extension availability guard and can hit _C.det_gemm_fwd when _C is
None. Update _DetGemmFn.forward and deterministic_gemm to perform the same
availability check used by DetGemmOp.__call__ before touching _C, and raise the
intended RuntimeError instead of allowing an AttributeError to surface.

In `@rl_engine/kernels/ops/triton/matmul/det_gemm.py`:
- Around line 101-105: The backward path in det_gemm.backward still computes dB
from Aᵀ @ grad_out, so the reduction depends on the live batch dimension and can
vary with chunking/padding. Update the db calculation to use the WS1 invariant
reduction contract instead of directly reducing over the current batch layout,
and ensure the Triton backward path produces the same dW regardless of batch
size or layout changes. Keep the fix localized around backward and the
_triton_gemm calls in det_gemm.py.
- Around line 68-90: The _triton_gemm helper currently derives K from a.shape
and launches _det_gemm_kernel without verifying that b.shape[0] matches, so add
an explicit inner-dimension check before the kernel call and raise the standard
matmul shape error on mismatch. Keep the validation in _triton_gemm near the
existing M, K, N shape extraction so invalid inputs are rejected before any
Triton launch.

In `@setup.py`:
- Around line 64-71: The SM90 build path still emits a conflicting plain SM90
gencode when KERNEL_ALIGN_DET_GEMM_SM90 is enabled, so update the nvcc flag
assembly in setup.py to avoid adding the compute_90/sm_90 pair on SM90 machines.
Make the logic in the startup flag block and the later SM90 append path
consistent so that enabling KERNEL_ALIGN_DET_GEMM_SM90 alone does not require
also setting KERNEL_ALIGN_FORCE_SM90. Use the existing nvcc_flags construction
and the SM90-specific branch to ensure only the intended 90a gencode is present.

In `@tests/test_det_gemm.py`:
- Around line 25-33: The CUDA backend is being added to _BACKENDS even when
deterministic_gemm is not actually available, which causes crashes before the
assertions run. Update the backend selection near pytestmark/_BACKENDS to check
the wrapper’s real availability signal from deterministic_gemm before appending
the CUDA case, or convert it into a clear module-level skip when the extension
binding is None. Keep the Triton entry gated separately with _HAS_TRITON.
- Around line 78-85: The correctness assertions in test_forward_correctness and
the related tests still use temporary loose thresholds instead of the shared
`#108` contract. Update the checks in the gemm test cases to use the `#108`
correctness harness/thresholds rather than hardcoded placeholder max_abs limits,
and make sure the affected test functions (including test_forward_correctness
and the nearby cases in the same range) validate against the contract
consistently.

---

Outside diff comments:
In `@rl_engine/kernels/registry.py`:
- Around line 99-112: The `KernelRegistry.get_op()` fallback is incorrectly
treating missing `det_gemm` as `OpBackend.PYTORCH_NATIVE`, which routes CPU
requests to `NativeLogpOp` instead of failing or using a valid `det_gemm`
backend. Update the registry/lookup logic so `det_gemm` is explicitly handled
for CPU in `KernelRegistry` (or excluded from the generic native fallback), and
ensure `get_op("det_gemm")` selects only a real `det_gemm` implementation or
raises a clear unsupported-op error.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 0de95808-0f0e-49ba-ba17-f216e838e757

📥 Commits

Reviewing files that changed from the base of the PR and between be5ec9b and 70a60a5.

📒 Files selected for processing (17)
  • benchmarks/benchmark_det_gemm.py
  • benchmarks/results/det_gemm_h100_tma.md
  • csrc/cuda/gemm/det_gemm_kernel.cu
  • csrc/cuda/gemm/det_gemm_tma.cuh
  • csrc/ops.cpp
  • docs/.nav.yml
  • docs/operators/det-gemm.md
  • rl_engine/kernels/ops/cuda/__init__.py
  • rl_engine/kernels/ops/cuda/matmul/__init__.py
  • rl_engine/kernels/ops/cuda/matmul/det_gemm.py
  • rl_engine/kernels/ops/pytorch/matmul/__init__.py
  • rl_engine/kernels/ops/pytorch/matmul/det_gemm.py
  • rl_engine/kernels/ops/triton/matmul/__init__.py
  • rl_engine/kernels/ops/triton/matmul/det_gemm.py
  • rl_engine/kernels/registry.py
  • setup.py
  • tests/test_det_gemm.py

Comment on lines +52 to +63
def run():
rows = []
for name, M, K, N in SHAPES:
a = torch.randn(M, K, device=DEV, dtype=torch.bfloat16)
b = torch.randn(K, N, device=DEV, dtype=torch.bfloat16)
torch.backends.cuda.matmul.allow_tf32 = True
t_tf32 = _time(lambda x, y: torch.matmul(x, y), a, b)
torch.backends.cuda.matmul.allow_tf32 = False
t_fp32 = _time(native_gemm, a, b)
t_cuda = _time(deterministic_gemm, a, b)
t_tri = _time(deterministic_gemm_triton, a, b) if _HAS_TRITON else float("nan")
rows.append((name, M, K, N, t_tf32, t_fp32, t_cuda, t_tri, t_cuda / t_fp32))

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟡 Minor | ⚡ Quick win

Preflight backend availability before timing.

GPU CI already shows the CUDA wrapper can import while its binding handle is still None, which then blows up as NoneType.det_gemm_fwd. In that state this benchmark will fail mid-run with the same opaque error. Add an explicit availability check up front and exit with a clear skip/error message instead of discovering it inside _time().

Also applies to: 88-98

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@benchmarks/benchmark_det_gemm.py` around lines 52 - 63, The benchmark
currently starts timing in run() without checking that the CUDA backend binding
is actually available, so a None handle can still fail later inside
deterministic_gemm as NoneType.det_gemm_fwd. Add an explicit preflight check
before the SHAPES loop to verify the CUDA wrapper/binding used by native_gemm
and deterministic_gemm is initialized, and if not, exit or skip with a clear
message instead of letting _time() trigger the opaque failure. If Triton is part
of the path, gate deterministic_gemm_triton the same way so the benchmark only
runs when the required backend symbols are present.

Source: Pipeline failures

Comment on lines +230 to +238
void check_in(const torch::Tensor& t, const char* n) {
TORCH_CHECK(t.is_cuda(), n, " must be CUDA");
TORCH_CHECK(t.scalar_type() == torch::kBFloat16, n, " must be bf16");
}

torch::Tensor gemm_dispatch(const torch::Tensor& a, const torch::Tensor& b) {
const int M = a.size(0), K = a.size(1), N = b.size(1);
auto c = torch::empty({M, N}, a.options());
auto stream = at::cuda::getCurrentCUDAStream();

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n csrc/cuda/gemm/det_gemm_kernel.cu | head -50

Repository: RL-Align/RL-Kernel

Length of output: 2419


🏁 Script executed:

cat -n csrc/cuda/gemm/det_gemm_kernel.cu | sed -n '230,270p'

Repository: RL-Align/RL-Kernel

Length of output: 2085


Pin launches to the input tensors’ CUDA device and enforce same-device inputs.

gemm_dispatch() relies on at::cuda::getCurrentCUDAStream() without checking if the current device matches the tensors' device, and check_in() fails to reject inputs residing on different GPUs. In multi-GPU contexts, this can cause kernel launches on the wrong device or with mismatched pointers.

Add an explicit device guard and cross-check input devices:

Suggested fix
+#include <c10/cuda/CUDAGuard.h>
+
 void check_in(const torch::Tensor& t, const char* n) {
   TORCH_CHECK(t.is_cuda(), n, " must be CUDA");
   TORCH_CHECK(t.scalar_type() == torch::kBFloat16, n, " must be bf16");
 }
 
 torch::Tensor gemm_dispatch(const torch::Tensor& a, const torch::Tensor& b) {
+  TORCH_CHECK(a.get_device() == b.get_device(), "det_gemm: inputs must be on the same CUDA device");
+  const c10::cuda::CUDAGuard device_guard(a.device());
   const int M = a.size(0), K = a.size(1), N = b.size(1);
   auto c = torch::empty({M, N}, a.options());
   auto stream = at::cuda::getCurrentCUDAStream();
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
void check_in(const torch::Tensor& t, const char* n) {
TORCH_CHECK(t.is_cuda(), n, " must be CUDA");
TORCH_CHECK(t.scalar_type() == torch::kBFloat16, n, " must be bf16");
}
torch::Tensor gemm_dispatch(const torch::Tensor& a, const torch::Tensor& b) {
const int M = a.size(0), K = a.size(1), N = b.size(1);
auto c = torch::empty({M, N}, a.options());
auto stream = at::cuda::getCurrentCUDAStream();
`#include` <c10/cuda/CUDAGuard.h>
void check_in(const torch::Tensor& t, const char* n) {
TORCH_CHECK(t.is_cuda(), n, " must be CUDA");
TORCH_CHECK(t.scalar_type() == torch::kBFloat16, n, " must be bf16");
}
torch::Tensor gemm_dispatch(const torch::Tensor& a, const torch::Tensor& b) {
TORCH_CHECK(a.get_device() == b.get_device(), "det_gemm: inputs must be on the same CUDA device");
const c10::cuda::CUDAGuard device_guard(a.device());
const int M = a.size(0), K = a.size(1), N = b.size(1);
auto c = torch::empty({M, N}, a.options());
auto stream = at::cuda::getCurrentCUDAStream();
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@csrc/cuda/gemm/det_gemm_kernel.cu` around lines 230 - 238, The gemm dispatch
path currently assumes the active CUDA device matches the input tensors and does
not reject mixed-GPU inputs. Update check_in() and gemm_dispatch() to verify
both tensors are on CUDA, are bf16, and reside on the same device, then add an
explicit device guard in gemm_dispatch() so the allocation of c and the kernel
launch via at::cuda::getCurrentCUDAStream() are pinned to the inputs’ device.
Use the existing gemm_dispatch, check_in, and stream setup to place the guard
and same-device validation close to the launch site.

Comment on lines +28 to +29
| CUDA (`DetGemmOp`) | yes | Hand-written kernel. First milestone is a naive FP32 implementation (correctness first); a tensor-core (`mma.sync`) pass matching `prefix_shared_attention.cu` follows. NVIDIA SM80+. |
| Triton (`TritonDetGemmOp`) | yes | Autotune disabled, BLOCK pinned, no split-K. Portable / ROCm fallback and cross-backend reference. |

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

Update the CUDA backend description to match the implementation.

This page still says the CUDA path is a naive FP32 first milestone with tensor-core work deferred, but the benchmark code/results in this PR describe an SM90 TMA + mma.sync path that already exists. That mismatch will mislead readers about what was actually merged and the expected performance profile.

Also applies to: 50-53

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/operators/det-gemm.md` around lines 28 - 29, The CUDA backend
description in the DetGemm documentation is outdated and still describes a naive
FP32 milestone with deferred tensor-core work, while the merged implementation
already uses the SM90 TMA plus mma.sync path. Update the wording in the DetGemm
operator docs table and any related CUDA backend references so they accurately
reflect the implemented DetGemmOp behavior and current performance profile,
keeping the TritonDetGemmOp description consistent with the same section.

Comment on lines +16 to +20
class _DetGemmFn(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b):
ctx.save_for_backward(a, b)
return _C.det_gemm_fwd(a, b)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Don’t bypass the compiled-extension availability check.

DetGemmOp.__call__ raises a clear RuntimeError when _C is unavailable, but both _DetGemmFn.forward() and deterministic_gemm() bypass that guard. That matches the reported CI failure: the functional path reaches _C.det_gemm_fwd(...) with _C is None and crashes with AttributeError instead of the intended explicit error.

Suggested fix
 class _DetGemmFn(torch.autograd.Function):
     `@staticmethod`
     def forward(ctx, a, b):
+        if not (_EXT_AVAILABLE and _C is not None and hasattr(_C, "det_gemm_fwd")):
+            raise RuntimeError(
+                "DetGemmOp: compiled _C.det_gemm kernel unavailable; no "
+                "batch-invariant fallback exists. Build the extension first."
+            )
         ctx.save_for_backward(a, b)
         return _C.det_gemm_fwd(a, b)

Also applies to: 57-59

🧰 Tools
🪛 GitHub Actions: GPU CI / 0_gpu-tests.txt

[error] 20-20: AttributeError in deterministic GEMM CUDA path: '_C' is None, so '_C.det_gemm_fwd(a, b)' fails with "AttributeError: 'NoneType' object has no attribute 'det_gemm_fwd'".

🪛 GitHub Actions: GPU CI / gpu-tests

[error] 20-20: AttributeError in CUDA deterministic GEMM forward: '_C' is None and has no attribute 'det_gemm_fwd' (raised at _C.det_gemm_fwd(a, b)).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/cuda/matmul/det_gemm.py` around lines 16 - 20, The
functional deterministic GEMM path is bypassing the compiled-extension
availability guard and can hit _C.det_gemm_fwd when _C is None. Update
_DetGemmFn.forward and deterministic_gemm to perform the same availability check
used by DetGemmOp.__call__ before touching _C, and raise the intended
RuntimeError instead of allowing an AttributeError to surface.

Source: Pipeline failures

Comment on lines +68 to +90
def _triton_gemm(a, b):
a, b = a.contiguous(), b.contiguous()
M, K = a.shape
_, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
grid = (triton.cdiv(M, _BLOCK_M), triton.cdiv(N, _BLOCK_N))
_det_gemm_kernel[grid](
a,
b,
c,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
BLOCK_M=_BLOCK_M,
BLOCK_N=_BLOCK_N,
BLOCK_K=_BLOCK_K,
)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🔴 Critical | ⚡ Quick win

Validate inner dimensions before launching the Triton kernel.

Line 70 takes K from a.shape and never checks b.shape[0] == K. On a mismatch, the kernel masks loads with a’s K, so it can walk past b instead of raising the usual matmul shape error.

Suggested fix
+def _validate_det_gemm_inputs(a, b):
+    if a.ndim != 2 or b.ndim != 2:
+        raise ValueError("det_gemm expects 2D tensors")
+    if a.shape[1] != b.shape[0]:
+        raise ValueError(f"incompatible shapes: {tuple(a.shape)} and {tuple(b.shape)}")
+
 def _triton_gemm(a, b):
+    _validate_det_gemm_inputs(a, b)
     a, b = a.contiguous(), b.contiguous()
     M, K = a.shape
     _, N = b.shape
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _triton_gemm(a, b):
a, b = a.contiguous(), b.contiguous()
M, K = a.shape
_, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
grid = (triton.cdiv(M, _BLOCK_M), triton.cdiv(N, _BLOCK_N))
_det_gemm_kernel[grid](
a,
b,
c,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
BLOCK_M=_BLOCK_M,
BLOCK_N=_BLOCK_N,
BLOCK_K=_BLOCK_K,
)
def _validate_det_gemm_inputs(a, b):
if a.ndim != 2 or b.ndim != 2:
raise ValueError("det_gemm expects 2D tensors")
if a.shape[1] != b.shape[0]:
raise ValueError(f"incompatible shapes: {tuple(a.shape)} and {tuple(b.shape)}")
def _triton_gemm(a, b):
_validate_det_gemm_inputs(a, b)
a, b = a.contiguous(), b.contiguous()
M, K = a.shape
_, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
grid = (triton.cdiv(M, _BLOCK_M), triton.cdiv(N, _BLOCK_N))
_det_gemm_kernel[grid](
a,
b,
c,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
BLOCK_M=_BLOCK_M,
BLOCK_N=_BLOCK_N,
BLOCK_K=_BLOCK_K,
)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/triton/matmul/det_gemm.py` around lines 68 - 90, The
_triton_gemm helper currently derives K from a.shape and launches
_det_gemm_kernel without verifying that b.shape[0] matches, so add an explicit
inner-dimension check before the kernel call and raise the standard matmul shape
error on mismatch. Keep the validation in _triton_gemm near the existing M, K, N
shape extraction so invalid inputs are rejected before any Triton launch.

Comment on lines +101 to +105
def backward(ctx, grad_out):
a, b = ctx.saved_tensors
grad_out = grad_out.contiguous()
da = _triton_gemm(grad_out, b.t().contiguous()) if ctx.needs_input_grad[0] else None
db = _triton_gemm(a.t().contiguous(), grad_out) if ctx.needs_input_grad[1] else None

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | 🏗️ Heavy lift

dB is still reduced over the live batch dimension.

Line 105 computes db as Aᵀ @ grad_out, which makes the reduction axis the original M. That means the Triton backward path can still change with batch/chunk/padding layout, so it does not meet the WS1 dW invariance contract yet. As per the PR objectives, backward dW must remain invariant across batch-size and layout changes.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/triton/matmul/det_gemm.py` around lines 101 - 105, The
backward path in det_gemm.backward still computes dB from Aᵀ @ grad_out, so the
reduction depends on the live batch dimension and can vary with
chunking/padding. Update the db calculation to use the WS1 invariant reduction
contract instead of directly reducing over the current batch layout, and ensure
the Triton backward path produces the same dW regardless of batch size or layout
changes. Keep the fix localized around backward and the _triton_gemm calls in
det_gemm.py.

Comment thread setup.py
Comment on lines 64 to +71
cc_major, cc_minor = torch.cuda.get_device_capability()
enable_sm90 = os.environ.get("KERNEL_ALIGN_FORCE_SM90") == "1"
nvcc_flags = ["-O3", "--use_fast_math", "-Xfatbin", "-compress-all"]
if not enable_sm90:
# SM90 build emits 90a below; mixing plain compute_90 breaks TMA ptxas.
nvcc_flags.append(
f"-gencode=arch=compute_{cc_major}{cc_minor},code=sm_{cc_major}{cc_minor}"
)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

KERNEL_ALIGN_DET_GEMM_SM90 still leaves the conflicting plain SM90 gencode enabled.

When KERNEL_ALIGN_DET_GEMM_SM90=1 is set on an SM90 machine, Lines 67-71 still add -gencode=arch=compute_90,code=sm_90, and Lines 138-145 then append compute_90a/sm_90a. The comment on Line 68 says that mix breaks TMA ptxas, so this new path can still fail unless callers also remember to set the separate KERNEL_ALIGN_FORCE_SM90 flag.

Suggested fix
         cc_major, cc_minor = torch.cuda.get_device_capability()
         enable_sm90 = os.environ.get("KERNEL_ALIGN_FORCE_SM90") == "1"
+        enable_det_gemm_sm90 = os.environ.get("KERNEL_ALIGN_DET_GEMM_SM90") == "1"
         nvcc_flags = ["-O3", "--use_fast_math", "-Xfatbin", "-compress-all"]
-        if not enable_sm90:
+        if not (enable_sm90 or enable_det_gemm_sm90):
             # SM90 build emits 90a below; mixing plain compute_90 breaks TMA ptxas.
             nvcc_flags.append(
                 f"-gencode=arch=compute_{cc_major}{cc_minor},code=sm_{cc_major}{cc_minor}"
             )
@@
-        enable_det_gemm_sm90 = os.environ.get("KERNEL_ALIGN_DET_GEMM_SM90") == "1"
         if enable_det_gemm_sm90:

Also applies to: 137-145

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@setup.py` around lines 64 - 71, The SM90 build path still emits a conflicting
plain SM90 gencode when KERNEL_ALIGN_DET_GEMM_SM90 is enabled, so update the
nvcc flag assembly in setup.py to avoid adding the compute_90/sm_90 pair on SM90
machines. Make the logic in the startup flag block and the later SM90 append
path consistent so that enabling KERNEL_ALIGN_DET_GEMM_SM90 alone does not
require also setting KERNEL_ALIGN_FORCE_SM90. Use the existing nvcc_flags
construction and the SM90-specific branch to ensure only the intended 90a
gencode is present.

Comment thread tests/test_det_gemm.py
Comment on lines +25 to +33
pytestmark = pytest.mark.skipif(
not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8,
reason="det_gemm requires CUDA SM80+",
)

# Each deterministic backend is validated independently.
_BACKENDS = [("cuda", deterministic_gemm)]
if _HAS_TRITON:
_BACKENDS.append(("triton", deterministic_gemm_triton))

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Don't parametrize the CUDA backend unless the extension actually loaded.

The current skip only checks for CUDA + SM80, but GPU CI shows deterministic_gemm is still callable in a state where its backing binding is None, so every CUDA case crashes before the invariance assertions run. Gate _BACKENDS on the wrapper's real availability signal (or fail fast with a clear module-level skip) instead of assuming import success means the op is usable.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_det_gemm.py` around lines 25 - 33, The CUDA backend is being added
to _BACKENDS even when deterministic_gemm is not actually available, which
causes crashes before the assertions run. Update the backend selection near
pytestmark/_BACKENDS to check the wrapper’s real availability signal from
deterministic_gemm before appending the CUDA case, or convert it into a clear
module-level skip when the extension binding is None. Keep the Triton entry
gated separately with _HAS_TRITON.

Source: Pipeline failures

Comment thread tests/test_det_gemm.py
Comment on lines +78 to +85
def test_forward_correctness(name, gemm):
# vs FP32 reference. Placeholder tolerance; PR3 swaps for #108 contract.
torch.manual_seed(3)
M, K, N = 128, 2048, 2048
a, b = _rand(M, K), _rand(K, N)
out = gemm(a, b).float()
ref = a.float() @ b.float()
assert (out - ref).abs().max().item() < 1.0 # TODO(#108): contract threshold

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | 🏗️ Heavy lift

Replace the placeholder tolerances with the #108 correctness contract before merge.

These max_abs < 1.0 / < 2.0 checks are explicitly temporary, and they are loose enough to hide material numeric regressions while still passing CI. The linked objective for #146 calls out validation through the shared #108 harness, so leaving placeholder thresholds here means the correctness part of the contract is still unverified.

Also applies to: 105-117

🧰 Tools
🪛 GitHub Actions: GPU CI / 0_gpu-tests.txt

[error] 83-83: test_forward_correctness[cuda-deterministic_gemm] failed because gemm(a, b) raised the AttributeError: 'NoneType' object has no attribute 'det_gemm_fwd'.

🪛 GitHub Actions: GPU CI / gpu-tests

[error] 83-83: test_forward_correctness[cuda-deterministic_gemm] failed when calling gemm(): AttributeError: 'NoneType' object has no attribute 'det_gemm_fwd'.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_det_gemm.py` around lines 78 - 85, The correctness assertions in
test_forward_correctness and the related tests still use temporary loose
thresholds instead of the shared `#108` contract. Update the checks in the gemm
test cases to use the `#108` correctness harness/thresholds rather than hardcoded
placeholder max_abs limits, and make sure the affected test functions (including
test_forward_correctness and the nearby cases in the same range) validate
against the contract consistently.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[WS1][kernels] Batch-invariant matmul / GEMM

1 participant