From 87457250954a5cef2aa8797f13714840031b63ab Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 1 May 2026 19:39:05 +0000 Subject: [PATCH 01/11] megatron: default chunk_size=1024 for logprob computation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Set chunk_size=1024 (was None) for both the ref (inference_only=True) and policy (inference_only=False) calls to from_parallel_logits_to_logprobs in MegatronModelWrapper. This bounds peak GPU memory during the log_softmax+gather step at the cost of some wall-clock time. Add examples/benchmarks/bench_chunked_logprobs.py to quantify the time/memory trade-off across seq_lens 32k–128k. Signed-off-by: SumanthRH --- examples/benchmarks/bench_chunked_logprobs.py | 130 ++++++++++++++++++ .../megatron/megatron_model_wrapper.py | 4 +- 2 files changed, 132 insertions(+), 2 deletions(-) create mode 100644 examples/benchmarks/bench_chunked_logprobs.py diff --git a/examples/benchmarks/bench_chunked_logprobs.py b/examples/benchmarks/bench_chunked_logprobs.py new file mode 100644 index 0000000000..235659ce4a --- /dev/null +++ b/examples/benchmarks/bench_chunked_logprobs.py @@ -0,0 +1,130 @@ +""" +Benchmark: chunked vs non-chunked logprob computation. + +Tests log_softmax + gather over large vocab × large sequence, +which is the bottleneck in from_parallel_logits_to_logprobs. + +Usage: + CUDA_VISIBLE_DEVICES=0 python examples/benchmarks/bench_chunked_logprobs.py +""" + +import time +import torch +import torch.nn.functional as F + +VOCAB_SIZE = 32000 +SEQ_LENS = [32768, 65536, 131072] +CHUNK_SIZE = 1024 +WARMUP_REPS = 2 +BENCH_REPS = 5 + + +def logprobs_chunked(logits: torch.Tensor, labels: torch.Tensor, chunk_size=None) -> torch.Tensor: + """ + Compute log-probs matching the SkyRL chunked pattern. + + logits : [T, V] — requires_grad must be True for gradient path + labels : [T] — token indices in [0, V) + Returns: [T] — per-token log-probs + """ + if chunk_size is None: + # Non-chunked: materialise full float32 logits at once + log_probs = F.log_softmax(logits.float(), dim=-1) + return log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1) + + results = [] + for i in range(0, logits.shape[0], chunk_size): + chunk = logits[i : i + chunk_size].float() + lp = F.log_softmax(chunk, dim=-1) + results.append(lp.gather(-1, labels[i : i + chunk_size].unsqueeze(-1)).squeeze(-1)) + return torch.cat(results) + + +def measure(logits: torch.Tensor, labels: torch.Tensor, chunk_size, reps: int): + """Run forward+backward and return (mean_wall_ms, peak_mem_bytes).""" + device = logits.device + times = [] + peak_mems = [] + + for _ in range(reps): + # Fresh leaf tensor each rep so grad accumulation doesn't interfere + logits_rep = logits.detach().requires_grad_(True) + + torch.cuda.reset_peak_memory_stats(device) + torch.cuda.synchronize(device) + t0 = time.perf_counter() + + out = logprobs_chunked(logits_rep, labels, chunk_size=chunk_size) + loss = out.sum() + loss.backward() + + torch.cuda.synchronize(device) + t1 = time.perf_counter() + + times.append((t1 - t0) * 1000.0) + peak_mems.append(torch.cuda.max_memory_allocated(device)) + + return sum(times) / len(times), sum(peak_mems) / len(peak_mems) + + +def main(): + if not torch.cuda.is_available(): + raise SystemError("No CUDA device found. Set CUDA_VISIBLE_DEVICES=0.") + + device = torch.device("cuda", 0) + print(f"Device : {torch.cuda.get_device_name(device)}") + print(f"Vocab : {VOCAB_SIZE:,} | chunk_size={CHUNK_SIZE} | warmup={WARMUP_REPS} bench={BENCH_REPS}\n") + + col_w = 14 + header = ( + f"{'seq_len':>10} " + f"{'no-chunk ms':>{col_w}} " + f"{'chunk ms':>{col_w}} " + f"{'speedup':>{col_w}} " + f"{'no-chunk MB':>{col_w}} " + f"{'chunk MB':>{col_w}} " + f"{'mem saved':>{col_w}}" + ) + sep = "-" * len(header) + print(header) + print(sep) + + for seq_len in SEQ_LENS: + # Allocate logits in bfloat16 (typical LLM dtype) with gradient tracking + logits = torch.randn(seq_len, VOCAB_SIZE, dtype=torch.bfloat16, device=device) + labels = torch.randint(0, VOCAB_SIZE, (seq_len,), device=device) + + # ----- warmup ----- + for _ in range(WARMUP_REPS): + _ = measure(logits, labels, chunk_size=None, reps=1) + _ = measure(logits, labels, chunk_size=CHUNK_SIZE, reps=1) + + # ----- benchmark ----- + t_none, mem_none = measure(logits, labels, chunk_size=None, reps=BENCH_REPS) + t_chunk, mem_chunk = measure(logits, labels, chunk_size=CHUNK_SIZE, reps=BENCH_REPS) + + speedup = t_none / t_chunk if t_chunk > 0 else float("inf") + mem_none_mb = mem_none / (1024**2) + mem_chunk_mb = mem_chunk / (1024**2) + mem_saved_mb = mem_none_mb - mem_chunk_mb + + print( + f"{seq_len:>10,} " + f"{t_none:>{col_w}.1f} " + f"{t_chunk:>{col_w}.1f} " + f"{speedup:>{col_w}.2f}x " + f"{mem_none_mb:>{col_w}.0f} " + f"{mem_chunk_mb:>{col_w}.0f} " + f"{mem_saved_mb:>{col_w}.0f}" + ) + + # Free memory before next iteration + del logits, labels + torch.cuda.empty_cache() + + print(sep) + print("All times are mean wall-clock (ms) over forward+backward passes.") + + +if __name__ == "__main__": + main() diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 152348590a..144a0f2e4a 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -104,7 +104,7 @@ def collection_func(logits, data): tp_group=tp_grp, inference_only=True, cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` - chunk_size=None, + chunk_size=1024, ) return torch.tensor(0.0, device=token_logprobs.device), {"log_probs": token_logprobs} @@ -264,7 +264,7 @@ def loss_func(logits, data): tp_group=tp_grp, inference_only=False, cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` - chunk_size=None, + chunk_size=1024, ) action_log_probs = token_logprobs[:, -num_actions:] From 9ee97596c95aefb1e6de610e01e043d819c93db5 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 1 May 2026 19:42:28 +0000 Subject: [PATCH 02/11] benchmarks: sweep chunk_size over [None, 1024, 4096, 8192] Signed-off-by: SumanthRH --- examples/benchmarks/bench_chunked_logprobs.py | 73 ++++++++++--------- 1 file changed, 40 insertions(+), 33 deletions(-) diff --git a/examples/benchmarks/bench_chunked_logprobs.py b/examples/benchmarks/bench_chunked_logprobs.py index 235659ce4a..2887135313 100644 --- a/examples/benchmarks/bench_chunked_logprobs.py +++ b/examples/benchmarks/bench_chunked_logprobs.py @@ -14,7 +14,7 @@ VOCAB_SIZE = 32000 SEQ_LENS = [32768, 65536, 131072] -CHUNK_SIZE = 1024 +CHUNK_SIZES = [None, 1024, 4096, 8192] WARMUP_REPS = 2 BENCH_REPS = 5 @@ -73,17 +73,16 @@ def main(): device = torch.device("cuda", 0) print(f"Device : {torch.cuda.get_device_name(device)}") - print(f"Vocab : {VOCAB_SIZE:,} | chunk_size={CHUNK_SIZE} | warmup={WARMUP_REPS} bench={BENCH_REPS}\n") + print(f"Vocab : {VOCAB_SIZE:,} | chunk_sizes={CHUNK_SIZES} | warmup={WARMUP_REPS} bench={BENCH_REPS}\n") col_w = 14 header = ( f"{'seq_len':>10} " - f"{'no-chunk ms':>{col_w}} " - f"{'chunk ms':>{col_w}} " - f"{'speedup':>{col_w}} " - f"{'no-chunk MB':>{col_w}} " - f"{'chunk MB':>{col_w}} " - f"{'mem saved':>{col_w}}" + f"{'chunk_size':>10} " + f"{'time ms':>{col_w}} " + f"{'peak MB':>{col_w}} " + f"{'vs no-chunk':>{col_w}} " + f"{'mem saved MB':>{col_w}}" ) sep = "-" * len(header) print(header) @@ -95,35 +94,43 @@ def main(): labels = torch.randint(0, VOCAB_SIZE, (seq_len,), device=device) # ----- warmup ----- - for _ in range(WARMUP_REPS): - _ = measure(logits, labels, chunk_size=None, reps=1) - _ = measure(logits, labels, chunk_size=CHUNK_SIZE, reps=1) - - # ----- benchmark ----- - t_none, mem_none = measure(logits, labels, chunk_size=None, reps=BENCH_REPS) - t_chunk, mem_chunk = measure(logits, labels, chunk_size=CHUNK_SIZE, reps=BENCH_REPS) - - speedup = t_none / t_chunk if t_chunk > 0 else float("inf") - mem_none_mb = mem_none / (1024**2) - mem_chunk_mb = mem_chunk / (1024**2) - mem_saved_mb = mem_none_mb - mem_chunk_mb - - print( - f"{seq_len:>10,} " - f"{t_none:>{col_w}.1f} " - f"{t_chunk:>{col_w}.1f} " - f"{speedup:>{col_w}.2f}x " - f"{mem_none_mb:>{col_w}.0f} " - f"{mem_chunk_mb:>{col_w}.0f} " - f"{mem_saved_mb:>{col_w}.0f}" - ) - - # Free memory before next iteration + for cs in CHUNK_SIZES: + for _ in range(WARMUP_REPS): + _ = measure(logits, labels, chunk_size=cs, reps=1) + + # ----- benchmark: collect baseline (no-chunk) first ----- + t_baseline, mem_baseline = measure(logits, labels, chunk_size=None, reps=BENCH_REPS) + + # ----- print one row per chunk_size ----- + for cs in CHUNK_SIZES: + if cs is None: + t_cs, mem_cs = t_baseline, mem_baseline + else: + t_cs, mem_cs = measure(logits, labels, chunk_size=cs, reps=BENCH_REPS) + + speedup = t_baseline / t_cs if t_cs > 0 else float("inf") + mem_cs_mb = mem_cs / (1024**2) + mem_baseline_mb = mem_baseline / (1024**2) + mem_saved_mb = mem_baseline_mb - mem_cs_mb + cs_label = "None" if cs is None else str(cs) + + print( + f"{seq_len:>10,} " + f"{cs_label:>10} " + f"{t_cs:>{col_w}.1f} " + f"{mem_cs_mb:>{col_w}.0f} " + f"{speedup:>{col_w}.2f}x " + f"{mem_saved_mb:>{col_w}.0f}" + ) + + print(sep) + + # Free memory before next seq_len del logits, labels torch.cuda.empty_cache() - print(sep) print("All times are mean wall-clock (ms) over forward+backward passes.") + print("vs no-chunk: speedup relative to chunk_size=None (>1 = faster).") if __name__ == "__main__": From b8712e3075bad1f498ac1a71d78dd0ae68a2d397 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 1 May 2026 19:49:44 +0000 Subject: [PATCH 03/11] benchmarks: add chunk_size=32 to sweep Signed-off-by: SumanthRH --- examples/benchmarks/bench_chunked_logprobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/benchmarks/bench_chunked_logprobs.py b/examples/benchmarks/bench_chunked_logprobs.py index 2887135313..af125a3042 100644 --- a/examples/benchmarks/bench_chunked_logprobs.py +++ b/examples/benchmarks/bench_chunked_logprobs.py @@ -14,7 +14,7 @@ VOCAB_SIZE = 32000 SEQ_LENS = [32768, 65536, 131072] -CHUNK_SIZES = [None, 1024, 4096, 8192] +CHUNK_SIZES = [None, 32, 1024, 4096, 8192] WARMUP_REPS = 2 BENCH_REPS = 5 From 198313ac763cf56b0f00f781ac58a5319f25249b Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 1 May 2026 21:42:38 +0000 Subject: [PATCH 04/11] x Signed-off-by: SumanthRH --- examples/benchmarks/bench_chunked_logprobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/benchmarks/bench_chunked_logprobs.py b/examples/benchmarks/bench_chunked_logprobs.py index af125a3042..2887135313 100644 --- a/examples/benchmarks/bench_chunked_logprobs.py +++ b/examples/benchmarks/bench_chunked_logprobs.py @@ -14,7 +14,7 @@ VOCAB_SIZE = 32000 SEQ_LENS = [32768, 65536, 131072] -CHUNK_SIZES = [None, 32, 1024, 4096, 8192] +CHUNK_SIZES = [None, 1024, 4096, 8192] WARMUP_REPS = 2 BENCH_REPS = 5 From 1dd293d4d9c205c18600f6fc38ac5f59cdf6bd98 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 1 May 2026 22:52:23 +0000 Subject: [PATCH 05/11] Add vocab sizes 64K and 128K, OOM handling Signed-off-by: SumanthRH --- examples/benchmarks/bench_chunked_logprobs.py | 145 ++++++++++++------ .../megatron/megatron_model_wrapper.py | 4 +- 2 files changed, 101 insertions(+), 48 deletions(-) diff --git a/examples/benchmarks/bench_chunked_logprobs.py b/examples/benchmarks/bench_chunked_logprobs.py index 2887135313..5e3e7df85d 100644 --- a/examples/benchmarks/bench_chunked_logprobs.py +++ b/examples/benchmarks/bench_chunked_logprobs.py @@ -12,9 +12,9 @@ import torch import torch.nn.functional as F -VOCAB_SIZE = 32000 +VOCAB_SIZES = [32000, 64000, 128000] SEQ_LENS = [32768, 65536, 131072] -CHUNK_SIZES = [None, 1024, 4096, 8192] +CHUNK_SIZES = [None, 1024, 4096, 8192, 16384] WARMUP_REPS = 2 BENCH_REPS = 5 @@ -72,11 +72,12 @@ def main(): raise SystemError("No CUDA device found. Set CUDA_VISIBLE_DEVICES=0.") device = torch.device("cuda", 0) - print(f"Device : {torch.cuda.get_device_name(device)}") - print(f"Vocab : {VOCAB_SIZE:,} | chunk_sizes={CHUNK_SIZES} | warmup={WARMUP_REPS} bench={BENCH_REPS}\n") + print(f"Device : {torch.cuda.get_device_name(device)}") + print(f"Vocab sizes : {VOCAB_SIZES} | chunk_sizes={CHUNK_SIZES} | warmup={WARMUP_REPS} bench={BENCH_REPS}\n") col_w = 14 header = ( + f"{'vocab_size':>10} " f"{'seq_len':>10} " f"{'chunk_size':>10} " f"{'time ms':>{col_w}} " @@ -85,51 +86,103 @@ def main(): f"{'mem saved MB':>{col_w}}" ) sep = "-" * len(header) - print(header) - print(sep) - - for seq_len in SEQ_LENS: - # Allocate logits in bfloat16 (typical LLM dtype) with gradient tracking - logits = torch.randn(seq_len, VOCAB_SIZE, dtype=torch.bfloat16, device=device) - labels = torch.randint(0, VOCAB_SIZE, (seq_len,), device=device) - - # ----- warmup ----- - for cs in CHUNK_SIZES: - for _ in range(WARMUP_REPS): - _ = measure(logits, labels, chunk_size=cs, reps=1) - - # ----- benchmark: collect baseline (no-chunk) first ----- - t_baseline, mem_baseline = measure(logits, labels, chunk_size=None, reps=BENCH_REPS) - - # ----- print one row per chunk_size ----- - for cs in CHUNK_SIZES: - if cs is None: - t_cs, mem_cs = t_baseline, mem_baseline - else: - t_cs, mem_cs = measure(logits, labels, chunk_size=cs, reps=BENCH_REPS) - - speedup = t_baseline / t_cs if t_cs > 0 else float("inf") - mem_cs_mb = mem_cs / (1024**2) - mem_baseline_mb = mem_baseline / (1024**2) - mem_saved_mb = mem_baseline_mb - mem_cs_mb - cs_label = "None" if cs is None else str(cs) - - print( - f"{seq_len:>10,} " - f"{cs_label:>10} " - f"{t_cs:>{col_w}.1f} " - f"{mem_cs_mb:>{col_w}.0f} " - f"{speedup:>{col_w}.2f}x " - f"{mem_saved_mb:>{col_w}.0f}" - ) + for vocab_size in VOCAB_SIZES: + print(f"\n=== vocab_size={vocab_size:,} ===") + print(header) print(sep) - # Free memory before next seq_len - del logits, labels - torch.cuda.empty_cache() - - print("All times are mean wall-clock (ms) over forward+backward passes.") + for seq_len in SEQ_LENS: + # Allocate logits in bfloat16 (typical LLM dtype) with gradient tracking + try: + logits = torch.randn(seq_len, vocab_size, dtype=torch.bfloat16, device=device) + labels = torch.randint(0, vocab_size, (seq_len,), device=device) + except torch.OutOfMemoryError: + oom_row = ( + f"{vocab_size:>10,} " + f"{seq_len:>10,} " + f"{'(all)':>10} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}}" + ) + print(oom_row) + print(sep) + torch.cuda.empty_cache() + continue + + # ----- warmup (skip OOM-prone chunk sizes gracefully) ----- + oom_chunk_sizes = set() + for cs in CHUNK_SIZES: + try: + for _ in range(WARMUP_REPS): + _ = measure(logits, labels, chunk_size=cs, reps=1) + except torch.OutOfMemoryError: + oom_chunk_sizes.add(cs) + torch.cuda.empty_cache() + + # ----- benchmark: collect baseline (no-chunk) first ----- + if None in oom_chunk_sizes: + t_baseline, mem_baseline = None, None + else: + t_baseline, mem_baseline = measure(logits, labels, chunk_size=None, reps=BENCH_REPS) + + # ----- print one row per chunk_size ----- + for cs in CHUNK_SIZES: + cs_label = "None" if cs is None else str(cs) + if cs in oom_chunk_sizes: + print( + f"{vocab_size:>10,} " + f"{seq_len:>10,} " + f"{cs_label:>10} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}}" + ) + continue + + if cs is None: + t_cs, mem_cs = t_baseline, mem_baseline + else: + try: + t_cs, mem_cs = measure(logits, labels, chunk_size=cs, reps=BENCH_REPS) + except torch.OutOfMemoryError: + torch.cuda.empty_cache() + print( + f"{vocab_size:>10,} " + f"{seq_len:>10,} " + f"{cs_label:>10} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}}" + ) + continue + + speedup = t_baseline / t_cs if (t_baseline is not None and t_cs > 0) else float("inf") + mem_cs_mb = mem_cs / (1024**2) + mem_baseline_mb = mem_baseline / (1024**2) if mem_baseline is not None else 0 + mem_saved_mb = mem_baseline_mb - mem_cs_mb + + print( + f"{vocab_size:>10,} " + f"{seq_len:>10,} " + f"{cs_label:>10} " + f"{t_cs:>{col_w}.1f} " + f"{mem_cs_mb:>{col_w}.0f} " + f"{speedup:>{col_w}.2f}x " + f"{mem_saved_mb:>{col_w}.0f}" + ) + + print(sep) + + # Free memory before next seq_len + del logits, labels + torch.cuda.empty_cache() + + print("\nAll times are mean wall-clock (ms) over forward+backward passes.") print("vs no-chunk: speedup relative to chunk_size=None (>1 = faster).") diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 144a0f2e4a..e951428c10 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -104,7 +104,7 @@ def collection_func(logits, data): tp_group=tp_grp, inference_only=True, cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` - chunk_size=1024, + chunk_size=8192, ) return torch.tensor(0.0, device=token_logprobs.device), {"log_probs": token_logprobs} @@ -264,7 +264,7 @@ def loss_func(logits, data): tp_group=tp_grp, inference_only=False, cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` - chunk_size=1024, + chunk_size=8192, ) action_log_probs = token_logprobs[:, -num_actions:] From 4d04fa476081c9940fd3a1ecca158e2159eaaf0b Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Fri, 1 May 2026 23:30:49 +0000 Subject: [PATCH 06/11] fix benchmark script Signed-off-by: SumanthRH --- examples/benchmarks/bench_chunked_logprobs.py | 235 +++++++++++------- .../megatron/megatron_model_wrapper.py | 4 +- 2 files changed, 151 insertions(+), 88 deletions(-) diff --git a/examples/benchmarks/bench_chunked_logprobs.py b/examples/benchmarks/bench_chunked_logprobs.py index 5e3e7df85d..fc3b3ae3e5 100644 --- a/examples/benchmarks/bench_chunked_logprobs.py +++ b/examples/benchmarks/bench_chunked_logprobs.py @@ -1,60 +1,86 @@ """ Benchmark: chunked vs non-chunked logprob computation. -Tests log_softmax + gather over large vocab × large sequence, -which is the bottleneck in from_parallel_logits_to_logprobs. +Tests ChunkedDistributedLogprob and DistributedLogprob from +skyrl/backends/skyrl_train/distributed/megatron/model_utils.py, +which is the actual code path used in SkyRL's MegatronModelWrapper. -Usage: - CUDA_VISIBLE_DEVICES=0 python examples/benchmarks/bench_chunked_logprobs.py +Usage (single GPU, torchrun required for distributed init): + uv run --isolated --extra megatron torchrun --nproc_per_node=1 \\ + examples/benchmarks/bench_chunked_logprobs.py """ +import os import time + import torch -import torch.nn.functional as F +import torch.distributed as dist + +# Must set NCCL env before initialising the process group +os.environ.setdefault("NCCL_NET", "Socket") +os.environ.setdefault("NCCL_NET_PLUGIN", "none") VOCAB_SIZES = [32000, 64000, 128000] SEQ_LENS = [32768, 65536, 131072] -CHUNK_SIZES = [None, 1024, 4096, 8192, 16384] +# chunk_size=None routes through DistributedLogprob (no chunking); all others use +# ChunkedDistributedLogprob with the given chunk size. +CHUNK_SIZES = [None, 32, 1024, 4096, 8192, 16384] WARMUP_REPS = 2 BENCH_REPS = 5 -def logprobs_chunked(logits: torch.Tensor, labels: torch.Tensor, chunk_size=None) -> torch.Tensor: - """ - Compute log-probs matching the SkyRL chunked pattern. +def measure( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + chunk_size, + tp_group, + reps: int, +): + """Run forward+backward through the real SkyRL logprob kernel. - logits : [T, V] — requires_grad must be True for gradient path - labels : [T] — token indices in [0, V) - Returns: [T] — per-token log-probs + Returns (mean_wall_ms, mean_peak_mem_bytes). """ - if chunk_size is None: - # Non-chunked: materialise full float32 logits at once - log_probs = F.log_softmax(logits.float(), dim=-1) - return log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1) - - results = [] - for i in range(0, logits.shape[0], chunk_size): - chunk = logits[i : i + chunk_size].float() - lp = F.log_softmax(chunk, dim=-1) - results.append(lp.gather(-1, labels[i : i + chunk_size].unsqueeze(-1)).squeeze(-1)) - return torch.cat(results) - - -def measure(logits: torch.Tensor, labels: torch.Tensor, chunk_size, reps: int): - """Run forward+backward and return (mean_wall_ms, peak_mem_bytes).""" - device = logits.device + from skyrl.backends.skyrl_train.distributed.megatron.model_utils import ( + ChunkedDistributedLogprob, + DistributedLogprob, + ) + + device = vocab_parallel_logits.device times = [] peak_mems = [] for _ in range(reps): - # Fresh leaf tensor each rep so grad accumulation doesn't interfere - logits_rep = logits.detach().requires_grad_(True) + # Fresh leaf tensor each rep so grad accumulation does not interfere. + logits_rep = vocab_parallel_logits.detach().requires_grad_(True) torch.cuda.reset_peak_memory_stats(device) torch.cuda.synchronize(device) t0 = time.perf_counter() - out = logprobs_chunked(logits_rep, labels, chunk_size=chunk_size) + if chunk_size is None: + # Non-chunked real implementation (DistributedLogprob) + out = DistributedLogprob.apply( + logits_rep, + target, + vocab_start_index, + vocab_end_index, + tp_group, + False, # inference_only=False -> saves tensors for backward + ) + else: + # Chunked real implementation (ChunkedDistributedLogprob) + out = ChunkedDistributedLogprob.apply( + logits_rep, + target, + vocab_start_index, + vocab_end_index, + chunk_size, + tp_group, + False, # inference_only=False -> saves tensors for backward + ) + loss = out.sum() loss.backward() @@ -68,12 +94,28 @@ def measure(logits: torch.Tensor, labels: torch.Tensor, chunk_size, reps: int): def main(): - if not torch.cuda.is_available(): - raise SystemError("No CUDA device found. Set CUDA_VISIBLE_DEVICES=0.") + # --- Distributed init (required by the real SkyRL kernel) --- + dist.init_process_group(backend="nccl") + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + + # Initialise Megatron model-parallel state (TP=1, single GPU). + import megatron.core.parallel_state as mpu + + mpu.initialize_model_parallel(tensor_model_parallel_size=1) - device = torch.device("cuda", 0) - print(f"Device : {torch.cuda.get_device_name(device)}") - print(f"Vocab sizes : {VOCAB_SIZES} | chunk_sizes={CHUNK_SIZES} | warmup={WARMUP_REPS} bench={BENCH_REPS}\n") + tp_group = dist.group.WORLD # TP=1, so the whole-world group is the TP group + + device = torch.device("cuda", local_rank) + + if dist.get_rank() == 0: + print(f"Device : {torch.cuda.get_device_name(device)}") + print(f"World size : {dist.get_world_size()}") + print( + f"Vocab sizes : {VOCAB_SIZES} | chunk_sizes={CHUNK_SIZES} " + f"| warmup={WARMUP_REPS} bench={BENCH_REPS}" + ) + print("Implementation: real SkyRL ChunkedDistributedLogprob / DistributedLogprob\n") col_w = 14 header = ( @@ -88,27 +130,35 @@ def main(): sep = "-" * len(header) for vocab_size in VOCAB_SIZES: - print(f"\n=== vocab_size={vocab_size:,} ===") - print(header) - print(sep) + if dist.get_rank() == 0: + print(f"\n=== vocab_size={vocab_size:,} ===") + print(header) + print(sep) for seq_len in SEQ_LENS: - # Allocate logits in bfloat16 (typical LLM dtype) with gradient tracking + # With TP=1 the full vocab lives on this rank. + vocab_start_index = 0 + vocab_end_index = vocab_size + + # Shape expected by the real kernel: [batch, seq, vocab // TP] + # We use batch=1 to keep allocations comparable to a single-sequence workload. try: - logits = torch.randn(seq_len, vocab_size, dtype=torch.bfloat16, device=device) - labels = torch.randint(0, vocab_size, (seq_len,), device=device) + logits = torch.randn(1, seq_len, vocab_size, dtype=torch.bfloat16, device=device) + # targets: [batch, seq], values in [0, vocab_size) + target = torch.randint(0, vocab_size, (1, seq_len), device=device) except torch.OutOfMemoryError: - oom_row = ( - f"{vocab_size:>10,} " - f"{seq_len:>10,} " - f"{'(all)':>10} " - f"{'OOM':>{col_w}} " - f"{'OOM':>{col_w}} " - f"{'OOM':>{col_w}} " - f"{'OOM':>{col_w}}" - ) - print(oom_row) - print(sep) + if dist.get_rank() == 0: + oom_row = ( + f"{vocab_size:>10,} " + f"{seq_len:>10,} " + f"{'(all)':>10} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}}" + ) + print(oom_row) + print(sep) torch.cuda.empty_cache() continue @@ -117,7 +167,7 @@ def main(): for cs in CHUNK_SIZES: try: for _ in range(WARMUP_REPS): - _ = measure(logits, labels, chunk_size=cs, reps=1) + measure(logits, target, vocab_start_index, vocab_end_index, cs, tp_group, reps=1) except torch.OutOfMemoryError: oom_chunk_sizes.add(cs) torch.cuda.empty_cache() @@ -126,30 +176,16 @@ def main(): if None in oom_chunk_sizes: t_baseline, mem_baseline = None, None else: - t_baseline, mem_baseline = measure(logits, labels, chunk_size=None, reps=BENCH_REPS) + t_baseline, mem_baseline = measure( + logits, target, vocab_start_index, vocab_end_index, None, tp_group, reps=BENCH_REPS + ) # ----- print one row per chunk_size ----- for cs in CHUNK_SIZES: cs_label = "None" if cs is None else str(cs) - if cs in oom_chunk_sizes: - print( - f"{vocab_size:>10,} " - f"{seq_len:>10,} " - f"{cs_label:>10} " - f"{'OOM':>{col_w}} " - f"{'OOM':>{col_w}} " - f"{'OOM':>{col_w}} " - f"{'OOM':>{col_w}}" - ) - continue - if cs is None: - t_cs, mem_cs = t_baseline, mem_baseline - else: - try: - t_cs, mem_cs = measure(logits, labels, chunk_size=cs, reps=BENCH_REPS) - except torch.OutOfMemoryError: - torch.cuda.empty_cache() + if cs in oom_chunk_sizes: + if dist.get_rank() == 0: print( f"{vocab_size:>10,} " f"{seq_len:>10,} " @@ -159,6 +195,27 @@ def main(): f"{'OOM':>{col_w}} " f"{'OOM':>{col_w}}" ) + continue + + if cs is None: + t_cs, mem_cs = t_baseline, mem_baseline + else: + try: + t_cs, mem_cs = measure( + logits, target, vocab_start_index, vocab_end_index, cs, tp_group, reps=BENCH_REPS + ) + except torch.OutOfMemoryError: + torch.cuda.empty_cache() + if dist.get_rank() == 0: + print( + f"{vocab_size:>10,} " + f"{seq_len:>10,} " + f"{cs_label:>10} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}}" + ) continue speedup = t_baseline / t_cs if (t_baseline is not None and t_cs > 0) else float("inf") @@ -166,24 +223,30 @@ def main(): mem_baseline_mb = mem_baseline / (1024**2) if mem_baseline is not None else 0 mem_saved_mb = mem_baseline_mb - mem_cs_mb - print( - f"{vocab_size:>10,} " - f"{seq_len:>10,} " - f"{cs_label:>10} " - f"{t_cs:>{col_w}.1f} " - f"{mem_cs_mb:>{col_w}.0f} " - f"{speedup:>{col_w}.2f}x " - f"{mem_saved_mb:>{col_w}.0f}" - ) + if dist.get_rank() == 0: + print( + f"{vocab_size:>10,} " + f"{seq_len:>10,} " + f"{cs_label:>10} " + f"{t_cs:>{col_w}.1f} " + f"{mem_cs_mb:>{col_w}.0f} " + f"{speedup:>{col_w}.2f}x " + f"{mem_saved_mb:>{col_w}.0f}" + ) - print(sep) + if dist.get_rank() == 0: + print(sep) # Free memory before next seq_len - del logits, labels + del logits, target torch.cuda.empty_cache() - print("\nAll times are mean wall-clock (ms) over forward+backward passes.") - print("vs no-chunk: speedup relative to chunk_size=None (>1 = faster).") + if dist.get_rank() == 0: + print("\nAll times are mean wall-clock (ms) over forward+backward passes.") + print("vs no-chunk: speedup relative to chunk_size=None (>1 = faster).") + print("chunk_size=None uses DistributedLogprob; all others use ChunkedDistributedLogprob.") + + dist.destroy_process_group() if __name__ == "__main__": diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index e951428c10..144a0f2e4a 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -104,7 +104,7 @@ def collection_func(logits, data): tp_group=tp_grp, inference_only=True, cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` - chunk_size=8192, + chunk_size=1024, ) return torch.tensor(0.0, device=token_logprobs.device), {"log_probs": token_logprobs} @@ -264,7 +264,7 @@ def loss_func(logits, data): tp_group=tp_grp, inference_only=False, cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` - chunk_size=8192, + chunk_size=1024, ) action_log_probs = token_logprobs[:, -num_actions:] From 9e08a177d1d46a402b4739b24066aee2683fcbc4 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Sat, 2 May 2026 00:18:29 +0000 Subject: [PATCH 07/11] benchmarks: address review comments (type annotations, comments, OOM display) Signed-off-by: SumanthRH --- examples/benchmarks/bench_chunked_logprobs.py | 18 +++++++++++------- .../workers/megatron/megatron_model_wrapper.py | 9 +++++++-- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/examples/benchmarks/bench_chunked_logprobs.py b/examples/benchmarks/bench_chunked_logprobs.py index fc3b3ae3e5..391098e636 100644 --- a/examples/benchmarks/bench_chunked_logprobs.py +++ b/examples/benchmarks/bench_chunked_logprobs.py @@ -12,6 +12,7 @@ import os import time +from typing import Optional import torch import torch.distributed as dist @@ -34,8 +35,8 @@ def measure( target: torch.Tensor, vocab_start_index: int, vocab_end_index: int, - chunk_size, - tp_group, + chunk_size: Optional[int], + tp_group: torch.distributed.ProcessGroup, reps: int, ): """Run forward+backward through the real SkyRL logprob kernel. @@ -218,10 +219,13 @@ def main(): ) continue - speedup = t_baseline / t_cs if (t_baseline is not None and t_cs > 0) else float("inf") mem_cs_mb = mem_cs / (1024**2) - mem_baseline_mb = mem_baseline / (1024**2) if mem_baseline is not None else 0 - mem_saved_mb = mem_baseline_mb - mem_cs_mb + if t_baseline is not None and t_cs > 0: + speedup_str = f"{t_baseline / t_cs:>{col_w}.2f}x" + mem_saved_str = f"{mem_baseline / (1024**2) - mem_cs_mb:>{col_w}.0f}" + else: + speedup_str = f"{'N/A':>{col_w}}" + mem_saved_str = f"{'N/A':>{col_w}}" if dist.get_rank() == 0: print( @@ -230,8 +234,8 @@ def main(): f"{cs_label:>10} " f"{t_cs:>{col_w}.1f} " f"{mem_cs_mb:>{col_w}.0f} " - f"{speedup:>{col_w}.2f}x " - f"{mem_saved_mb:>{col_w}.0f}" + f"{speedup_str} " + f"{mem_saved_str}" ) if dist.get_rank() == 0: diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 144a0f2e4a..d6252d98ca 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -32,6 +32,11 @@ from skyrl.backends.skyrl_train.utils.torch_utils import masked_mean from skyrl.train.config import TrainerConfig +# NOTE (sumanthrh): We use a chunk size of 1024 for calaulating logprobs +# from logits to avoid OOMs for large sequence lengths. +# For more details, see https://github.com/NovaSky-AI/SkyRL/pull/1610 +CHUNK_SIZE_LOGPROBS = 1024 + class MegatronModelWrapper: def __init__( @@ -104,7 +109,7 @@ def collection_func(logits, data): tp_group=tp_grp, inference_only=True, cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` - chunk_size=1024, + chunk_size=CHUNK_SIZE_LOGPROBS, # chunk seq dim to bound peak memory; see examples/benchmarks/bench_chunked_logprobs.py ) return torch.tensor(0.0, device=token_logprobs.device), {"log_probs": token_logprobs} @@ -264,7 +269,7 @@ def loss_func(logits, data): tp_group=tp_grp, inference_only=False, cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` - chunk_size=1024, + chunk_size=CHUNK_SIZE_LOGPROBS, # chunk seq dim to bound peak memory; see examples/benchmarks/bench_chunked_logprobs.py ) action_log_probs = token_logprobs[:, -num_actions:] From c2e6d6689b19808e4d7f96a95638db65a75c54f1 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Sat, 2 May 2026 00:18:58 +0000 Subject: [PATCH 08/11] megatron: fix typo calaulating -> calculating in CHUNK_SIZE_LOGPROBS comment Signed-off-by: SumanthRH --- .../skyrl_train/workers/megatron/megatron_model_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index d6252d98ca..d8db6ad37a 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -32,7 +32,7 @@ from skyrl.backends.skyrl_train.utils.torch_utils import masked_mean from skyrl.train.config import TrainerConfig -# NOTE (sumanthrh): We use a chunk size of 1024 for calaulating logprobs +# NOTE (sumanthrh): We use a chunk size of 1024 for calculating logprobs # from logits to avoid OOMs for large sequence lengths. # For more details, see https://github.com/NovaSky-AI/SkyRL/pull/1610 CHUNK_SIZE_LOGPROBS = 1024 From ca1976fdaec7f91fcf2488123e7eca64273880fb Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Wed, 6 May 2026 21:04:26 +0000 Subject: [PATCH 09/11] move benchmark file to benchmarks/ Signed-off-by: SumanthRH --- .../backends/skyrl_train/utils/torch_utils.py | 10 ++-- .../skyrl_train/workers/fsdp/fsdp_worker.py | 2 + .../megatron/megatron_model_wrapper.py | 9 +--- .../skyrl_train/workers/model_wrapper.py | 8 ++- .../benchmarks/bench_chunked_logprobs.py | 51 +++++-------------- skyrl/train/config/config.py | 4 ++ 6 files changed, 32 insertions(+), 52 deletions(-) rename {examples => skyrl}/benchmarks/bench_chunked_logprobs.py (81%) diff --git a/skyrl/backends/skyrl_train/utils/torch_utils.py b/skyrl/backends/skyrl_train/utils/torch_utils.py index 70ea95cc0b..a7a368f074 100644 --- a/skyrl/backends/skyrl_train/utils/torch_utils.py +++ b/skyrl/backends/skyrl_train/utils/torch_utils.py @@ -28,16 +28,14 @@ FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False -CHUNK_SIZE = 1024 - - def chunked_cross_entropy_from_log_probs( - logprobs: Float[torch.Tensor, "batch_size seqlen vocab_size"], requires_grad: bool = False + logprobs: Float[torch.Tensor, "batch_size seqlen vocab_size"], + requires_grad: bool = False, + chunk_size: int = 1024, ) -> Float[torch.Tensor, "batch_size seqlen"]: cm = nullcontext() if requires_grad else torch.no_grad() with cm: # Calculate entropy in chunks to avoid OOM - chunk_size = CHUNK_SIZE num_chunks = (logprobs.size(1) + chunk_size - 1) // chunk_size entropy_tensor = torch.zeros( (logprobs.shape[0], logprobs.shape[1]), dtype=logprobs.dtype, device=logprobs.device @@ -61,6 +59,7 @@ def chunked_entropy_from_logits( logits: Float[torch.Tensor, "batch_size seqlen vocab"], requires_grad: bool = False, attention_mask: Float[torch.Tensor, "batch_size seqlen"] = None, + chunk_size: int = 1024, ) -> Float[torch.Tensor, "batch_size seqlen"]: """Chunked entropy calculation from logits. @@ -88,7 +87,6 @@ def chunked_entropy_from_logits( cm = nullcontext() if requires_grad else torch.no_grad() with cm: # Calculate entropy in chunks to avoid OOM - chunk_size = CHUNK_SIZE num_chunks = (logits.size(1) + chunk_size - 1) // chunk_size entropy_tensor = torch.zeros((logits.shape[0], logits.shape[1]), dtype=logits.dtype, device=logits.device) diff --git a/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py b/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py index a82440c91e..b8026d1854 100644 --- a/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py +++ b/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py @@ -193,6 +193,7 @@ def init_model(self, model_path, num_training_steps: int = None): model_config_kwargs=self.cfg.policy.model_config_kwargs, meta_init=use_meta, language_model_only=self.cfg.policy.language_model_only, + logprobs_chunk_size=self.cfg.logprobs_chunk_size, ) self._seq_parallel_monkey_patch(model=wrapped_model.model) @@ -415,6 +416,7 @@ def init_model(self, model_path): model_config_kwargs=self.cfg.ref.model_config_kwargs, meta_init=use_meta, language_model_only=self.cfg.ref.language_model_only, + logprobs_chunk_size=self.cfg.logprobs_chunk_size, ) self._seq_parallel_monkey_patch(model=wrapped_model.model) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index d8db6ad37a..f2d1aa0a44 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -32,11 +32,6 @@ from skyrl.backends.skyrl_train.utils.torch_utils import masked_mean from skyrl.train.config import TrainerConfig -# NOTE (sumanthrh): We use a chunk size of 1024 for calculating logprobs -# from logits to avoid OOMs for large sequence lengths. -# For more details, see https://github.com/NovaSky-AI/SkyRL/pull/1610 -CHUNK_SIZE_LOGPROBS = 1024 - class MegatronModelWrapper: def __init__( @@ -109,7 +104,7 @@ def collection_func(logits, data): tp_group=tp_grp, inference_only=True, cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` - chunk_size=CHUNK_SIZE_LOGPROBS, # chunk seq dim to bound peak memory; see examples/benchmarks/bench_chunked_logprobs.py + chunk_size=self.cfg.logprobs_chunk_size, # chunk seq dim to bound peak memory ) return torch.tensor(0.0, device=token_logprobs.device), {"log_probs": token_logprobs} @@ -269,7 +264,7 @@ def loss_func(logits, data): tp_group=tp_grp, inference_only=False, cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` - chunk_size=CHUNK_SIZE_LOGPROBS, # chunk seq dim to bound peak memory; see examples/benchmarks/bench_chunked_logprobs.py + chunk_size=self.cfg.logprobs_chunk_size, # chunk seq dim to bound peak memory ) action_log_probs = token_logprobs[:, -num_actions:] diff --git a/skyrl/backends/skyrl_train/workers/model_wrapper.py b/skyrl/backends/skyrl_train/workers/model_wrapper.py index 53c876f73e..ba2069f27a 100644 --- a/skyrl/backends/skyrl_train/workers/model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/model_wrapper.py @@ -80,6 +80,7 @@ def __init__( model_config_kwargs: dict = {}, meta_init: bool = False, language_model_only: bool = False, + logprobs_chunk_size: int = 1024, **kwargs, ) -> None: super().__init__() @@ -218,6 +219,8 @@ def __init__( else: self.model = pretrain_or_model + self.logprobs_chunk_size = logprobs_chunk_size + # TODO (sumanthrh): do the same for `logprobs_from_logits` and test. # Credits: https://www.tylerromero.com/posts/2025-02-selective-log-softmax/#efficient-solution self.chunked_entropy_from_logits_fn = ( @@ -351,7 +354,10 @@ def forward( entropy_mask = attention_mask_fwd entropy_BS = self.chunked_entropy_from_logits_fn( - logits_BSV, requires_grad=entropy_requires_grad, attention_mask=entropy_mask + logits_BSV, + requires_grad=entropy_requires_grad, + attention_mask=entropy_mask, + chunk_size=self.logprobs_chunk_size, ) if self.sequence_parallel_size > 1: diff --git a/examples/benchmarks/bench_chunked_logprobs.py b/skyrl/benchmarks/bench_chunked_logprobs.py similarity index 81% rename from examples/benchmarks/bench_chunked_logprobs.py rename to skyrl/benchmarks/bench_chunked_logprobs.py index 391098e636..1c149ff2ba 100644 --- a/examples/benchmarks/bench_chunked_logprobs.py +++ b/skyrl/benchmarks/bench_chunked_logprobs.py @@ -7,7 +7,7 @@ Usage (single GPU, torchrun required for distributed init): uv run --isolated --extra megatron torchrun --nproc_per_node=1 \\ - examples/benchmarks/bench_chunked_logprobs.py + skyrl/benchmarks/bench_chunked_logprobs.py """ import os @@ -17,10 +17,6 @@ import torch import torch.distributed as dist -# Must set NCCL env before initialising the process group -os.environ.setdefault("NCCL_NET", "Socket") -os.environ.setdefault("NCCL_NET_PLUGIN", "none") - VOCAB_SIZES = [32000, 64000, 128000] SEQ_LENS = [32768, 65536, 131072] # chunk_size=None routes through DistributedLogprob (no chunking); all others use @@ -163,29 +159,29 @@ def main(): torch.cuda.empty_cache() continue - # ----- warmup (skip OOM-prone chunk sizes gracefully) ----- - oom_chunk_sizes = set() + # ----- single pass: warmup + benchmark inline per chunk size ----- + results: dict[Optional[int], tuple[Optional[float], Optional[float]]] = {} for cs in CHUNK_SIZES: try: for _ in range(WARMUP_REPS): measure(logits, target, vocab_start_index, vocab_end_index, cs, tp_group, reps=1) + t_cs, mem_cs = measure( + logits, target, vocab_start_index, vocab_end_index, cs, tp_group, reps=BENCH_REPS + ) + results[cs] = (t_cs, mem_cs) except torch.OutOfMemoryError: - oom_chunk_sizes.add(cs) - torch.cuda.empty_cache() + results[cs] = (None, None) + finally: + torch.cuda.empty_cache() # isolate between chunk sizes - # ----- benchmark: collect baseline (no-chunk) first ----- - if None in oom_chunk_sizes: - t_baseline, mem_baseline = None, None - else: - t_baseline, mem_baseline = measure( - logits, target, vocab_start_index, vocab_end_index, None, tp_group, reps=BENCH_REPS - ) + t_baseline, mem_baseline = results[None] # ----- print one row per chunk_size ----- for cs in CHUNK_SIZES: cs_label = "None" if cs is None else str(cs) + t_cs, mem_cs = results[cs] - if cs in oom_chunk_sizes: + if t_cs is None: if dist.get_rank() == 0: print( f"{vocab_size:>10,} " @@ -198,27 +194,6 @@ def main(): ) continue - if cs is None: - t_cs, mem_cs = t_baseline, mem_baseline - else: - try: - t_cs, mem_cs = measure( - logits, target, vocab_start_index, vocab_end_index, cs, tp_group, reps=BENCH_REPS - ) - except torch.OutOfMemoryError: - torch.cuda.empty_cache() - if dist.get_rank() == 0: - print( - f"{vocab_size:>10,} " - f"{seq_len:>10,} " - f"{cs_label:>10} " - f"{'OOM':>{col_w}} " - f"{'OOM':>{col_w}} " - f"{'OOM':>{col_w}} " - f"{'OOM':>{col_w}}" - ) - continue - mem_cs_mb = mem_cs / (1024**2) if t_baseline is not None and t_cs > 0: speedup_str = f"{t_baseline / t_cs:>{col_w}.2f}x" diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index aae26ba315..6414548cf0 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -639,6 +639,10 @@ class TrainerConfig(BaseConfig): rope_theta: Optional[float] = None log_example_interval: int = 1 """Log an example prompt every N training steps, ``0``/``-1`` to disable""" + logprobs_chunk_size: int = 1024 + """Chunk size along the sequence dimension when computing log-probs from logits. + Reducing this lowers peak GPU memory at the cost of ~2x wall-clock time. + Set to None to disable chunking. See https://github.com/NovaSky-AI/SkyRL/pull/1610.""" def __post_init__(self): # ref model defaults to the policy model From 4fa27ebfe125ca73cdd0f49bd3a93f57b09be050 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Thu, 7 May 2026 00:25:30 +0000 Subject: [PATCH 10/11] x Signed-off-by: SumanthRH --- skyrl/train/config/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index 6414548cf0..877422b66f 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -641,8 +641,8 @@ class TrainerConfig(BaseConfig): """Log an example prompt every N training steps, ``0``/``-1`` to disable""" logprobs_chunk_size: int = 1024 """Chunk size along the sequence dimension when computing log-probs from logits. - Reducing this lowers peak GPU memory at the cost of ~2x wall-clock time. - Set to None to disable chunking. See https://github.com/NovaSky-AI/SkyRL/pull/1610.""" + This lowers peak GPU memory at the cost of ~2x wall-clock time. + Set to None to disable chunking. See https://github.com/NovaSky-AI/SkyRL/pull/1610 for more details.""" def __post_init__(self): # ref model defaults to the policy model From 87aac06a51371569e1eef638c04ffd27b9a6bee4 Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Thu, 7 May 2026 00:26:27 +0000 Subject: [PATCH 11/11] x Signed-off-by: SumanthRH --- skyrl/backends/skyrl_train/utils/torch_utils.py | 1 + skyrl/train/config/config.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/skyrl/backends/skyrl_train/utils/torch_utils.py b/skyrl/backends/skyrl_train/utils/torch_utils.py index a7a368f074..3da9dc4077 100644 --- a/skyrl/backends/skyrl_train/utils/torch_utils.py +++ b/skyrl/backends/skyrl_train/utils/torch_utils.py @@ -70,6 +70,7 @@ def chunked_entropy_from_logits( requires_grad: Whether to enable gradient computation attention_mask: Optional attention mask of shape (batch_size, seqlen). When provided, entropy values for padded positions (mask=0) will be zeroed out. + chunk_size: Sequence dimension chunk size (must be a positive integer). Returns: Entropy tensor of shape (batch_size, seqlen). If attention_mask is provided, diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index 877422b66f..e895883774 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -639,16 +639,29 @@ class TrainerConfig(BaseConfig): rope_theta: Optional[float] = None log_example_interval: int = 1 """Log an example prompt every N training steps, ``0``/``-1`` to disable""" - logprobs_chunk_size: int = 1024 + logprobs_chunk_size: Optional[int] = 1024 """Chunk size along the sequence dimension when computing log-probs from logits. This lowers peak GPU memory at the cost of ~2x wall-clock time. - Set to None to disable chunking. See https://github.com/NovaSky-AI/SkyRL/pull/1610 for more details.""" + ``None`` disables chunking (Megatron backend only; FSDP requires a positive int). + See https://github.com/NovaSky-AI/SkyRL/pull/1610 for more details.""" def __post_init__(self): # ref model defaults to the policy model if self.ref.model.path is None: self.ref.model.path = self.policy.model.path + if self.logprobs_chunk_size is not None and ( + not isinstance(self.logprobs_chunk_size, int) or self.logprobs_chunk_size <= 0 + ): + raise ValueError( + f"logprobs_chunk_size must be a positive integer or None, got {self.logprobs_chunk_size!r}." + ) + if self.logprobs_chunk_size is None and self.strategy != "megatron": + raise ValueError( + "logprobs_chunk_size=None (no chunking) is only supported with the Megatron backend. " + f"Set a positive integer for strategy={self.strategy!r}." + ) + def validate_dict_keys_against_dataclass(datacls: Type[Any], d: dict): """