[train] Default to chunked logprobs for Megatron #1610
Set chunk_size=1024 (was None) for both the ref (inference_only=True) and policy (inference_only=False) calls to from_parallel_logits_to_logprobs in MegatronModelWrapper. This bounds peak GPU memory during the log_softmax+gather step at the cost of some wall-clock time. Add examples/benchmarks/bench_chunked_logprobs.py to quantify the time/memory trade-off across seq_lens 32k–128k. Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
2b6706a to 4d04fa4
…display) Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
…comment Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
44a9b44 to 2324f2d
Code Review
This pull request implements chunked logprob computation within the Megatron model wrapper to prevent out-of-memory errors when processing large sequence lengths. It includes a new benchmark utility to measure the performance impact and memory savings of various chunk sizes. The feedback suggests moving the hardcoded chunk size constant into the TrainerConfig to allow for hardware-specific tuning and the ability to disable chunking to reduce latency.
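The log_softmax+gather step that this PR chunks can be illustrated without any framework. The following is a minimal pure-Python sketch, not SkyRL's implementation: `log_softmax_gather` and `chunked_logprobs` are hypothetical names, and nested lists stand in for a single sequence of logits. It shows that slicing along the sequence dimension leaves the per-token logprobs unchanged while only `chunk_size` positions are processed at a time.

```python
import math
from typing import List, Optional


def log_softmax_gather(logits_rows: List[List[float]], token_ids: List[int]) -> List[float]:
    """Log-probability of the chosen token at each position (log_softmax + gather)."""
    out = []
    for row, tok in zip(logits_rows, token_ids):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        out.append(row[tok] - log_sum_exp)
    return out


def chunked_logprobs(
    logits: List[List[float]],
    token_ids: List[int],
    chunk_size: Optional[int] = 1024,
) -> List[float]:
    """Process the sequence in slices of at most chunk_size positions.

    chunk_size=None is treated here as "no chunking" (an assumption that
    mirrors the behaviour discussed in this PR).
    """
    seqlen = len(logits)
    step = chunk_size if chunk_size is not None else max(seqlen, 1)
    result: List[float] = []
    for start in range(0, seqlen, step):
        end = start + step
        # Only this slice of positions is worked on at any one time.
        result.extend(log_softmax_gather(logits[start:end], token_ids[start:end]))
    return result


if __name__ == "__main__":
    toy_logits = [[0.0, 1.0, 2.0], [1.0, 1.0, 1.0], [3.0, 0.0, -1.0]]
    toy_ids = [2, 0, 0]
    print(chunked_logprobs(toy_logits, toy_ids, chunk_size=2))
```

In a tensor implementation the per-chunk intermediate would have shape (bsz, chunk_size, vocab_size) instead of (bsz, seq_len, vocab_size), which is presumably where the memory saving comes from.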
# NOTE (sumanthrh): We use a chunk size of 1024 for calculating logprobs
# from logits to avoid OOMs for large sequence lengths.
# For more details, see https://github.com/NovaSky-AI/SkyRL/pull/1610
CHUNK_SIZE_LOGPROBS = 1024
The CHUNK_SIZE_LOGPROBS is currently hardcoded as a global constant. While 1024 is a sensible default based on the provided benchmarks, hardcoding it prevents users from tuning this parameter for their specific hardware or disabling it if they have sufficient memory and want to avoid the latency penalty (which your benchmarks show can be significant, e.g., from 132ms to 600ms). It would be better to make this configurable via TrainerConfig so it can be adjusted without modifying the source code.
Added logprobs_chunk_size to trainer config
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
2324f2d to 4fa27eb
/gemini review
Code Review
This pull request introduces configurable chunking for log-probability and entropy calculations to manage peak GPU memory usage, adding a logprobs_chunk_size parameter to the trainer configuration and propagating it through the model wrappers. A new benchmark script is also included to measure the performance impact of different chunk sizes. Feedback indicates that setting the chunk size to None (to disable chunking) could lead to a TypeError during calculation, and suggests updating type hints and adding logic to handle the None case by defaulting to the full sequence length.
 def chunked_cross_entropy_from_log_probs(
-    logprobs: Float[torch.Tensor, "batch_size seqlen vocab_size"], requires_grad: bool = False
+    logprobs: Float[torch.Tensor, "batch_size seqlen vocab_size"],
+    requires_grad: bool = False,
+    chunk_size: int = 1024,
 ) -> Float[torch.Tensor, "batch_size seqlen"]:
     cm = nullcontext() if requires_grad else torch.no_grad()
     with cm:
         # Calculate entropy in chunks to avoid OOM
-        chunk_size = CHUNK_SIZE
         num_chunks = (logprobs.size(1) + chunk_size - 1) // chunk_size
The TrainerConfig allows logprobs_chunk_size to be None to disable chunking. However, if chunk_size is None, the calculation of num_chunks will raise a TypeError. We should handle the None case by defaulting to the full sequence length and update the type hint to int | None.
def chunked_cross_entropy_from_log_probs(
    logprobs: Float[torch.Tensor, "batch_size seqlen vocab_size"],
    requires_grad: bool = False,
    chunk_size: int | None = 1024,
) -> Float[torch.Tensor, "batch_size seqlen"]:
    cm = nullcontext() if requires_grad else torch.no_grad()
    with cm:
        # Calculate entropy in chunks to avoid OOM
        chunk_size = chunk_size or logprobs.size(1)
        num_chunks = (logprobs.size(1) + chunk_size - 1) // chunk_size

     with cm:
         # Calculate entropy in chunks to avoid OOM
-        chunk_size = CHUNK_SIZE
         num_chunks = (logits.size(1) + chunk_size - 1) // chunk_size
Ensure chunk_size is a valid integer before calculating num_chunks to avoid a crash when chunking is disabled (chunk_size=None).
-        num_chunks = (logits.size(1) + chunk_size - 1) // chunk_size
+        chunk_size = chunk_size or logits.size(1)
+        num_chunks = (logits.size(1) + chunk_size - 1) // chunk_size
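The failure mode and the suggested fix can be reproduced with plain integers. This is a small sketch under stated assumptions: `num_chunks` is a hypothetical helper, not a function in the codebase, and it uses an explicit `is None` check rather than the `chunk_size or seqlen` idiom from the suggestion.

```python
from typing import Optional


def num_chunks(seqlen: int, chunk_size: Optional[int]) -> int:
    """Ceil-divide seqlen by chunk_size, treating None as "one chunk"."""
    if chunk_size is None:
        # Fall back to the full sequence length; max() guards seqlen == 0.
        chunk_size = max(seqlen, 1)
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    return (seqlen + chunk_size - 1) // chunk_size


if __name__ == "__main__":
    # Without the guard, the arithmetic itself fails when chunk_size is None.
    try:
        (31 + None - 1) // None
    except TypeError as exc:
        print("unguarded:", exc)

    print(num_chunks(31, 1024))     # a short sequence fits in one chunk
    print(num_chunks(65536, 1024))  # a long sequence is split
    print(num_chunks(65536, None))  # chunking disabled
```

One difference worth noting: `chunk_size or seqlen` also maps `chunk_size=0` to the full sequence length, because `0` is falsy in Python, whereas the explicit check above rejects it. Which behaviour is preferable is a design choice for the config.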
     logits: Float[torch.Tensor, "batch_size seqlen vocab"],
     requires_grad: bool = False,
     attention_mask: Float[torch.Tensor, "batch_size seqlen"] = None,
+    chunk_size: int = 1024,
 ) -> Float[torch.Tensor, "batch_size seqlen"]:
Update the type hint for chunk_size to int | None to align with the configuration which allows disabling chunking by setting it to None.
-    logits: Float[torch.Tensor, "batch_size seqlen vocab"],
-    requires_grad: bool = False,
-    attention_mask: Float[torch.Tensor, "batch_size seqlen"] = None,
-    chunk_size: int = 1024,
-) -> Float[torch.Tensor, "batch_size seqlen"]:
+    logits: Float[torch.Tensor, "batch_size seqlen vocab"],
+    requires_grad: bool = False,
+    attention_mask: Float[torch.Tensor, "batch_size seqlen"] = None,
+    chunk_size: int | None = 1024,
+) -> Float[torch.Tensor, "batch_size seqlen"]:
     model_config_kwargs: dict = {},
     meta_init: bool = False,
     language_model_only: bool = False,
+    logprobs_chunk_size: int = 1024,
     **kwargs,
 ) -> None:
The logprobs_chunk_size parameter should be typed as int | None to match the TrainerConfig and allow disabling chunking.
-    model_config_kwargs: dict = {},
-    meta_init: bool = False,
-    language_model_only: bool = False,
-    logprobs_chunk_size: int = 1024,
-    **kwargs,
-) -> None:
+    model_config_kwargs: dict = {},
+    meta_init: bool = False,
+    language_model_only: bool = False,
+    logprobs_chunk_size: int | None = 1024,
+    **kwargs,
+) -> None:
600aab3 to 87aac06
/gemini review
Code Review
This pull request introduces configurable sequence-dimension chunking for log-probability and entropy calculations to reduce peak GPU memory usage. It adds a logprobs_chunk_size setting to the TrainerConfig, updates FSDP and Megatron model wrappers to utilize this parameter, and provides a new benchmarking script. The review feedback suggests renaming a utility function for clarity, correcting a speedup interpretation in the benchmark script, and improving the configuration docstring by refining performance descriptions and removing ephemeral PR links.
…y loss (#1648)

### What

`tests/backends/skyrl_train/gpu/gpu_ci/megatron/test_megatron_worker.py::test_megatron_train[tp2_pp2_policy_seq_packing_with_entropy_loss]` has been failing since #1610 made chunked logprobs the default for Megatron, with:

```bash
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 31, 75968]], which is output 0 of torch::autograd::CopySlices, is at version 3; expected version 1 instead.
```

### Why

Two `autograd.Function`s save the same `vocab_parallel_logits` tensor for backward:

- `ChunkedDistributedLogprob.forward` saves it directly (the input logits).
- `_VocabParallelEntropy.forward` saves `action_logits = logits[:, -num_actions-1:-1, :]`, which is a view of the same storage.

`_VocabParallelEntropy.backward` used a `sub_` + `add_` "modify-then-restore" trick on its saved tensor to avoid an allocation. The restore makes the values correct, but each in-place op still bumps the storage version counter (1 → 2 → 3). When `ChunkedDistributedLogprob.backward` later reads `ctx.saved_tensors`, its version check fails.

Before #1610, the non-chunked `DistributedLogprob` saved `softmax_output` rather than the input logits, so the two Functions didn't share storage and the version bump went unnoticed.

### Fix

`skyrl/backends/skyrl_train/distributed/megatron/model_utils.py`: replace the in-place `sub_`/`add_` pair with one out-of-place `vocab_parallel_logits.sub(...)`. Same math, no mutation of the saved tensor. Costs one temporary the size of the action-logits slice during entropy backward; freed when backward returns. Only on the `use_entropy_loss=True` path (default is `False`).

### Test plan

- [x] Repro'd the original failure locally.
- [x] `tp2_pp2_policy_seq_packing_with_entropy_loss` passes after fix.
- [x] `tp2_pp2_policy_seq_packing` (non-entropy sibling) still passes.
What does this PR do?
Defaults to chunked logprobs for Megatron with chunk_size of 1024
Currently, we don't do any chunking for logprobs calculation in the Megatron backend. For models like Qwen 3.5 with large vocab sizes (248,320), this can lead to OOMs.
This PR chooses a chunk size of 1024 as the best balance between memory saved and added latency. For large sequence lengths and vocab sizes, this can save a lot of memory: for example, at a 65K sequence length and 128K vocab size, we save 42GB of memory while the time taken increases from 132.2 ms to 600 ms (acceptable IMO).
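For a rough sense of scale, the size of one dense tensor of shape (bsz, seq_len, vocab_size) can be estimated with simple arithmetic. The sketch below is a back-of-the-envelope estimate, not a measurement: `logits_tensor_bytes` is a hypothetical helper, float32 (4 bytes per element) is an assumption, "65K" is taken to mean 65,536, and the `tp` argument assumes the vocab dimension divides evenly across tensor-parallel ranks. The numbers reported in this PR cover a full forward + backward pass, so this single-tensor estimate is not expected to match them.

```python
GIB = 1024 ** 3


def logits_tensor_bytes(
    bsz: int,
    seq_len: int,
    vocab_size: int,
    bytes_per_elem: int = 4,  # assumption: float32
    tp: int = 1,              # assumption: vocab is sharded evenly over TP ranks
) -> int:
    """Bytes needed for one dense (bsz, seq_len, vocab_size / tp) tensor."""
    if vocab_size % tp != 0:
        raise ValueError("this sketch assumes vocab_size is divisible by tp")
    return bsz * seq_len * (vocab_size // tp) * bytes_per_elem


if __name__ == "__main__":
    full = logits_tensor_bytes(1, 65_536, 128_000)
    one_chunk = logits_tensor_bytes(1, 1024, 128_000)
    print(f"full sequence:  {full / GIB:.2f} GiB")       # 31.25 GiB
    print(f"one 1024 chunk: {one_chunk / GIB:.2f} GiB")  # 0.49 GiB
```

Under these assumptions a single 1024-position chunk is 64 times smaller than the full-sequence tensor, which is the ratio of the two sequence lengths.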
Current memory usage
Currently, the full logprobs tensor of size (bsz, seq_len, vocab_size) gets materialized while we calculate the logprobs for the given input_ids.
Right now, the only way users can save memory in the logprobs calculation is with tensor parallelism, where each TP rank works with TP-sharded logits of size (bsz, seq_len, vocab_size / TP). Chunking along the sequence dimension is also needed for large sequence lengths.
Benchmark Settings
Device: NVIDIA B200
Vocab sizes: 32,000, 64,000 and 128,000
Measurement: forward + backward pass
Chunk sizes tested: None (no chunking), 32, 1024, 4096, 8192, 16384. I added a chunk size of 32 to the mix just to see what a small chunk size would yield.
Results
Vocab Size 32K
Vocab Size 64K
Vocab Size 128K