Skip to content

feat(trainer): pack RL multimodal samples#2889

Draft
hubert-marek wants to merge 9 commits into
mainfrom
feat/rl-mm-packing
Draft

feat(trainer): pack RL multimodal samples#2889
hubert-marek wants to merge 9 commits into
mainfrom
feat/rl-mm-packing

Conversation

@hubert-marek

@hubert-marek hubert-marek commented Jun 26, 2026

Copy link
Copy Markdown
Contributor

Summary

  • Adds RL pack_multimodal support for eager multimodal samples while preserving run/LoRA and rank modality scheduling constraints.
  • Makes seq_lens the first-class packed sample boundary contract and threads it through transport, RL data loading, trainer forward, and Qwen3.5 MRoPE.
  • Gates packed multimodal training with an explicit model capability validator, varlen attention validation, and config-layer rejection for VLM context parallelism.

Verification

  • uv run --no-project python -m py_compile src/prime_rl/trainer/batch.py src/prime_rl/trainer/rl/data.py src/prime_rl/transport/types.py tests/unit/orchestrator/test_batch.py
  • uv run --no-project --with ruff==0.13.0 ruff check --fix --config=pyproject.toml <PR files>
  • uv run --no-project --with ruff==0.13.0 ruff format --config=pyproject.toml <PR files>
  • PYTHONPATH="src:packages/prime-rl-configs/src" uv run --no-project --with pytest --with numpy --with msgspec --with torch --with jaxtyping --with tomli-w --with pydantic --with psutil --with setproctitle --with pandas --with rich --with transformers --with wandb --with 'prime-pydantic-config[toml]' --with pyzmq --with loguru python -m pytest tests/unit/orchestrator/test_batch.py -q
  • PYTHONPATH=src:packages/prime-rl-configs/src uv run --no-project --with ruff ruff check src/prime_rl/utils/vlm.py src/prime_rl/trainer/rl/train.py packages/prime-rl-configs/src/prime_rl/configs/trainer.py tests/unit/utils/test_vlm.py tests/unit/test_configs.py
  • PYTHONPATH=src:packages/prime-rl-configs/src uv run --no-project --with pytest --with torch --with numpy --with transformers --with psutil --with setproctitle --with loguru pytest tests/unit/utils/test_vlm.py -q
  • Direct TrainerConfig validation check for VLM + cp=2 + pack_multimodal=true

Note

High Risk
Changes multimodal batching, position/MRoPE handling, and varlen attention boundaries on the RL training path—incorrect packing or seq_lens would silently skew logprobs and loss.

Overview
Adds optional packed multimodal RL microbatches behind trainer.pack_multimodal (default on), with the trainer turning it off when the model, attention kernel, or context parallelism cannot support safe packing.

The packer can now co-pack text and compatible eager multimodal samples from the same run/LoRA when enabled, merging mm_kwargs and recording per-sample boundaries in new seq_lens metadata (also threaded through transport, DataLoader, padding, and the training forward). Qwen-style image_grid_thw forwards no longer pass trainer position_ids; they receive seq_lens so MRoPE and varlen attention respect packed segments. Qwen3.5 MoE VLMs advertise supports_packed_multimodal_training and use seq_lens for cu_seqlens / MRoPE construction.

Reviewed by Cursor Bugbot for commit 15956ad. Bugbot is set up for automated code reviews on this repo. Configure here.

@hubert-marek hubert-marek marked this pull request as ready for review June 26, 2026 21:46
Co-authored-by: Cursor <cursoragent@cursor.com>
Comment thread src/prime_rl/utils/vlm.py Outdated
return False


def get_packed_mm_disabled_reasons(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lest call this, "validate_multi_modal_pack" and lets fail instead of disabling if not compatible.

also the cp X multi modal pack should be done at the config layer

Comment thread src/prime_rl/trainer/batch.py Outdated
Comment on lines +114 to +115
_validate_encoded_tensor_payload(dst)
_validate_encoded_tensor_payload(src)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we really need this check here ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove them

Comment thread src/prime_rl/trainer/batch.py Outdated
Comment on lines +116 to +119
if dst.dtype != src.dtype:
raise ValueError(f"Cannot pack mm_kwargs[{key!r}] with different dtypes: {dst.dtype} vs {src.dtype}")
if len(dst.shape) == 0 or len(dst.shape) != len(src.shape) or dst.shape[1:] != src.shape[1:]:
raise ValueError(f"Cannot pack mm_kwargs[{key!r}] with incompatible shapes: {dst.shape} vs {src.shape}")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here can't we check before ?

Comment thread src/prime_rl/trainer/batch.py Outdated
Comment on lines +98 to +142
def _encoded_tensor_itemsize(encoded: EncodedTensor) -> int:
dtype = encoded.dtype.replace("numpy.", "").replace("torch.", "")
return np.dtype(dtype).itemsize


def _validate_encoded_tensor_payload(encoded: EncodedTensor) -> None:
expected_nbytes = int(np.prod(encoded.shape)) * _encoded_tensor_itemsize(encoded)
if len(encoded.data) != expected_nbytes:
raise ValueError(
"EncodedTensor byte length does not match dtype and shape: "
f"dtype={encoded.dtype}, shape={encoded.shape}, "
f"data_nbytes={len(encoded.data)}, expected_nbytes={expected_nbytes}"
)


def _append_encoded_tensor(dst: EncodedTensor, src: EncodedTensor, key: str) -> None:
_validate_encoded_tensor_payload(dst)
_validate_encoded_tensor_payload(src)
if dst.dtype != src.dtype:
raise ValueError(f"Cannot pack mm_kwargs[{key!r}] with different dtypes: {dst.dtype} vs {src.dtype}")
if len(dst.shape) == 0 or len(dst.shape) != len(src.shape) or dst.shape[1:] != src.shape[1:]:
raise ValueError(f"Cannot pack mm_kwargs[{key!r}] with incompatible shapes: {dst.shape} vs {src.shape}")
dst.data += src.data
dst.shape[0] += src.shape[0]


def _append_mm_kwargs(dst: dict[str, EncodedTensor], src: dict[str, EncodedTensor]) -> None:
if set(dst) != set(src):
raise ValueError(f"Cannot pack mm_kwargs with different keys: {sorted(dst)} vs {sorted(src)}")
for key in dst:
_append_encoded_tensor(dst[key], src[key], key)


def _can_pack_mm_kwargs(dst: dict[str, EncodedTensor] | None, src: dict[str, EncodedTensor] | None) -> bool:
if dst is None or src is None or set(dst) != set(src):
return False
return all(
dst[key].dtype == src[key].dtype
and len(dst[key].shape) > 0
and len(dst[key].shape) == len(src[key].shape)
and dst[key].shape[1:] == src[key].shape[1:]
for key in dst
)


Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not the biggest fan here, overly defensive code there must be an easier way to do this

Comment thread src/prime_rl/trainer/batch.py Outdated
Comment on lines +143 to +144
def _has_video_tokens(sample: MicroBatch) -> bool:
return sample.mm_token_type_ids is not None and 2 in sample.mm_token_type_ids

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we working with video ? if not lets just remvoe this

@hubert-marek hubert-marek requested review from samsja and removed request for samsja June 27, 2026 00:37
@hubert-marek hubert-marek marked this pull request as draft June 27, 2026 00:39
Fail fast for unsupported multimodal packing runtimes and move the VLM context-parallelism incompatibility into config validation.

Co-authored-by: Cursor <cursoragent@cursor.com>
@hubert-marek hubert-marek changed the base branch from fix/qwen35-mrope-model to main June 27, 2026 00:42

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 15956ad26e

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

return False
if existing_mm_sample is not None and sample_is_mm:
return _can_pack_mm_kwargs(existing_mm_sample.mm_kwargs, sample.mm_kwargs)
return True

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid packing text spans that contain image tokens

When pack_multimodal is enabled, this also permits a text-only sample to share a bin with an image sample. In the custom Qwen VLM path, the image scatter still counts placeholders with input_ids == image_token_id (see modeling_qwen3_5_moe.py:994-1005) rather than mm_token_type_ids, so a text-only rollout that generated or otherwise contains the image special token will be packed with pixel_values from another sample and then fail the image token/feature mismatch check. Either keep such text spans out of multimodal bins or scatter from the modality mask so packed text cannot add phantom image placeholders.

Useful? React with 👍 / 👎.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants