diff --git a/skyrl/backends/skyrl_train/utils/torch_utils.py b/skyrl/backends/skyrl_train/utils/torch_utils.py index 70ea95cc0b..3da9dc4077 100644 --- a/skyrl/backends/skyrl_train/utils/torch_utils.py +++ b/skyrl/backends/skyrl_train/utils/torch_utils.py @@ -28,16 +28,14 @@ FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False -CHUNK_SIZE = 1024 - - def chunked_cross_entropy_from_log_probs( - logprobs: Float[torch.Tensor, "batch_size seqlen vocab_size"], requires_grad: bool = False + logprobs: Float[torch.Tensor, "batch_size seqlen vocab_size"], + requires_grad: bool = False, + chunk_size: int = 1024, ) -> Float[torch.Tensor, "batch_size seqlen"]: cm = nullcontext() if requires_grad else torch.no_grad() with cm: # Calculate entropy in chunks to avoid OOM - chunk_size = CHUNK_SIZE num_chunks = (logprobs.size(1) + chunk_size - 1) // chunk_size entropy_tensor = torch.zeros( (logprobs.shape[0], logprobs.shape[1]), dtype=logprobs.dtype, device=logprobs.device @@ -61,6 +59,7 @@ def chunked_entropy_from_logits( logits: Float[torch.Tensor, "batch_size seqlen vocab"], requires_grad: bool = False, attention_mask: Float[torch.Tensor, "batch_size seqlen"] = None, + chunk_size: int = 1024, ) -> Float[torch.Tensor, "batch_size seqlen"]: """Chunked entropy calculation from logits. @@ -71,6 +70,7 @@ def chunked_entropy_from_logits( requires_grad: Whether to enable gradient computation attention_mask: Optional attention mask of shape (batch_size, seqlen). When provided, entropy values for padded positions (mask=0) will be zeroed out. + chunk_size: Sequence dimension chunk size (must be a positive integer). Returns: Entropy tensor of shape (batch_size, seqlen). If attention_mask is provided, @@ -88,7 +88,6 @@ def chunked_entropy_from_logits( cm = nullcontext() if requires_grad else torch.no_grad() with cm: # Calculate entropy in chunks to avoid OOM - chunk_size = CHUNK_SIZE num_chunks = (logits.size(1) + chunk_size - 1) // chunk_size entropy_tensor = torch.zeros((logits.shape[0], logits.shape[1]), dtype=logits.dtype, device=logits.device) diff --git a/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py b/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py index a82440c91e..b8026d1854 100644 --- a/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py +++ b/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py @@ -193,6 +193,7 @@ def init_model(self, model_path, num_training_steps: int = None): model_config_kwargs=self.cfg.policy.model_config_kwargs, meta_init=use_meta, language_model_only=self.cfg.policy.language_model_only, + logprobs_chunk_size=self.cfg.logprobs_chunk_size, ) self._seq_parallel_monkey_patch(model=wrapped_model.model) @@ -415,6 +416,7 @@ def init_model(self, model_path): model_config_kwargs=self.cfg.ref.model_config_kwargs, meta_init=use_meta, language_model_only=self.cfg.ref.language_model_only, + logprobs_chunk_size=self.cfg.logprobs_chunk_size, ) self._seq_parallel_monkey_patch(model=wrapped_model.model) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py index 152348590a..f2d1aa0a44 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py @@ -104,7 +104,7 @@ def collection_func(logits, data): tp_group=tp_grp, inference_only=True, cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` - chunk_size=None, + chunk_size=self.cfg.logprobs_chunk_size, # chunk seq dim to bound peak memory ) return torch.tensor(0.0, device=token_logprobs.device), {"log_probs": token_logprobs} @@ -264,7 +264,7 @@ def loss_func(logits, data): tp_group=tp_grp, inference_only=False, cp_group=None, # we handle cp gathering in `postprocess_packed_seqs` - chunk_size=None, + chunk_size=self.cfg.logprobs_chunk_size, # chunk seq dim to bound peak memory ) action_log_probs = token_logprobs[:, -num_actions:] diff --git a/skyrl/backends/skyrl_train/workers/model_wrapper.py b/skyrl/backends/skyrl_train/workers/model_wrapper.py index 53c876f73e..ba2069f27a 100644 --- a/skyrl/backends/skyrl_train/workers/model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/model_wrapper.py @@ -80,6 +80,7 @@ def __init__( model_config_kwargs: dict = {}, meta_init: bool = False, language_model_only: bool = False, + logprobs_chunk_size: int = 1024, **kwargs, ) -> None: super().__init__() @@ -218,6 +219,8 @@ def __init__( else: self.model = pretrain_or_model + self.logprobs_chunk_size = logprobs_chunk_size + # TODO (sumanthrh): do the same for `logprobs_from_logits` and test. # Credits: https://www.tylerromero.com/posts/2025-02-selective-log-softmax/#efficient-solution self.chunked_entropy_from_logits_fn = ( @@ -351,7 +354,10 @@ def forward( entropy_mask = attention_mask_fwd entropy_BS = self.chunked_entropy_from_logits_fn( - logits_BSV, requires_grad=entropy_requires_grad, attention_mask=entropy_mask + logits_BSV, + requires_grad=entropy_requires_grad, + attention_mask=entropy_mask, + chunk_size=self.logprobs_chunk_size, ) if self.sequence_parallel_size > 1: diff --git a/skyrl/benchmarks/bench_chunked_logprobs.py b/skyrl/benchmarks/bench_chunked_logprobs.py new file mode 100644 index 0000000000..1c149ff2ba --- /dev/null +++ b/skyrl/benchmarks/bench_chunked_logprobs.py @@ -0,0 +1,232 @@ +""" +Benchmark: chunked vs non-chunked logprob computation. + +Tests ChunkedDistributedLogprob and DistributedLogprob from +skyrl/backends/skyrl_train/distributed/megatron/model_utils.py, +which is the actual code path used in SkyRL's MegatronModelWrapper. + +Usage (single GPU, torchrun required for distributed init): + uv run --isolated --extra megatron torchrun --nproc_per_node=1 \\ + skyrl/benchmarks/bench_chunked_logprobs.py +""" + +import os +import time +from typing import Optional + +import torch +import torch.distributed as dist + +VOCAB_SIZES = [32000, 64000, 128000] +SEQ_LENS = [32768, 65536, 131072] +# chunk_size=None routes through DistributedLogprob (no chunking); all others use +# ChunkedDistributedLogprob with the given chunk size. +CHUNK_SIZES = [None, 32, 1024, 4096, 8192, 16384] +WARMUP_REPS = 2 +BENCH_REPS = 5 + + +def measure( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + chunk_size: Optional[int], + tp_group: torch.distributed.ProcessGroup, + reps: int, +): + """Run forward+backward through the real SkyRL logprob kernel. + + Returns (mean_wall_ms, mean_peak_mem_bytes). + """ + from skyrl.backends.skyrl_train.distributed.megatron.model_utils import ( + ChunkedDistributedLogprob, + DistributedLogprob, + ) + + device = vocab_parallel_logits.device + times = [] + peak_mems = [] + + for _ in range(reps): + # Fresh leaf tensor each rep so grad accumulation does not interfere. + logits_rep = vocab_parallel_logits.detach().requires_grad_(True) + + torch.cuda.reset_peak_memory_stats(device) + torch.cuda.synchronize(device) + t0 = time.perf_counter() + + if chunk_size is None: + # Non-chunked real implementation (DistributedLogprob) + out = DistributedLogprob.apply( + logits_rep, + target, + vocab_start_index, + vocab_end_index, + tp_group, + False, # inference_only=False -> saves tensors for backward + ) + else: + # Chunked real implementation (ChunkedDistributedLogprob) + out = ChunkedDistributedLogprob.apply( + logits_rep, + target, + vocab_start_index, + vocab_end_index, + chunk_size, + tp_group, + False, # inference_only=False -> saves tensors for backward + ) + + loss = out.sum() + loss.backward() + + torch.cuda.synchronize(device) + t1 = time.perf_counter() + + times.append((t1 - t0) * 1000.0) + peak_mems.append(torch.cuda.max_memory_allocated(device)) + + return sum(times) / len(times), sum(peak_mems) / len(peak_mems) + + +def main(): + # --- Distributed init (required by the real SkyRL kernel) --- + dist.init_process_group(backend="nccl") + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + + # Initialise Megatron model-parallel state (TP=1, single GPU). + import megatron.core.parallel_state as mpu + + mpu.initialize_model_parallel(tensor_model_parallel_size=1) + + tp_group = dist.group.WORLD # TP=1, so the whole-world group is the TP group + + device = torch.device("cuda", local_rank) + + if dist.get_rank() == 0: + print(f"Device : {torch.cuda.get_device_name(device)}") + print(f"World size : {dist.get_world_size()}") + print( + f"Vocab sizes : {VOCAB_SIZES} | chunk_sizes={CHUNK_SIZES} " + f"| warmup={WARMUP_REPS} bench={BENCH_REPS}" + ) + print("Implementation: real SkyRL ChunkedDistributedLogprob / DistributedLogprob\n") + + col_w = 14 + header = ( + f"{'vocab_size':>10} " + f"{'seq_len':>10} " + f"{'chunk_size':>10} " + f"{'time ms':>{col_w}} " + f"{'peak MB':>{col_w}} " + f"{'vs no-chunk':>{col_w}} " + f"{'mem saved MB':>{col_w}}" + ) + sep = "-" * len(header) + + for vocab_size in VOCAB_SIZES: + if dist.get_rank() == 0: + print(f"\n=== vocab_size={vocab_size:,} ===") + print(header) + print(sep) + + for seq_len in SEQ_LENS: + # With TP=1 the full vocab lives on this rank. + vocab_start_index = 0 + vocab_end_index = vocab_size + + # Shape expected by the real kernel: [batch, seq, vocab // TP] + # We use batch=1 to keep allocations comparable to a single-sequence workload. + try: + logits = torch.randn(1, seq_len, vocab_size, dtype=torch.bfloat16, device=device) + # targets: [batch, seq], values in [0, vocab_size) + target = torch.randint(0, vocab_size, (1, seq_len), device=device) + except torch.OutOfMemoryError: + if dist.get_rank() == 0: + oom_row = ( + f"{vocab_size:>10,} " + f"{seq_len:>10,} " + f"{'(all)':>10} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}}" + ) + print(oom_row) + print(sep) + torch.cuda.empty_cache() + continue + + # ----- single pass: warmup + benchmark inline per chunk size ----- + results: dict[Optional[int], tuple[Optional[float], Optional[float]]] = {} + for cs in CHUNK_SIZES: + try: + for _ in range(WARMUP_REPS): + measure(logits, target, vocab_start_index, vocab_end_index, cs, tp_group, reps=1) + t_cs, mem_cs = measure( + logits, target, vocab_start_index, vocab_end_index, cs, tp_group, reps=BENCH_REPS + ) + results[cs] = (t_cs, mem_cs) + except torch.OutOfMemoryError: + results[cs] = (None, None) + finally: + torch.cuda.empty_cache() # isolate between chunk sizes + + t_baseline, mem_baseline = results[None] + + # ----- print one row per chunk_size ----- + for cs in CHUNK_SIZES: + cs_label = "None" if cs is None else str(cs) + t_cs, mem_cs = results[cs] + + if t_cs is None: + if dist.get_rank() == 0: + print( + f"{vocab_size:>10,} " + f"{seq_len:>10,} " + f"{cs_label:>10} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}} " + f"{'OOM':>{col_w}}" + ) + continue + + mem_cs_mb = mem_cs / (1024**2) + if t_baseline is not None and t_cs > 0: + speedup_str = f"{t_baseline / t_cs:>{col_w}.2f}x" + mem_saved_str = f"{mem_baseline / (1024**2) - mem_cs_mb:>{col_w}.0f}" + else: + speedup_str = f"{'N/A':>{col_w}}" + mem_saved_str = f"{'N/A':>{col_w}}" + + if dist.get_rank() == 0: + print( + f"{vocab_size:>10,} " + f"{seq_len:>10,} " + f"{cs_label:>10} " + f"{t_cs:>{col_w}.1f} " + f"{mem_cs_mb:>{col_w}.0f} " + f"{speedup_str} " + f"{mem_saved_str}" + ) + + if dist.get_rank() == 0: + print(sep) + + # Free memory before next seq_len + del logits, target + torch.cuda.empty_cache() + + if dist.get_rank() == 0: + print("\nAll times are mean wall-clock (ms) over forward+backward passes.") + print("vs no-chunk: speedup relative to chunk_size=None (>1 = faster).") + print("chunk_size=None uses DistributedLogprob; all others use ChunkedDistributedLogprob.") + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index aae26ba315..e895883774 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -639,12 +639,29 @@ class TrainerConfig(BaseConfig): rope_theta: Optional[float] = None log_example_interval: int = 1 """Log an example prompt every N training steps, ``0``/``-1`` to disable""" + logprobs_chunk_size: Optional[int] = 1024 + """Chunk size along the sequence dimension when computing log-probs from logits. + This lowers peak GPU memory at the cost of ~2x wall-clock time. + ``None`` disables chunking (Megatron backend only; FSDP requires a positive int). + See https://github.com/NovaSky-AI/SkyRL/pull/1610 for more details.""" def __post_init__(self): # ref model defaults to the policy model if self.ref.model.path is None: self.ref.model.path = self.policy.model.path + if self.logprobs_chunk_size is not None and ( + not isinstance(self.logprobs_chunk_size, int) or self.logprobs_chunk_size <= 0 + ): + raise ValueError( + f"logprobs_chunk_size must be a positive integer or None, got {self.logprobs_chunk_size!r}." + ) + if self.logprobs_chunk_size is None and self.strategy != "megatron": + raise ValueError( + "logprobs_chunk_size=None (no chunking) is only supported with the Megatron backend. " + f"Set a positive integer for strategy={self.strategy!r}." + ) + def validate_dict_keys_against_dataclass(datacls: Type[Any], d: dict): """