Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions tests/experimental/test_online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from types import SimpleNamespace

import pytest
import torch
from datasets import Dataset, features, load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers.utils import is_peft_available, is_vision_available

import trl.experimental.online_dpo.online_dpo_trainer as online_dpo_trainer_module
from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer

from ..testing_utils import TrlTestCase, require_peft, require_torch_accelerator, require_vision, require_vllm
Expand All @@ -31,6 +35,52 @@
from transformers import AutoModelForImageTextToText, AutoProcessor


def test_vllm_server_preserves_completion_token_sequences(monkeypatch):
    """Check that `_generate_vllm_server` returns whole completions in generation-major order.

    A stub client answers two prompts with two completions each, listed prompt by prompt. The method is expected to
    hand back every completion as an intact token list, with the first completion of each prompt ahead of the second
    ones, and to repeat the tokenized prompts once per generation in that same order.
    """

    # Pass-through stand-ins for the cross-process helpers, so the method can run in a single process.
    def passthrough_gather(value):
        return value

    def passthrough_broadcast(value, from_process):
        return value

    monkeypatch.setattr(online_dpo_trainer_module, "gather_object", passthrough_gather)
    monkeypatch.setattr(online_dpo_trainer_module, "broadcast_object_list", passthrough_broadcast)

    recorded_requests = []

    class StubVLLMClient:
        """Records the generation request and replies with canned completions of differing lengths."""

        def generate(self, **kwargs):
            recorded_requests.append(kwargs)
            return {
                "completion_ids": [
                    [11, 12],
                    [13, 14, 15],
                    [21, 22],
                    [23],
                ]
            }

    class StubProcessor:
        """Returns fixed token ids after checking it was given the two local prompts."""

        def __call__(self, **kwargs):
            assert kwargs["text"] == ["first prompt", "second prompt"]
            return {"input_ids": torch.tensor([[101, 102], [201, 202]])}

    # Skip `__init__` and set by hand the attributes this test supplies to the method.
    trainer = OnlineDPOTrainer.__new__(OnlineDPOTrainer)
    stub_attributes = {
        "accelerator": SimpleNamespace(is_main_process=True, process_index=0),
        "args": SimpleNamespace(generation_kwargs=None),
        "generation_config": SimpleNamespace(max_tokens=16),
        "min_p": None,
        "num_generations": 2,
        "processing_class": StubProcessor(),
        "repetition_penalty": 1.0,
        "state": SimpleNamespace(global_step=1),
        "temperature": 0.7,
        "top_k": None,
        "top_p": 0.95,
        "vllm_client": StubVLLMClient(),
        "_last_loaded_step": 1,
    }
    for attribute_name, attribute_value in stub_attributes.items():
        setattr(trainer, attribute_name, attribute_value)

    completion_ids, prompt_ids = trainer._generate_vllm_server(["first prompt", "second prompt"])

    # The prompts are forwarded as given, and the client is asked for `num_generations` samples each.
    first_request = recorded_requests[0]
    assert first_request["prompts"] == ["first prompt", "second prompt"]
    assert first_request["n"] == 2

    # Generation-major order: first completions of both prompts, then the second completions.
    assert completion_ids == [[11, 12], [21, 22], [13, 14, 15], [23]]
    assert prompt_ids == [[101, 102], [201, 202], [101, 102], [201, 202]]


class TestOnlineDPOTrainer(TrlTestCase):
def setup_method(self):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
Expand Down
38 changes: 19 additions & 19 deletions trl/experimental/online_dpo/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,18 +644,12 @@ def _generate_vllm_server(self, prompts, images=None):
all_images = gather_object(images)

if self.accelerator.is_main_process:
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts[:: self.num_generations]
if has_images:
ordered_set_of_images = [
[img] if img is not None else None for img in all_images[:: self.num_generations]
]
ordered_set_of_images = [[img] if img is not None else None for img in all_images]
else:
ordered_set_of_images = None
completion_ids = self.vllm_client.generate(
prompts=ordered_set_of_prompts,
prompts=all_prompts,
images=ordered_set_of_images,
n=self.num_generations,
repetition_penalty=self.repetition_penalty,
Expand All @@ -669,20 +663,27 @@ def _generate_vllm_server(self, prompts, images=None):
else None,
generation_kwargs=self.args.generation_kwargs,
)["completion_ids"]
# Flatten: each prompt generates 2 completions
completion_ids = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions]
# The server returns completions grouped by prompt. Online DPO expects them grouped by generation.
completion_ids = [
completion_ids[prompt_index * self.num_generations + generation_index]
for generation_index in range(self.num_generations)
for prompt_index in range(len(all_prompts))
]
else:
completion_ids = [None] * (len(all_prompts) * 2)
completion_ids = [None] * (len(all_prompts) * self.num_generations)

# Broadcast completions to all processes
completion_ids = broadcast_object_list(completion_ids, from_process=0)

# Each process takes its slice
process_slice = slice(
self.accelerator.process_index * len(prompts) * 2,
(self.accelerator.process_index + 1) * len(prompts) * 2,
)
completion_ids = completion_ids[process_slice]
process_start = self.accelerator.process_index * len(prompts)
process_end = process_start + len(prompts)
process_indices = [
generation_index * len(all_prompts) + prompt_index
for generation_index in range(self.num_generations)
for prompt_index in range(process_start, process_end)
]
completion_ids = [completion_ids[index] for index in process_indices]

# Create prompt_ids by tokenizing locally
prompt_inputs = self.processing_class(
Expand All @@ -692,9 +693,8 @@ def _generate_vllm_server(self, prompts, images=None):
padding_side="left",
add_special_tokens=False,
)
prompt_ids = []
for prompt_tokens in prompt_inputs["input_ids"]:
prompt_ids.extend([prompt_tokens.tolist(), prompt_tokens.tolist()]) # 2 copies for 2 completions
prompt_token_ids = [prompt_tokens.tolist() for prompt_tokens in prompt_inputs["input_ids"]]
prompt_ids = [prompt_tokens for _ in range(self.num_generations) for prompt_tokens in prompt_token_ids]
return completion_ids, prompt_ids

def _generate_vllm_colocate(self, prompts, images=None):
Expand Down