fix: load image-text policy for async grpo#6032
Conversation
AmineDiro
left a comment
There was a problem hiding this comment.
Thanks for catching ! 🤗
I would like to point out if you could reuse the existing testing infra for consistency with how the trl tests VLMs everywhere else. See test_grpo_trainer.py::test_training_vlm (and the family around it): real tiny checkpoints (trl-internal-testing/tiny-*ForConditionalGeneration), parametrized per arch with version-gated skipif, @require_vision, and behavioral assertions on real parameters.
Could we follow that pattern? No vLLM server is needed since the namespace is observable directly on the loaded model:
@require_vision
def test_load_policy_model_vlm_namespace_and_freeze(self):
model = _load_policy_model("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration")
names = [n for n, _ in model.named_parameters()]
assert any(n.startswith("language_model.model.") for n in names) # matches vLLM
assert any("visual" in n for n in names)
for n, p in model.named_parameters():
assert p.requires_grad != ("visual" in n or "vision" in n)
def test_load_policy_model_text_namespace(self):
model = _load_policy_model("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
assert any(n.startswith("model.") for n, _ in model.named_parameters())Thx
| RewardFunc = Callable[..., list[float]] | ||
|
|
||
|
|
||
| def _load_policy_model(model_name: str): |
There was a problem hiding this comment.
I think we should probably replace this with create_model_from_path(model, dtype=torch.float32, device_map=None) like GRPO and then the freeze vision layer on top. This converges with GRPOTrainer
What does this PR do?
Fixes #6028.
AsyncGRPOTraineralways loaded the policy withAutoModelForCausalLM. For image-text / conditional-generation checkpoints this gives the trainermodel.*parameter names, while vLLM serves the same checkpoint withlanguage_model.model.*and vision-tower keys. The first weight sync can then fail because the two sides do not agree on parameter names.This PR keeps text models on the existing causal-LM path, but detects image-text / conditional-generation architectures from
AutoConfig.architecturesand loads them withAutoModelForImageTextToText. The vision tower parameters are frozen for the text-only RL use case described in the issue, so the trainer namespace matches the vLLM server while avoiding accidental vision-tower updates.I also updated the trainer argument docs and added focused tests for both loader paths.
Before submitting
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Who can review?
Anyone in the community is free to review the PR once the tests have passed.
To verify
python -m py_compile trl\experimental\async_grpo\async_grpo_trainer.py tests\experimental\test_async_grpo_trainer.pypython -m ruff check trl\experimental\async_grpo\async_grpo_trainer.py tests\experimental\test_async_grpo_trainer.pypython -m ruff format --check trl\experimental\async_grpo\async_grpo_trainer.py tests\experimental\test_async_grpo_trainer.py.\.venv\Scripts\python.exe -m pytest tests\experimental\test_async_grpo_trainer.py::TestAsyncGRPOTrainer::test_load_policy_model_keeps_text_models_on_causal_lm tests\experimental\test_async_grpo_trainer.py::TestAsyncGRPOTrainer::test_load_policy_model_uses_image_text_model_for_conditional_generation -qNote
Medium Risk
Changes core model loading and which parameters are trainable for VL checkpoints, directly affecting weight sync and training behavior; scope is limited to async GRPO policy init.
Overview
Fixes weight-sync failures when training image-text / conditional-generation checkpoints with
AsyncGRPOTrainer, where the trainer usedAutoModelForCausalLMparameter names (model.*) while vLLM exposeslanguage_model.model.*and vision keys.Adds
_load_policy_model, which inspectsAutoConfig.architecturesand keeps text-only models onAutoModelForCausalLM. Architectures containingConditionalGenerationorImageTextToTextload viaAutoModelForImageTextToText, withvisual/visionparameters frozen for text-only RL. Trainer init now uses this helper instead of always loading causal LM. Docs and mocked unit tests cover both paths.Reviewed by Cursor Bugbot for commit 545d902. Bugbot is set up for automated code reviews on this repo. Configure here.