feat(trainer): pack RL multimodal samples#2889
Co-authored-by: Cursor <cursoragent@cursor.com>
    return False


def get_packed_mm_disabled_reasons(
Let's call this "validate_multi_modal_pack", and let's fail instead of disabling if the setup is not compatible.
Also, the cp × multimodal pack check should be done at the config layer.
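A minimal sketch of what this suggestion could look like, assuming a plain dataclass in place of the real trainer config. `PackingConfigSketch`, its field names, and the two boolean arguments to `validate_multi_modal_pack` are hypothetical and only illustrate the "fail at config time, raise instead of disable" shape; they are not the PR's actual code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PackingConfigSketch:
    """Hypothetical stand-in for the trainer config; field names are illustrative."""

    is_vlm: bool = False
    cp: int = 1  # context-parallel degree
    pack_multimodal: bool = True

    def __post_init__(self) -> None:
        # Reject the incompatible combination when the config is built,
        # before any trainer code runs.
        if self.is_vlm and self.cp > 1 and self.pack_multimodal:
            raise ValueError("pack_multimodal is not supported for VLMs with cp > 1")


def validate_multi_modal_pack(model_supports_packing: bool, attn_supports_varlen: bool) -> None:
    """Raise with every reason found, rather than returning reasons and disabling."""
    reasons = []
    if not model_supports_packing:
        reasons.append("model does not support packed multimodal training")
    if not attn_supports_varlen:
        reasons.append("attention implementation does not support varlen packing")
    if reasons:
        raise ValueError("pack_multimodal=true is unsupported: " + "; ".join(reasons))
```

With this shape, a misconfigured run stops at startup with an explicit message instead of silently training without packing.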
    _validate_encoded_tensor_payload(dst)
    _validate_encoded_tensor_payload(src)
Do we really need this check here?
We can remove them.
    if dst.dtype != src.dtype:
        raise ValueError(f"Cannot pack mm_kwargs[{key!r}] with different dtypes: {dst.dtype} vs {src.dtype}")
    if len(dst.shape) == 0 or len(dst.shape) != len(src.shape) or dst.shape[1:] != src.shape[1:]:
        raise ValueError(f"Cannot pack mm_kwargs[{key!r}] with incompatible shapes: {dst.shape} vs {src.shape}")
Same here: can't we check before?
def _encoded_tensor_itemsize(encoded: EncodedTensor) -> int:
    dtype = encoded.dtype.replace("numpy.", "").replace("torch.", "")
    return np.dtype(dtype).itemsize


def _validate_encoded_tensor_payload(encoded: EncodedTensor) -> None:
    expected_nbytes = int(np.prod(encoded.shape)) * _encoded_tensor_itemsize(encoded)
    if len(encoded.data) != expected_nbytes:
        raise ValueError(
            "EncodedTensor byte length does not match dtype and shape: "
            f"dtype={encoded.dtype}, shape={encoded.shape}, "
            f"data_nbytes={len(encoded.data)}, expected_nbytes={expected_nbytes}"
        )


def _append_encoded_tensor(dst: EncodedTensor, src: EncodedTensor, key: str) -> None:
    _validate_encoded_tensor_payload(dst)
    _validate_encoded_tensor_payload(src)
    if dst.dtype != src.dtype:
        raise ValueError(f"Cannot pack mm_kwargs[{key!r}] with different dtypes: {dst.dtype} vs {src.dtype}")
    if len(dst.shape) == 0 or len(dst.shape) != len(src.shape) or dst.shape[1:] != src.shape[1:]:
        raise ValueError(f"Cannot pack mm_kwargs[{key!r}] with incompatible shapes: {dst.shape} vs {src.shape}")
    dst.data += src.data
    dst.shape[0] += src.shape[0]


def _append_mm_kwargs(dst: dict[str, EncodedTensor], src: dict[str, EncodedTensor]) -> None:
    if set(dst) != set(src):
        raise ValueError(f"Cannot pack mm_kwargs with different keys: {sorted(dst)} vs {sorted(src)}")
    for key in dst:
        _append_encoded_tensor(dst[key], src[key], key)


def _can_pack_mm_kwargs(dst: dict[str, EncodedTensor] | None, src: dict[str, EncodedTensor] | None) -> bool:
    if dst is None or src is None or set(dst) != set(src):
        return False
    return all(
        dst[key].dtype == src[key].dtype
        and len(dst[key].shape) > 0
        and len(dst[key].shape) == len(src[key].shape)
        and dst[key].shape[1:] == src[key].shape[1:]
        for key in dst
    )
Not the biggest fan here: this is overly defensive code, and there must be an easier way to do this.
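One possible simplification, sketched under assumptions: compute a single hashable "pack signature" per sample once, so that compatibility becomes an equality test and the append path does not need to repeat the checks. `EncodedTensorSketch`, `pack_signature`, and `can_pack` are hypothetical names; the stand-in class only mirrors the `data`, `shape`, and `dtype` fields visible in the diff.

```python
from dataclasses import dataclass


@dataclass
class EncodedTensorSketch:
    """Hypothetical stand-in for EncodedTensor with the fields used in the diff."""

    data: bytes
    shape: list[int]
    dtype: str


def pack_signature(mm_kwargs):
    """Return a hashable signature, or None if the sample cannot be packed."""
    if mm_kwargs is None:
        return None
    # A 0-d tensor has no leading dimension to concatenate along.
    if any(len(t.shape) == 0 for t in mm_kwargs.values()):
        return None
    # Keys, dtypes, and trailing dimensions must all agree between samples.
    return tuple(sorted((k, t.dtype, tuple(t.shape[1:])) for k, t in mm_kwargs.items()))


def can_pack(dst, src) -> bool:
    sig = pack_signature(dst)
    return sig is not None and sig == pack_signature(src)
```

Because the signature covers keys, dtypes, and trailing shapes, a packer that groups samples by signature could append without re-validating each pair.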
def _has_video_tokens(sample: MicroBatch) -> bool:
    return sample.mm_token_type_ids is not None and 2 in sample.mm_token_type_ids
Are we working with video? If not, let's just remove this.
Fail fast for unsupported multimodal packing runtimes and move the VLM context-parallelism incompatibility into config validation. Co-authored-by: Cursor <cursoragent@cursor.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 15956ad26e
        return False
    if existing_mm_sample is not None and sample_is_mm:
        return _can_pack_mm_kwargs(existing_mm_sample.mm_kwargs, sample.mm_kwargs)
    return True
Avoid packing text spans that contain image tokens
When pack_multimodal is enabled, this also permits a text-only sample to share a bin with an image sample. In the custom Qwen VLM path, the image scatter still counts placeholders with input_ids == image_token_id (see modeling_qwen3_5_moe.py:994-1005) rather than mm_token_type_ids, so a text-only rollout that generated or otherwise contains the image special token will be packed with pixel_values from another sample and then fail the image token/feature mismatch check. Either keep such text spans out of multimodal bins or scatter from the modality mask so packed text cannot add phantom image placeholders.
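A rough sketch of the first option (keeping such text spans out of multimodal bins), assuming the packer can see each sample's token ids when choosing a bin. `contains_image_placeholder`, `can_share_bin`, and their parameters are hypothetical names for illustration, not functions from this PR.

```python
def contains_image_placeholder(input_ids: list[int], image_token_id: int) -> bool:
    """True if the token sequence contains the image special token."""
    return image_token_id in input_ids


def can_share_bin(
    bin_has_mm: bool,
    sample_is_mm: bool,
    input_ids: list[int],
    image_token_id: int,
) -> bool:
    # A text-only sample may join a multimodal bin only if it carries no
    # image placeholder tokens, so the placeholder count in the packed
    # sequence still equals the number of image features.
    if bin_has_mm and not sample_is_mm:
        return not contains_image_placeholder(input_ids, image_token_id)
    return True
```

This only covers a text sample joining a bin that already holds multimodal data; the symmetric case, a multimodal sample joining a bin whose text samples contain the token, would need the same check on the bin's existing contents.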
Summary
- Adds `pack_multimodal` support for eager multimodal samples while preserving run/LoRA and rank modality scheduling constraints.
- Makes `seq_lens` the first-class packed sample boundary contract and threads it through transport, RL data loading, trainer forward, and Qwen3.5 MRoPE.

Verification
- `uv run --no-project python -m py_compile src/prime_rl/trainer/batch.py src/prime_rl/trainer/rl/data.py src/prime_rl/transport/types.py tests/unit/orchestrator/test_batch.py`
- `uv run --no-project --with ruff==0.13.0 ruff check --fix --config=pyproject.toml <PR files>`
- `uv run --no-project --with ruff==0.13.0 ruff format --config=pyproject.toml <PR files>`
- `PYTHONPATH="src:packages/prime-rl-configs/src" uv run --no-project --with pytest --with numpy --with msgspec --with torch --with jaxtyping --with tomli-w --with pydantic --with psutil --with setproctitle --with pandas --with rich --with transformers --with wandb --with 'prime-pydantic-config[toml]' --with pyzmq --with loguru python -m pytest tests/unit/orchestrator/test_batch.py -q`
- `PYTHONPATH=src:packages/prime-rl-configs/src uv run --no-project --with ruff ruff check src/prime_rl/utils/vlm.py src/prime_rl/trainer/rl/train.py packages/prime-rl-configs/src/prime_rl/configs/trainer.py tests/unit/utils/test_vlm.py tests/unit/test_configs.py`
- `PYTHONPATH=src:packages/prime-rl-configs/src uv run --no-project --with pytest --with torch --with numpy --with transformers --with psutil --with setproctitle --with loguru pytest tests/unit/utils/test_vlm.py -q`
- `TrainerConfig` validation check for VLM + `cp=2` + `pack_multimodal=true`

Note
High Risk
Changes multimodal batching, position/MRoPE handling, and varlen attention boundaries on the RL training path; incorrect packing or `seq_lens` would silently skew logprobs and loss.

Overview
Adds optional packed multimodal RL microbatches behind `trainer.pack_multimodal` (default on), with the trainer turning it off when the model, attention kernel, or context parallelism cannot support safe packing.

The packer can now co-pack text and compatible eager multimodal samples from the same run/LoRA when enabled, merging `mm_kwargs` and recording per-sample boundaries in new `seq_lens` metadata (also threaded through transport, `DataLoader`, padding, and the training forward). Qwen-style `image_grid_thw` forwards no longer pass trainer `position_ids`; they receive `seq_lens` so MRoPE and varlen attention respect packed segments. Qwen3.5 MoE VLMs advertise `supports_packed_multimodal_training` and use `seq_lens` for `cu_seqlens` / MRoPE construction.

Reviewed by Cursor Bugbot for commit 15956ad.
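For readers unfamiliar with the boundary contract, here is a small illustrative sketch of how per-sample `seq_lens` typically map to the cumulative `cu_seqlens` offsets used by varlen attention, plus the plain 1D position ids that restart at each packed segment. The helper names are hypothetical and this is the text-only analogue; the actual MRoPE construction for Qwen3.5 also depends on `image_grid_thw` and is not reproduced here.

```python
from itertools import accumulate


def cu_seqlens_from_seq_lens(seq_lens: list[int]) -> list[int]:
    """Cumulative offsets: segment i spans [cu[i], cu[i + 1]) in the packed row."""
    return [0, *accumulate(seq_lens)]


def segment_position_ids(seq_lens: list[int]) -> list[int]:
    """1D position ids that restart at 0 for every packed segment."""
    return [pos for n in seq_lens for pos in range(n)]


seq_lens = [3, 2, 4]
print(cu_seqlens_from_seq_lens(seq_lens))  # [0, 3, 5, 9]
print(segment_position_ids(seq_lens))      # [0, 1, 2, 0, 1, 0, 1, 2, 3]
```

If `seq_lens` were wrong by even one token, attention would leak across samples and positions would be misaligned without raising any error, which is the silent failure mode flagged in the risk note.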
mm_kwargsand recording per-sample boundaries in newseq_lensmetadata (also threaded through transport,DataLoader, padding, and the training forward). Qwen-styleimage_grid_thwforwards no longer pass trainerposition_ids; they receiveseq_lensso MRoPE and varlen attention respect packed segments. Qwen3.5 MoE VLMs advertisesupports_packed_multimodal_trainingand useseq_lensforcu_seqlens/ MRoPE construction.Reviewed by Cursor Bugbot for commit 15956ad. Bugbot is set up for automated code reviews on this repo. Configure here.