[WS1][kernels] Batch-invariant deterministic GEMM (fwd + bwd)#180
[WS1][kernels] Batch-invariant deterministic GEMM (fwd + bwd)#180Flink-ddd wants to merge 4 commits into
Conversation
📝 WalkthroughWalkthroughThis PR adds a deterministic GEMM CUDA path with SM90 TMA support, Python/C++ bindings, Triton and native reference backends, registry wiring, build flags, docs, benchmark artifacts, and CUDA tests covering batch invariance and backward correctness. ChangesDeterministic GEMM rollout
Sequence Diagram(s)sequenceDiagram
participant DetGemmOp
participant _DetGemmFn
participant det_gemm_fwd
participant gemm_dispatch
participant det_gemm_da
participant det_gemm_db
DetGemmOp->>_DetGemmFn: apply(a, b)
_DetGemmFn->>det_gemm_fwd: forward(a, b)
det_gemm_fwd->>gemm_dispatch: choose SM90 or naive
gemm_dispatch-->>det_gemm_fwd: C
_DetGemmFn-->>DetGemmOp: C
_DetGemmFn->>det_gemm_da: grad_out, b
_DetGemmFn->>det_gemm_db: a, grad_out
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related issues
Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Signed-off-by: vensen <vensenmu@gmail.com>
e797bc9 to
90d120c
Compare
There was a problem hiding this comment.
Actionable comments posted: 9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
rl_engine/kernels/registry.py (1)
99-112: 🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
det_gemmfalls back to the logp backend on CPU.CUDA and ROCm get explicit
det_gemmentries, but CPU does not. Becauseget_op()defaults missing op types toOpBackend.PYTORCH_NATIVE,get_op("det_gemm")on CPU will try to instantiateNativeLogpOpinstead of failing fast.Suggested fix
"cpu": { "logp": [OpBackend.PYTORCH_NATIVE], "attn": [OpBackend.PYTORCH_ATTN], "grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.PYTORCH_RATIO_KL], + "det_gemm": [], },🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@rl_engine/kernels/registry.py` around lines 99 - 112, The `KernelRegistry.get_op()` fallback is incorrectly treating missing `det_gemm` as `OpBackend.PYTORCH_NATIVE`, which routes CPU requests to `NativeLogpOp` instead of failing or using a valid `det_gemm` backend. Update the registry/lookup logic so `det_gemm` is explicitly handled for CPU in `KernelRegistry` (or excluded from the generic native fallback), and ensure `get_op("det_gemm")` selects only a real `det_gemm` implementation or raises a clear unsupported-op error.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@benchmarks/benchmark_det_gemm.py`:
- Around line 52-63: The benchmark currently starts timing in run() without
checking that the CUDA backend binding is actually available, so a None handle
can still fail later inside deterministic_gemm as NoneType.det_gemm_fwd. Add an
explicit preflight check before the SHAPES loop to verify the CUDA
wrapper/binding used by native_gemm and deterministic_gemm is initialized, and
if not, exit or skip with a clear message instead of letting _time() trigger the
opaque failure. If Triton is part of the path, gate deterministic_gemm_triton
the same way so the benchmark only runs when the required backend symbols are
present.
In `@csrc/cuda/gemm/det_gemm_kernel.cu`:
- Around line 230-238: The gemm dispatch path currently assumes the active CUDA
device matches the input tensors and does not reject mixed-GPU inputs. Update
check_in() and gemm_dispatch() to verify both tensors are on CUDA, are bf16, and
reside on the same device, then add an explicit device guard in gemm_dispatch()
so the allocation of c and the kernel launch via
at::cuda::getCurrentCUDAStream() are pinned to the inputs’ device. Use the
existing gemm_dispatch, check_in, and stream setup to place the guard and
same-device validation close to the launch site.
In `@docs/operators/det-gemm.md`:
- Around line 28-29: The CUDA backend description in the DetGemm documentation
is outdated and still describes a naive FP32 milestone with deferred tensor-core
work, while the merged implementation already uses the SM90 TMA plus mma.sync
path. Update the wording in the DetGemm operator docs table and any related CUDA
backend references so they accurately reflect the implemented DetGemmOp behavior
and current performance profile, keeping the TritonDetGemmOp description
consistent with the same section.
In `@rl_engine/kernels/ops/cuda/matmul/det_gemm.py`:
- Around line 16-20: The functional deterministic GEMM path is bypassing the
compiled-extension availability guard and can hit _C.det_gemm_fwd when _C is
None. Update _DetGemmFn.forward and deterministic_gemm to perform the same
availability check used by DetGemmOp.__call__ before touching _C, and raise the
intended RuntimeError instead of allowing an AttributeError to surface.
In `@rl_engine/kernels/ops/triton/matmul/det_gemm.py`:
- Around line 101-105: The backward path in det_gemm.backward still computes dB
from Aᵀ @ grad_out, so the reduction depends on the live batch dimension and can
vary with chunking/padding. Update the db calculation to use the WS1 invariant
reduction contract instead of directly reducing over the current batch layout,
and ensure the Triton backward path produces the same dW regardless of batch
size or layout changes. Keep the fix localized around backward and the
_triton_gemm calls in det_gemm.py.
- Around line 68-90: The _triton_gemm helper currently derives K from a.shape
and launches _det_gemm_kernel without verifying that b.shape[0] matches, so add
an explicit inner-dimension check before the kernel call and raise the standard
matmul shape error on mismatch. Keep the validation in _triton_gemm near the
existing M, K, N shape extraction so invalid inputs are rejected before any
Triton launch.
In `@setup.py`:
- Around line 64-71: The SM90 build path still emits a conflicting plain SM90
gencode when KERNEL_ALIGN_DET_GEMM_SM90 is enabled, so update the nvcc flag
assembly in setup.py to avoid adding the compute_90/sm_90 pair on SM90 machines.
Make the logic in the startup flag block and the later SM90 append path
consistent so that enabling KERNEL_ALIGN_DET_GEMM_SM90 alone does not require
also setting KERNEL_ALIGN_FORCE_SM90. Use the existing nvcc_flags construction
and the SM90-specific branch to ensure only the intended 90a gencode is present.
In `@tests/test_det_gemm.py`:
- Around line 25-33: The CUDA backend is being added to _BACKENDS even when
deterministic_gemm is not actually available, which causes crashes before the
assertions run. Update the backend selection near pytestmark/_BACKENDS to check
the wrapper’s real availability signal from deterministic_gemm before appending
the CUDA case, or convert it into a clear module-level skip when the extension
binding is None. Keep the Triton entry gated separately with _HAS_TRITON.
- Around line 78-85: The correctness assertions in test_forward_correctness and
the related tests still use temporary loose thresholds instead of the shared
`#108` contract. Update the checks in the gemm test cases to use the `#108`
correctness harness/thresholds rather than hardcoded placeholder max_abs limits,
and make sure the affected test functions (including test_forward_correctness
and the nearby cases in the same range) validate against the contract
consistently.
---
Outside diff comments:
In `@rl_engine/kernels/registry.py`:
- Around line 99-112: The `KernelRegistry.get_op()` fallback is incorrectly
treating missing `det_gemm` as `OpBackend.PYTORCH_NATIVE`, which routes CPU
requests to `NativeLogpOp` instead of failing or using a valid `det_gemm`
backend. Update the registry/lookup logic so `det_gemm` is explicitly handled
for CPU in `KernelRegistry` (or excluded from the generic native fallback), and
ensure `get_op("det_gemm")` selects only a real `det_gemm` implementation or
raises a clear unsupported-op error.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 0de95808-0f0e-49ba-ba17-f216e838e757
📒 Files selected for processing (17)
benchmarks/benchmark_det_gemm.pybenchmarks/results/det_gemm_h100_tma.mdcsrc/cuda/gemm/det_gemm_kernel.cucsrc/cuda/gemm/det_gemm_tma.cuhcsrc/ops.cppdocs/.nav.ymldocs/operators/det-gemm.mdrl_engine/kernels/ops/cuda/__init__.pyrl_engine/kernels/ops/cuda/matmul/__init__.pyrl_engine/kernels/ops/cuda/matmul/det_gemm.pyrl_engine/kernels/ops/pytorch/matmul/__init__.pyrl_engine/kernels/ops/pytorch/matmul/det_gemm.pyrl_engine/kernels/ops/triton/matmul/__init__.pyrl_engine/kernels/ops/triton/matmul/det_gemm.pyrl_engine/kernels/registry.pysetup.pytests/test_det_gemm.py
| def run(): | ||
| rows = [] | ||
| for name, M, K, N in SHAPES: | ||
| a = torch.randn(M, K, device=DEV, dtype=torch.bfloat16) | ||
| b = torch.randn(K, N, device=DEV, dtype=torch.bfloat16) | ||
| torch.backends.cuda.matmul.allow_tf32 = True | ||
| t_tf32 = _time(lambda x, y: torch.matmul(x, y), a, b) | ||
| torch.backends.cuda.matmul.allow_tf32 = False | ||
| t_fp32 = _time(native_gemm, a, b) | ||
| t_cuda = _time(deterministic_gemm, a, b) | ||
| t_tri = _time(deterministic_gemm_triton, a, b) if _HAS_TRITON else float("nan") | ||
| rows.append((name, M, K, N, t_tf32, t_fp32, t_cuda, t_tri, t_cuda / t_fp32)) |
There was a problem hiding this comment.
🩺 Stability & Availability | 🟡 Minor | ⚡ Quick win
Preflight backend availability before timing.
GPU CI already shows the CUDA wrapper can import while its binding handle is still None, which then blows up as NoneType.det_gemm_fwd. In that state this benchmark will fail mid-run with the same opaque error. Add an explicit availability check up front and exit with a clear skip/error message instead of discovering it inside _time().
Also applies to: 88-98
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@benchmarks/benchmark_det_gemm.py` around lines 52 - 63, The benchmark
currently starts timing in run() without checking that the CUDA backend binding
is actually available, so a None handle can still fail later inside
deterministic_gemm as NoneType.det_gemm_fwd. Add an explicit preflight check
before the SHAPES loop to verify the CUDA wrapper/binding used by native_gemm
and deterministic_gemm is initialized, and if not, exit or skip with a clear
message instead of letting _time() trigger the opaque failure. If Triton is part
of the path, gate deterministic_gemm_triton the same way so the benchmark only
runs when the required backend symbols are present.
Source: Pipeline failures
| void check_in(const torch::Tensor& t, const char* n) { | ||
| TORCH_CHECK(t.is_cuda(), n, " must be CUDA"); | ||
| TORCH_CHECK(t.scalar_type() == torch::kBFloat16, n, " must be bf16"); | ||
| } | ||
|
|
||
| torch::Tensor gemm_dispatch(const torch::Tensor& a, const torch::Tensor& b) { | ||
| const int M = a.size(0), K = a.size(1), N = b.size(1); | ||
| auto c = torch::empty({M, N}, a.options()); | ||
| auto stream = at::cuda::getCurrentCUDAStream(); |
There was a problem hiding this comment.
🩺 Stability & Availability | 🟠 Major
🧩 Analysis chain
🏁 Script executed:
cat -n csrc/cuda/gemm/det_gemm_kernel.cu | head -50Repository: RL-Align/RL-Kernel
Length of output: 2419
🏁 Script executed:
cat -n csrc/cuda/gemm/det_gemm_kernel.cu | sed -n '230,270p'Repository: RL-Align/RL-Kernel
Length of output: 2085
Pin launches to the input tensors’ CUDA device and enforce same-device inputs.
gemm_dispatch() relies on at::cuda::getCurrentCUDAStream() without checking if the current device matches the tensors' device, and check_in() fails to reject inputs residing on different GPUs. In multi-GPU contexts, this can cause kernel launches on the wrong device or with mismatched pointers.
Add an explicit device guard and cross-check input devices:
Suggested fix
+#include <c10/cuda/CUDAGuard.h>
+
void check_in(const torch::Tensor& t, const char* n) {
TORCH_CHECK(t.is_cuda(), n, " must be CUDA");
TORCH_CHECK(t.scalar_type() == torch::kBFloat16, n, " must be bf16");
}
torch::Tensor gemm_dispatch(const torch::Tensor& a, const torch::Tensor& b) {
+ TORCH_CHECK(a.get_device() == b.get_device(), "det_gemm: inputs must be on the same CUDA device");
+ const c10::cuda::CUDAGuard device_guard(a.device());
const int M = a.size(0), K = a.size(1), N = b.size(1);
auto c = torch::empty({M, N}, a.options());
auto stream = at::cuda::getCurrentCUDAStream();📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| void check_in(const torch::Tensor& t, const char* n) { | |
| TORCH_CHECK(t.is_cuda(), n, " must be CUDA"); | |
| TORCH_CHECK(t.scalar_type() == torch::kBFloat16, n, " must be bf16"); | |
| } | |
| torch::Tensor gemm_dispatch(const torch::Tensor& a, const torch::Tensor& b) { | |
| const int M = a.size(0), K = a.size(1), N = b.size(1); | |
| auto c = torch::empty({M, N}, a.options()); | |
| auto stream = at::cuda::getCurrentCUDAStream(); | |
| `#include` <c10/cuda/CUDAGuard.h> | |
| void check_in(const torch::Tensor& t, const char* n) { | |
| TORCH_CHECK(t.is_cuda(), n, " must be CUDA"); | |
| TORCH_CHECK(t.scalar_type() == torch::kBFloat16, n, " must be bf16"); | |
| } | |
| torch::Tensor gemm_dispatch(const torch::Tensor& a, const torch::Tensor& b) { | |
| TORCH_CHECK(a.get_device() == b.get_device(), "det_gemm: inputs must be on the same CUDA device"); | |
| const c10::cuda::CUDAGuard device_guard(a.device()); | |
| const int M = a.size(0), K = a.size(1), N = b.size(1); | |
| auto c = torch::empty({M, N}, a.options()); | |
| auto stream = at::cuda::getCurrentCUDAStream(); |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@csrc/cuda/gemm/det_gemm_kernel.cu` around lines 230 - 238, The gemm dispatch
path currently assumes the active CUDA device matches the input tensors and does
not reject mixed-GPU inputs. Update check_in() and gemm_dispatch() to verify
both tensors are on CUDA, are bf16, and reside on the same device, then add an
explicit device guard in gemm_dispatch() so the allocation of c and the kernel
launch via at::cuda::getCurrentCUDAStream() are pinned to the inputs’ device.
Use the existing gemm_dispatch, check_in, and stream setup to place the guard
and same-device validation close to the launch site.
| | CUDA (`DetGemmOp`) | yes | Hand-written kernel. First milestone is a naive FP32 implementation (correctness first); a tensor-core (`mma.sync`) pass matching `prefix_shared_attention.cu` follows. NVIDIA SM80+. | | ||
| | Triton (`TritonDetGemmOp`) | yes | Autotune disabled, BLOCK pinned, no split-K. Portable / ROCm fallback and cross-backend reference. | |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Update the CUDA backend description to match the implementation.
This page still says the CUDA path is a naive FP32 first milestone with tensor-core work deferred, but the benchmark code/results in this PR describe an SM90 TMA + mma.sync path that already exists. That mismatch will mislead readers about what was actually merged and the expected performance profile.
Also applies to: 50-53
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@docs/operators/det-gemm.md` around lines 28 - 29, The CUDA backend
description in the DetGemm documentation is outdated and still describes a naive
FP32 milestone with deferred tensor-core work, while the merged implementation
already uses the SM90 TMA plus mma.sync path. Update the wording in the DetGemm
operator docs table and any related CUDA backend references so they accurately
reflect the implemented DetGemmOp behavior and current performance profile,
keeping the TritonDetGemmOp description consistent with the same section.
| class _DetGemmFn(torch.autograd.Function): | ||
| @staticmethod | ||
| def forward(ctx, a, b): | ||
| ctx.save_for_backward(a, b) | ||
| return _C.det_gemm_fwd(a, b) |
There was a problem hiding this comment.
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
Don’t bypass the compiled-extension availability check.
DetGemmOp.__call__ raises a clear RuntimeError when _C is unavailable, but both _DetGemmFn.forward() and deterministic_gemm() bypass that guard. That matches the reported CI failure: the functional path reaches _C.det_gemm_fwd(...) with _C is None and crashes with AttributeError instead of the intended explicit error.
Suggested fix
class _DetGemmFn(torch.autograd.Function):
`@staticmethod`
def forward(ctx, a, b):
+ if not (_EXT_AVAILABLE and _C is not None and hasattr(_C, "det_gemm_fwd")):
+ raise RuntimeError(
+ "DetGemmOp: compiled _C.det_gemm kernel unavailable; no "
+ "batch-invariant fallback exists. Build the extension first."
+ )
ctx.save_for_backward(a, b)
return _C.det_gemm_fwd(a, b)Also applies to: 57-59
🧰 Tools
🪛 GitHub Actions: GPU CI / 0_gpu-tests.txt
[error] 20-20: AttributeError in deterministic GEMM CUDA path: '_C' is None, so '_C.det_gemm_fwd(a, b)' fails with "AttributeError: 'NoneType' object has no attribute 'det_gemm_fwd'".
🪛 GitHub Actions: GPU CI / gpu-tests
[error] 20-20: AttributeError in CUDA deterministic GEMM forward: '_C' is None and has no attribute 'det_gemm_fwd' (raised at _C.det_gemm_fwd(a, b)).
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@rl_engine/kernels/ops/cuda/matmul/det_gemm.py` around lines 16 - 20, The
functional deterministic GEMM path is bypassing the compiled-extension
availability guard and can hit _C.det_gemm_fwd when _C is None. Update
_DetGemmFn.forward and deterministic_gemm to perform the same availability check
used by DetGemmOp.__call__ before touching _C, and raise the intended
RuntimeError instead of allowing an AttributeError to surface.
Source: Pipeline failures
| def _triton_gemm(a, b): | ||
| a, b = a.contiguous(), b.contiguous() | ||
| M, K = a.shape | ||
| _, N = b.shape | ||
| c = torch.empty((M, N), device=a.device, dtype=a.dtype) | ||
| grid = (triton.cdiv(M, _BLOCK_M), triton.cdiv(N, _BLOCK_N)) | ||
| _det_gemm_kernel[grid]( | ||
| a, | ||
| b, | ||
| c, | ||
| M, | ||
| N, | ||
| K, | ||
| a.stride(0), | ||
| a.stride(1), | ||
| b.stride(0), | ||
| b.stride(1), | ||
| c.stride(0), | ||
| c.stride(1), | ||
| BLOCK_M=_BLOCK_M, | ||
| BLOCK_N=_BLOCK_N, | ||
| BLOCK_K=_BLOCK_K, | ||
| ) |
There was a problem hiding this comment.
🩺 Stability & Availability | 🔴 Critical | ⚡ Quick win
Validate inner dimensions before launching the Triton kernel.
Line 70 takes K from a.shape and never checks b.shape[0] == K. On a mismatch, the kernel masks loads with a’s K, so it can walk past b instead of raising the usual matmul shape error.
Suggested fix
+def _validate_det_gemm_inputs(a, b):
+ if a.ndim != 2 or b.ndim != 2:
+ raise ValueError("det_gemm expects 2D tensors")
+ if a.shape[1] != b.shape[0]:
+ raise ValueError(f"incompatible shapes: {tuple(a.shape)} and {tuple(b.shape)}")
+
def _triton_gemm(a, b):
+ _validate_det_gemm_inputs(a, b)
a, b = a.contiguous(), b.contiguous()
M, K = a.shape
_, N = b.shape📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def _triton_gemm(a, b): | |
| a, b = a.contiguous(), b.contiguous() | |
| M, K = a.shape | |
| _, N = b.shape | |
| c = torch.empty((M, N), device=a.device, dtype=a.dtype) | |
| grid = (triton.cdiv(M, _BLOCK_M), triton.cdiv(N, _BLOCK_N)) | |
| _det_gemm_kernel[grid]( | |
| a, | |
| b, | |
| c, | |
| M, | |
| N, | |
| K, | |
| a.stride(0), | |
| a.stride(1), | |
| b.stride(0), | |
| b.stride(1), | |
| c.stride(0), | |
| c.stride(1), | |
| BLOCK_M=_BLOCK_M, | |
| BLOCK_N=_BLOCK_N, | |
| BLOCK_K=_BLOCK_K, | |
| ) | |
| def _validate_det_gemm_inputs(a, b): | |
| if a.ndim != 2 or b.ndim != 2: | |
| raise ValueError("det_gemm expects 2D tensors") | |
| if a.shape[1] != b.shape[0]: | |
| raise ValueError(f"incompatible shapes: {tuple(a.shape)} and {tuple(b.shape)}") | |
| def _triton_gemm(a, b): | |
| _validate_det_gemm_inputs(a, b) | |
| a, b = a.contiguous(), b.contiguous() | |
| M, K = a.shape | |
| _, N = b.shape | |
| c = torch.empty((M, N), device=a.device, dtype=a.dtype) | |
| grid = (triton.cdiv(M, _BLOCK_M), triton.cdiv(N, _BLOCK_N)) | |
| _det_gemm_kernel[grid]( | |
| a, | |
| b, | |
| c, | |
| M, | |
| N, | |
| K, | |
| a.stride(0), | |
| a.stride(1), | |
| b.stride(0), | |
| b.stride(1), | |
| c.stride(0), | |
| c.stride(1), | |
| BLOCK_M=_BLOCK_M, | |
| BLOCK_N=_BLOCK_N, | |
| BLOCK_K=_BLOCK_K, | |
| ) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@rl_engine/kernels/ops/triton/matmul/det_gemm.py` around lines 68 - 90, The
_triton_gemm helper currently derives K from a.shape and launches
_det_gemm_kernel without verifying that b.shape[0] matches, so add an explicit
inner-dimension check before the kernel call and raise the standard matmul shape
error on mismatch. Keep the validation in _triton_gemm near the existing M, K, N
shape extraction so invalid inputs are rejected before any Triton launch.
| def backward(ctx, grad_out): | ||
| a, b = ctx.saved_tensors | ||
| grad_out = grad_out.contiguous() | ||
| da = _triton_gemm(grad_out, b.t().contiguous()) if ctx.needs_input_grad[0] else None | ||
| db = _triton_gemm(a.t().contiguous(), grad_out) if ctx.needs_input_grad[1] else None |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | 🏗️ Heavy lift
dB is still reduced over the live batch dimension.
Line 105 computes db as Aᵀ @ grad_out, which makes the reduction axis the original M. That means the Triton backward path can still change with batch/chunk/padding layout, so it does not meet the WS1 dW invariance contract yet. As per the PR objectives, backward dW must remain invariant across batch-size and layout changes.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@rl_engine/kernels/ops/triton/matmul/det_gemm.py` around lines 101 - 105, The
backward path in det_gemm.backward still computes dB from Aᵀ @ grad_out, so the
reduction depends on the live batch dimension and can vary with
chunking/padding. Update the db calculation to use the WS1 invariant reduction
contract instead of directly reducing over the current batch layout, and ensure
the Triton backward path produces the same dW regardless of batch size or layout
changes. Keep the fix localized around backward and the _triton_gemm calls in
det_gemm.py.
| cc_major, cc_minor = torch.cuda.get_device_capability() | ||
| enable_sm90 = os.environ.get("KERNEL_ALIGN_FORCE_SM90") == "1" | ||
| nvcc_flags = ["-O3", "--use_fast_math", "-Xfatbin", "-compress-all"] | ||
| if not enable_sm90: | ||
| # SM90 build emits 90a below; mixing plain compute_90 breaks TMA ptxas. | ||
| nvcc_flags.append( | ||
| f"-gencode=arch=compute_{cc_major}{cc_minor},code=sm_{cc_major}{cc_minor}" | ||
| ) |
There was a problem hiding this comment.
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
KERNEL_ALIGN_DET_GEMM_SM90 still leaves the conflicting plain SM90 gencode enabled.
When KERNEL_ALIGN_DET_GEMM_SM90=1 is set on an SM90 machine, Lines 67-71 still add -gencode=arch=compute_90,code=sm_90, and Lines 138-145 then append compute_90a/sm_90a. The comment on Line 68 says that mix breaks TMA ptxas, so this new path can still fail unless callers also remember to set the separate KERNEL_ALIGN_FORCE_SM90 flag.
Suggested fix
cc_major, cc_minor = torch.cuda.get_device_capability()
enable_sm90 = os.environ.get("KERNEL_ALIGN_FORCE_SM90") == "1"
+ enable_det_gemm_sm90 = os.environ.get("KERNEL_ALIGN_DET_GEMM_SM90") == "1"
nvcc_flags = ["-O3", "--use_fast_math", "-Xfatbin", "-compress-all"]
- if not enable_sm90:
+ if not (enable_sm90 or enable_det_gemm_sm90):
# SM90 build emits 90a below; mixing plain compute_90 breaks TMA ptxas.
nvcc_flags.append(
f"-gencode=arch=compute_{cc_major}{cc_minor},code=sm_{cc_major}{cc_minor}"
)
@@
- enable_det_gemm_sm90 = os.environ.get("KERNEL_ALIGN_DET_GEMM_SM90") == "1"
if enable_det_gemm_sm90:Also applies to: 137-145
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@setup.py` around lines 64 - 71, The SM90 build path still emits a conflicting
plain SM90 gencode when KERNEL_ALIGN_DET_GEMM_SM90 is enabled, so update the
nvcc flag assembly in setup.py to avoid adding the compute_90/sm_90 pair on SM90
machines. Make the logic in the startup flag block and the later SM90 append
path consistent so that enabling KERNEL_ALIGN_DET_GEMM_SM90 alone does not
require also setting KERNEL_ALIGN_FORCE_SM90. Use the existing nvcc_flags
construction and the SM90-specific branch to ensure only the intended 90a
gencode is present.
| pytestmark = pytest.mark.skipif( | ||
| not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8, | ||
| reason="det_gemm requires CUDA SM80+", | ||
| ) | ||
|
|
||
| # Each deterministic backend is validated independently. | ||
| _BACKENDS = [("cuda", deterministic_gemm)] | ||
| if _HAS_TRITON: | ||
| _BACKENDS.append(("triton", deterministic_gemm_triton)) |
There was a problem hiding this comment.
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
Don't parametrize the CUDA backend unless the extension actually loaded.
The current skip only checks for CUDA + SM80, but GPU CI shows deterministic_gemm is still callable in a state where its backing binding is None, so every CUDA case crashes before the invariance assertions run. Gate _BACKENDS on the wrapper's real availability signal (or fail fast with a clear module-level skip) instead of assuming import success means the op is usable.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/test_det_gemm.py` around lines 25 - 33, The CUDA backend is being added
to _BACKENDS even when deterministic_gemm is not actually available, which
causes crashes before the assertions run. Update the backend selection near
pytestmark/_BACKENDS to check the wrapper’s real availability signal from
deterministic_gemm before appending the CUDA case, or convert it into a clear
module-level skip when the extension binding is None. Keep the Triton entry
gated separately with _HAS_TRITON.
Source: Pipeline failures
| def test_forward_correctness(name, gemm): | ||
| # vs FP32 reference. Placeholder tolerance; PR3 swaps for #108 contract. | ||
| torch.manual_seed(3) | ||
| M, K, N = 128, 2048, 2048 | ||
| a, b = _rand(M, K), _rand(K, N) | ||
| out = gemm(a, b).float() | ||
| ref = a.float() @ b.float() | ||
| assert (out - ref).abs().max().item() < 1.0 # TODO(#108): contract threshold |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | 🏗️ Heavy lift
Replace the placeholder tolerances with the #108 correctness contract before merge.
These max_abs < 1.0 / < 2.0 checks are explicitly temporary, and they are loose enough to hide material numeric regressions while still passing CI. The linked objective for #146 calls out validation through the shared #108 harness, so leaving placeholder thresholds here means the correctness part of the contract is still unverified.
Also applies to: 105-117
🧰 Tools
🪛 GitHub Actions: GPU CI / 0_gpu-tests.txt
[error] 83-83: test_forward_correctness[cuda-deterministic_gemm] failed because gemm(a, b) raised the AttributeError: 'NoneType' object has no attribute 'det_gemm_fwd'.
🪛 GitHub Actions: GPU CI / gpu-tests
[error] 83-83: test_forward_correctness[cuda-deterministic_gemm] failed when calling gemm(): AttributeError: 'NoneType' object has no attribute 'det_gemm_fwd'.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/test_det_gemm.py` around lines 78 - 85, The correctness assertions in
test_forward_correctness and the related tests still use temporary loose
thresholds instead of the shared `#108` contract. Update the checks in the gemm
test cases to use the `#108` correctness harness/thresholds rather than hardcoded
placeholder max_abs limits, and make sure the affected test functions (including
test_forward_correctness and the nearby cases in the same range) validate
against the contract consistently.
Fixes #146
Purpose
In RL train–inference alignment, a token's logprob must not drift between rollout (vLLM-side, varying batch shapes) and training (Megatron-side). A major source of this drift is the GEMM: cuBLAS selects kernels by problem shape and may use split-K, so the K-reduction order — and therefore the low bits of the output changes with batch size, chunked-prefill splitting, and padding layout. The same row produces different results depending on the batch it rides in.
This PR implements a batch-invariant deterministic GEMM (forward + backward), one op in the WS1 forward chain (#146). A row's output is bitwise invariant to batch size, chunked-prefill splitting, and padding layout. Invariance is achieved by pinning the tile shape, fixing the K-accumulation order, and forbidding split-K / shape-based kernel selection. BF16 inputs, FP32 accumulation, no TF32.
This is PR2 of the planned series (design note → kernel → tests → LM-head wiring → benchmark). Scope here: the kernel(s), op wiring, and invariance tests.
Invariance contract. The same (M, N, K) row produces the same output regardless of the surrounding batch, chunked-prefill split, or padding. Two design points protect this:
Backends. CUDA (SM90 TMA + mma.sync tensor cores; naive FP32 scalar fallback below SM90), Triton (portable / ROCm fallback + cross-backend invariance eference), and a PyTorch reference (torch.matmul, intentionally non-deterministic — reference/benchmark only, deliberately excluded from dispatch, since a non-deterministic fallback would silently defeat the op's purpose).
Out of scope (per #146): tensor-parallel GEMM (WS2), FP8, and occupancy/throughput tuning of the tensor-core path — correctness and invariance first; a slower-than-cuBLAS deterministic baseline is the accepted first milestone. Throughput tuning will be a separate perf PR (see Follow-ups).
Test Plan
on both the CUDA tensor-core kernel and the Triton path.
Test Result
Build: compiled and linked successfully with the SM90 TMA + mma.sync path enabled (CUDA 12.4, H100). det_gemm uses its own csrc/cuda/gemm/det_gemm_tma.cuh (shared::cluster); the shared csrc/utils/tma_utils.cuh emits shared::cta, which CUDA 12.4 ptxas rejects for cp.async.bulk.tensor — scoped the fix to this PR rather than touching the shared logp helpers (tracked separately in #). The
KERNEL_ALIGN_DET_GEMM_SM90build flag is independent of the fused_logp SM90 sources, which are left untouched.Tests: pytest tests/test_det_gemm.py → 22/22 passed on H100 SXM (CUDA 12.4) — forward/backward batch-invariance, chunked-prefill,
padding-invariance (all bitwise), correctness vs FP32, and all 5 target
projection shapes, on both the CUDA tensor-core kernel and Triton.
Benchmark (NVIDIA H100 80GB HBM3, SM90): overhead = det CUDA vs cuBLAS (TF32 disabled). Both deterministic paths trade speed for bitwise invariance by design (no split-K, fixed accumulation, FP32, no TF32).
(The det CUDA path uses SM90 TMA +
mma.synctensor cores, 128×128 tile, single-CTA-per-tile, no split-K. Occupancy/throughput tuning is deferred per #146 and will be an ncu-driven perf PR; this milestone is correctness- and invariance-first.)Files
Follow-ups
Summary by CodeRabbit
New Features
Bug Fixes