Skip to content

[WS1]feat: add embedding invariance and LM-head linear_logp routing#190

Open
inaniloquentee wants to merge 2 commits into
mainfrom
codex/issue-151-embedding-lookup-invariance
Open

[WS1]feat: add embedding invariance and LM-head linear_logp routing#190
inaniloquentee wants to merge 2 commits into
mainfrom
codex/issue-151-embedding-lookup-invariance

Conversation

@inaniloquentee

@inaniloquentee inaniloquentee commented Jun 25, 2026

Copy link
Copy Markdown
Collaborator

Summary

Add the second half of issue151 by routing the toy DeepSpeed training worker through the deterministic linear_logp LM-head path, while keeping the earlier embedding lookup invariance coverage.

What this PR does

  • Keeps the embedding lookup invariance tests across RL batch layouts.
  • Routes the DeepSpeed worker LM-head projection through linear_logp instead of materializing logits and calling selected_logprobs_reference.
  • Adds safe masked token-id handling for inactive positions and explicit range validation for active token ids.
  • Records the LM-head projection backend in worker metrics.
  • Fixes a small ROCm FlashAttention type annotation so repo-wide mypy passes under the CI Python target.

Validation

  • DCO: both commits on this branch include Signed-off-by.
  • Formatting/lint: black, isort, flake8, ruff, git diff --check.
  • Type check: mypy --python-version 3.10 --ignore-missing-imports rl_engine/.
  • Tests: python -m pytest -q tests/test_deepspeed_training_worker.py tests/test_linear_logp.py tests/test_embedding_lookup_invariance.py tests/test_reference_ops.py rl_engine/tests/test_dispatch.py

Notes

  • I could not run pre-commit run end-to-end because hook bootstrap in this environment failed while installing remote hook deps from the configured package mirror, so I ran the equivalent local hooks directly.
  • Full local tests also depends on optional env pieces (triton, tabulate) outside this branch; the CI-targeted subset above passed cleanly.

Signed-off-by: inaniloquentee <3051000145@qq.com>
@coderabbitai

coderabbitai Bot commented Jun 25, 2026

Copy link
Copy Markdown

Review Change Stack

📝 Walkthrough

Walkthrough

Adds deterministic embedding invariance tests and updates DeepSpeed training to score LM-head outputs through a linear-logp path with new routing, masking, metrics, and worker coverage.

Changes

Embedding lookup invariance tests

Layer / File(s) Summary
Test fixtures and batch setup
tests/test_embedding_lookup_invariance.py
Defines probe constants, batch layouts, CUDA gating, deterministic embedding weights, and synthetic batch construction with probe tokens and masks.
Batch permutation and probe checks
tests/test_embedding_lookup_invariance.py
Adds row-permutation logic that keeps batch tensors aligned and a shared assertion helper for probe-position embeddings.
Layout invariance test
tests/test_embedding_lookup_invariance.py
Compares probe embeddings across all configured batch layouts against reference vectors.
Permutation and masked-tail tests
tests/test_embedding_lookup_invariance.py
Verifies row-order preservation under permutation and unchanged active embeddings after masked-tail token mutations.

DeepSpeed linear-logp training

Layer / File(s) Summary
Model and scoring path
rl_engine/executors/deepspeed_trainer.py
Replaces the Sequential LM head with an embedding LM-head model, adds device-based linear-logp selection, validates token ids, and extracts log-probabilities from engine outputs.
Metrics and worker tests
rl_engine/executors/deepspeed_trainer.py, tests/test_deepspeed_training_worker.py
Extends training metrics with projection metadata and adds worker tests for ragged rollout handling, projection routing, and safe token-id behavior.

ROCm flash attention typing

Layer / File(s) Summary
Flash attention annotation
rl_engine/kernels/ops/rocm/attention/flash_attn.py
Adds a callable import and annotates the attention op attribute in the ROCm flash attention wrapper.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

I hopped through layouts, neat and bright,
With probe tokens holding tight.
DeepSpeed learned a logp trail,
While masked tails and rows stayed hale.
🐇✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title is concise and accurately captures the two main changes: embedding invariance tests and LM-head linear_logp routing.
✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch codex/issue-151-embedding-lookup-invariance

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

Signed-off-by: inaniloquentee <3051000145@qq.com>
@inaniloquentee inaniloquentee changed the title test: add embedding lookup invariance coverage feat: add embedding invariance and LM-head linear_logp routing Jun 26, 2026

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/executors/deepspeed_trainer.py`:
- Around line 301-312: The `forward` method in `deepspeed_trainer.py` is reading
`self.lm_head.weight` and `self.lm_head.bias` directly on the `target_ids` path,
which bypasses ZeRO-3 parameter-gather hooks. Move the custom scoring logic into
a dedicated `nn.Module` such as `_LinearLogpHead` and call it through its
`forward` method from `forward`, so both the normal logits path and the
`target_ids` scoring path go through module execution and trigger DeepSpeed
hooks correctly.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f6f14192-299a-49eb-8654-92c7ef959302

📥 Commits

Reviewing files that changed from the base of the PR and between b4fffba and 80cff12.

📒 Files selected for processing (3)
  • rl_engine/executors/deepspeed_trainer.py
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py
  • tests/test_deepspeed_training_worker.py
✅ Files skipped from review due to trivial changes (1)
  • rl_engine/kernels/ops/rocm/attention/flash_attn.py

Comment on lines +301 to +312
def forward(
self,
input_ids: torch.Tensor,
*,
target_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden = self.embedding(input_ids.long())
if target_ids is None:
return self.lm_head(hidden)
if self.linear_logp_op is None:
raise ValueError("target_ids scoring requires a linear_logp_op")
return self.linear_logp_op(hidden, self.lm_head.weight, target_ids, self.lm_head.bias)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify whether this PR has a ZeRO-3 train-path test and where the LM-head
# parameters are accessed directly.
rg -n -C3 'zero_stage\s*=\s*3|"zero_optimization": \{"stage": 3\}' tests rl_engine
rg -n -C3 'lm_head\.(weight|bias)|linear_logp_op' rl_engine/executors/deepspeed_trainer.py

Repository: RL-Align/RL-Kernel

Length of output: 3317


Direct parameter access in ZeRO-3 path bypasses module hooks

In rl_engine/executors/deepspeed_trainer.py (lines 301–312), the forward method accesses self.lm_head.weight and self.lm_head.bias directly when target_ids is provided. Under DeepSpeed ZeRO-3, model parameters are partitioned across ranks and must be gathered via the module’s forward method to trigger the necessary parameter-gather hooks. Direct attribute access bypasses these hooks, which can lead to runtime failures or silent incorrect results in distributed training.

Although tests exist for zero_stage=3, they do not explicitly verify the code path where target_ids is passed and linear_logp_op is invoked. To ensure ZeRO-3 compatibility, encapsulate the custom scoring logic inside a dedicated nn.Module (e.g., _LinearLogpHead) and route both standard and custom paths through its forward method. This guarantees parameter-gather hooks are invoked regardless of the execution path.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/executors/deepspeed_trainer.py` around lines 301 - 312, The
`forward` method in `deepspeed_trainer.py` is reading `self.lm_head.weight` and
`self.lm_head.bias` directly on the `target_ids` path, which bypasses ZeRO-3
parameter-gather hooks. Move the custom scoring logic into a dedicated
`nn.Module` such as `_LinearLogpHead` and call it through its `forward` method
from `forward`, so both the normal logits path and the `target_ids` scoring path
go through module execution and trigger DeepSpeed hooks correctly.

@Flink-ddd Flink-ddd changed the title feat: add embedding invariance and LM-head linear_logp routing [WS1]feat: add embedding invariance and LM-head linear_logp routing Jun 27, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant