Add kernel gtest operator checker and used logprob to test gtest#197
Add kernel gtest operator checker and used logprob to test gtest#197a-kaa wants to merge 4 commits into
Conversation
📝 WalkthroughWalkthroughIntroduces a ChangesOperator Checking Framework
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 5
🧹 Nitpick comments (1)
rl_engine/kernels/gtest/op_checks.py (1)
184-187: 🎯 Functional Correctness | 🔵 Trivial | ⚡ Quick winCall the candidate's public callable path first.
For
torch.nn.Module-like candidates, jumping straight to.forward()bypasses__call__hooks and wrappers, so the checker may validate a different path than production. Prefercandidate(**inputs)when the object is callable, and only fall back to.forward()for non-callable adapters.Proposed change
def _call_candidate(candidate: Callable[..., Any] | Any, inputs: Mapping[str, Any]) -> Any: - if hasattr(candidate, "forward") and callable(candidate.forward): - return candidate.forward(**inputs) - return candidate(**inputs) + if callable(candidate): + return candidate(**inputs) + if hasattr(candidate, "forward") and callable(candidate.forward): + return candidate.forward(**inputs) + raise TypeError(f"candidate is not callable: {type(candidate)!r}")🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@rl_engine/kernels/gtest/op_checks.py` around lines 184 - 187, The _call_candidate helper is invoking .forward() first for Module-like objects, which skips the public callable path. Update _call_candidate so it prefers candidate(**inputs) whenever candidate is callable, and only falls back to candidate.forward(**inputs) for non-callable adapters; keep the existing callable checks around candidate and forward to preserve compatibility.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/kernels/gtest/op_checks.py`:
- Around line 149-150: The candidate and gold evaluations in the op check path
are sharing the same input tree, so in-place writes from the candidate can
affect the reference result. Update the logic around _call_candidate and
case.gold_fn in op_checks.py to build separate cloned copies of case.inputs for
each call, using distinct cloned input trees for the candidate and gold paths.
- Around line 151-155: The arity mismatch in the operator check path currently
raises a ValueError and aborts run_operator_suite() before it can build an
OperatorCheckReport. Update the mismatch handling in op_checks.py around the
candidate/gold output comparison so it records a failed CaseCheck/OutputCheck
for the relevant candidate rather than throwing, allowing
scripts/check_operator.py to continue and emit its structured report. Use the
existing run_operator_suite(), CaseCheck, and OutputCheck flow to surface the
mismatch as a normal test failure.
In `@rl_engine/kernels/gtest/operator_inputs.py`:
- Around line 199-210: Reject non-positive vocab values at the start of
_token_ids before any mode-specific logic runs. Add an upfront validation in
_token_ids that raises a clear error when vocab <= 0, so constant mode does not
hit token_value % vocab and random mode does not fall through to torch.randint
with an invalid range. Keep the fix localized to _token_ids and use its existing
parameters to enforce the check.
In `@rl_engine/kernels/gtest/operator_specs.py`:
- Around line 78-106: make_candidate() currently lets CUDA-backed candidates
through without validating the resolved device, so add an early guard in this
function to reject cuda/cuda-sm90 when the selected device is not CUDA-capable.
Use the existing args.device flow from
make_operator_case()/scripts/check_operator.py and check the resolved device
before loading the candidate, raising a clear ValueError for invalid
backend/device combinations. Keep the existing candidate selection logic and
unique symbols like make_candidate, CandidateSpec, and _LogpSM90CandidateAdapter
to locate the fix.
In `@rl_engine/kernels/gtest/tolerance.py`:
- Around line 11-18: The tolerance contract loader in load_contract is reading
tolerance_contract.yaml with json.load, which only supports strict JSON. Update
load_contract to parse the contract as YAML (using the existing _CONTRACT_PATH
target) or, if you intend to keep JSON parsing, rename the contract and path to
a .json file so the loader and file format match.
---
Nitpick comments:
In `@rl_engine/kernels/gtest/op_checks.py`:
- Around line 184-187: The _call_candidate helper is invoking .forward() first
for Module-like objects, which skips the public callable path. Update
_call_candidate so it prefers candidate(**inputs) whenever candidate is
callable, and only falls back to candidate.forward(**inputs) for non-callable
adapters; keep the existing callable checks around candidate and forward to
preserve compatibility.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: e1add4ce-5e2e-4b44-943e-1d6e4abc2a8c
📒 Files selected for processing (14)
docs/contributing/issue-108-session-log.mddocs/operators/fused-logp.mdrl_engine/kernels/gtest/__init__.pyrl_engine/kernels/gtest/op_checks.pyrl_engine/kernels/gtest/operator_inputs.pyrl_engine/kernels/gtest/operator_specs.pyrl_engine/kernels/gtest/tolerance.pyrl_engine/kernels/gtest/tolerance_contract.yamlrl_engine/kernels/ops/pytorch/loss/logp.pyscripts/check_operator.pytests/test_logp.pytests/test_op_checks.pytests/test_operator_inputs.pytests/test_tolerance_contract.py
Flink-ddd
left a comment
There was a problem hiding this comment.
This is an excellently designed testing harness. Building a unified gtest runner with a YAML-backed tolerance contract is exactly what Workstream 1 needs to scale across CUDA, Triton, and ROCm. The JSON output integration will make CI regression tracking much easier.
Since this is an initial MVP version, the foundation looks very solid. Below are some reviews that outline some architectural improvements:
| @@ -0,0 +1,1044 @@ | |||
| # ISSUE-108 Session Log | |||
|
|
|||
| 本文档记录本 session 中围绕 RL-Kernel 算子测试框架、CUDA 验证和 upstream 同步的所有关键修改。后续本 session 中每次代码修改都必须继续追加到本文档,记录目标、设计判断、修改文件、验证方式和结果。 | |||
There was a problem hiding this comment.
please use english to replace
There was a problem hiding this comment.
Please translate the session log from Chinese to English to maintain the repository's open-source linguistic consistency.
| ) | ||
|
|
||
|
|
||
| def _run_case( |
There was a problem hiding this comment.
Currently, _run_case only evaluates the forward pass (_call_candidate). As we have established in the RMSNorm, SwiGLU, and Embedding PRs, backward-pass consistency is critical to preventing RL training drift.
While it doesn't need to be implemented in this exact PR, you must add a TODO or open a tracking issue to support gradient checking in the gtest framework. Future iterations will need to set requires_grad=True on floating inputs, call .backward() on the outputs, and compare the resulting .grad tensors against the gold path using this same tolerance contract.
|
|
||
| if candidate_name in spec.candidate_paths: | ||
| candidate_op = _load_object(spec.candidate_paths[candidate_name])() | ||
| if args.op == "logp" and candidate_name == "cuda-sm90": |
There was a problem hiding this comment.
Injecting _LogpSM90CandidateAdapter using an if statement (if args.op == "logp" and candidate_name == "cuda-sm90":) inside make_candidate works for the MVP, but it will quickly become spaghetti code as you add more operators, ROCm backends, and Triton kernels that require shape-flattening or specific adapters.
Suggestion: Make the adapter mapping declarative. Add an optional candidate_adapters: dict[str, Callable] field to the OperatorSpec dataclass, so make_candidate can simply look up and wrap the candidate dynamically without needing to know operator-specific logic.
| logits, token_ids = _make_inputs(2, 16, 257, dtype=dtype) | ||
| out = op.forward(logits, token_ids) | ||
| assert out.dtype == dtype | ||
|
|
There was a problem hiding this comment.
I noticed you included the unit tests for NativeLogpOp in this PR to prove the gtest framework works. Similar to the previous reference operators, test_logp.py is completely missing backward-pass tests (test_gradient_flows). Please ensure you add a slice-based Batch-Invariance (Axis-A) test for NativeLogpOp's backward pass before WS1 concludes.
| } | ||
| }, | ||
| "arch_overrides": { | ||
| "sm90": {} |
There was a problem hiding this comment.
Starting with sm90 as the only arch_overrides key is perfectly fine for this stage. The structure here is very clean, and it will be trivial to expand this to include rocm or cdna3 specific tolerances when you reach WS2 and WS3. No changes needed here.
Summary
part of [WS1] Ground-truth harness + numerical contract for batch-invariant ops #108
Usage
Run a CPU smoke check against the PyTorch gold implementation:
Run a CUDA candidate check against the PyTorch gold path:
Print the full structured report as JSON:
Available key options:
Example output:
suite=logp passed=True pass_rate=1.0000
candidate=cuda-logp backend=cuda passed=True pass_rate=1.0000
case=logp-torch.bfloat16-1x1x4096 output=0 shape=(1, 1) dtype=torch.bfloat16 max_abs=2.69813538e-02 mean_abs=2.69813538e-02 max_rel=3.03093810e-03
tol=(atol=5.000e-02, rtol=0.000e+00) passed=True
Adding a New Operator
To add a new operator to the checker, keep the public test flow unchanged and only update the operator-specific registration/input files.
1. Add input generation
(Already added, need check the shapes)
2. Register gold and candidates
(NEED OPS OWNER ADD)
Rules:
Validation
python scripts/check_operator.py --op logp --candidate pytorch --device cpu --dtype fp32 --batch 1 --seq 2 --vocab 17In fact, the test did not pass; however, it proves that the workflow of our testing framework has achieved the minimum viable capability.
Notes
rl_engine.kernels.ops.pytorch.Summary by CodeRabbit
New Features
Bug Fixes
Documentation