[WS1]feat: add embedding invariance and LM-head linear_logp routing#190
[WS1]feat: add embedding invariance and LM-head linear_logp routing#190inaniloquentee wants to merge 2 commits into
Conversation
Signed-off-by: inaniloquentee <3051000145@qq.com>
📝 WalkthroughWalkthroughAdds deterministic embedding invariance tests and updates DeepSpeed training to score LM-head outputs through a linear-logp path with new routing, masking, metrics, and worker coverage. ChangesEmbedding lookup invariance tests
DeepSpeed linear-logp training
ROCm flash attention typing
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Signed-off-by: inaniloquentee <3051000145@qq.com>
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/executors/deepspeed_trainer.py`:
- Around line 301-312: The `forward` method in `deepspeed_trainer.py` is reading
`self.lm_head.weight` and `self.lm_head.bias` directly on the `target_ids` path,
which bypasses ZeRO-3 parameter-gather hooks. Move the custom scoring logic into
a dedicated `nn.Module` such as `_LinearLogpHead` and call it through its
`forward` method from `forward`, so both the normal logits path and the
`target_ids` scoring path go through module execution and trigger DeepSpeed
hooks correctly.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f6f14192-299a-49eb-8654-92c7ef959302
📒 Files selected for processing (3)
rl_engine/executors/deepspeed_trainer.pyrl_engine/kernels/ops/rocm/attention/flash_attn.pytests/test_deepspeed_training_worker.py
✅ Files skipped from review due to trivial changes (1)
- rl_engine/kernels/ops/rocm/attention/flash_attn.py
| def forward( | ||
| self, | ||
| input_ids: torch.Tensor, | ||
| *, | ||
| target_ids: Optional[torch.Tensor] = None, | ||
| ) -> torch.Tensor: | ||
| hidden = self.embedding(input_ids.long()) | ||
| if target_ids is None: | ||
| return self.lm_head(hidden) | ||
| if self.linear_logp_op is None: | ||
| raise ValueError("target_ids scoring requires a linear_logp_op") | ||
| return self.linear_logp_op(hidden, self.lm_head.weight, target_ids, self.lm_head.bias) |
There was a problem hiding this comment.
🩺 Stability & Availability | 🟠 Major
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify whether this PR has a ZeRO-3 train-path test and where the LM-head
# parameters are accessed directly.
rg -n -C3 'zero_stage\s*=\s*3|"zero_optimization": \{"stage": 3\}' tests rl_engine
rg -n -C3 'lm_head\.(weight|bias)|linear_logp_op' rl_engine/executors/deepspeed_trainer.pyRepository: RL-Align/RL-Kernel
Length of output: 3317
Direct parameter access in ZeRO-3 path bypasses module hooks
In rl_engine/executors/deepspeed_trainer.py (lines 301–312), the forward method accesses self.lm_head.weight and self.lm_head.bias directly when target_ids is provided. Under DeepSpeed ZeRO-3, model parameters are partitioned across ranks and must be gathered via the module’s forward method to trigger the necessary parameter-gather hooks. Direct attribute access bypasses these hooks, which can lead to runtime failures or silent incorrect results in distributed training.
Although tests exist for zero_stage=3, they do not explicitly verify the code path where target_ids is passed and linear_logp_op is invoked. To ensure ZeRO-3 compatibility, encapsulate the custom scoring logic inside a dedicated nn.Module (e.g., _LinearLogpHead) and route both standard and custom paths through its forward method. This guarantees parameter-gather hooks are invoked regardless of the execution path.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@rl_engine/executors/deepspeed_trainer.py` around lines 301 - 312, The
`forward` method in `deepspeed_trainer.py` is reading `self.lm_head.weight` and
`self.lm_head.bias` directly on the `target_ids` path, which bypasses ZeRO-3
parameter-gather hooks. Move the custom scoring logic into a dedicated
`nn.Module` such as `_LinearLogpHead` and call it through its `forward` method
from `forward`, so both the normal logits path and the `target_ids` scoring path
go through module execution and trigger DeepSpeed hooks correctly.
Summary
Add the second half of issue151 by routing the toy DeepSpeed training worker through the deterministic
linear_logpLM-head path, while keeping the earlier embedding lookup invariance coverage.What this PR does
linear_logpinstead of materializing logits and callingselected_logprobs_reference.Validation
Signed-off-by.black,isort,flake8,ruff,git diff --check.mypy --python-version 3.10 --ignore-missing-imports rl_engine/.python -m pytest -q tests/test_deepspeed_training_worker.py tests/test_linear_logp.py tests/test_embedding_lookup_invariance.py tests/test_reference_ops.py rl_engine/tests/test_dispatch.pyNotes
pre-commit runend-to-end because hook bootstrap in this environment failed while installing remote hook deps from the configured package mirror, so I ran the equivalent local hooks directly.testsalso depends on optional env pieces (triton,tabulate) outside this branch; the CI-targeted subset above passed cleanly.