
Add packing-aware dynamic batching to AsyncGRPO #6092

Open: AmineDiro wants to merge 2 commits into main from padding-free-scheduling

AmineDiro (Member) commented Jun 17, 2026

What this PR does

Builds on the landed padding-free path (#5854) to make AsyncGRPO micro-batching packing-aware and token-bounded, while riding HF Trainer's existing gradient accumulation (no training-loop surgery; FSDP/EP collectives stay in lockstep because K = gradient_accumulation_steps is uniform across ranks).

Two properties for each micro-batch:

  • the num_processes DP rows are Σ Lᵢ²-balanced (attention is O(L²), FFN is O(L), so equal token counts don't equalize wall-time), so there is no cross-rank straggler at the per-micro-batch all-reduce;
  • each row holds ≤ token_budget tokens (optional), which caps peak memory independently of per_device_train_batch_size.

Main structure: the collator becomes a pure packer ("packer"); a planner upstream decides which samples go to which row ("planner").

See this animation:
dynamic_batching_animation (1).html

What we introduce, and why

There are two separate mechanisms, both added on top of the padding-free packing from #5854.

1. Σ Lᵢ² row balancing (pfree_sample, class FixedCountBatcher). It replaces the old strided split examples[i::self.num_processes] with a greedy longest-first assignment that gives every DP row nearly the same Σ Lᵢ². Equal token counts do not equalize wall-time (attention is O(L²), FFN is O(L)); equal Σ Lᵢ² does. This removes the cross-rank straggler. In the benchmark, a strided imbalance of about 1.4 drops to about 0, and MFU goes up 19% at 4B.
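To make the difference between the two splits concrete, here is a minimal sketch on plain sequence lengths. It is not the FixedCountBatcher code from this PR: the function names, the toy lengths, and the tie-breaking are assumptions made for illustration.

```python
def strided_split(lengths, num_rows):
    """Old grouping described above: row i takes lengths[i::num_rows]."""
    return [lengths[i::num_rows] for i in range(num_rows)]


def balanced_split(lengths, num_rows):
    """Greedy longest-first assignment with a fixed sample count per row.

    Samples are visited from longest to shortest; each one goes to the row
    with the smallest running sum of squared lengths that still has a slot.
    Assumes len(lengths) is divisible by num_rows.
    """
    per_row = len(lengths) // num_rows
    rows = [[] for _ in range(num_rows)]
    loads = [0] * num_rows
    for n in sorted(lengths, reverse=True):
        open_rows = [i for i in range(num_rows) if len(rows[i]) < per_row]
        i = min(open_rows, key=lambda j: loads[j])
        rows[i].append(n)
        loads[i] += n * n
    return rows


def sq_load(row):
    """Sum of squared lengths, the attention-cost proxy used in this PR."""
    return sum(n * n for n in row)


# Toy lengths (hypothetical), alternating long and short samples.
lengths = [4096, 300, 3800, 250, 3500, 200, 3000, 150]
for name, split in (("strided", strided_split), ("balanced", balanced_split)):
    row_loads = [sq_load(r) for r in split(lengths, 2)]
    print(name, row_loads, "max/min = %.2f" % (max(row_loads) / min(row_loads)))
```

On this toy input the strided split puts all long samples in one row, while the greedy split keeps the two row loads within a few percent of each other.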

2. Token-budget packing (pfree_tokenlimit, class TokenBudgetBatcher). It fills each DP row up to token_budget tokens, with a variable number of samples, instead of a fixed sample count. This caps peak memory independently of per_device_train_batch_size.

What the token budget controls, and what it does not.

It is not a max-sequence-length knob, and it does not cap a single sample. It caps the packed row, which is the sum of the samples in that row. Fixed-count packing cannot do this. Each row contains per_device_bs samples of variable length, so it can spike to per_device × max_seq_len when a micro-batch draws all-long samples. To stay safe, fixed-count has to pick a small per_device_bs for the worst case, which wastes memory in the average case.
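As a rough worked example of that spike, the arithmetic below uses the realized mean and max full-sequence lengths reported in the benchmark comment further down; the per-device sample count of 8 is just an illustrative choice.

```python
per_device_bs = 8   # illustrative fixed sample count per row
mean_len = 1597     # realized mean full-sequence length (benchmark)
max_len = 5120      # realized max full-sequence length (benchmark)

typical_row = per_device_bs * mean_len  # tokens in an average row
worst_row = per_device_bs * max_len     # tokens if every sample is max length
print(typical_row, worst_row, round(worst_row / typical_row, 2))
# 12776 40960 3.21
```

With these numbers, a fixed-count row has to be provisioned for roughly three times the tokens it holds on average.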

The token budget adapts per micro-batch: short samples pack many per row, long samples pack few, but the row is always about token_budget tokens. So, at a fixed memory ceiling, it runs a larger average forward. This advantage only shows up when memory is the binding constraint (see Limitations).

How it fits the rest of the pipeline.

  • Queue (source, unchanged): RolloutQueueDataset reads single scored samples FIFO from the rollout mp.Queue and drops stale ones.
  • Planner (new): it pulls samples and splits one micro-batch into num_processes rows. FixedCountBatcher uses a fixed sample count with Σ Lᵢ² balancing. TokenBudgetBatcher uses a token budget. One planner item is one micro-batch.
  • Packer (collator): it turns the pre-split rows into a (num_processes, T_max) tensor batch (concat tokens, reset position_ids per sequence, expand advantages per token, pad rows). It is a pure packer with one input contract.
  • Dispatch to DP: DataLoaderDispatcher sends row i to rank i. batch_size=1, so one planner item is exactly one scattered micro-batch.
  • Grad-accum and collectives: HF Trainer pulls K = gradient_accumulation_steps micro-batches per optimizer step. K is the same on every rank, so FSDP/EP collectives stay in lockstep. This is what avoids the deadlock you would get from chunking inside compute_loss with ragged chunk counts.
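The packer step above can be sketched in plain Python lists. This is only an illustration of the four operations named in the packer bullet (concatenate, reset position_ids, expand advantages, pad), not DataCollatorForRollout itself; the sample dict layout and the pad values are assumptions.

```python
def pack_rows(rows, pad_token_id=0):
    """Pack pre-split rows into equal-length lists, one row per DP rank.

    rows: list of rows; each row is a list of samples; each sample is a dict
    with 'input_ids' (list of int) and a scalar 'advantage' (assumed layout).
    """
    packed = []
    for row in rows:
        input_ids, position_ids, advantages = [], [], []
        for sample in row:
            ids = sample["input_ids"]
            input_ids.extend(ids)                                 # concat tokens
            position_ids.extend(range(len(ids)))                  # restart at 0 per sequence
            advantages.extend([sample["advantage"]] * len(ids))   # one value per token
        packed.append((input_ids, position_ids, advantages))

    t_max = max(len(ids) for ids, _, _ in packed)
    batch = {"input_ids": [], "position_ids": [], "advantages": []}
    for ids, pos, adv in packed:
        pad = t_max - len(ids)  # right-pad shorter rows up to T_max
        batch["input_ids"].append(ids + [pad_token_id] * pad)
        batch["position_ids"].append(pos + [0] * pad)
        batch["advantages"].append(adv + [0.0] * pad)
    return batch


rows = [
    [{"input_ids": [1, 2, 3], "advantage": 0.5}, {"input_ids": [4, 5], "advantage": -1.0}],
    [{"input_ids": [6, 7], "advantage": 2.0}],
]
batch = pack_rows(rows)
print(batch["position_ids"])
```

The result has shape (num_processes, T_max) as nested lists; a real collator would tensorize it and also carry whatever masks the loss needs.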

Limitations and open discussion

1. FIFO scheduling, not global. The scheduler processes the queue using FIFO: pulling data, dropping stale items, and splitting. It only manages and packs within the current micro-batch window and does not reorder or bucket data across the entire rollout buffer to optimize the global split. A more advanced scheduler could improve balance and create more uniform forwards, but this would come at the expense of freshness, as buffered samples age and may be dropped due to max_staleness. This challenge is unique to asynchronous RL and does not occur in SFT.

2. No scheduling across micro-batches (on purpose). We use HF Trainer's gradient accumulation: K micro-batches per step, with K the same on every rank. This keeps FSDP/EP collectives in lockstep, but it prevents reordering micro-batches in time. With pipeline parallelism, that ordering matters because it balances the warm-up and cool-down bubble. See NVIDIA's serpentine (zig-zag) micro-batch ordering for PP (ref: add link). We dropped serpentine ordering because we do not run PP. In pure DP plus FSDP, stragglers are within a micro-batch, and the Σ Lᵢ² balancing removes them. If PP is added later, micro-batch ordering should be revisited.

3. The token budget is a soft cap below the longest sample. For a hard memory cap, set token_budget >= max_completion_length + max_prompt. A sample longer than the budget is placed alone in a row, and that row goes over the budget, so peak memory is max(token_budget, longest_sample). The batcher closes a micro-batch only once every row has at least one sample, so this never produces an empty row or desyncs DP collectives.
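One way the soft-cap behavior described in point 3 could be realized is sketched below on plain lengths. This is an assumption-laden illustration, not the TokenBudgetBatcher in this PR: the rule that an empty row accepts any sample, and the handling of the trailing partial micro-batch, are choices made for the sketch.

```python
def token_budget_batches(lengths, num_rows, token_budget):
    """Yield micro-batches as lists of rows (each row is a list of lengths)."""
    rows = [[] for _ in range(num_rows)]
    tokens = [0] * num_rows
    sq = [0] * num_rows
    for n in lengths:
        # A row can take the sample if it stays within budget, or if it is
        # still empty (this is what makes the cap soft for oversized samples).
        fits = [i for i in range(num_rows) if not rows[i] or tokens[i] + n <= token_budget]
        if not fits:
            # Empty rows always fit, so reaching here means no row is empty.
            yield rows
            rows = [[] for _ in range(num_rows)]
            tokens = [0] * num_rows
            sq = [0] * num_rows
            fits = list(range(num_rows))
        i = min(fits, key=lambda j: sq[j])  # least-loaded row by sum of squares
        rows[i].append(n)
        tokens[i] += n
        sq[i] += n * n
    # Trailing partial micro-batch: emitted only if no row is empty.
    if all(rows):
        yield rows


batches = list(token_budget_batches([5000, 300, 200, 100, 4000, 50], num_rows=2, token_budget=1000))
print(batches)
```

In this sketch a row exceeds token_budget only when it holds a single oversized sample, so the peak row size is max(token_budget, longest_sample), and every emitted micro-batch has a sample in every row.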

The benchmark does not pick a clear winner. The right default depends on the regime.

  • Memory loose (gradient checkpointing on, the common case): pfree_sample (fixed count plus longest-first) matches or beats pfree_tokenlimit on MFU, and it balances Σ Lᵢ² strictly better. Longest-first directly minimizes the Σ Lᵢ² imbalance across rows. The token budget equalizes tokens, not Σ Lᵢ², so its rows can have the same token count but very different Σ Lᵢ² (one long sequence vs many short ones).
  • Memory-bound (no checkpointing, long context, or very large models): pfree_tokenlimit is the only one that can safely run the largest forward the memory allows. Fixed count has to under-provision per_device for length spikes or it OOMs. This is the one regime where pfree_tokenlimit clearly wins. In the no-checkpointing run, it ran a 6.2k-token forward at 19.8% MFU, where pfree_sample was forced down to bs=2 and OOM'd at bs=4.
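A quick numeric check of the "same token count, very different Σ Lᵢ²" point, using made-up row contents:

```python
one_long = [4096]        # a single long sequence
many_short = [256] * 16  # sixteen short sequences, same total tokens

for name, row in (("one long", one_long), ("many short", many_short)):
    print(name, "tokens:", sum(row), "sum of squares:", sum(n * n for n in row))
# one long tokens: 4096 sum of squares: 16777216
# many short tokens: 4096 sum of squares: 1048576
```

Both rows hold 4096 tokens, but the attention-cost proxy differs by a factor of 16, which is why a token budget alone does not remove stragglers.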

pfree_sample is the simpler, better-balanced default. pfree_tokenlimit is opt-in (token_budget > 0) for memory-bound, spiky-length, long-context training, and its cap is soft below the longest sample. We expose both and leave the choice to the user. Whether the token-budget path is worth it for most workloads, versus just fixed-count Σ Lᵢ² balancing, is genuinely open. Real workload data would settle it.


Note

Medium Risk
Changes how training batches are formed and scattered across DP ranks (FSDP collectives assume every rank gets a non-empty forward); default token_budget <= 0 preserves fixed-count behavior but still replaces strided grouping with explicit balancing.

Overview
Adds packing-aware micro-batching for AsyncGRPO: upstream planners (FixedCountBatcher or TokenBudgetBatcher) partition rollout samples into one row per DP rank, balancing Σ Lᵢ² (attention cost) to cut cross-rank stragglers. DataCollatorForRollout is now a pure packer that tensorizes those pre-partitioned rows (dataloader batch_size=1 per micro-batch) instead of striding a flat sample list.

New token_budget on AsyncGRPOConfig: when > 0, TokenBudgetBatcher caps real tokens per rank forward with dynamic samples-per-row; when <= 0 (default), FixedCountBatcher keeps per_device_train_batch_size × num_processes samples per micro-batch with the same balancing.

TestPackingAwareBatching covers balancing, both batchers, and collator packing/padding (CPU-only).

Reviewed by Cursor Bugbot for commit 83496e2. Bugbot is set up for automated code reviews on this repo.

@AmineDiro AmineDiro requested a review from qgallouedec June 17, 2026 21:29
@AmineDiro AmineDiro marked this pull request as ready for review June 18, 2026 12:47

@cursor cursor Bot left a comment


Cursor Bugbot has reviewed your changes using default effort and found 3 potential issues.


fits = [i for i in range(self.num_processes) if token_counts[i] + n <= self.token_budget]
if not fits:
    # No row has room for this sample: close the micro-batch and start a fresh one.
    yield rows


Empty DP rows in token batcher

High Severity

When TokenBudgetBatcher closes a micro-batch because the next sample fits no row, it yields the current rows without ensuring every DP rank has at least one sample. A full row plus an oversized successor can emit a batch where some ranks are empty, breaking the collator or causing ranks to forward zero tokens and desync FSDP/EP collectives.


rows = [[] for _ in range(self.num_processes)]
squared_loads = [0] * self.num_processes
token_counts = [0] * self.num_processes
fits = list(range(self.num_processes))


All-empty micro-batch before oversized sample

High Severity

If a rollout’s input_ids length exceeds token_budget, TokenBudgetBatcher yields an all-empty micro-batch before placing that sample. DataCollatorForRollout then hits an empty all_examples list and can fail when building metrics or tensors, and training may take a step on zero tokens.


i = min(fits, key=lambda j: squared_loads[j])
rows[i].append(sample)
squared_loads[i] += n * n
token_counts[i] += n


Buffered samples never flushed

Medium Severity

TokenBudgetBatcher and FixedCountBatcher never emit a final partial micro-batch when iteration stops. Samples left in rows or an unfilled batch buffer are dropped after they were already taken from the rollout queue, so those rollouts never contribute to training.


@cursor cursor Bot left a comment


Cursor Bugbot has reviewed your changes using default effort and found 2 potential issues.


i = min(fits, key=lambda j: squared_loads[j])
rows[i].append(sample)
squared_loads[i] += n * n
token_counts[i] += n


Empty ranks exceed token budget

High Severity

When a rollout’s input_ids length exceeds token_budget, TokenBudgetBatcher can emit a micro-batch with one or more empty DP rows, and after rollover it assigns the sample using all ranks without re-checking the budget, so a row can hold more than token_budget tokens. That breaks the documented non-empty-row invariant and can desync FSDP/EP collectives; an all-empty yield can also make DataCollatorForRollout index an empty all_examples.
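The invariants this comment refers to could be guarded in a CPU-only test with a small checker along the following lines. The helper name and the row representation (per-sample token counts) are hypothetical; nothing here is taken from the PR's test suite.

```python
def check_micro_batch(rows, token_budget):
    """Raise ValueError if a planned micro-batch breaks the stated invariants.

    rows: one list of per-sample token counts per DP rank (assumed layout).
    """
    empty = [i for i, row in enumerate(rows) if not row]
    if empty:
        raise ValueError(f"empty DP rows: {empty}")
    # A row over budget is tolerated only when it holds one oversized sample.
    over = [i for i, row in enumerate(rows) if sum(row) > token_budget and len(row) > 1]
    if over:
        raise ValueError(f"multi-sample rows over budget: {over}")


check_micro_batch([[5000], [300, 200]], token_budget=1000)  # passes silently
```

Running such a check over every planner item, including streams that contain samples longer than token_budget, would exercise the cases flagged here.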


i = min(fits, key=lambda j: squared_loads[j])
rows[i].append(sample)
squared_loads[i] += n * n
token_counts[i] += n


Planner drops trailing buffered samples

Medium Severity

TokenBudgetBatcher and FixedCountBatcher never yield a final partial micro-batch when iteration stops, so samples left in rows or batch are discarded instead of reaching the collator and compute_loss.


@bot-ci-comment


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@AmineDiro (Member, Author)

Benchmark setup

Benchmarked the training loop in isolation (no vLLM) with a synthetic rollout worker, so the only variable is the
batching strategy. (Harness lives in padding_free_bench/, not part of this PR.)

  • Models/parallelism: Qwen3-1.7B (DDP) → Qwen3-4B / 8B (FSDP2), DP = 1 / 4 / 8, bf16 autocast (params fp32),
    grad-checkpointing on (turned off only for the memory-bound experiments). H100 80GB, FlashAttention-3,
    num_generations=8, gradient_accumulation_steps=4.
  • Synthetic worker: emits random samples in per-prompt blocks of num_generations; completion lengths from a
    clipped log-normal (median 700, σ 0.9) with a 10% truncation spike at 4096 (mirrors vLLM EOS-vs-cap), prompt
    median 256. Seeded → every strategy sees identical data. Realized full-sequence lengths: min 116 / mean 1597 /
    max 5120 tokens.
  • Strategies compared:
    • padded: the pre-padding-free baseline; one right-padded row per sample, standard batched forward (no packing).
    • pfree_strided: StridedBatcher; padding-free, original-main grouping examples[i::P], no Σ Lᵢ² balancing.
    • pfree_sample: FixedCountBatcher; padding-free, Σ Lᵢ²-balanced, fixed per_device × num_processes samples per micro-batch.
    • pfree_tokenlimit: TokenBudgetBatcher (the planner); padding-free, Σ Lᵢ² balancing plus a per-row token budget (dynamic count).
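For readers who want a feel for the synthetic completion lengths, a sampler in the spirit of the description above might look like this. It is a guess at the shape only, not the padding_free_bench/ harness: the function name, the order of the spike and clip steps, and the minimum length of 1 are assumptions, and it is not claimed to reproduce the reported statistics.

```python
import math
import random


def sample_completion_length(rng, median=700, sigma=0.9, cap=4096, spike=0.10):
    """Clipped log-normal with a fraction of samples forced to the cap."""
    if rng.random() < spike:
        return cap  # models a generation that hit the length cap
    # A log-normal with mu = ln(median) has the requested median.
    n = int(rng.lognormvariate(math.log(median), sigma))
    return max(1, min(n, cap))


rng = random.Random(0)  # seeded, so every strategy would see the same stream
lengths = [sample_completion_length(rng) for _ in range(10_000)]
print(min(lengths), round(sum(lengths) / len(lengths)), max(lengths))
```

Seeding the generator is what lets every batching strategy be compared on an identical sample stream.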

Rollout distribution

[image: rollout length distribution]

What we ran

All runs: Qwen3, num_generations=8, gradient_accumulation_steps=4, bf16 autocast (fp32 params), FlashAttention-3, on H100-80GB. Synthetic vLLM-shaped lengths: completion log-normal(median 700, σ0.9), 10% truncated @ 4096, prompt median 256 (seed 0 → identical stream for every run).

| | min | mean | p50 | p90 | max |
|---|---|---|---|---|---|
| completion | 24 | 1316 | 794 | 4096 | 4096 |
| full | 116 | 1597 | 1086 | 4269 | 5120 |

full (prompt+completion) is what each forward processes; the heavy tail (mean 1597, max 5120) is what makes balancing matter.

All runs summary


Columns

  • bs/budget: bs=N = fixed per_device samples/row; tok=N = per-row token budget (dynamic count).
  • microbatch tok/row: mean real tokens per DP row = the actual per-rank forward size.
  • Strategy: padded (1 padded row/sample) · pfree_strided (orig main examples[i::P], no balance) · pfree_sample (Σ Lᵢ²-balanced fixed count) · pfree_tokenlimit (Σ Lᵢ² + token cap).
  • MFU: vs bf16 peak 989.5 TFLOPS, world-size corrected. GPU mem: per-rank peak (max across ranks).

| Model | DP | Parallel | grad-ckpt | Strategy | bs/budget | grad-acc | microbatch tok/row | TPS | MFU % | Peak mem GB | ΣLᵢ² imbal |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 1.7B | 1 | DDP | on | padded | bs=4 | 4 | 6112 | 5538 | 6.8 | 37.0 | 0.00 |
| 1.7B | 1 | DDP | on | padded | bs=8 | 4 | 13009 | 5235 | 6.4 | 44.4 | 0.00 |
| 1.7B | 1 | DDP | on | pfree_sample | bs=4 | 4 | 6112 | 6777 | 8.3 | 35.7 | 0.00 |
| 1.7B | 1 | DDP | on | pfree_sample | bs=8 | 4 | 13009 | 7167 | 8.8 | 38.8 | 0.00 |
| 1.7B | 1 | DDP | on | pfree_tokenlimit | tok=4096 | 4 | 3358 | 5669 | 6.8 | 32.3 | 0.00 |
| 1.7B | 1 | DDP | on | pfree_tokenlimit | tok=8192 | 4 | 6950 | 6997 | 8.5 | 33.8 | 0.00 |
| 1.7B | 1 | DDP | on | pfree_tokenlimit | tok=16384 | 4 | 14922 | 7282 | 9.0 | 37.1 | 0.00 |
| 1.7B | 4 | DDP | on | padded | bs=4 | 4 | 6468 | 16712 | 5.1 | 44.1 | 1.52 |
| 1.7B | 4 | DDP | on | padded | bs=8 | 4 | 12763 | 17894 | 5.5 | 51.5 | 1.14 |
| 1.7B | 4 | DDP | on | pfree_sample | bs=4 | 4 | 6468 | 21409 | 6.6 | 42.8 | 0.29 |
| 1.7B | 4 | DDP | on | pfree_sample | bs=8 | 4 | 12763 | 25100 | 7.7 | 44.7 | 0.00 |
| 1.7B | 4 | DDP | on | pfree_tokenlimit | tok=8192 | 4 | 6370 | 23926 | 7.3 | 40.7 | 0.81 |
| 1.7B | 4 | DDP | on | pfree_tokenlimit | tok=16384 | 4 | 15265 | 26772 | 8.2 | 44.1 | 0.59 |
| 1.7B | 8 | DDP | on | padded | bs=4 | 4 | 6382 | 31090 | 4.8 | 44.5 | 2.02 |
| 1.7B | 8 | DDP | on | padded | bs=8 | 4 | 12764 | 34116 | 5.2 | 51.6 | 1.52 |
| 1.7B | 8 | DDP | on | pfree_sample | bs=4 | 4 | 6382 | 39749 | 6.1 | 42.5 | 0.32 |
| 1.7B | 8 | DDP | on | pfree_sample | bs=8 | 4 | 12764 | 49788 | 7.6 | 44.6 | 0.00 |
| 1.7B | 8 | DDP | on | pfree_sample | bs=16 | 4 | 25323 | 55845 | 8.5 | 49.5 | 0.00 |
| 1.7B | 8 | DDP | on | pfree_sample | bs=24 | 4 | 38102 | 54493 | 8.3 | 54.9 | 0.00 |
| 1.7B | 8 | DDP | on | pfree_tokenlimit | tok=8192 | 4 | 6312 | 44888 | 6.9 | 40.7 | 1.02 |
| 1.7B | 8 | DDP | on | pfree_tokenlimit | tok=16384 | 4 | 15479 | 51988 | 7.9 | 44.1 | 0.75 |
| 1.7B | 8 | DDP | on | pfree_tokenlimit | tok=24576 | 4 | 23750 | 53675 | 8.2 | 47.4 | 0.62 |
| 1.7B | 8 | DDP | on | pfree_tokenlimit | tok=32768 | 4 | 31943 | 54204 | 8.3 | 50.7 | 0.54 |
| 1.7B | 8 | DDP | on | pfree_tokenlimit | tok=49152 | 4 | 48495 | 56270 | 8.6 | 57.4 | 0.45 |
| 4B | 8 | FSDP | on | padded | bs=8 | 4 | 12850 | 27355 | 10.0 | 33.1 | 1.46 |
| 4B | 8 | FSDP | on | pfree_strided | bs=8 | 4 | 12751 | 42489 | 15.4 | 28.7 | 1.40 |
| 4B | 8 | FSDP | on | pfree_strided | bs=16 | 4 | 25309 | 47165 | 17.1 | 34.1 | 1.02 |
| 4B | 8 | FSDP | on | pfree_sample | bs=8 | 4 | 12850 | 51749 | 18.8 | 24.7 | 0.00 |
| 4B | 8 | FSDP | on | pfree_sample | bs=16 | 4 | 25605 | 57563 | 20.9 | 29.0 | 0.00 |
| 4B | 8 | FSDP | on | pfree_tokenlimit | tok=16384 | 4 | 15349 | 55114 | 20.0 | 24.3 | 0.87 |
| 4B | 8 | FSDP | on | pfree_tokenlimit | tok=32768 | 4 | 31930 | 56271 | 20.4 | 30.3 | 0.52 |
| 4B | 8 | FSDP | on | pfree_tokenlimit | tok=49152 | 4 | 48575 | 66729 | 24.2 | 36.6 | 0.42 |
| 4B | 8 | FSDP | OFF | padded | bs=2 | 4 | 3276 | 26682 | 9.7 | 80.1 | 2.67 |
| 4B | 8 | FSDP | OFF | pfree_sample | bs=2 | 4 | 3276 | 43953 | 16.0 | 53.3 | 1.93 |
| 4B | 8 | FSDP | OFF | pfree_tokenlimit | tok=8192 | 4 | 6174 | 54402 | 19.8 | 69.8 | 0.99 |
| 8B | 8 | FSDP | on | padded | bs=8 | 4 | 12696 | 18214 | 12.4 | 52.9 | 1.45 |
| 8B | 8 | FSDP | on | pfree_sample | bs=8 | 4 | 12696 | 34287 | 23.3 | 41.5 | 0.00 |
| 8B | 8 | FSDP | on | pfree_sample | bs=16 | 4 | 25497 | 38130 | 25.9 | 47.2 | 0.00 |
| 8B | 8 | FSDP | on | pfree_sample | bs=32 | 4 | 50671 | 42162 | 28.6 | 62.1 | 0.00 |
| 8B | 8 | FSDP | on | pfree_tokenlimit | tok=32768 | 4 | 31975 | 39693 | 26.9 | 48.9 | 0.53 |
| 8B | 8 | FSDP | on | pfree_tokenlimit | tok=49152 | 4 | 48504 | 42201 | 28.6 | 57.8 | 0.43 |
| 8B | 8 | FSDP | on | pfree_tokenlimit | tok=65536 | 4 | 65009 | 39001 | 26.5 | 67.2 | 0.37 |
