Skip to content

[model, feature] qwen3-omni: add packed sequence support and shared sequence utilities#4304

Open
hbhflw2000 wants to merge 6 commits into
NVIDIA-NeMo:mainfrom
hbhflw2000:pr4_omni3_packseq_sequence_utils
Open

[model, feature] qwen3-omni: add packed sequence support and shared sequence utilities#4304
hbhflw2000 wants to merge 6 commits into
NVIDIA-NeMo:mainfrom
hbhflw2000:pr4_omni3_packseq_sequence_utils

Conversation

@hbhflw2000

@hbhflw2000 hbhflw2000 commented Jun 11, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

Add Qwen3-Omni packed sequence training support and introduce shared raw sequence padding / packed-sequence metadata utilities for the Qwen3-Omni training path.

Changelog

  • Add Qwen3-Omni pack_sequences_in_batch=True forward-step support.
  • Preserve dense CP behavior by keeping raw input_ids available for model-internal mRoPE while slicing train tensors on CP ranks.
  • Add shared raw-batch sequence padding helpers in training/utils/padding_utils.py.
  • Add shared uniform PackedSeqParams construction in training/utils/packed_seq_utils.py.
  • Follow the existing Qwen3-VL packed-padding pattern WITHOUT changing Qwen3-VL code in this PR.
  • Add unit coverage for Qwen3-Omni packed sequence / CP behavior and shared sequence utilities.

Design note / RFC

This implementation follows the existing Qwen3-VL packed-padding pattern: pad raw batch sequence tensors to an aligned dense length, build uniform THD PackedSeqParams, and keep model-specific multimodal / mRoPE handling inside the Qwen3-Omni step and model code.

This PR intentionally does not reuse slice_batch_for_context_parallel for Qwen3-Omni raw-batch padding. That utility operates after embedding preparation and slices inputs_embeds, while Qwen3-Omni needs pre-forward raw sequence normalization so the full input_ids tensor remains available for multimodal placeholder handling and mRoPE.

The shared abstraction here is intentionally narrow: compute the padded target sequence length, pad/truncate common raw batch tensors, and construct uniform THD PackedSeqParams. Model-specific logic such as multimodal merge, CP rank slicing, and mRoPE handling remains in Qwen3-Omni code.

ATTENTION: Qwen3-VL code is intentionally left unchanged in this PR. Applying these helpers back to Qwen3-VL can be considered separately with Qwen3-VL-specific regression coverage.

Validation

Unit tests:

pytest tests/unit_tests/training/utils/test_padding_utils.py tests/unit_tests/training/utils/test_packed_seq_utils.py
# 16 passed

pytest tests/unit_tests/models/qwen_omni/test_qwen3_omni_step.py tests/unit_tests/models/qwen_omni/modeling_qwen3_omni/test_omni_model.py
# 27 passed

E2E validation:
4-node / 32-GPU Qwen3-Omni packed sequence full-model training passed:
Parallel config: TP=2, PP=2, CP=2, EP=4, SP=True.
Training config: seq_length=16384, global_batch_size=16, micro_batch_size=2, train_iters=200.
Result: completed 200 steps with finite loss, stable grad norm, and stable throughput.

@copy-pr-bot

copy-pr-bot Bot commented Jun 11, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@hbhflw2000 hbhflw2000 force-pushed the pr4_omni3_packseq_sequence_utils branch from f0f95d8 to 042eed1 Compare June 11, 2026 11:27
@yaoyu-33 yaoyu-33 added area:model Model implementations and HF bridge logic feature New capabilities, enhancements, or enablement work needs-review PR is ready for code review and waiting on a reviewer labels Jun 11, 2026
@yaoyu-33

Copy link
Copy Markdown
Contributor

/claude review

Comment thread src/megatron/bridge/models/qwen_omni/qwen3_omni_step.py Outdated
@claude

claude Bot commented Jun 14, 2026

Copy link
Copy Markdown
Contributor

Light Code Review - Clean implementation that follows the existing Qwen3-VL packed-padding pattern well. One inline comment posted about using _parallel_size() consistently. Suggested test cases: No perf tests impacted.

@claude

claude Bot commented Jun 14, 2026

Copy link
Copy Markdown
Contributor

Details: pack_or_pad_batch_sequences accesses pg_collection.tp.size() / .cp.size() directly instead of the null-safe _parallel_size() helper introduced in the same PR. The sibling function pad_batch_sequences_for_context_parallel was already refactored to use it. See inline comment for a suggested fix.

@claude

claude Bot commented Jun 14, 2026

Copy link
Copy Markdown
Contributor

Test coverage notes: (1) test_forward_step_passes_packed_sequence_params_to_model only exercises CP=1 -- a CP>1 variant would cover _get_dense_batch_on_this_cp_rank + packed path interaction but may require more complex mocking and could be deferred. (2) No direct unit test for pack_or_pad_batch_sequences (tested indirectly via forward_step) -- a focused test would make FP8 padding alignment (math.lcm) and force_to_seq_length easier to verify in isolation. Suggested test cases: No perf tests impacted.

cp_size=cp_size,
cp_rank=cp_rank,
sequence_parallel=self.config.sequence_parallel,
if packed_seq_params is None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these 2 here seems not need condition check, can just be one path?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated, thanks. I merged this into one split_deepstack_embs path. Packed THD tensors are already CP-aware after preprocess_packed_seqs, so only the non-packed path keeps the actual cp_size/cp_rank.

deepstack_visual_embeds,
tp_size=tp_size,
tp_rank=tp_rank,
cp_size=1,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why hard code cp?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This follows the existing Qwen3-VL packed THD handling: packed tensors are already CP-partitioned by preprocess_packed_seqs, so this split should only apply SP and avoid a second dense CP split. I made the intent explicit with local variables/comments in the latest commit.

@yaoyu-33 yaoyu-33 added waiting-on-customer Waiting on the original author to respond and removed needs-review PR is ready for code review and waiting on a reviewer labels Jun 15, 2026

@yaoyu-33 yaoyu-33 left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall this is a well-structured PR — the shared utilities are clean, the design note clearly explains the trade-offs vs slice_batch_for_context_parallel, and the unit test coverage is solid. A few things worth addressing before merge:


Bug: is_thd_format = True is permanent state mutation

thinker_model.py line 165:

self.language_model.rotary_pos_emb.is_thd_format = True

This mutates the language model's rotary embedding state without ever resetting it. If the model is called again with packed_seq_params=None (e.g. inference or a non-packed eval step), is_thd_format stays True from the previous call. The guarding condition one line above already checks if packed_seq_params is not None, so this is:

if packed_seq_params is not None and position_ids is not None:
    ...
    self.language_model.rotary_pos_emb.is_thd_format = True
# but never reset to False

Should be set/reset in a symmetric pattern, e.g.:

self.language_model.rotary_pos_emb.is_thd_format = packed_seq_params is not None

placed unconditionally just before return self.language_model(...).


Magic constants in _get_qwen3_omni_audio_output_lengths

input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13

100 and 13 are unexplained. Given the docstring says "Match HF Qwen3-Omni audio encoder forward output lengths", these likely represent the audio encoder chunk size (100 frames) and output tokens per full chunk (13). A comment explaining what these constants represent would make this much easier to audit when HF changes the audio encoder:

# Audio encoder processes chunks of 100 input frames,
# each producing 13 output tokens; remainder frames go through
# two stride-2 convolutions then a final stride-2 pool.

Also worth noting: the existing _get_feat_extract_output_lengths is a private API (_ prefix) so avoiding it is correct — but a reference to the HF source or audio encoder config that grounds these constants would prevent future drift.


Redundant attention mask construction in step

qwen3_omni_step.py lines 326–330:

forward_args["attention_mask"] = torch.ones_like(
    original_tokens, dtype=torch.bool, device=original_tokens.device,
)

thinker_model.forward already reconstructs attention_mask = torch.ones_like(input_ids, ...) when packed_seq_params is not None and attention_mask is None. Passing ones from the step means the model's guard is never triggered, but the result is the same. Minor, but adds an allocation on the critical path that the model discards immediately (it sets attention_mask = None before passing to language_model()). Could pass None here and let the model handle it, consistent with the existing guard.


Readability: nested ternary in language_model call

thinker_model.py lines 169–175:

input_ids=(
    lm_input_ids
    if packed_seq_params is not None
    else None
    if combined_embeddings is not None
    else input_ids
),

This is a nested conditional expression. A simple if/elif/else assignment before the call would be easier to follow.


Minor: no unit test for CP + packed combined path in forward_step

The test test_forward_step_passes_packed_sequence_params_to_model uses cp.size()=1. The E2E tested CP=2 + packing combined, but there's no unit-level coverage of the _get_dense_batch_on_this_cp_rank branch inside the packed path (if pack_sequences_in_batch and cp > 1). Not blocking given the E2E coverage, but worth a follow-up.


use_fp8_padding=True is hardcoded in the packed path regardless of whether FP8 training is actually enabled (adds lcm-16 alignment unconditionally). This is a conservative safe choice but worth a comment explaining why.

@hbhflw2000

Copy link
Copy Markdown
Contributor Author

@yaoyu-33 Thanks for the detailed review.

Fixed in the latest commits:

  • Reset rotary_pos_emb.is_thd_format symmetrically on packed/non-packed forwards, with a unit test.
  • Replaced audio length magic numbers with named constants and comments.
  • Rewrote the nested input_ids ternary.
  • Added a comment for the conservative use_fp8_padding=True choice.
  • Added a focused CP>1 packed forward_step unit test covering the CP slicing interaction.

I kept the explicit attention_mask for now to avoid changing the validated packed+CP path. The new unit test verifies labels/loss_mask are CP-local while full raw input_ids and packed_seq_params are still passed to the model.

Signed-off-by: hbhflw2000 <417911774@qq.com>
Signed-off-by: hbhflw2000 <417911774@qq.com>
Signed-off-by: hbhflw2000 <417911774@qq.com>
Signed-off-by: hbhflw2000 <417911774@qq.com>
Signed-off-by: hbhflw2000 <417911774@qq.com>
Signed-off-by: hbhflw2000 <417911774@qq.com>
@hbhflw2000 hbhflw2000 force-pushed the pr4_omni3_packseq_sequence_utils branch from f1aeec0 to 6c31ed3 Compare June 17, 2026 03:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:model Model implementations and HF bridge logic community-request feature New capabilities, enhancements, or enablement work waiting-on-customer Waiting on the original author to respond

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants