Warn when sequence-level importance sampling is combined with a token-summed loss type #6042
discobot wants to merge 1 commit into
…-summed loss type

With importance_sampling_level='sequence' the per-token loss has shape (B, 1), and loss types that sum per-token contributions (bnpo, dr_grpo, dapo, cispo) broadcast it against the (B, T) completion mask. This effectively weights each sequence by its completion length instead of optimizing the per-sequence objective GSPO prescribes. The recommended combo (loss_type='grpo') was only documented in the paper index, so the default config triggered this silently.

Add a warning in GRPOTrainer.__init__ mirroring the existing luspo/vespo mismatch checks, plus regression tests for the warning and the recommended setup.
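The length-weighting effect can be checked with a few lines of arithmetic. The sketch below is plain Python, not TRL code: `broadcast_loss`, `grpo_style`, and `token_summed` are hypothetical helpers. It assumes that loss_type='grpo' averages tokens within each sequence before averaging over the batch, and that a bnpo-like reduction sums all unmasked tokens and divides by the total token count.

```python
def broadcast_loss(seq_losses, mask):
    """Expand a (B, 1) per-sequence loss against a (B, T) 0/1 completion mask."""
    return [[loss * m for m in row] for loss, row in zip(seq_losses, mask)]


def grpo_style(seq_losses, mask):
    """Average tokens within each sequence first, then average over sequences."""
    masked = broadcast_loss(seq_losses, mask)
    per_seq = [sum(row) / max(sum(m_row), 1) for row, m_row in zip(masked, mask)]
    return sum(per_seq) / len(per_seq)


def token_summed(seq_losses, mask):
    """Sum every unmasked token in the batch, then divide by the token count."""
    masked = broadcast_loss(seq_losses, mask)
    total = sum(sum(row) for row in masked)
    return total / max(sum(sum(row) for row in mask), 1)


# Two completions: a short one (2 tokens) and a long one (8 tokens), padded to T = 8.
mask = [[1, 1, 0, 0, 0, 0, 0, 0], [1] * 8]
seq_losses = [1.0, 3.0]  # one loss value per sequence, i.e. shape (B, 1)

print(grpo_style(seq_losses, mask))    # each sequence counts equally
print(token_summed(seq_losses, mask))  # the 8-token sequence counts 4x as much
```

With these inputs the per-sequence reduction is the plain mean of the two sequence losses, while the token-summed reduction is pulled toward the longer completion.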
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 6e1a5005f1
            "set to `'token'` (the default)."
        )

    if args.importance_sampling_level == "sequence" and args.loss_type in ["bnpo", "dr_grpo", "dapo", "cispo"]:
Propagate the sequence-level warning to SDPO
This only adds the warning to GRPOTrainer, but I checked the copied SDPO policy-loss path in trl/experimental/sdpo/sdpo_trainer.py:1160-1176: it also supports importance_sampling_level='sequence' and then token-summed bnpo, dr_grpo, and dapo reductions. With SDPOConfig(importance_sampling_level='sequence', loss_type='dapo'), users still silently get length-weighted sequence losses, so the duplicated trainer logic remains inconsistent with the repository requirement that these patterns stay aligned.
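One way to keep duplicated trainers aligned would be a shared helper that each trainer calls with its own set of token-summed loss types. The sketch below is only an illustration under that assumption: `warn_if_length_weighted` and the two constants are hypothetical names, it uses the standard library `logging` module rather than whatever logger the trainers actually use, and the SDPO set simply mirrors the three loss types named in this comment.

```python
import logging

logger = logging.getLogger(__name__)

# Loss types assumed to sum per-token contributions; each trainer passes its own set.
GRPO_TOKEN_SUMMED = frozenset({"bnpo", "dr_grpo", "dapo", "cispo"})
SDPO_TOKEN_SUMMED = frozenset({"bnpo", "dr_grpo", "dapo"})


def warn_if_length_weighted(importance_sampling_level, loss_type, token_summed_types):
    """Log a warning for the sequence-level + token-summed combination.

    Returns True when the warning was emitted, which makes the helper easy to test.
    """
    if importance_sampling_level == "sequence" and loss_type in token_summed_types:
        logger.warning(
            "importance_sampling_level='sequence' with loss_type=%r weights each "
            "sequence by its completion length; consider loss_type='grpo'.",
            loss_type,
        )
        return True
    return False
```

Calling the helper from both constructors would keep the message text in one place; whether that fits the repository's copy-and-align convention is a maintainer decision.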
What does this PR do?
Fixes #3823.
Combining importance_sampling_level="sequence" (the GSPO setup) with a loss type that sums per-token contributions (bnpo, dr_grpo, cispo, or the default dapo) silently length-weights each sequence's loss instead of applying the per-sequence objective GSPO prescribes.

This PR adds the warning suggested in #3823 (comment), placed right after the existing luspo/vespo mismatch checks in GRPOTrainer.__init__ and mirroring their style. It points users to loss_type="grpo", matching the GSPO recipe in the paper index. The experimental PAPOTrainer reuses GRPOTrainer.__init__, so it inherits the check.

One thing that turned up while testing: the existing test_train_sequence_importance_sampling itself uses the footgun combo (sequence-level importance sampling with the default dapo), so the new warning now fires during its setup. I left it unchanged to keep the diff minimal; happy to switch it to loss_type="grpo" if preferred.

Tests: a parametrized test asserting the warning fires for all four affected loss types (fails on main without the fix), and a negative test asserting the recommended GSPO combo stays silent. Also ran the other warning tests and test_train_loss_types locally (CPU).
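The shape of the two tests described above can be sketched with the standard library alone. This is not the PR's test code: `check_config` is a hypothetical stand-in for the init-time check, the logger name is made up, and `unittest` with `subTest` is used here in place of whatever parametrization the repository's test suite relies on.

```python
import logging
import unittest

logger = logging.getLogger("gspo_warning_sketch")

TOKEN_SUMMED = ("bnpo", "dr_grpo", "dapo", "cispo")


def check_config(importance_sampling_level, loss_type):
    """Stand-in for the init-time check; not the trainer's real code."""
    if importance_sampling_level == "sequence" and loss_type in TOKEN_SUMMED:
        logger.warning("sequence-level importance sampling with loss_type=%s", loss_type)


class SequenceLevelWarningTest(unittest.TestCase):
    def test_warns_for_token_summed_loss_types(self):
        # One sub-test per affected loss type, so a failure names the culprit.
        for loss_type in TOKEN_SUMMED:
            with self.subTest(loss_type=loss_type):
                with self.assertLogs(logger, level="WARNING") as captured:
                    check_config("sequence", loss_type)
                self.assertIn(loss_type, captured.output[0])

    def test_recommended_combo_is_silent(self):
        records = []
        handler = logging.Handler()
        handler.emit = records.append  # collect any record the check emits
        logger.addHandler(handler)
        try:
            check_config("sequence", "grpo")
        finally:
            logger.removeHandler(handler)
        self.assertEqual(records, [])
```

The negative test matters as much as the positive one: it guards against the condition being widened later so that the recommended GSPO setup starts warning too.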
Before submitting
importance_sampling_level == "sequence") #3823)
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
The fix and tests were developed with Claude Code and verified by running the test suite locally.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Note
Low Risk
Init-time logging only; training math and defaults are unchanged aside from surfacing a previously silent config pitfall.
Overview
Adds a startup warning in GRPOTrainer.__init__ when importance_sampling_level="sequence" is used with loss_type in bnpo, dr_grpo, dapo, or cispo. Those losses sum per-token terms, so with sequence-level importance sampling the objective effectively length-weights completions instead of matching the GSPO per-sequence setup. The message recommends loss_type="grpo" and links to the GSPO paper index, consistent with existing luspo/vespo config mismatch warnings.

Tests cover the four loss types (warning present) and sequence + grpo (no warning). PAPOTrainer inherits the check via GRPOTrainer.__init__.

Reviewed by Cursor Bugbot for commit 6e1a500.