[FEAT][kernels] route LM-head vocab projection through deterministic GEMM #194

Open
inaniloquentee wants to merge 1 commit into main from feat/deterministic-lm-head-gemm

inaniloquentee (Collaborator) commented Jun 26, 2026

Summary

  • route the DeepSpeed training LM-head vocab projection through linear_logp so the deterministic GEMM path is used instead of materializing logits for selected logprobs
  • gather lm_head parameters for ZeRO-3 around the forward/backward window and report the resolved DeepSpeed zero stage in training/export metadata
  • harden hidden-state extraction and input-token validation, with regression coverage for tuple outputs, config overrides, and ZeRO-3 publishing
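To make the first bullet concrete, here is a framework-free sketch of computing the log-probability of one selected token from a hidden state and an LM-head weight matrix without ever holding the full logits row. It is not the `linear_logp` kernel from this PR: `selected_logp`, `chunk_size`, and the use of plain Python lists instead of tensors are all assumptions made for illustration.

```python
import math
from typing import Sequence


def selected_logp(
    hidden: Sequence[float],
    weight: Sequence[Sequence[float]],  # one row per vocabulary entry
    target: int,
    chunk_size: int = 2,  # hypothetical knob, chosen small for the demo
) -> float:
    """Return log_softmax(weight @ hidden)[target], scanning the vocab in chunks."""
    if not 0 <= target < len(weight):
        raise ValueError(f"target {target} outside vocab of size {len(weight)}")

    running_max = -math.inf  # largest logit seen so far
    running_sum = 0.0        # sum of exp(logit - running_max) seen so far
    target_logit = 0.0

    for start in range(0, len(weight), chunk_size):
        rows = weight[start:start + chunk_size]
        # Only this chunk of logits exists at any one time.
        chunk = [sum(w * h for w, h in zip(row, hidden)) for row in rows]
        if start <= target < start + len(chunk):
            target_logit = chunk[target - start]
        new_max = max(running_max, max(chunk))
        # Rescale the previous partial sum to the new max before adding the chunk.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(x - new_max) for x in chunk
        )
        running_max = new_max

    return target_logit - (running_max + math.log(running_sum))
```

Because the chunks are always visited in the same order, the partial sums are accumulated in the same order on every call, which is the property a deterministic reduction relies on.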

PR split

  • this PR is the issue151 follow-up only
  • the tensor-parallel linear_logp work stays separate from this branch so review can stay focused on the deterministic GEMM routing change

Testing

  • python -m pre_commit run --all-files
  • python -m mypy --python-version 3.10 --ignore-missing-imports rl_engine/
  • python -m pytest rl_engine/tests/test_dispatch.py -v
  • PYTEST_DISABLE_PLUGIN_AUTOLOAD=1 python -m pytest tests/test_attention_correctness.py -q -rs
  • python -m pytest tests/test_deepspeed_training_worker.py tests/test_linear_logp.py -q
  • python -m pytest tests/test_alignment_model_wrappers.py tests/test_deepspeed_training_worker.py tests/test_linear_logp.py tests/test_reference_ops.py tests/test_op_accuracy.py tests/test_weight_sync_bridge.py tests/test_ray_actor_manager.py tests/test_stateless_executor.py tests/test_stateless_training_contract.py -q
  • python -m mkdocs build --strict -f mkdocs.yaml

Summary by CodeRabbit

  • New Features

    • Training now uses a more robust log-probability path that works across devices and supports tied embedding/language-model head weights.
    • Added broader support for different model output shapes and safer handling of masked token targets.
  • Bug Fixes

    • Improved DeepSpeed training and weight publishing behavior across ZeRO configurations.
    • Added stronger validation to prevent invalid token handling during training.
  • Tests

    • Expanded coverage for layout invariance, masked inputs, gradient consistency, and ZeRO-specific training/publishing behavior.

…GEMM

Signed-off-by: inaniloquentee <3051000145@qq.com>
coderabbitai Bot commented Jun 26, 2026

Walkthrough

The DeepSpeed trainer now computes log-probabilities through a hidden-state linear-logp path, resolves ZeRO stage from DeepSpeed config, updates weight publishing, and adds tests for gather behavior, masking, hidden-state extraction, and layout-invariant gradients.
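As a rough illustration of resolving the ZeRO stage from a DeepSpeed config, the sketch below reads the standard `zero_optimization.stage` key from a config mapping. The helper name `resolve_zero_stage` and its validation rules are assumptions; the PR's own resolution logic (including how config overrides are handled) is not shown on this page.

```python
from typing import Any, Mapping


def resolve_zero_stage(ds_config: Mapping[str, Any], default: int = 0) -> int:
    """Read the ZeRO stage from a DeepSpeed-style config mapping (hypothetical helper)."""
    zero_cfg = ds_config.get("zero_optimization") or {}
    stage = zero_cfg.get("stage", default)
    # Reject bools explicitly, since bool is a subclass of int in Python.
    if isinstance(stage, bool) or not isinstance(stage, int) or stage not in (0, 1, 2, 3):
        raise ValueError(f"unsupported ZeRO stage: {stage!r}")
    return stage


# A missing section falls back to the default; an explicit stage is returned as-is.
stage_default = resolve_zero_stage({})
stage_three = resolve_zero_stage({"zero_optimization": {"stage": 3}})
```

Resolving the stage once and storing it lets later decisions, such as whether to gather parameters or which manifest layout to export, read one value instead of re-parsing the config.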

Changes

DeepSpeed linear-logp training and tests

| Layer / File(s) | Summary |
| --- | --- |
| Linear-logp helpers (rl_engine/executors/deepspeed_trainer.py) | Adds zero-stage resolution, the embedding-plus-lm-head model, hidden-state extraction, token validation, device op selection, and gathered-parameter logprob helpers. |
| Training worker path (rl_engine/executors/deepspeed_trainer.py) | DeepSpeedTrainingWorker.train() now builds the embedding-plus-lm-head model, computes logprobs through _extract_logps, and reports the new metrics. |
| Weight publishing (rl_engine/executors/deepspeed_trainer.py) | publish_weights() uses the resolved ZeRO stage to choose the manifest layout and full-state export condition. |
| DeepSpeed worker tests (tests/test_deepspeed_training_worker.py) | Expands the DeepSpeed worker tests with gather tracking, a NativeLinearLogpOp spy, masking checks, hidden-state extraction checks, ZeRO-3 behavior, metrics, and publishing coverage. |
| Linear-logp operator tests (tests/test_linear_logp.py) | Adds layout-permutation helpers and checks forward masking, chunked backward gradients, and tied embedding/lm-head gradient invariance. |
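The hidden-state extraction and token validation mentioned in the table could look roughly like the sketch below. Both helper names, the `last_hidden_state` attribute lookup, and the `-100` ignore index are assumptions for illustration and are not taken from the PR's code.

```python
from typing import Any, Sequence


def extract_hidden_states(output: Any) -> Any:
    """Pick the hidden states out of a tuple, an output object, or a bare value."""
    if isinstance(output, (tuple, list)):
        if not output:
            raise ValueError("model returned an empty tuple")
        return output[0]  # assumption: hidden states come first in tuple outputs
    value = getattr(output, "last_hidden_state", None)
    if value is not None:
        return value
    return output  # assume the model already returned the hidden states directly


def validate_token_ids(
    token_ids: Sequence[int], vocab_size: int, ignore_index: int = -100
) -> None:
    """Raise if any unmasked token id falls outside [0, vocab_size)."""
    for position, token in enumerate(token_ids):
        if token == ignore_index:
            continue  # masked targets are skipped rather than rejected
        if not 0 <= token < vocab_size:
            raise ValueError(
                f"token id {token} at position {position} is outside "
                f"the vocabulary of size {vocab_size}"
            )
```

Validating ids before the projection turns an out-of-range index into an explicit error at the call site instead of a failure inside the kernel.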

Sequence Diagram(s)

sequenceDiagram
  participant DeepSpeedTrainingWorker
  participant GatheredParameters
  participant kernel_registry
  participant NativeLinearLogpOp
  participant DeepSpeedEngine
  DeepSpeedTrainingWorker->>GatheredParameters: enter lm_head parameter context
  DeepSpeedTrainingWorker->>kernel_registry: resolve device linear-logp op
  kernel_registry-->>DeepSpeedTrainingWorker: NativeLinearLogpOp or registered op
  DeepSpeedTrainingWorker->>NativeLinearLogpOp: compute current_logps and ref_logps
  NativeLinearLogpOp-->>DeepSpeedTrainingWorker: per-token logprobs
  DeepSpeedTrainingWorker->>DeepSpeedEngine: backward(loss)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • RL-Align/RL-Kernel#122: Adds the linear_logp kernel and registry dispatch that this PR consumes in DeepSpeed training.
  • RL-Align/RL-Kernel#94: Changes the DeepSpeed training worker’s log-probability and KL handling in the same executor path.

Suggested labels

needs-gpu-ci

Suggested reviewers

  • maxiaosong1124
  • bitborne
  • Flink-ddd

Poem

A bunny hopped through tensors bright,
And found new log-probs taking flight.
With ZeRO rows and kernels neat,
My whiskers twitch at gradient heat.
🐰✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 4.55%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately summarizes the main change: routing LM-head vocab projection through a deterministic GEMM path. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/executors/deepspeed_trainer.py`:
- Around line 437-441: The backend selection in _linear_logp_op_for_device is
still driven by the global device_ctx through
kernel_registry.get_op("linear_logp"), so non-CPU workers can pick the wrong
kernel when device_ctx is stale. Fix this by either synchronizing device_ctx
with the active worker device before worker construction or changing
kernel_registry.get_op and its callers to accept and use the explicit device
argument from _linear_logp_op_for_device / the DeepSpeed worker initialization
path.
- Around line 221-226: The manifest layout merge in the DeepSpeed trainer is
allowing caller-supplied metadata to overwrite resolved layout values. Update
the layout assembly around the manifest_metadata["layout"] merge so the resolved
fields from the trainer stay authoritative, and only non-conflicting extra
layout keys from caller metadata are added. Use the existing
deepspeed_trainer.py layout-building logic and the symbols manifest_metadata,
layout, _deepspeed_zero_stage, _engine_world_size, and _engine_rank to keep the
exported manifest consistent with the actual model state.

In `@tests/test_deepspeed_training_worker.py`:
- Around line 64-65: The shared test counters in the test class are mutable
class attributes, so Ruff expects them to be annotated as ClassVar rather than
treated as instance defaults. Update the class that defines modifier_ranks and
parameter_counts to mark both attributes as ClassVar, keeping them as
intentionally shared fake state while satisfying the lint rule.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c9e7344b-c879-48a3-9169-e905dc1d7e32

📥 Commits

Reviewing files that changed from the base of the PR and between ea196da and 6101978.

📒 Files selected for processing (3)
  • rl_engine/executors/deepspeed_trainer.py
  • tests/test_deepspeed_training_worker.py
  • tests/test_linear_logp.py

Comment on lines +221 to 226
"zero_stage": self._deepspeed_zero_stage,
"world_size": self._engine_world_size(),
"rank": self._engine_rank(),
}
layout.update(dict(manifest_metadata.get("layout", {})))
manifest_metadata["layout"] = layout

🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win

Preserve the resolved layout fields when merging caller metadata.

Line 225 lets the caller-supplied manifest_metadata["layout"] overwrite the resolved zero_stage, world_size, or rank, so a ZeRO-3 export can publish a manifest that claims a different layout than the one the model state actually used.

Suggested fix
-        layout = {
+        layout = dict(manifest_metadata.get("layout", {}))
+        layout.update({
             "kind": "full-state",
             "zero_stage": self._deepspeed_zero_stage,
             "world_size": self._engine_world_size(),
             "rank": self._engine_rank(),
-        }
-        layout.update(dict(manifest_metadata.get("layout", {})))
+        })
         manifest_metadata["layout"] = layout
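The effect of the reordered merge can be checked in isolation with plain dictionaries. In this sketch `build_layout` and the `shard_format` key are hypothetical; only the merge order mirrors the suggested fix.

```python
from typing import Any, Mapping


def build_layout(
    caller_layout: Mapping[str, Any], resolved: Mapping[str, Any]
) -> dict[str, Any]:
    """Merge layouts so that the trainer's resolved values always win."""
    layout = dict(caller_layout)  # start from the caller's extra keys
    layout.update(resolved)       # resolved keys overwrite any conflicting ones
    return layout


# The caller claims ZeRO stage 0 and adds one extra, non-conflicting key.
caller = {"zero_stage": 0, "shard_format": "safetensors"}
resolved = {"kind": "full-state", "zero_stage": 3, "world_size": 8, "rank": 0}
layout = build_layout(caller, resolved)
```

After the merge, `layout["zero_stage"]` is the resolved `3` while the extra `shard_format` key is preserved, which is the behavior the review comment asks for.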

Comment on lines +437 to +441
def _linear_logp_op_for_device(device: torch.device | str) -> Any:
    resolved = torch.device(device)
    if resolved.type == "cpu":
        return NativeLinearLogpOp()
    return kernel_registry.get_op("linear_logp")

🚀 Performance & Scalability | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify whether worker device selection updates the registry/device context before linear_logp op lookup.
rg -n -C3 'device_ctx|set_device|_linear_logp_op_for_device|kernel_registry\.get_op|DeepSpeedTrainingWorker' rl_engine

Repository: RL-Align/RL-Kernel

Length of output: 7361


Non-CPU backend selection decoupled from actual worker device context.

The _linear_logp_op_for_device function accepts a specific device argument but delegates to kernel_registry.get_op("linear_logp"), which ignores this argument and relies exclusively on the global device_ctx singleton. Since device_ctx is initialized at module load time and not synchronized with self.engine.device before worker construction, CUDA/ROCm workers may silently select an incorrect backend if the global context fails to reflect the active worker device.

Action: Align device_ctx before worker initialization or refactor the registry to accept an explicit device argument for kernel selection.
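One way to picture the second option, a registry that takes the device explicitly, is the toy sketch below. `DeviceAwareRegistry`, `device_type_of`, and the string placeholders standing in for kernels are all hypothetical; the repository's real `kernel_registry` API may look quite different.

```python
from typing import Callable, Dict, Tuple


def device_type_of(device: str) -> str:
    """Reduce a device string such as 'cuda:1' to its type, 'cuda'."""
    return device.split(":", 1)[0]


class DeviceAwareRegistry:
    """Toy registry keyed by (op name, device type) instead of a global context."""

    def __init__(self) -> None:
        self._ops: Dict[Tuple[str, str], Callable[[], object]] = {}

    def register(self, name: str, device_type: str, factory: Callable[[], object]) -> None:
        self._ops[(name, device_type)] = factory

    def get_op(self, name: str, device_type: str) -> object:
        try:
            factory = self._ops[(name, device_type)]
        except KeyError:
            raise LookupError(f"no {name!r} op registered for {device_type!r}") from None
        return factory()


registry = DeviceAwareRegistry()
registry.register("linear_logp", "cpu", lambda: "native-op")        # placeholder kernel
registry.register("linear_logp", "cuda", lambda: "cuda-kernel-op")  # placeholder kernel

# The lookup depends only on the worker's own device, not on shared global state.
op = registry.get_op("linear_logp", device_type_of("cuda:1"))
```

With the device passed in at the call site, a stale global context cannot change which kernel a worker receives, and a missing backend fails loudly instead of silently falling back.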


Comment on lines +64 to +65
modifier_ranks = []
parameter_counts = []

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

Annotate shared fake state as ClassVar.

Ruff flags these mutable class attributes; they are intentionally shared test counters, so mark them as class variables instead of instance defaults.

Suggested fix
+from typing import ClassVar
+
 class FakeGatheredParameters:
     calls = 0
     active = 0
     max_active = 0
-    modifier_ranks = []
-    parameter_counts = []
+    modifier_ranks: ClassVar[list[object]] = []
+    parameter_counts: ClassVar[list[int]] = []
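For context on why the annotation matters, the sketch below shows a context-manager fake whose counters live on the class and are therefore shared by every instance. `RecordingGather` and its `reset` helper are hypothetical stand-ins, not the test file's actual `FakeGatheredParameters`.

```python
from typing import ClassVar


class RecordingGather:
    """Fake gather context that records calls in class-level shared state."""

    calls: ClassVar[int] = 0
    parameter_counts: ClassVar[list[int]] = []

    def __init__(self, params: list[object]) -> None:
        cls = type(self)
        cls.calls += 1                          # counted across all instances
        cls.parameter_counts.append(len(params))

    def __enter__(self) -> "RecordingGather":
        return self

    def __exit__(self, *exc_info: object) -> bool:
        return False  # never swallow exceptions raised inside the block

    @classmethod
    def reset(cls) -> None:
        """Clear shared state so one test cannot leak counts into the next."""
        cls.calls = 0
        cls.parameter_counts.clear()


with RecordingGather(["w", "b"]):
    pass
with RecordingGather(["w"]):
    pass
recorded = (RecordingGather.calls, list(RecordingGather.parameter_counts))
```

Because the lists are shared on purpose, calling something like `reset()` in test setup is a reasonable companion to the `ClassVar` annotation.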
🧰 Tools
🪛 Ruff (0.15.18)

[warning] 64-64: Mutable default value for class attribute

(RUF012)


[warning] 65-65: Mutable default value for class attribute

(RUF012)

Source: Linters/SAST tools
