Skip to content
11 changes: 5 additions & 6 deletions skyrl/backends/skyrl_train/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,14 @@
FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False


CHUNK_SIZE = 1024


def chunked_cross_entropy_from_log_probs(
logprobs: Float[torch.Tensor, "batch_size seqlen vocab_size"], requires_grad: bool = False
logprobs: Float[torch.Tensor, "batch_size seqlen vocab_size"],
requires_grad: bool = False,
chunk_size: int = 1024,
) -> Float[torch.Tensor, "batch_size seqlen"]:
Comment thread
SumanthRH marked this conversation as resolved.
cm = nullcontext() if requires_grad else torch.no_grad()
with cm:
# Calculate entropy in chunks to avoid OOM
chunk_size = CHUNK_SIZE
num_chunks = (logprobs.size(1) + chunk_size - 1) // chunk_size
Comment on lines 31 to 39
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The TrainerConfig allows logprobs_chunk_size to be None to disable chunking. However, if chunk_size is None, the calculation of num_chunks will raise a TypeError. We should handle the None case by defaulting to the full sequence length and update the type hint to int | None.

def chunked_cross_entropy_from_log_probs(
    logprobs: Float[torch.Tensor, "batch_size seqlen vocab_size"],
    requires_grad: bool = False,
    chunk_size: int | None = 1024,
) -> Float[torch.Tensor, "batch_size seqlen"]:
    cm = nullcontext() if requires_grad else torch.no_grad()
    with cm:
        # Calculate entropy in chunks to avoid OOM
        chunk_size = chunk_size or logprobs.size(1)
        num_chunks = (logprobs.size(1) + chunk_size - 1) // chunk_size

entropy_tensor = torch.zeros(
(logprobs.shape[0], logprobs.shape[1]), dtype=logprobs.dtype, device=logprobs.device
Expand All @@ -61,6 +59,7 @@ def chunked_entropy_from_logits(
logits: Float[torch.Tensor, "batch_size seqlen vocab"],
requires_grad: bool = False,
attention_mask: Float[torch.Tensor, "batch_size seqlen"] = None,
chunk_size: int = 1024,
) -> Float[torch.Tensor, "batch_size seqlen"]:
Comment on lines 59 to 63
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Update the type hint for chunk_size to int | None to align with the configuration which allows disabling chunking by setting it to None.

Suggested change
logits: Float[torch.Tensor, "batch_size seqlen vocab"],
requires_grad: bool = False,
attention_mask: Float[torch.Tensor, "batch_size seqlen"] = None,
chunk_size: int = 1024,
) -> Float[torch.Tensor, "batch_size seqlen"]:
logits: Float[torch.Tensor, "batch_size seqlen vocab"],
requires_grad: bool = False,
attention_mask: Float[torch.Tensor, "batch_size seqlen"] = None,
chunk_size: int | None = 1024,
) -> Float[torch.Tensor, "batch_size seqlen"]:

"""Chunked entropy calculation from logits.

Expand All @@ -71,6 +70,7 @@ def chunked_entropy_from_logits(
requires_grad: Whether to enable gradient computation
attention_mask: Optional attention mask of shape (batch_size, seqlen). When provided,
entropy values for padded positions (mask=0) will be zeroed out.
chunk_size: Sequence dimension chunk size (must be a positive integer).

Returns:
Entropy tensor of shape (batch_size, seqlen). If attention_mask is provided,
Expand All @@ -88,7 +88,6 @@ def chunked_entropy_from_logits(
cm = nullcontext() if requires_grad else torch.no_grad()
with cm:
# Calculate entropy in chunks to avoid OOM
chunk_size = CHUNK_SIZE
num_chunks = (logits.size(1) + chunk_size - 1) // chunk_size
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Ensure chunk_size is a valid integer before calculating num_chunks to avoid a crash when chunking is disabled (chunk_size=None).

Suggested change
num_chunks = (logits.size(1) + chunk_size - 1) // chunk_size
chunk_size = chunk_size or logits.size(1)
num_chunks = (logits.size(1) + chunk_size - 1) // chunk_size

entropy_tensor = torch.zeros((logits.shape[0], logits.shape[1]), dtype=logits.dtype, device=logits.device)

Expand Down
2 changes: 2 additions & 0 deletions skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def init_model(self, model_path, num_training_steps: int = None):
model_config_kwargs=self.cfg.policy.model_config_kwargs,
meta_init=use_meta,
language_model_only=self.cfg.policy.language_model_only,
logprobs_chunk_size=self.cfg.logprobs_chunk_size,
)
self._seq_parallel_monkey_patch(model=wrapped_model.model)

Expand Down Expand Up @@ -415,6 +416,7 @@ def init_model(self, model_path):
model_config_kwargs=self.cfg.ref.model_config_kwargs,
meta_init=use_meta,
language_model_only=self.cfg.ref.language_model_only,
logprobs_chunk_size=self.cfg.logprobs_chunk_size,
)
self._seq_parallel_monkey_patch(model=wrapped_model.model)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def collection_func(logits, data):
tp_group=tp_grp,
inference_only=True,
cp_group=None, # we handle cp gathering in `postprocess_packed_seqs`
chunk_size=None,
chunk_size=self.cfg.logprobs_chunk_size, # chunk seq dim to bound peak memory
)
return torch.tensor(0.0, device=token_logprobs.device), {"log_probs": token_logprobs}

Expand Down Expand Up @@ -264,7 +264,7 @@ def loss_func(logits, data):
tp_group=tp_grp,
inference_only=False,
cp_group=None, # we handle cp gathering in `postprocess_packed_seqs`
chunk_size=None,
chunk_size=self.cfg.logprobs_chunk_size, # chunk seq dim to bound peak memory
)

action_log_probs = token_logprobs[:, -num_actions:]
Expand Down
8 changes: 7 additions & 1 deletion skyrl/backends/skyrl_train/workers/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
model_config_kwargs: dict = {},
meta_init: bool = False,
language_model_only: bool = False,
logprobs_chunk_size: int = 1024,
**kwargs,
) -> None:
Comment on lines 80 to 85
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logprobs_chunk_size parameter should be typed as int | None to match the TrainerConfig and allow disabling chunking.

Suggested change
model_config_kwargs: dict = {},
meta_init: bool = False,
language_model_only: bool = False,
logprobs_chunk_size: int = 1024,
**kwargs,
) -> None:
model_config_kwargs: dict = {},
meta_init: bool = False,
language_model_only: bool = False,
logprobs_chunk_size: int | None = 1024,
**kwargs,
) -> None:

super().__init__()
Expand Down Expand Up @@ -218,6 +219,8 @@ def __init__(
else:
self.model = pretrain_or_model

self.logprobs_chunk_size = logprobs_chunk_size

# TODO (sumanthrh): do the same for `logprobs_from_logits` and test.
# Credits: https://www.tylerromero.com/posts/2025-02-selective-log-softmax/#efficient-solution
self.chunked_entropy_from_logits_fn = (
Expand Down Expand Up @@ -351,7 +354,10 @@ def forward(
entropy_mask = attention_mask_fwd

entropy_BS = self.chunked_entropy_from_logits_fn(
logits_BSV, requires_grad=entropy_requires_grad, attention_mask=entropy_mask
logits_BSV,
requires_grad=entropy_requires_grad,
attention_mask=entropy_mask,
chunk_size=self.logprobs_chunk_size,
)

if self.sequence_parallel_size > 1:
Expand Down
232 changes: 232 additions & 0 deletions skyrl/benchmarks/bench_chunked_logprobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
"""
Benchmark: chunked vs non-chunked logprob computation.

Tests ChunkedDistributedLogprob and DistributedLogprob from
skyrl/backends/skyrl_train/distributed/megatron/model_utils.py,
which is the actual code path used in SkyRL's MegatronModelWrapper.

Usage (single GPU, torchrun required for distributed init):
uv run --isolated --extra megatron torchrun --nproc_per_node=1 \\
skyrl/benchmarks/bench_chunked_logprobs.py
"""

import os
import time
from typing import Optional

import torch
import torch.distributed as dist

VOCAB_SIZES = [32000, 64000, 128000]
SEQ_LENS = [32768, 65536, 131072]
# chunk_size=None routes through DistributedLogprob (no chunking); all others use
# ChunkedDistributedLogprob with the given chunk size.
CHUNK_SIZES = [None, 32, 1024, 4096, 8192, 16384]
WARMUP_REPS = 2
BENCH_REPS = 5


def measure(
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
chunk_size: Optional[int],
tp_group: torch.distributed.ProcessGroup,
reps: int,
):
"""Run forward+backward through the real SkyRL logprob kernel.

Returns (mean_wall_ms, mean_peak_mem_bytes).
"""
from skyrl.backends.skyrl_train.distributed.megatron.model_utils import (
ChunkedDistributedLogprob,
DistributedLogprob,
)

device = vocab_parallel_logits.device
times = []
peak_mems = []

for _ in range(reps):
# Fresh leaf tensor each rep so grad accumulation does not interfere.
logits_rep = vocab_parallel_logits.detach().requires_grad_(True)

torch.cuda.reset_peak_memory_stats(device)
torch.cuda.synchronize(device)
t0 = time.perf_counter()

if chunk_size is None:
# Non-chunked real implementation (DistributedLogprob)
out = DistributedLogprob.apply(
logits_rep,
target,
vocab_start_index,
vocab_end_index,
tp_group,
False, # inference_only=False -> saves tensors for backward
)
else:
# Chunked real implementation (ChunkedDistributedLogprob)
out = ChunkedDistributedLogprob.apply(
logits_rep,
target,
vocab_start_index,
vocab_end_index,
chunk_size,
tp_group,
False, # inference_only=False -> saves tensors for backward
)

loss = out.sum()
loss.backward()

torch.cuda.synchronize(device)
t1 = time.perf_counter()

times.append((t1 - t0) * 1000.0)
peak_mems.append(torch.cuda.max_memory_allocated(device))

return sum(times) / len(times), sum(peak_mems) / len(peak_mems)


def main():
# --- Distributed init (required by the real SkyRL kernel) ---
dist.init_process_group(backend="nccl")
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

# Initialise Megatron model-parallel state (TP=1, single GPU).
import megatron.core.parallel_state as mpu

mpu.initialize_model_parallel(tensor_model_parallel_size=1)

tp_group = dist.group.WORLD # TP=1, so the whole-world group is the TP group

device = torch.device("cuda", local_rank)

if dist.get_rank() == 0:
print(f"Device : {torch.cuda.get_device_name(device)}")
print(f"World size : {dist.get_world_size()}")
print(
f"Vocab sizes : {VOCAB_SIZES} | chunk_sizes={CHUNK_SIZES} "
f"| warmup={WARMUP_REPS} bench={BENCH_REPS}"
)
print("Implementation: real SkyRL ChunkedDistributedLogprob / DistributedLogprob\n")

col_w = 14
header = (
f"{'vocab_size':>10} "
f"{'seq_len':>10} "
f"{'chunk_size':>10} "
f"{'time ms':>{col_w}} "
f"{'peak MB':>{col_w}} "
f"{'vs no-chunk':>{col_w}} "
f"{'mem saved MB':>{col_w}}"
)
sep = "-" * len(header)

for vocab_size in VOCAB_SIZES:
if dist.get_rank() == 0:
print(f"\n=== vocab_size={vocab_size:,} ===")
print(header)
print(sep)

for seq_len in SEQ_LENS:
# With TP=1 the full vocab lives on this rank.
vocab_start_index = 0
vocab_end_index = vocab_size

# Shape expected by the real kernel: [batch, seq, vocab // TP]
# We use batch=1 to keep allocations comparable to a single-sequence workload.
try:
logits = torch.randn(1, seq_len, vocab_size, dtype=torch.bfloat16, device=device)
# targets: [batch, seq], values in [0, vocab_size)
target = torch.randint(0, vocab_size, (1, seq_len), device=device)
except torch.OutOfMemoryError:
if dist.get_rank() == 0:
oom_row = (
f"{vocab_size:>10,} "
f"{seq_len:>10,} "
f"{'(all)':>10} "
f"{'OOM':>{col_w}} "
f"{'OOM':>{col_w}} "
f"{'OOM':>{col_w}} "
f"{'OOM':>{col_w}}"
)
print(oom_row)
print(sep)
torch.cuda.empty_cache()
continue

# ----- single pass: warmup + benchmark inline per chunk size -----
results: dict[Optional[int], tuple[Optional[float], Optional[float]]] = {}
for cs in CHUNK_SIZES:
try:
for _ in range(WARMUP_REPS):
measure(logits, target, vocab_start_index, vocab_end_index, cs, tp_group, reps=1)
t_cs, mem_cs = measure(
logits, target, vocab_start_index, vocab_end_index, cs, tp_group, reps=BENCH_REPS
)
results[cs] = (t_cs, mem_cs)
except torch.OutOfMemoryError:
results[cs] = (None, None)
finally:
torch.cuda.empty_cache() # isolate between chunk sizes

t_baseline, mem_baseline = results[None]

# ----- print one row per chunk_size -----
for cs in CHUNK_SIZES:
cs_label = "None" if cs is None else str(cs)
t_cs, mem_cs = results[cs]

if t_cs is None:
if dist.get_rank() == 0:
print(
f"{vocab_size:>10,} "
f"{seq_len:>10,} "
f"{cs_label:>10} "
f"{'OOM':>{col_w}} "
f"{'OOM':>{col_w}} "
f"{'OOM':>{col_w}} "
f"{'OOM':>{col_w}}"
)
continue

mem_cs_mb = mem_cs / (1024**2)
if t_baseline is not None and t_cs > 0:
speedup_str = f"{t_baseline / t_cs:>{col_w}.2f}x"
mem_saved_str = f"{mem_baseline / (1024**2) - mem_cs_mb:>{col_w}.0f}"
else:
speedup_str = f"{'N/A':>{col_w}}"
mem_saved_str = f"{'N/A':>{col_w}}"

if dist.get_rank() == 0:
print(
f"{vocab_size:>10,} "
f"{seq_len:>10,} "
f"{cs_label:>10} "
f"{t_cs:>{col_w}.1f} "
f"{mem_cs_mb:>{col_w}.0f} "
f"{speedup_str} "
f"{mem_saved_str}"
)

if dist.get_rank() == 0:
print(sep)

# Free memory before next seq_len
del logits, target
torch.cuda.empty_cache()

if dist.get_rank() == 0:
print("\nAll times are mean wall-clock (ms) over forward+backward passes.")
print("vs no-chunk: speedup relative to chunk_size=None (>1 = faster).")
Comment thread
SumanthRH marked this conversation as resolved.
print("chunk_size=None uses DistributedLogprob; all others use ChunkedDistributedLogprob.")

dist.destroy_process_group()


if __name__ == "__main__":
main()
17 changes: 17 additions & 0 deletions skyrl/train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,12 +639,29 @@ class TrainerConfig(BaseConfig):
rope_theta: Optional[float] = None
log_example_interval: int = 1
"""Log an example prompt every N training steps, ``0``/``-1`` to disable"""
logprobs_chunk_size: Optional[int] = 1024
"""Chunk size along the sequence dimension when computing log-probs from logits.
This lowers peak GPU memory at the cost of ~2x wall-clock time.
``None`` disables chunking (Megatron backend only; FSDP requires a positive int).
See https://github.com/NovaSky-AI/SkyRL/pull/1610 for more details."""
Comment thread
SumanthRH marked this conversation as resolved.

def __post_init__(self):
# ref model defaults to the policy model
if self.ref.model.path is None:
self.ref.model.path = self.policy.model.path

if self.logprobs_chunk_size is not None and (
not isinstance(self.logprobs_chunk_size, int) or self.logprobs_chunk_size <= 0
):
raise ValueError(
f"logprobs_chunk_size must be a positive integer or None, got {self.logprobs_chunk_size!r}."
)
if self.logprobs_chunk_size is None and self.strategy != "megatron":
raise ValueError(
"logprobs_chunk_size=None (no chunking) is only supported with the Megatron backend. "
f"Set a positive integer for strategy={self.strategy!r}."
)


def validate_dict_keys_against_dataclass(datacls: Type[Any], d: dict):
"""
Expand Down
Loading