fix: preserve OnlineDPO vLLM completion ids #6038
Open
he-yufeng wants to merge 1 commit into
What does this PR do?
Fixes #5514.
`trl vllm-serve` already returns `completion_ids` as one token-id list per generated completion. `OnlineDPOTrainer._generate_vllm_server()` was treating that flat list as if it were grouped by prompt, so it split each token sequence into single-token completions.

This PR keeps each returned completion token sequence intact, then reorders the server output into the generation-major order that Online DPO expects for reward splitting. It also keeps the local `prompt_ids` in the same order.

Before submitting
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Test plan
`python -m py_compile trl\experimental\online_dpo\online_dpo_trainer.py tests\experimental\test_online_dpo_trainer.py`

`python -c "import sys, types, importlib.machinery; m=types.ModuleType('kernels'); m.__spec__=importlib.machinery.ModuleSpec('kernels', loader=None); sys.modules['kernels']=m; import pytest; raise SystemExit(pytest.main(['tests/experimental/test_online_dpo_trainer.py', '-k', 'vllm_server_preserves_completion_token_sequences', '-q']))"`

`python -m ruff check trl\experimental\online_dpo\online_dpo_trainer.py tests\experimental\test_online_dpo_trainer.py`

`python -m ruff format --check trl\experimental\online_dpo\online_dpo_trainer.py tests\experimental\test_online_dpo_trainer.py`

`git diff --check`

The targeted pytest command temporarily stubs the installed `kernels` package in the test process because this local environment hits an import-time `transformers`/`kernels` version mismatch (`ValueError: Either a revision or a version must be specified.`) before test collection. The regression itself is a pure mock test and does not require vLLM or kernels.

Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Note
Medium Risk
Changes Online DPO vLLM-server generation and distributed indexing; wrong ordering would corrupt training, but scope is limited to that code path and is covered by a regression test.
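To illustrate why a wrong ordering would corrupt training, here is a minimal sketch. It assumes two generations per prompt and that chosen/rejected candidates are paired by splitting the batch into a first and second half; `pair_by_halves` and the `p0g0`-style labels are hypothetical names used only for this illustration, not code from the trainer.

```python
from typing import List, Sequence, Tuple


def pair_by_halves(items: Sequence[str], num_prompts: int) -> List[Tuple[str, str]]:
    """Pair element i of the first half with element i of the second half.

    Assumption: exactly two generations per prompt, so the batch holds
    2 * num_prompts completions.
    """
    if len(items) != 2 * num_prompts:
        raise ValueError("expected exactly two completions per prompt")
    first, second = items[:num_prompts], items[num_prompts:]
    return list(zip(first, second))


# "pXgY" stands for generation Y of prompt X (hypothetical labels).
generation_major = ["p0g0", "p1g0", "p0g1", "p1g1"]
prompt_major = ["p0g0", "p0g1", "p1g0", "p1g1"]

# Generation-major input pairs the two completions of the same prompt.
good_pairs = pair_by_halves(generation_major, num_prompts=2)

# Prompt-major input silently pairs completions from different prompts.
bad_pairs = pair_by_halves(prompt_major, num_prompts=2)
```

With the prompt-major layout no error is raised; the pairs simply compare completions of different prompts, which is why this kind of bug is easy to miss without a regression test.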
Overview
Fixes incorrect handling of `trl vllm-serve` output in `OnlineDPOTrainer._generate_vllm_server()` (#5514). The server already returns one full token-id list per completion; the trainer was flattening those lists into single-token “completions” and deduplicating prompts before calling generate.

The path now passes all gathered prompts to the vLLM client with `n=num_generations`, keeps each returned completion sequence intact, and reorders results into the generation-major layout Online DPO uses for chosen/rejected reward pairing. Multi-process slicing and `prompt_ids` duplication use `num_generations` instead of hardcoded `2`.

A mocked regression test `test_vllm_server_preserves_completion_token_sequences` locks in the expected reordering and prompt alignment.

Reviewed by Cursor Bugbot for commit 7db9a15.
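The reordering described in this overview can be sketched roughly as follows. This is a standalone illustration, not the code from the PR: it assumes the server returns completions in prompt-major order (all generations for prompt 0, then prompt 1, and so on), and the helper names `to_generation_major` and `split_into_single_tokens` are hypothetical.

```python
from typing import List


def to_generation_major(
    completion_ids: List[List[int]], num_prompts: int, num_generations: int
) -> List[List[int]]:
    """Reorder [p0g0, p0g1, p1g0, p1g1] into [p0g0, p1g0, p0g1, p1g1].

    Each inner list is one full completion and is kept intact.
    Assumption: the input is prompt-major.
    """
    if len(completion_ids) != num_prompts * num_generations:
        raise ValueError("unexpected number of completions")
    return [
        completion_ids[p * num_generations + g]
        for g in range(num_generations)
        for p in range(num_prompts)
    ]


def split_into_single_tokens(completion_ids: List[List[int]]) -> List[List[int]]:
    """Reproduce the failure mode: every token becomes its own 'completion'."""
    return [[token] for sequence in completion_ids for token in sequence]


# Two prompts, two generations each, prompt-major (made-up token ids).
server_output = [[1, 2], [3], [4, 5, 6], [7]]

reordered = to_generation_major(server_output, num_prompts=2, num_generations=2)
broken = split_into_single_tokens(server_output)
```

In this sketch `reordered` still holds four completions of their original lengths, while `broken` holds seven one-token entries, which matches the symptom described for #5514.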