Add packing-aware dynamic batching to AsyncGRPO#6092
Conversation
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes using default effort and found 3 potential issues.
❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.
Reviewed by Cursor Bugbot for commit 83496e2. Configure here.
| fits = [i for i in range(self.num_processes) if token_counts[i] + n <= self.token_budget] | ||
| if not fits: | ||
| # No row has room for this sample: close the micro-batch and start a fresh one. | ||
| yield rows |
There was a problem hiding this comment.
Empty DP rows in token batcher
High Severity
When TokenBudgetBatcher closes a micro-batch because the next sample fits no row, it yields the current rows without ensuring every DP rank has at least one sample. A full row plus an oversized successor can emit a batch where some ranks are empty, breaking the collator or causing ranks to forward zero tokens and desync FSDP/EP collectives.
Reviewed by Cursor Bugbot for commit 83496e2. Configure here.
| rows = [[] for _ in range(self.num_processes)] | ||
| squared_loads = [0] * self.num_processes | ||
| token_counts = [0] * self.num_processes | ||
| fits = list(range(self.num_processes)) |
There was a problem hiding this comment.
All-empty micro-batch before oversized sample
High Severity
If a rollout’s input_ids length exceeds token_budget, TokenBudgetBatcher yields an all-empty micro-batch before placing that sample. DataCollatorForRollout then hits an empty all_examples list and can fail when building metrics or tensors, and training may take a step on zero tokens.
Reviewed by Cursor Bugbot for commit 83496e2. Configure here.
| i = min(fits, key=lambda j: squared_loads[j]) | ||
| rows[i].append(sample) | ||
| squared_loads[i] += n * n | ||
| token_counts[i] += n |
There was a problem hiding this comment.
Buffered samples never flushed
Medium Severity
TokenBudgetBatcher and FixedCountBatcher never emit a final partial micro-batch when iteration stops. Samples left in rows or an unfilled batch buffer are dropped after they were already taken from the rollout queue, so those rollouts never contribute to training.
Additional Locations (1)
Reviewed by Cursor Bugbot for commit 83496e2. Configure here.
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes using default effort and found 2 potential issues.
❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.
Reviewed by Cursor Bugbot for commit 83496e2. Configure here.
| i = min(fits, key=lambda j: squared_loads[j]) | ||
| rows[i].append(sample) | ||
| squared_loads[i] += n * n | ||
| token_counts[i] += n |
There was a problem hiding this comment.
Empty ranks exceed token budget
High Severity
When a rollout’s input_ids length exceeds token_budget, TokenBudgetBatcher can emit a micro-batch with one or more empty DP rows, and after rollover it assigns the sample using all ranks without re-checking the budget, so a row can hold more than token_budget tokens. That breaks the documented non-empty-row invariant and can desync FSDP/EP collectives; an all-empty yield can also make DataCollatorForRollout index an empty all_examples.
Reviewed by Cursor Bugbot for commit 83496e2. Configure here.
| i = min(fits, key=lambda j: squared_loads[j]) | ||
| rows[i].append(sample) | ||
| squared_loads[i] += n * n | ||
| token_counts[i] += n |
There was a problem hiding this comment.
Planner drops trailing buffered samples
Medium Severity
TokenBudgetBatcher and FixedCountBatcher never yield a final partial micro-batch when iteration stops, so samples left in rows or batch are discarded instead of reaching the collator and compute_loss.
Additional Locations (1)
Reviewed by Cursor Bugbot for commit 83496e2. Configure here.
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |




What this PR does
Builds on the landed padding-free path (#5854) to make AsyncGRPO micro-batching packing-aware and
token-bounded, while riding HF Trainer's existing gradient accumulation (no training-loop surgery, FSDP/EP
collectives stay in lockstep because
K = gradient_accumulation_stepsis uniform across ranks).Two properties for each micro-batch:
num_processesDP rows are Σ Lᵢ²-balanced (attention isO(L²), FFN isO(L), so equal token counts don't equalize wall-time). No cross-rank straggler at the per-micro-batch all-reducetoken_budgettokens (optional) → caps peak memory, decoupled fromper_device_train_batch_size.Main structure : The collator becomes a pure packer ("packer"); a planner upstream decides which samples go to which row ("planner").
See this animation:
dynamic_batching_animation (1).html
What we introduce, and why
There are two separate mechanisms, both added on top of the padding-free packing from #5854.
1. Σ Lᵢ² row balancing (
pfree_sample, classFixedCountBatcher). It replaces the old strided splitexamples[i::self.num_processes]with a greedy longest-first assignment that gives every DP row the same Σ Lᵢ². Equal token counts do not equalize wall-time (attention isO(L²), FFN isO(L)); equal Σ Lᵢ² does. This removes the cross-rank straggler. In the benchmark, a strided imbalance of about 1.4 drops to about 0, and MFU goes up 19% at 4B !2. Token-budget packing (
pfree_tokenlimit, classTokenBudgetBatcher). It fills each DP row up totoken_budgettokens, with a variable number of samples, instead of a fixed sample count. This caps peak memory independently ofper_device.What the token budget controls, and what it does not.
It is not a max-sequence-length knob, and it does not cap a single sample. It caps the packed row, which is the sum of the samples in that row. Fixed-count packing cannot do this. Each row contains
per_device_bssamples of variable length, so it can spike toper_device × max_seq_lenwhen a micro-batch draws all-long samples. To stay safe, fixed-count has to pick a smallper_device_bsfor the worst case, which wastes memory in the average case.The token budget adapts per micro-batch: short samples pack many per row, long samples pack few, but the row is always about
token_budgettokens. So, at a fixed memory ceiling, it runs a larger average forward. This advantage only shows up when memory is the binding constraint (see Limitations).How it fits the rest of the pipeline.
RolloutQueueDatasetreads single scored samples FIFO from the rolloutmp.Queueand drops stale ones.num_processesrows.FixedCountBatcheruses a fixed sample count with Σ Lᵢ² balancing.TokenBudgetBatcheruses a token budget. One planner item is one micro-batch.(num_processes, T_max)tensor batch (concat tokens, resetposition_idsper sequence, expand advantages per token, pad rows). It is a pure packer with one input contract.DataLoaderDispatchersends rowito ranki.batch_size=1, so one planner item is exactly one scattered micro-batch.K = gradient_accumulation_stepsmicro-batches per optimizer step.Kis the same on every rank, so FSDP/EP collectives stay in lockstep. This is what avoids the deadlock you would get from chunking insidecompute_losswith ragged chunk counts.Limitations and open discussion
1. FIFO scheduling, not global. The scheduler processes the queue using FIFO: pulling data, dropping stale items, and splitting. It only manages and packs within the current micro-batch window and does not reorder or bucket data across the entire rollout buffer to optimize the global split. A more advanced scheduler could improve balance and create more uniform forwards, but this would come at the expense of freshness, as buffered samples age and may be dropped due to
max_staleness. This challenge is unique to asynchronous RL and does not occur in SFT.2. No scheduling across micro-batches (on purpose). We use HF Trainer's gradient accumulation:
Kmicro-batches per step, withKthe same on every rank. This keeps FSDP/EP collectives in lockstep, but it prevents reordering micro-batches in time. With pipeline parallelism, that ordering matters because it balances the warm-up and cool-down bubble. See NVIDIA's serpentine (zig-zag) micro-batch ordering for PP (ref: add link). We dropped serpentine ordering because we do not run PP. In pure DP plus FSDP, stragglers are within a micro-batch, and the Σ Lᵢ² balancing removes them. If PP is added later, micro-batch ordering should be revisited.3. The token budget is a soft cap below the longest sample. For a hard memory cap, set
token_budget >= max_completion_length + max_prompt. A sample longer than the budget is placed alone in a row, and that row goes over the budget, so peak memory ismax(token_budget, longest_sample). The batcher closes a micro-batch only once every row has at least one sample, so this never produces an empty row or desyncs DP collectives.The benchmark does not pick a clear winner. The right default depends on the regime.
pfree_sample(fixed count plus longest-first) matches or beatspfree_tokenlimiton MFU, and it balances Σ Lᵢ² strictly better. Longest-first directly minimizes Σ Lᵢ². The token budget equalizes tokens, not Σ Lᵢ², so its rows can have the same token count but very different Σ Lᵢ² (one long sequence vs many short ones).pfree_tokenlimitis the only one that can safely run the largest forward the memory allows. Fixed count has to under-provisionper_devicefor length spikes or it OOMs. This is the one regime where it clearly wins. In the no-checkpointing run, it ran a 6.2k-token forward at 19.8% MFU, wherepfree_samplewas forced down to bs2 and OOM'd at bs4.pfree_sampleis the simpler, better-balanced default.pfree_tokenlimitis opt-in(token_budget > 0) for memory-bound, spiky-length, long-context training, and its cap is soft below the longest sample. We expose both and leave the choice to the user. Whether the token-budget path is worth it for most workloads, versus just fixed-count Σ Lᵢ² balancing, is genuinely open. Real workload data would settle it.Note
Medium Risk
Changes how training batches are formed and scattered across DP ranks (FSDP collectives assume every rank gets a non-empty forward); default
token_budget <= 0preserves fixed-count behavior but still replaces strided grouping with explicit balancing.Overview
Adds packing-aware micro-batching for AsyncGRPO: upstream planners (
FixedCountBatcherorTokenBudgetBatcher) partition rollout samples into one row per DP rank, balancing Σ Lᵢ² (attention cost) to cut cross-rank stragglers.DataCollatorForRolloutis now a pure packer that tensorizes those pre-partitioned rows (dataloaderbatch_size=1per micro-batch) instead of striding a flat sample list.New
token_budgetonAsyncGRPOConfig: when> 0,TokenBudgetBatchercaps real tokens per rank forward with dynamic samples-per-row; when<= 0(default),FixedCountBatcherkeepsper_device_train_batch_size × num_processessamples per micro-batch with the same balancing.TestPackingAwareBatchingcovers balancing, both batchers, and collator packing/padding (CPU-only).Reviewed by Cursor Bugbot for commit 83496e2. Bugbot is set up for automated code reviews on this repo. Configure here.