From 7db9a15911739087286b5bad93fcd4714a8cb700 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Sat, 13 Jun 2026 12:37:50 +0800 Subject: [PATCH] fix: preserve OnlineDPO vLLM completion ids --- tests/experimental/test_online_dpo_trainer.py | 50 +++++++++++++++++++ .../online_dpo/online_dpo_trainer.py | 38 +++++++------- 2 files changed, 69 insertions(+), 19 deletions(-) diff --git a/tests/experimental/test_online_dpo_trainer.py b/tests/experimental/test_online_dpo_trainer.py index dc22aedcfdf..a44e7d6feb4 100644 --- a/tests/experimental/test_online_dpo_trainer.py +++ b/tests/experimental/test_online_dpo_trainer.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from types import SimpleNamespace + import pytest +import torch from datasets import Dataset, features, load_dataset from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer from transformers.utils import is_peft_available, is_vision_available +import trl.experimental.online_dpo.online_dpo_trainer as online_dpo_trainer_module from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer from ..testing_utils import TrlTestCase, require_peft, require_torch_accelerator, require_vision, require_vllm @@ -31,6 +35,52 @@ from transformers import AutoModelForImageTextToText, AutoProcessor +def test_vllm_server_preserves_completion_token_sequences(monkeypatch): + monkeypatch.setattr(online_dpo_trainer_module, "gather_object", lambda value: value) + monkeypatch.setattr(online_dpo_trainer_module, "broadcast_object_list", lambda value, from_process: value) + + calls = [] + + class DummyVLLMClient: + def generate(self, **kwargs): + calls.append(kwargs) + return { + "completion_ids": [ + [11, 12], + [13, 14, 15], + [21, 22], + [23], + ] + } + + class DummyProcessor: + def __call__(self, **kwargs): + assert kwargs["text"] == ["first prompt", "second prompt"] + return {"input_ids": torch.tensor([[101, 102], [201, 202]])} + + trainer = OnlineDPOTrainer.__new__(OnlineDPOTrainer) + trainer.accelerator = SimpleNamespace(is_main_process=True, process_index=0) + trainer.args = SimpleNamespace(generation_kwargs=None) + trainer.generation_config = SimpleNamespace(max_tokens=16) + trainer.min_p = None + trainer.num_generations = 2 + trainer.processing_class = DummyProcessor() + trainer.repetition_penalty = 1.0 + trainer.state = SimpleNamespace(global_step=1) + trainer.temperature = 0.7 + trainer.top_k = None + trainer.top_p = 0.95 + trainer.vllm_client = DummyVLLMClient() + trainer._last_loaded_step = 1 + + completion_ids, prompt_ids = trainer._generate_vllm_server(["first prompt", "second prompt"]) + + assert calls[0]["prompts"] == ["first prompt", "second prompt"] + assert calls[0]["n"] == 2 + assert completion_ids == [[11, 12], [21, 22], [13, 14, 15], [23]] + assert prompt_ids == [[101, 102], [201, 202], [101, 102], [201, 202]] + + class TestOnlineDPOTrainer(TrlTestCase): def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" diff --git a/trl/experimental/online_dpo/online_dpo_trainer.py b/trl/experimental/online_dpo/online_dpo_trainer.py index b54c4962e48..8d7dcea56b2 100644 --- a/trl/experimental/online_dpo/online_dpo_trainer.py +++ b/trl/experimental/online_dpo/online_dpo_trainer.py @@ -644,18 +644,12 @@ def _generate_vllm_server(self, prompts, images=None): all_images = gather_object(images) if self.accelerator.is_main_process: - # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate - # num_generations outputs for each one. This is faster than generating outputs for each duplicate - # prompt individually. - ordered_set_of_prompts = all_prompts[:: self.num_generations] if has_images: - ordered_set_of_images = [ - [img] if img is not None else None for img in all_images[:: self.num_generations] - ] + ordered_set_of_images = [[img] if img is not None else None for img in all_images] else: ordered_set_of_images = None completion_ids = self.vllm_client.generate( - prompts=ordered_set_of_prompts, + prompts=all_prompts, images=ordered_set_of_images, n=self.num_generations, repetition_penalty=self.repetition_penalty, @@ -669,20 +663,27 @@ def _generate_vllm_server(self, prompts, images=None): else None, generation_kwargs=self.args.generation_kwargs, )["completion_ids"] - # Flatten: each prompt generates 2 completions - completion_ids = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions] + # The server returns completions grouped by prompt. Online DPO expects them grouped by generation. + completion_ids = [ + completion_ids[prompt_index * self.num_generations + generation_index] + for generation_index in range(self.num_generations) + for prompt_index in range(len(all_prompts)) + ] else: - completion_ids = [None] * (len(all_prompts) * 2) + completion_ids = [None] * (len(all_prompts) * self.num_generations) # Broadcast completions to all processes completion_ids = broadcast_object_list(completion_ids, from_process=0) # Each process takes its slice - process_slice = slice( - self.accelerator.process_index * len(prompts) * 2, - (self.accelerator.process_index + 1) * len(prompts) * 2, - ) - completion_ids = completion_ids[process_slice] + process_start = self.accelerator.process_index * len(prompts) + process_end = process_start + len(prompts) + process_indices = [ + generation_index * len(all_prompts) + prompt_index + for generation_index in range(self.num_generations) + for prompt_index in range(process_start, process_end) + ] + completion_ids = [completion_ids[index] for index in process_indices] # Create prompt_ids by tokenizing locally prompt_inputs = self.processing_class( @@ -692,9 +693,8 @@ def _generate_vllm_server(self, prompts, images=None): padding_side="left", add_special_tokens=False, ) - prompt_ids = [] - for prompt_tokens in prompt_inputs["input_ids"]: - prompt_ids.extend([prompt_tokens.tolist(), prompt_tokens.tolist()]) # 2 copies for 2 completions + prompt_token_ids = [prompt_tokens.tolist() for prompt_tokens in prompt_inputs["input_ids"]] + prompt_ids = [prompt_tokens for _ in range(self.num_generations) for prompt_tokens in prompt_token_ids] return completion_ids, prompt_ids def _generate_vllm_colocate(self, prompts, images=None):