Skip to content

fix: load image-text policy for async grpo#6032

Open
he-yufeng wants to merge 1 commit into
huggingface:mainfrom
he-yufeng:fix/async-grpo-vlm-weight-sync
Open

fix: load image-text policy for async grpo#6032
he-yufeng wants to merge 1 commit into
huggingface:mainfrom
he-yufeng:fix/async-grpo-vlm-weight-sync

Conversation

@he-yufeng

@he-yufeng he-yufeng commented Jun 12, 2026

Copy link
Copy Markdown

What does this PR do?

Fixes #6028.

AsyncGRPOTrainer always loaded the policy with AutoModelForCausalLM. For image-text / conditional-generation checkpoints this gives the trainer model.* parameter names, while vLLM serves the same checkpoint with language_model.model.* and vision-tower keys. The first weight sync can then fail because the two sides do not agree on parameter names.

This PR keeps text models on the existing causal-LM path, but detects image-text / conditional-generation architectures from AutoConfig.architectures and loads them with AutoModelForImageTextToText. The vision tower parameters are frozen for the text-only RL use case described in the issue, so the trainer namespace matches the vLLM server while avoiding accidental vision-tower updates.

I also updated the trainer argument docs and added focused tests for both loader paths.

Before submitting

AI writing disclosure

We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Who can review?

Anyone in the community is free to review the PR once the tests have passed.

To verify

  • python -m py_compile trl\experimental\async_grpo\async_grpo_trainer.py tests\experimental\test_async_grpo_trainer.py
  • python -m ruff check trl\experimental\async_grpo\async_grpo_trainer.py tests\experimental\test_async_grpo_trainer.py
  • python -m ruff format --check trl\experimental\async_grpo\async_grpo_trainer.py tests\experimental\test_async_grpo_trainer.py
  • .\.venv\Scripts\python.exe -m pytest tests\experimental\test_async_grpo_trainer.py::TestAsyncGRPOTrainer::test_load_policy_model_keeps_text_models_on_causal_lm tests\experimental\test_async_grpo_trainer.py::TestAsyncGRPOTrainer::test_load_policy_model_uses_image_text_model_for_conditional_generation -q

Note

Medium Risk
Changes core model loading and which parameters are trainable for VL checkpoints, directly affecting weight sync and training behavior; scope is limited to async GRPO policy init.

Overview
Fixes weight-sync failures when training image-text / conditional-generation checkpoints with AsyncGRPOTrainer, where the trainer used AutoModelForCausalLM parameter names (model.*) while vLLM exposes language_model.model.* and vision keys.

Adds _load_policy_model, which inspects AutoConfig.architectures and keeps text-only models on AutoModelForCausalLM. Architectures containing ConditionalGeneration or ImageTextToText load via AutoModelForImageTextToText, with visual / vision parameters frozen for text-only RL. Trainer init now uses this helper instead of always loading causal LM. Docs and mocked unit tests cover both paths.

Reviewed by Cursor Bugbot for commit 545d902. Bugbot is set up for automated code reviews on this repo. Configure here.

@AmineDiro AmineDiro self-requested a review June 19, 2026 08:33

@AmineDiro AmineDiro left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching ! 🤗
I would like to point out if you could reuse the existing testing infra for consistency with how the trl tests VLMs everywhere else. See test_grpo_trainer.py::test_training_vlm (and the family around it): real tiny checkpoints (trl-internal-testing/tiny-*ForConditionalGeneration), parametrized per arch with version-gated skipif, @require_vision, and behavioral assertions on real parameters.

Could we follow that pattern? No vLLM server is needed since the namespace is observable directly on the loaded model:

@require_vision
def test_load_policy_model_vlm_namespace_and_freeze(self):
  model = _load_policy_model("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration")
  names = [n for n, _ in model.named_parameters()]
  assert any(n.startswith("language_model.model.") for n in names)  # matches vLLM
  assert any("visual" in n for n in names)
  for n, p in model.named_parameters():
      assert p.requires_grad != ("visual" in n or "vision" in n)

def test_load_policy_model_text_namespace(self):
  model = _load_policy_model("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
  assert any(n.startswith("model.") for n, _ in model.named_parameters())

Thx

RewardFunc = Callable[..., list[float]]


def _load_policy_model(model_name: str):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should probably replace this with create_model_from_path(model, dtype=torch.float32, device_map=None) like GRPO and then the freeze vision layer on top. This converges with GRPOTrainer

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

AsyncGRPOTrainer: vision-language (*ForConditionalGeneration) checkpoints can't be trained (weight-sync key mismatch)

2 participants