Skip to content

Add kernel gtest operator checker and used logprob to test gtest#197

Open
a-kaa wants to merge 4 commits into
RL-Align:mainfrom
a-kaa:logp-gtest
Open

Add kernel gtest operator checker and used logprob to test gtest#197
a-kaa wants to merge 4 commits into
RL-Align:mainfrom
a-kaa:logp-gtest

Conversation

@a-kaa

@a-kaa a-kaa commented Jun 27, 2026

Copy link
Copy Markdown
Collaborator

Summary

Usage

Run a CPU smoke check against the PyTorch gold implementation:

  python scripts/check_operator.py \
    --op logp \
    --candidate pytorch \
    --device cpu \
    --dtype fp32 \
    --batch 1 \
    --seq 2 \
    --vocab 17

Run a CUDA candidate check against the PyTorch gold path:

  python scripts/check_operator.py \
    --op logp \
    --candidate cuda \
    --device cuda \
    --dtype bf16 \
    --arch-key sm90 \
    --batch 1 \
    --seq 1 \
    --vocab 4096

Print the full structured report as JSON:

  python scripts/check_operator.py \
    --op logp \
    --candidate pytorch \
    --device cpu \
    --dtype fp32 \
    --batch 1 \
    --seq 2 \
    --vocab 17 \
    --json

Available key options:

  • --op: operator name. Current minimal version supports logp.
  • --candidate: backend candidate, for example pytorch, cuda, cuda-generic, cuda-sm90, or registry.
  • --dtype: fp32, bf16, or fp16.
  • --device: auto, cpu, cuda, or any torch device string.
  • --arch-key: optional tolerance override key, for example sm90.
  • --batch, --seq, --vocab: shape controls.
  • --input-mode: random or constant.
  • --constant-value: floating-point tensor value used by constant mode.
  • --token-value: token id used by constant mode, modulo vocab.
  • --seed: random input seed.
  • --json: emit full JSON report.

Example output:

suite=logp passed=True pass_rate=1.0000
candidate=cuda-logp backend=cuda passed=True pass_rate=1.0000
case=logp-torch.bfloat16-1x1x4096 output=0 shape=(1, 1) dtype=torch.bfloat16 max_abs=2.69813538e-02 mean_abs=2.69813538e-02 max_rel=3.03093810e-03
tol=(atol=5.000e-02, rtol=0.000e+00) passed=True

Adding a New Operator

To add a new operator to the checker, keep the public test flow unchanged and only update the operator-specific registration/input files.

1. Add input generation

(Already added, need check the shapes)

  rl_engine/kernels/gtest/operator_inputs.py 


  Update make_operator_inputs():

  builders = {
      ...
      "new_op": _make_new_op_inputs,
  }

  Update operator_shape_name():

  names = {
      ...
      "new_op": f"{batch}x{seq}x...",
  }

  Add the input builder:

  def _make_new_op_inputs(args, dtype, device):
      batch, seq = _batch_seq(args)
      return {
          "x": _floating_tensor((batch, seq, ...), args, dtype, device, offset=0),
          ...
      }

2. Register gold and candidates

(NEED OPS OWNER ADD)


  rl_engine/kernels/gtest/operator_specs.py

  Add an OperatorSpec entry:

  "new_op": OperatorSpec(
      name="new_op",
      op_class="elementwise",
      gold_path="rl_engine.kernels.ops.pytorch....NativeNewOp",
      registry_name="new_op",
      candidate_paths={
          "pytorch": "rl_engine.kernels.ops.pytorch....NativeNewOp",
          "cuda": "rl_engine.kernels.ops.cuda....CudaNewOp",
          "triton": "rl_engine.kernels.ops.triton....TritonNewOp",
      },
  )

Rules:

  • gold_path must come from rl_engine.kernels.ops.pytorch.
  • CUDA/Triton/ROCm implementations are candidates only.
  • Do not compare two operators that implement different math.
  • candidate=pytorch is only for checker smoke tests.

Validation

  • python scripts/check_operator.py --op logp --candidate pytorch --device cpu --dtype fp32 --batch 1 --seq 2 --vocab 17
71837019648a53174ec8e566ce210027

In fact, the test did not pass; however, it proves that the workflow of our testing framework has achieved the minimum viable capability.

Notes

  • Gold paths are required to come from rl_engine.kernels.ops.pytorch.
  • CUDA/Triton/ROCm implementations are treated as candidates.
  • SM90 fused logp remains under separate validation and is not included as a passing path in this PR.

Summary by CodeRabbit

  • New Features

    • Added a command-line tool to run operator checks and return a readable or JSON report.
    • Introduced broader operator input generation and reusable test-suite support for multiple kernel operators.
  • Bug Fixes

    • Improved log-probability handling for consistent forward and fp32 outputs.
    • Updated tolerance rules to better match observed accuracy across data types and hardware paths.
  • Documentation

    • Added detailed contribution/session notes and refreshed fused-logp usage, test, and contract guidance.

@coderabbitai

coderabbitai Bot commented Jun 27, 2026

Copy link
Copy Markdown

Review Change Stack

📝 Walkthrough

Walkthrough

Introduces a rl_engine/kernels/gtest/ package implementing a reusable gtest-style kernel operator validation framework: a YAML tolerance contract, structured dataclass harness (op_checks.py), parameterized input factory (operator_inputs.py), operator specs and candidate resolution (operator_specs.py), and a scripts/check_operator.py CLI. NativeLogpOp is refactored to expose forward/forward_fp32 methods. Tests and documentation are added throughout.

Changes

Operator Checking Framework

Layer / File(s) Summary
Tolerance contract and loader
rl_engine/kernels/gtest/tolerance_contract.yaml, rl_engine/kernels/gtest/tolerance.py, tests/test_tolerance_contract.py
Defines per-dtype atol/rtol for elementwise, reduction, and logprob operator classes plus an arch_overrides block; load_contract() reads and parses the YAML; tests assert structure and numeric values.
op_checks harness dataclasses and runner
rl_engine/kernels/gtest/op_checks.py, rl_engine/kernels/gtest/__init__.py
Introduces OperatorCase, CandidateSpec, OutputCheck, CaseCheck, CandidateReport, OperatorCheckReport dataclasses and run_operator_suite with internal helpers for candidate invocation, tensor flattening, tolerance resolution, and FP32 comparison.
Operator input factory
rl_engine/kernels/gtest/operator_inputs.py, tests/test_operator_inputs.py
make_operator_inputs and operator_shape_name dispatch to per-operator builders using deterministic random/constant tensor generation helpers; tests verify all Issue-108 ops and logp determinism.
NativeLogpOp refactor
rl_engine/kernels/ops/pytorch/loss/logp.py, tests/test_logp.py
Adds op_class = "logprob", moves baseline logic into forward/forward_fp32, converts apply/apply_fp32 to aliases; full pytest suite covers correctness, batch invariance, accuracy, and registry dispatch.
Operator specs and candidate resolution
rl_engine/kernels/gtest/operator_specs.py
OperatorSpec and OP_SPECS define the logp operator; make_operator_case dynamically loads the gold op; make_candidate resolves native/registry/cuda-sm90 backends, wrapping SM90 with _LogpSM90CandidateAdapter.
check_operator.py CLI
scripts/check_operator.py
Adds a CLI with dtype/device normalization, parse_args, _summarize, and main that builds one candidate and one operator case, runs the suite, and exits 1 on failure.
op_checks integration tests
tests/test_op_checks.py
Tests native/registry/bad/to_dict/arch_key override paths of run_operator_suite for the logp operator.
Session log and operator docs
docs/contributing/issue-108-session-log.md, docs/operators/fused-logp.md
Session log records full development timeline, H20/H100 CUDA verification results, SM90 status, and new-operator onboarding guide; fused-logp.md updates NativeLogpOp API reference, tensor contract, test commands, and implementation file list.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested labels

component: kernels

Suggested reviewers

  • Flink-ddd
  • inaniloquentee
  • EthanZero2Hero

Poem

🐇 Hop, hop — the gtest is here at last,
Tolerance tables and contracts now cast.
forward_fp32 leaps with FP32 care,
SM90 stumbles, but we document fair.
The rabbit checks ops with a JSON report,
No kernel escapes this rigorous sort! ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 7.32% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title matches the main change: adding a kernel gtest operator checker and logprob test path.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

🧹 Nitpick comments (1)
rl_engine/kernels/gtest/op_checks.py (1)

184-187: 🎯 Functional Correctness | 🔵 Trivial | ⚡ Quick win

Call the candidate's public callable path first.

For torch.nn.Module-like candidates, jumping straight to .forward() bypasses __call__ hooks and wrappers, so the checker may validate a different path than production. Prefer candidate(**inputs) when the object is callable, and only fall back to .forward() for non-callable adapters.

Proposed change
 def _call_candidate(candidate: Callable[..., Any] | Any, inputs: Mapping[str, Any]) -> Any:
-    if hasattr(candidate, "forward") and callable(candidate.forward):
-        return candidate.forward(**inputs)
-    return candidate(**inputs)
+    if callable(candidate):
+        return candidate(**inputs)
+    if hasattr(candidate, "forward") and callable(candidate.forward):
+        return candidate.forward(**inputs)
+    raise TypeError(f"candidate is not callable: {type(candidate)!r}")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/gtest/op_checks.py` around lines 184 - 187, The
_call_candidate helper is invoking .forward() first for Module-like objects,
which skips the public callable path. Update _call_candidate so it prefers
candidate(**inputs) whenever candidate is callable, and only falls back to
candidate.forward(**inputs) for non-callable adapters; keep the existing
callable checks around candidate and forward to preserve compatibility.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/kernels/gtest/op_checks.py`:
- Around line 149-150: The candidate and gold evaluations in the op check path
are sharing the same input tree, so in-place writes from the candidate can
affect the reference result. Update the logic around _call_candidate and
case.gold_fn in op_checks.py to build separate cloned copies of case.inputs for
each call, using distinct cloned input trees for the candidate and gold paths.
- Around line 151-155: The arity mismatch in the operator check path currently
raises a ValueError and aborts run_operator_suite() before it can build an
OperatorCheckReport. Update the mismatch handling in op_checks.py around the
candidate/gold output comparison so it records a failed CaseCheck/OutputCheck
for the relevant candidate rather than throwing, allowing
scripts/check_operator.py to continue and emit its structured report. Use the
existing run_operator_suite(), CaseCheck, and OutputCheck flow to surface the
mismatch as a normal test failure.

In `@rl_engine/kernels/gtest/operator_inputs.py`:
- Around line 199-210: Reject non-positive vocab values at the start of
_token_ids before any mode-specific logic runs. Add an upfront validation in
_token_ids that raises a clear error when vocab <= 0, so constant mode does not
hit token_value % vocab and random mode does not fall through to torch.randint
with an invalid range. Keep the fix localized to _token_ids and use its existing
parameters to enforce the check.

In `@rl_engine/kernels/gtest/operator_specs.py`:
- Around line 78-106: make_candidate() currently lets CUDA-backed candidates
through without validating the resolved device, so add an early guard in this
function to reject cuda/cuda-sm90 when the selected device is not CUDA-capable.
Use the existing args.device flow from
make_operator_case()/scripts/check_operator.py and check the resolved device
before loading the candidate, raising a clear ValueError for invalid
backend/device combinations. Keep the existing candidate selection logic and
unique symbols like make_candidate, CandidateSpec, and _LogpSM90CandidateAdapter
to locate the fix.

In `@rl_engine/kernels/gtest/tolerance.py`:
- Around line 11-18: The tolerance contract loader in load_contract is reading
tolerance_contract.yaml with json.load, which only supports strict JSON. Update
load_contract to parse the contract as YAML (using the existing _CONTRACT_PATH
target) or, if you intend to keep JSON parsing, rename the contract and path to
a .json file so the loader and file format match.

---

Nitpick comments:
In `@rl_engine/kernels/gtest/op_checks.py`:
- Around line 184-187: The _call_candidate helper is invoking .forward() first
for Module-like objects, which skips the public callable path. Update
_call_candidate so it prefers candidate(**inputs) whenever candidate is
callable, and only falls back to candidate.forward(**inputs) for non-callable
adapters; keep the existing callable checks around candidate and forward to
preserve compatibility.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e1add4ce-5e2e-4b44-943e-1d6e4abc2a8c

📥 Commits

Reviewing files that changed from the base of the PR and between ea196da and 7defbdd.

📒 Files selected for processing (14)
  • docs/contributing/issue-108-session-log.md
  • docs/operators/fused-logp.md
  • rl_engine/kernels/gtest/__init__.py
  • rl_engine/kernels/gtest/op_checks.py
  • rl_engine/kernels/gtest/operator_inputs.py
  • rl_engine/kernels/gtest/operator_specs.py
  • rl_engine/kernels/gtest/tolerance.py
  • rl_engine/kernels/gtest/tolerance_contract.yaml
  • rl_engine/kernels/ops/pytorch/loss/logp.py
  • scripts/check_operator.py
  • tests/test_logp.py
  • tests/test_op_checks.py
  • tests/test_operator_inputs.py
  • tests/test_tolerance_contract.py

Comment thread rl_engine/kernels/gtest/op_checks.py
Comment thread rl_engine/kernels/gtest/op_checks.py
Comment thread rl_engine/kernels/gtest/operator_inputs.py
Comment thread rl_engine/kernels/gtest/operator_specs.py
Comment thread rl_engine/kernels/gtest/tolerance.py

@Flink-ddd Flink-ddd left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an excellently designed testing harness. Building a unified gtest runner with a YAML-backed tolerance contract is exactly what Workstream 1 needs to scale across CUDA, Triton, and ROCm. The JSON output integration will make CI regression tracking much easier.

Since this is an initial MVP version, the foundation looks very solid. Below are some reviews that outline some architectural improvements:

@@ -0,0 +1,1044 @@
# ISSUE-108 Session Log

本文档记录本 session 中围绕 RL-Kernel 算子测试框架、CUDA 验证和 upstream 同步的所有关键修改。后续本 session 中每次代码修改都必须继续追加到本文档,记录目标、设计判断、修改文件、验证方式和结果。

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use english to replace

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please translate the session log from Chinese to English to maintain the repository's open-source linguistic consistency.

)


def _run_case(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, _run_case only evaluates the forward pass (_call_candidate). As we have established in the RMSNorm, SwiGLU, and Embedding PRs, backward-pass consistency is critical to preventing RL training drift.

While it doesn't need to be implemented in this exact PR, you must add a TODO or open a tracking issue to support gradient checking in the gtest framework. Future iterations will need to set requires_grad=True on floating inputs, call .backward() on the outputs, and compare the resulting .grad tensors against the gold path using this same tolerance contract.


if candidate_name in spec.candidate_paths:
candidate_op = _load_object(spec.candidate_paths[candidate_name])()
if args.op == "logp" and candidate_name == "cuda-sm90":

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Injecting _LogpSM90CandidateAdapter using an if statement (if args.op == "logp" and candidate_name == "cuda-sm90":) inside make_candidate works for the MVP, but it will quickly become spaghetti code as you add more operators, ROCm backends, and Triton kernels that require shape-flattening or specific adapters.

Suggestion: Make the adapter mapping declarative. Add an optional candidate_adapters: dict[str, Callable] field to the OperatorSpec dataclass, so make_candidate can simply look up and wrap the candidate dynamically without needing to know operator-specific logic.

Comment thread tests/test_logp.py
logits, token_ids = _make_inputs(2, 16, 257, dtype=dtype)
out = op.forward(logits, token_ids)
assert out.dtype == dtype

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed you included the unit tests for NativeLogpOp in this PR to prove the gtest framework works. Similar to the previous reference operators, test_logp.py is completely missing backward-pass tests (test_gradient_flows). Please ensure you add a slice-based Batch-Invariance (Axis-A) test for NativeLogpOp's backward pass before WS1 concludes.

}
},
"arch_overrides": {
"sm90": {}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Starting with sm90 as the only arch_overrides key is perfectly fine for this stage. The structure here is very clean, and it will be trivial to expand this to include rocm or cdna3 specific tolerances when you reach WS2 and WS3. No changes needed here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants