-
Notifications
You must be signed in to change notification settings - Fork 346
[train] Default to chunked logprobs for Megatron #1610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8745725
9ee9759
b8712e3
198313a
1dd293d
4d04fa4
9e08a17
c2e6d66
ca1976f
4fa27eb
87aac06
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -28,16 +28,14 @@ | |||||||||||||||||||||
| FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| CHUNK_SIZE = 1024 | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def chunked_cross_entropy_from_log_probs( | ||||||||||||||||||||||
| logprobs: Float[torch.Tensor, "batch_size seqlen vocab_size"], requires_grad: bool = False | ||||||||||||||||||||||
| logprobs: Float[torch.Tensor, "batch_size seqlen vocab_size"], | ||||||||||||||||||||||
| requires_grad: bool = False, | ||||||||||||||||||||||
| chunk_size: int = 1024, | ||||||||||||||||||||||
| ) -> Float[torch.Tensor, "batch_size seqlen"]: | ||||||||||||||||||||||
| cm = nullcontext() if requires_grad else torch.no_grad() | ||||||||||||||||||||||
| with cm: | ||||||||||||||||||||||
| # Calculate entropy in chunks to avoid OOM | ||||||||||||||||||||||
| chunk_size = CHUNK_SIZE | ||||||||||||||||||||||
| num_chunks = (logprobs.size(1) + chunk_size - 1) // chunk_size | ||||||||||||||||||||||
|
Comment on lines
31
to
39
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The def chunked_cross_entropy_from_log_probs(
logprobs: Float[torch.Tensor, "batch_size seqlen vocab_size"],
requires_grad: bool = False,
chunk_size: int | None = 1024,
) -> Float[torch.Tensor, "batch_size seqlen"]:
cm = nullcontext() if requires_grad else torch.no_grad()
with cm:
# Calculate entropy in chunks to avoid OOM
chunk_size = chunk_size or logprobs.size(1)
num_chunks = (logprobs.size(1) + chunk_size - 1) // chunk_size |
||||||||||||||||||||||
| entropy_tensor = torch.zeros( | ||||||||||||||||||||||
| (logprobs.shape[0], logprobs.shape[1]), dtype=logprobs.dtype, device=logprobs.device | ||||||||||||||||||||||
|
|
@@ -61,6 +59,7 @@ def chunked_entropy_from_logits( | |||||||||||||||||||||
| logits: Float[torch.Tensor, "batch_size seqlen vocab"], | ||||||||||||||||||||||
| requires_grad: bool = False, | ||||||||||||||||||||||
| attention_mask: Float[torch.Tensor, "batch_size seqlen"] = None, | ||||||||||||||||||||||
| chunk_size: int = 1024, | ||||||||||||||||||||||
| ) -> Float[torch.Tensor, "batch_size seqlen"]: | ||||||||||||||||||||||
|
Comment on lines
59
to
63
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update the type hint for
Suggested change
|
||||||||||||||||||||||
| """Chunked entropy calculation from logits. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -71,6 +70,7 @@ def chunked_entropy_from_logits( | |||||||||||||||||||||
| requires_grad: Whether to enable gradient computation | ||||||||||||||||||||||
| attention_mask: Optional attention mask of shape (batch_size, seqlen). When provided, | ||||||||||||||||||||||
| entropy values for padded positions (mask=0) will be zeroed out. | ||||||||||||||||||||||
| chunk_size: Sequence dimension chunk size (must be a positive integer). | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||
| Entropy tensor of shape (batch_size, seqlen). If attention_mask is provided, | ||||||||||||||||||||||
|
|
@@ -88,7 +88,6 @@ def chunked_entropy_from_logits( | |||||||||||||||||||||
| cm = nullcontext() if requires_grad else torch.no_grad() | ||||||||||||||||||||||
| with cm: | ||||||||||||||||||||||
| # Calculate entropy in chunks to avoid OOM | ||||||||||||||||||||||
| chunk_size = CHUNK_SIZE | ||||||||||||||||||||||
| num_chunks = (logits.size(1) + chunk_size - 1) // chunk_size | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ensure
Suggested change
|
||||||||||||||||||||||
| entropy_tensor = torch.zeros((logits.shape[0], logits.shape[1]), dtype=logits.dtype, device=logits.device) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -80,6 +80,7 @@ def __init__( | |||||||||||||||||||||||||
| model_config_kwargs: dict = {}, | ||||||||||||||||||||||||||
| meta_init: bool = False, | ||||||||||||||||||||||||||
| language_model_only: bool = False, | ||||||||||||||||||||||||||
| logprobs_chunk_size: int = 1024, | ||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||
|
Comment on lines
80
to
85
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||
|
|
@@ -218,6 +219,8 @@ def __init__( | |||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| self.model = pretrain_or_model | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self.logprobs_chunk_size = logprobs_chunk_size | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # TODO (sumanthrh): do the same for `logprobs_from_logits` and test. | ||||||||||||||||||||||||||
| # Credits: https://www.tylerromero.com/posts/2025-02-selective-log-softmax/#efficient-solution | ||||||||||||||||||||||||||
| self.chunked_entropy_from_logits_fn = ( | ||||||||||||||||||||||||||
|
|
@@ -351,7 +354,10 @@ def forward( | |||||||||||||||||||||||||
| entropy_mask = attention_mask_fwd | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| entropy_BS = self.chunked_entropy_from_logits_fn( | ||||||||||||||||||||||||||
| logits_BSV, requires_grad=entropy_requires_grad, attention_mask=entropy_mask | ||||||||||||||||||||||||||
| logits_BSV, | ||||||||||||||||||||||||||
| requires_grad=entropy_requires_grad, | ||||||||||||||||||||||||||
| attention_mask=entropy_mask, | ||||||||||||||||||||||||||
| chunk_size=self.logprobs_chunk_size, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if self.sequence_parallel_size > 1: | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,232 @@ | ||
| """ | ||
| Benchmark: chunked vs non-chunked logprob computation. | ||
|
|
||
| Tests ChunkedDistributedLogprob and DistributedLogprob from | ||
| skyrl/backends/skyrl_train/distributed/megatron/model_utils.py, | ||
| which is the actual code path used in SkyRL's MegatronModelWrapper. | ||
|
|
||
| Usage (single GPU, torchrun required for distributed init): | ||
| uv run --isolated --extra megatron torchrun --nproc_per_node=1 \\ | ||
| skyrl/benchmarks/bench_chunked_logprobs.py | ||
| """ | ||
|
|
||
| import os | ||
| import time | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
|
|
||
| VOCAB_SIZES = [32000, 64000, 128000] | ||
| SEQ_LENS = [32768, 65536, 131072] | ||
| # chunk_size=None routes through DistributedLogprob (no chunking); all others use | ||
| # ChunkedDistributedLogprob with the given chunk size. | ||
| CHUNK_SIZES = [None, 32, 1024, 4096, 8192, 16384] | ||
| WARMUP_REPS = 2 | ||
| BENCH_REPS = 5 | ||
|
|
||
|
|
||
| def measure( | ||
| vocab_parallel_logits: torch.Tensor, | ||
| target: torch.Tensor, | ||
| vocab_start_index: int, | ||
| vocab_end_index: int, | ||
| chunk_size: Optional[int], | ||
| tp_group: torch.distributed.ProcessGroup, | ||
| reps: int, | ||
| ): | ||
| """Run forward+backward through the real SkyRL logprob kernel. | ||
|
|
||
| Returns (mean_wall_ms, mean_peak_mem_bytes). | ||
| """ | ||
| from skyrl.backends.skyrl_train.distributed.megatron.model_utils import ( | ||
| ChunkedDistributedLogprob, | ||
| DistributedLogprob, | ||
| ) | ||
|
|
||
| device = vocab_parallel_logits.device | ||
| times = [] | ||
| peak_mems = [] | ||
|
|
||
| for _ in range(reps): | ||
| # Fresh leaf tensor each rep so grad accumulation does not interfere. | ||
| logits_rep = vocab_parallel_logits.detach().requires_grad_(True) | ||
|
|
||
| torch.cuda.reset_peak_memory_stats(device) | ||
| torch.cuda.synchronize(device) | ||
| t0 = time.perf_counter() | ||
|
|
||
| if chunk_size is None: | ||
| # Non-chunked real implementation (DistributedLogprob) | ||
| out = DistributedLogprob.apply( | ||
| logits_rep, | ||
| target, | ||
| vocab_start_index, | ||
| vocab_end_index, | ||
| tp_group, | ||
| False, # inference_only=False -> saves tensors for backward | ||
| ) | ||
| else: | ||
| # Chunked real implementation (ChunkedDistributedLogprob) | ||
| out = ChunkedDistributedLogprob.apply( | ||
| logits_rep, | ||
| target, | ||
| vocab_start_index, | ||
| vocab_end_index, | ||
| chunk_size, | ||
| tp_group, | ||
| False, # inference_only=False -> saves tensors for backward | ||
| ) | ||
|
|
||
| loss = out.sum() | ||
| loss.backward() | ||
|
|
||
| torch.cuda.synchronize(device) | ||
| t1 = time.perf_counter() | ||
|
|
||
| times.append((t1 - t0) * 1000.0) | ||
| peak_mems.append(torch.cuda.max_memory_allocated(device)) | ||
|
|
||
| return sum(times) / len(times), sum(peak_mems) / len(peak_mems) | ||
|
|
||
|
|
||
| def main(): | ||
| # --- Distributed init (required by the real SkyRL kernel) --- | ||
| dist.init_process_group(backend="nccl") | ||
| local_rank = int(os.environ.get("LOCAL_RANK", 0)) | ||
| torch.cuda.set_device(local_rank) | ||
|
|
||
| # Initialise Megatron model-parallel state (TP=1, single GPU). | ||
| import megatron.core.parallel_state as mpu | ||
|
|
||
| mpu.initialize_model_parallel(tensor_model_parallel_size=1) | ||
|
|
||
| tp_group = dist.group.WORLD # TP=1, so the whole-world group is the TP group | ||
|
|
||
| device = torch.device("cuda", local_rank) | ||
|
|
||
| if dist.get_rank() == 0: | ||
| print(f"Device : {torch.cuda.get_device_name(device)}") | ||
| print(f"World size : {dist.get_world_size()}") | ||
| print( | ||
| f"Vocab sizes : {VOCAB_SIZES} | chunk_sizes={CHUNK_SIZES} " | ||
| f"| warmup={WARMUP_REPS} bench={BENCH_REPS}" | ||
| ) | ||
| print("Implementation: real SkyRL ChunkedDistributedLogprob / DistributedLogprob\n") | ||
|
|
||
| col_w = 14 | ||
| header = ( | ||
| f"{'vocab_size':>10} " | ||
| f"{'seq_len':>10} " | ||
| f"{'chunk_size':>10} " | ||
| f"{'time ms':>{col_w}} " | ||
| f"{'peak MB':>{col_w}} " | ||
| f"{'vs no-chunk':>{col_w}} " | ||
| f"{'mem saved MB':>{col_w}}" | ||
| ) | ||
| sep = "-" * len(header) | ||
|
|
||
| for vocab_size in VOCAB_SIZES: | ||
| if dist.get_rank() == 0: | ||
| print(f"\n=== vocab_size={vocab_size:,} ===") | ||
| print(header) | ||
| print(sep) | ||
|
|
||
| for seq_len in SEQ_LENS: | ||
| # With TP=1 the full vocab lives on this rank. | ||
| vocab_start_index = 0 | ||
| vocab_end_index = vocab_size | ||
|
|
||
| # Shape expected by the real kernel: [batch, seq, vocab // TP] | ||
| # We use batch=1 to keep allocations comparable to a single-sequence workload. | ||
| try: | ||
| logits = torch.randn(1, seq_len, vocab_size, dtype=torch.bfloat16, device=device) | ||
| # targets: [batch, seq], values in [0, vocab_size) | ||
| target = torch.randint(0, vocab_size, (1, seq_len), device=device) | ||
| except torch.OutOfMemoryError: | ||
| if dist.get_rank() == 0: | ||
| oom_row = ( | ||
| f"{vocab_size:>10,} " | ||
| f"{seq_len:>10,} " | ||
| f"{'(all)':>10} " | ||
| f"{'OOM':>{col_w}} " | ||
| f"{'OOM':>{col_w}} " | ||
| f"{'OOM':>{col_w}} " | ||
| f"{'OOM':>{col_w}}" | ||
| ) | ||
| print(oom_row) | ||
| print(sep) | ||
| torch.cuda.empty_cache() | ||
| continue | ||
|
|
||
| # ----- single pass: warmup + benchmark inline per chunk size ----- | ||
| results: dict[Optional[int], tuple[Optional[float], Optional[float]]] = {} | ||
| for cs in CHUNK_SIZES: | ||
| try: | ||
| for _ in range(WARMUP_REPS): | ||
| measure(logits, target, vocab_start_index, vocab_end_index, cs, tp_group, reps=1) | ||
| t_cs, mem_cs = measure( | ||
| logits, target, vocab_start_index, vocab_end_index, cs, tp_group, reps=BENCH_REPS | ||
| ) | ||
| results[cs] = (t_cs, mem_cs) | ||
| except torch.OutOfMemoryError: | ||
| results[cs] = (None, None) | ||
| finally: | ||
| torch.cuda.empty_cache() # isolate between chunk sizes | ||
|
|
||
| t_baseline, mem_baseline = results[None] | ||
|
|
||
| # ----- print one row per chunk_size ----- | ||
| for cs in CHUNK_SIZES: | ||
| cs_label = "None" if cs is None else str(cs) | ||
| t_cs, mem_cs = results[cs] | ||
|
|
||
| if t_cs is None: | ||
| if dist.get_rank() == 0: | ||
| print( | ||
| f"{vocab_size:>10,} " | ||
| f"{seq_len:>10,} " | ||
| f"{cs_label:>10} " | ||
| f"{'OOM':>{col_w}} " | ||
| f"{'OOM':>{col_w}} " | ||
| f"{'OOM':>{col_w}} " | ||
| f"{'OOM':>{col_w}}" | ||
| ) | ||
| continue | ||
|
|
||
| mem_cs_mb = mem_cs / (1024**2) | ||
| if t_baseline is not None and t_cs > 0: | ||
| speedup_str = f"{t_baseline / t_cs:>{col_w}.2f}x" | ||
| mem_saved_str = f"{mem_baseline / (1024**2) - mem_cs_mb:>{col_w}.0f}" | ||
| else: | ||
| speedup_str = f"{'N/A':>{col_w}}" | ||
| mem_saved_str = f"{'N/A':>{col_w}}" | ||
|
|
||
| if dist.get_rank() == 0: | ||
| print( | ||
| f"{vocab_size:>10,} " | ||
| f"{seq_len:>10,} " | ||
| f"{cs_label:>10} " | ||
| f"{t_cs:>{col_w}.1f} " | ||
| f"{mem_cs_mb:>{col_w}.0f} " | ||
| f"{speedup_str} " | ||
| f"{mem_saved_str}" | ||
| ) | ||
|
|
||
| if dist.get_rank() == 0: | ||
| print(sep) | ||
|
|
||
| # Free memory before next seq_len | ||
| del logits, target | ||
| torch.cuda.empty_cache() | ||
|
|
||
| if dist.get_rank() == 0: | ||
| print("\nAll times are mean wall-clock (ms) over forward+backward passes.") | ||
| print("vs no-chunk: speedup relative to chunk_size=None (>1 = faster).") | ||
|
SumanthRH marked this conversation as resolved.
|
||
| print("chunk_size=None uses DistributedLogprob; all others use ChunkedDistributedLogprob.") | ||
|
|
||
| dist.destroy_process_group() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
Uh oh!
There was an error while loading. Please reload this page.