Skip to content

[train] Default to chunked logprobs for Megatron#1610

Merged
SumanthRH merged 11 commits into
mainfrom
default-chunk-logprobs-megatron
May 7, 2026
Merged

[train] Default to chunked logprobs for Megatron#1610
SumanthRH merged 11 commits into
mainfrom
default-chunk-logprobs-megatron

Conversation

@SumanthRH
Copy link
Copy Markdown
Member

@SumanthRH SumanthRH commented May 1, 2026

What does this PR do?

Defaults to chunked logprobs for Megatron with chunk_size of 1024

Currently, we don't do any chunking for logprobs calculation in the Megatron backend. For models like Qwen 3.5 with large vocab sizes (248,320), this can lead to OOMs.

This PR chooses a chunk size of 1024 as the best balance for memory saved vs increase in latency. For large sequence lengths and vocab sizes, this can save a lot of memory - ex: for 65K sequence lenght and 128K vocab size, we save 42GB of memory while the time taken increases from 132.2 ms to 600 ms (acceptable IMO).

Current memory usage

Currently, the full logprobs tensor of size (bsz, seq_len, vocab_size) gets materialized while we calculate the logprobs for the given input_ids.

The only way users can save memory in logprobs calculation is with tensor parallelism right now, where each TP rank would work with Tp sharded logits (bsz, seq_len, vocab_size/ TP) . Chunking along sequence is also needed for large sequence lengths.

Benchmark Settings

Device: NVIDIA B200
Vocab sizes: 32,000, 64,000 and 128,000
Measurement: forward + backward pass
Chunk sizes tested: None (no chunking), 32, 1024, 4096, 8192, 16384. I added 32 chunk size to the mix just to see what a small chunk size will yield.

Results

Vocab size 32K

seq_len chunk_size time (ms) peak mem (MB) vs no-chunk mem saved (MB)
32,768 None 17.0 15,601 1.00x 0
32,768 32 311.3 10,016 0.05x 5,585
32,768 1,024 41.5 10,501 0.41x 5,100
32,768 4,096 38.3 12,001 0.44x 3,600
32,768 8,192 37.7 14,001 0.45x 1,600
32,768 16,384 37.5 20,001 0.45x -4,400
65,536 None 33.5 31,202 1.00x 0
65,536 32 627.8 20,017 0.05x 11,185
65,536 1,024 82.7 20,501 0.41x 10,700
65,536 4,096 76.5 22,001 0.44x 9,200
65,536 8,192 75.2 24,001 0.45x 7,200
65,536 16,384 74.7 28,001 0.45x 3,200
131,072 None 66.6 62,404 1.00x 0
131,072 32 1,269.6 40,018 0.05x 22,385
131,072 1,024 165.9 40,503 0.40x 21,901
131,072 4,096 152.9 42,003 0.44x 20,401
131,072 8,192 150.2 44,003 0.44x 18,401
131,072 16,384 149.0 48,003 0.45x 14,401

Vocab Size 64K

seq_len chunk_size time (ms) peak mem (MB) vs no-chunk mem saved (MB)
32,768 None 33.4 31,201 1.00x 0
32,768 32 316.1 20,032 0.11x 11,169
32,768 1,024 78.6 21,001 0.42x 10,200
32,768 4,096 75.0 24,001 0.45x 7,200
32,768 8,192 74.5 28,001 0.45x 3,200
32,768 16,384 74.5 40,001 0.45x -8,800
65,536 None 66.4 62,402 1.00x 0
65,536 32 637.5 40,033 0.10x 22,369
65,536 1,024 157.4 41,001 0.42x 21,400
65,536 4,096 149.8 44,001 0.44x 18,400
65,536 8,192 148.6 48,001 0.45x 14,400
65,536 16,384 148.5 56,001 0.45x 6,400
131,072 None 133.7 124,804 1.00x 0
131,072 32 1,274.2 80,034 0.10x 44,770
131,072 1,024 314.7 81,003 0.42x 43,801
131,072 4,096 299.4 84,003 0.45x 40,801
131,072 8,192 297.0 88,003 0.45x 36,801
131,072 16,384 296.7 96,003 0.45x 28,801

Vocab Size 128K

seq_len chunk_size time (ms) peak mem (MB) vs no-chunk mem saved (MB)
32,768 None 66.4 62,401 1.00x 0
32,768 32 331.5 40,063 0.20x 22,338
32,768 1,024 152.8 42,001 0.43x 20,400
32,768 4,096 148.5 48,001 0.45x 14,400
32,768 8,192 148.3 56,001 0.45x 6,400
32,768 16,384 148.3 80,001 0.45x -17,600
65,536 None 132.2 124,802 1.00x 0
65,536 32 699.5 80,064 0.19x 44,738
65,536 1,024 305.5 82,001 0.43x 42,800
65,536 4,096 296.9 88,001 0.45x 36,800
65,536 8,192 296.5 96,001 0.45x 28,800
65,536 16,384 296.3 112,001 0.45x 12,800
131,072 None OOM
131,072 32 OOM
131,072 1,024 OOM
131,072 4,096 OOM
131,072 8,192 OOM
131,072 16,384 OOM

SumanthRH added 6 commits May 1, 2026 19:39
Set chunk_size=1024 (was None) for both the ref (inference_only=True)
and policy (inference_only=False) calls to from_parallel_logits_to_logprobs
in MegatronModelWrapper. This bounds peak GPU memory during the
log_softmax+gather step at the cost of some wall-clock time.

Add examples/benchmarks/bench_chunked_logprobs.py to quantify the
time/memory trade-off across seq_lens 32k–128k.

Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
x
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
@SumanthRH SumanthRH force-pushed the default-chunk-logprobs-megatron branch from 2b6706a to 4d04fa4 Compare May 1, 2026 23:31
SumanthRH added 2 commits May 2, 2026 00:18
…display)

Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
…comment

Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
@SumanthRH SumanthRH force-pushed the default-chunk-logprobs-megatron branch from 44a9b44 to 2324f2d Compare May 6, 2026 23:27
@SumanthRH SumanthRH marked this pull request as ready for review May 6, 2026 23:28
@SumanthRH SumanthRH requested a review from erictang000 May 6, 2026 23:29
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements chunked logprob computation within the Megatron model wrapper to prevent out-of-memory errors when processing large sequence lengths. It includes a new benchmark utility to measure the performance impact and memory savings of various chunk sizes. The feedback suggests moving the hardcoded chunk size constant into the TrainerConfig to allow for hardware-specific tuning and the ability to disable chunking to reduce latency.

# NOTE (sumanthrh): We use a chunk size of 1024 for calculating logprobs
# from logits to avoid OOMs for large sequence lengths.
# For more details, see https://github.com/NovaSky-AI/SkyRL/pull/1610
CHUNK_SIZE_LOGPROBS = 1024
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The CHUNK_SIZE_LOGPROBS is currently hardcoded as a global constant. While 1024 is a sensible default based on the provided benchmarks, hardcoding it prevents users from tuning this parameter for their specific hardware or disabling it if they have sufficient memory and want to avoid the latency penalty (which your benchmarks show can be significant, e.g., from 132ms to 600ms). It would be better to make this configurable via TrainerConfig so it can be adjusted without modifying the source code.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added logprobs_chunk_size to trainer config

SumanthRH added 2 commits May 7, 2026 00:01
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
x
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
@SumanthRH SumanthRH force-pushed the default-chunk-logprobs-megatron branch from 2324f2d to 4fa27eb Compare May 7, 2026 00:25
@SumanthRH
Copy link
Copy Markdown
Member Author

/gemini review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces configurable chunking for log-probability and entropy calculations to manage peak GPU memory usage, adding a logprobs_chunk_size parameter to the trainer configuration and propagating it through the model wrappers. A new benchmark script is also included to measure the performance impact of different chunk sizes. Feedback indicates that setting the chunk size to None (to disable chunking) could lead to a TypeError during calculation, and suggests updating type hints and adding logic to handle the None case by defaulting to the full sequence length.

Comment on lines 31 to 39
def chunked_cross_entropy_from_log_probs(
logprobs: Float[torch.Tensor, "batch_size seqlen vocab_size"], requires_grad: bool = False
logprobs: Float[torch.Tensor, "batch_size seqlen vocab_size"],
requires_grad: bool = False,
chunk_size: int = 1024,
) -> Float[torch.Tensor, "batch_size seqlen"]:
cm = nullcontext() if requires_grad else torch.no_grad()
with cm:
# Calculate entropy in chunks to avoid OOM
chunk_size = CHUNK_SIZE
num_chunks = (logprobs.size(1) + chunk_size - 1) // chunk_size
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The TrainerConfig allows logprobs_chunk_size to be None to disable chunking. However, if chunk_size is None, the calculation of num_chunks will raise a TypeError. We should handle the None case by defaulting to the full sequence length and update the type hint to int | None.

def chunked_cross_entropy_from_log_probs(
    logprobs: Float[torch.Tensor, "batch_size seqlen vocab_size"],
    requires_grad: bool = False,
    chunk_size: int | None = 1024,
) -> Float[torch.Tensor, "batch_size seqlen"]:
    cm = nullcontext() if requires_grad else torch.no_grad()
    with cm:
        # Calculate entropy in chunks to avoid OOM
        chunk_size = chunk_size or logprobs.size(1)
        num_chunks = (logprobs.size(1) + chunk_size - 1) // chunk_size

with cm:
# Calculate entropy in chunks to avoid OOM
chunk_size = CHUNK_SIZE
num_chunks = (logits.size(1) + chunk_size - 1) // chunk_size
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Ensure chunk_size is a valid integer before calculating num_chunks to avoid a crash when chunking is disabled (chunk_size=None).

Suggested change
num_chunks = (logits.size(1) + chunk_size - 1) // chunk_size
chunk_size = chunk_size or logits.size(1)
num_chunks = (logits.size(1) + chunk_size - 1) // chunk_size

Comment on lines 59 to 63
logits: Float[torch.Tensor, "batch_size seqlen vocab"],
requires_grad: bool = False,
attention_mask: Float[torch.Tensor, "batch_size seqlen"] = None,
chunk_size: int = 1024,
) -> Float[torch.Tensor, "batch_size seqlen"]:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update the type hint for chunk_size to int | None to align with the configuration which allows disabling chunking by setting it to None.

Suggested change
logits: Float[torch.Tensor, "batch_size seqlen vocab"],
requires_grad: bool = False,
attention_mask: Float[torch.Tensor, "batch_size seqlen"] = None,
chunk_size: int = 1024,
) -> Float[torch.Tensor, "batch_size seqlen"]:
logits: Float[torch.Tensor, "batch_size seqlen vocab"],
requires_grad: bool = False,
attention_mask: Float[torch.Tensor, "batch_size seqlen"] = None,
chunk_size: int | None = 1024,
) -> Float[torch.Tensor, "batch_size seqlen"]:

Comment on lines 80 to 85
model_config_kwargs: dict = {},
meta_init: bool = False,
language_model_only: bool = False,
logprobs_chunk_size: int = 1024,
**kwargs,
) -> None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logprobs_chunk_size parameter should be typed as int | None to match the TrainerConfig and allow disabling chunking.

Suggested change
model_config_kwargs: dict = {},
meta_init: bool = False,
language_model_only: bool = False,
logprobs_chunk_size: int = 1024,
**kwargs,
) -> None:
model_config_kwargs: dict = {},
meta_init: bool = False,
language_model_only: bool = False,
logprobs_chunk_size: int | None = 1024,
**kwargs,
) -> None:

x
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
@SumanthRH SumanthRH force-pushed the default-chunk-logprobs-megatron branch from 600aab3 to 87aac06 Compare May 7, 2026 01:09
@SumanthRH
Copy link
Copy Markdown
Member Author

/gemini review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces configurable sequence-dimension chunking for log-probability and entropy calculations to reduce peak GPU memory usage. It adds a logprobs_chunk_size setting to the TrainerConfig, updates FSDP and Megatron model wrappers to utilize this parameter, and provides a new benchmarking script. The review feedback suggests renaming a utility function for clarity, correcting a speedup interpretation in the benchmark script, and improving the configuration docstring by refining performance descriptions and removing ephemeral PR links.

Comment thread skyrl/backends/skyrl_train/utils/torch_utils.py
Comment thread skyrl/benchmarks/bench_chunked_logprobs.py
Comment thread skyrl/train/config/config.py
@SumanthRH SumanthRH merged commit 4c6793f into main May 7, 2026
3 of 5 checks passed
erictang000 added a commit that referenced this pull request May 11, 2026
…y loss (#1648)

### What


`tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_megatron_worker.py::test_megatron_train[tp2_pp2_policy_seq_packing_with_entropy_loss]`
has been failing since #1610 made
  chunked logprobs the default for Megatron, with:

```bash
  RuntimeError: one of the variables needed for gradient computation has been modified by an
  inplace operation: [torch.cuda.FloatTensor [1, 31, 75968]], which is output 0 of
  torch::autograd::CopySlices, is at version 3; expected version 1 instead.
```
  ### Why

Two `autograd.Function`s save the same `vocab_parallel_logits` tensor
for backward:

- `ChunkedDistributedLogprob.forward` saves it directly (the input
logits).
- `_VocabParallelEntropy.forward` saves `action_logits = logits[:,
-num_actions-1:-1, :]`, which is a view of the same storage.

`_VocabParallelEntropy.backward` used a `sub_` + `add_`
"modify-then-restore" trick on its saved tensor to avoid an allocation.
The restore makes the values correct, but each
in-place op still bumps the storage version counter (1 → 2 → 3). When
`ChunkedDistributedLogprob.backward` later reads `ctx.saved_tensors`,
its version check fails.

Before #1610, the non-chunked `DistributedLogprob` saved
`softmax_output` rather than the input logits, so the two Functions
didn't share storage and the version bump went
  unnoticed.

  ### Fix

`skyrl/backends/skyrl_train/distributed/megatron/model_utils.py` —
replace the in-place `sub_`/`add_` pair with one out-of-place
`vocab_parallel_logits.sub(...)`. Same math, no
mutation of the saved tensor. Costs one temporary the size of the
action-logits slice during entropy backward; freed when backward
returns. Only on the `use_entropy_loss=True`
  path (default is `False`).

  ### Test plan

  - [x] Repro'd the original failure locally.
  - [x] `tp2_pp2_policy_seq_packing_with_entropy_loss` passes after fix.
  - [x] `tp2_pp2_policy_seq_packing` (non-entropy sibling) still passes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant