[FEAT][kernels] route LM-head vocab projection through deterministic GEMM#194
[FEAT][kernels] route LM-head vocab projection through deterministic GEMM#194inaniloquentee wants to merge 1 commit into
Conversation
…GEMM Signed-off-by: inaniloquentee <3051000145@qq.com>
📝 WalkthroughWalkthroughThe DeepSpeed trainer now computes log-probabilities through a hidden-state linear-logp path, resolves ZeRO stage from DeepSpeed config, updates weight publishing, and adds tests for gather behavior, masking, hidden-state extraction, and layout-invariant gradients. ChangesDeepSpeed linear-logp training and tests
Sequence Diagram(s)sequenceDiagram
participant DeepSpeedTrainingWorker
participant GatheredParameters
participant kernel_registry
participant NativeLinearLogpOp
participant DeepSpeedEngine
DeepSpeedTrainingWorker->>GatheredParameters: enter lm_head parameter context
DeepSpeedTrainingWorker->>kernel_registry: resolve device linear-logp op
kernel_registry-->>DeepSpeedTrainingWorker: NativeLinearLogpOp or registered op
DeepSpeedTrainingWorker->>NativeLinearLogpOp: compute current_logps and ref_logps
NativeLinearLogpOp-->>DeepSpeedTrainingWorker: per-token logprobs
DeepSpeedTrainingWorker->>DeepSpeedEngine: backward(loss)
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/executors/deepspeed_trainer.py`:
- Around line 437-441: The backend selection in _linear_logp_op_for_device is
still driven by the global device_ctx through
kernel_registry.get_op("linear_logp"), so non-CPU workers can pick the wrong
kernel when device_ctx is stale. Fix this by either synchronizing device_ctx
with the active worker device before worker construction or changing
kernel_registry.get_op and its callers to accept and use the explicit device
argument from _linear_logp_op_for_device / the DeepSpeed worker initialization
path.
- Around line 221-226: The manifest layout merge in the DeepSpeed trainer is
allowing caller-supplied metadata to overwrite resolved layout values. Update
the layout assembly around the manifest_metadata["layout"] merge so the resolved
fields from the trainer stay authoritative, and only non-conflicting extra
layout keys from caller metadata are added. Use the existing
deepspeed_trainer.py layout-building logic and the symbols manifest_metadata,
layout, _deepspeed_zero_stage, _engine_world_size, and _engine_rank to keep the
exported manifest consistent with the actual model state.
In `@tests/test_deepspeed_training_worker.py`:
- Around line 64-65: The shared test counters in the test class are mutable
class attributes, so Ruff expects them to be annotated as ClassVar rather than
treated as instance defaults. Update the class that defines modifier_ranks and
parameter_counts to mark both attributes as ClassVar, keeping them as
intentionally shared fake state while satisfying the lint rule.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c9e7344b-c879-48a3-9169-e905dc1d7e32
📒 Files selected for processing (3)
rl_engine/executors/deepspeed_trainer.pytests/test_deepspeed_training_worker.pytests/test_linear_logp.py
| "zero_stage": self._deepspeed_zero_stage, | ||
| "world_size": self._engine_world_size(), | ||
| "rank": self._engine_rank(), | ||
| } | ||
| layout.update(dict(manifest_metadata.get("layout", {}))) | ||
| manifest_metadata["layout"] = layout |
There was a problem hiding this comment.
🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win
Preserve the resolved layout fields when merging caller metadata.
Line 225 lets metadata["layout"] overwrite the resolved zero_stage, world_size, or rank, so a ZeRO-3 export can publish a manifest that claims a different layout than the model state actually used.
Suggested fix
- layout = {
+ layout = dict(manifest_metadata.get("layout", {}))
+ layout.update({
"kind": "full-state",
"zero_stage": self._deepspeed_zero_stage,
"world_size": self._engine_world_size(),
"rank": self._engine_rank(),
- }
- layout.update(dict(manifest_metadata.get("layout", {})))
+ })
manifest_metadata["layout"] = layout📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| "zero_stage": self._deepspeed_zero_stage, | |
| "world_size": self._engine_world_size(), | |
| "rank": self._engine_rank(), | |
| } | |
| layout.update(dict(manifest_metadata.get("layout", {}))) | |
| manifest_metadata["layout"] = layout | |
| "zero_stage": self._deepspeed_zero_stage, | |
| "world_size": self._engine_world_size(), | |
| "rank": self._engine_rank(), | |
| } | |
| layout = dict(manifest_metadata.get("layout", {})) | |
| layout.update({ | |
| "kind": "full-state", | |
| "zero_stage": self._deepspeed_zero_stage, | |
| "world_size": self._engine_world_size(), | |
| "rank": self._engine_rank(), | |
| }) | |
| manifest_metadata["layout"] = layout |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@rl_engine/executors/deepspeed_trainer.py` around lines 221 - 226, The
manifest layout merge in the DeepSpeed trainer is allowing caller-supplied
metadata to overwrite resolved layout values. Update the layout assembly around
the manifest_metadata["layout"] merge so the resolved fields from the trainer
stay authoritative, and only non-conflicting extra layout keys from caller
metadata are added. Use the existing deepspeed_trainer.py layout-building logic
and the symbols manifest_metadata, layout, _deepspeed_zero_stage,
_engine_world_size, and _engine_rank to keep the exported manifest consistent
with the actual model state.
| def _linear_logp_op_for_device(device: torch.device | str) -> Any: | ||
| resolved = torch.device(device) | ||
| if resolved.type == "cpu": | ||
| return NativeLinearLogpOp() | ||
| return kernel_registry.get_op("linear_logp") |
There was a problem hiding this comment.
🚀 Performance & Scalability | 🟠 Major
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify whether worker device selection updates the registry/device context before linear_logp op lookup.
rg -n -C3 'device_ctx|set_device|_linear_logp_op_for_device|kernel_registry\.get_op|DeepSpeedTrainingWorker' rl_engineRepository: RL-Align/RL-Kernel
Length of output: 7361
Non-CPU backend selection decoupled from actual worker device context.
The _linear_logp_op_for_device function accepts a specific device argument but delegates to kernel_registry.get_op("linear_logp"), which ignores this argument and relies exclusively on the global device_ctx singleton. Since device_ctx is initialized at module load time and not synchronized with self.engine.device before worker construction, CUDA/ROCm workers may silently select an incorrect backend if the global context fails to reflect the active worker device.
Action: Align device_ctx before worker initialization or refactor the registry to accept an explicit device argument for kernel selection.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@rl_engine/executors/deepspeed_trainer.py` around lines 437 - 441, The backend
selection in _linear_logp_op_for_device is still driven by the global device_ctx
through kernel_registry.get_op("linear_logp"), so non-CPU workers can pick the
wrong kernel when device_ctx is stale. Fix this by either synchronizing
device_ctx with the active worker device before worker construction or changing
kernel_registry.get_op and its callers to accept and use the explicit device
argument from _linear_logp_op_for_device / the DeepSpeed worker initialization
path.
| modifier_ranks = [] | ||
| parameter_counts = [] |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Annotate shared fake state as ClassVar.
Ruff flags these mutable class attributes; they are intentionally shared test counters, so mark them as class variables instead of instance defaults.
Suggested fix
+from typing import ClassVar
+
class FakeGatheredParameters:
calls = 0
active = 0
max_active = 0
- modifier_ranks = []
- parameter_counts = []
+ modifier_ranks: ClassVar[list[object]] = []
+ parameter_counts: ClassVar[list[int]] = []📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| modifier_ranks = [] | |
| parameter_counts = [] | |
| from typing import ClassVar | |
| class FakeGatheredParameters: | |
| calls = 0 | |
| active = 0 | |
| max_active = 0 | |
| modifier_ranks: ClassVar[list[object]] = [] | |
| parameter_counts: ClassVar[list[int]] = [] |
🧰 Tools
🪛 Ruff (0.15.18)
[warning] 64-64: Mutable default value for class attribute
(RUF012)
[warning] 65-65: Mutable default value for class attribute
(RUF012)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/test_deepspeed_training_worker.py` around lines 64 - 65, The shared
test counters in the test class are mutable class attributes, so Ruff expects
them to be annotated as ClassVar rather than treated as instance defaults.
Update the class that defines modifier_ranks and parameter_counts to mark both
attributes as ClassVar, keeping them as intentionally shared fake state while
satisfying the lint rule.
Source: Linters/SAST tools
Summary
PR split
Testing
Summary by CodeRabbit
New Features
Bug Fixes
Tests