Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
fdeb482
x
erictang000 Apr 29, 2026
2f40ffe
x
erictang000 Apr 29, 2026
496bfb5
[wip] starting point for overnight nemotron3 nano debug
erictang000 Apr 30, 2026
86fe57b
[nemotron3] fix wake_up(kv_cache) OOM for 30B nano test
erictang000 Apr 30, 2026
d3d13ec
[debug] dump bridge-emitted weight names via SKYRL_DUMP_WEIGHT_NAMES
erictang000 Apr 30, 2026
d52a1e7
[debug] dump bucket-ordered broadcast names via SKYRL_DUMP_BROADCAST_…
erictang000 Apr 30, 2026
08c5d4b
[debug] env var to bypass bucketing for nemotron NaN diagnosis
erictang000 Apr 30, 2026
01c4a1d
[docs] running notes on the nemotron3 nano post-sync NaN
erictang000 Apr 30, 2026
a406c02
[debug] nemotron3-nano_tp2_ep2 variant for EP-localization
erictang000 Apr 30, 2026
7e49668
[debug] include value stats in broadcast dump
erictang000 Apr 30, 2026
7dcc5a2
[test] revert diagnostic-only nemotron3-nano_tp2_ep2 variant
erictang000 Apr 30, 2026
4b7e946
[docs] expand nemotron3 nano debug writeup with full findings
erictang000 Apr 30, 2026
1ca719c
[deps] bump vllm 0.19.0 -> 0.20.0, torch 2.10 -> 2.11
erictang000 Apr 30, 2026
7ee0593
[deps] regenerate uv.lock for vllm 0.20 / torch 2.11 upgrade
erictang000 Apr 30, 2026
f4af91d
[deps] use vllm 0.20.0+cu129 wheel; keep torch on cu128
erictang000 Apr 30, 2026
c867a68
[nemotron3][vllm020] force moe_backend=triton for nano test
erictang000 Apr 30, 2026
1e08a0d
[docs] capture vllm 0.20 upgrade results
erictang000 Apr 30, 2026
495cd4a
[nemotron3][vllm020] also set moe_backend=triton for the tiny model
erictang000 Apr 30, 2026
1d79e23
[docs] tiny test passes end-to-end on vllm 0.20
erictang000 Apr 30, 2026
1470e13
Merge branch 'main' of https://github.com/erictang000/SkyRL into nemo…
erictang000 Apr 30, 2026
4a72c42
[docs] capture run17/run18 results on merged stack
erictang000 Apr 30, 2026
96a48a6
x
erictang000 Apr 30, 2026
6a38b86
[nemotron3][vllm020] fix Mamba conv1d corruption + clean up debug ins…
erictang000 Apr 30, 2026
1318ff1
[overnight] start nemotron3_nano gsm8k + dapo runs
erictang000 May 1, 2026
218c625
[overnight] add moe_backend=triton + max_model_len overrides for vllm…
erictang000 May 1, 2026
808e035
[overnight] use inline-dict syntax for engine_init_kwargs override
erictang000 May 1, 2026
9842a52
[overnight] run03 reward=0 at step 1, monitoring
erictang000 May 1, 2026
8d5a5b0
[overnight] disable thinking mode for gsm8k (was burning all 1024 tok…
erictang000 May 1, 2026
4cb4ecc
[overnight] thinking back on + tight sampling + smaller batch
erictang000 May 1, 2026
4c615b0
[overnight] default gsm8k scoring to 'flexible' (extracts last number)
erictang000 May 1, 2026
d00eda4
[overnight] log run06 disk-full + uv cache move to /mnt/nvme
erictang000 May 1, 2026
76b4977
[overnight] move all uv cache subdirs to /mnt/nvme (run07 hit EXDEV)
erictang000 May 1, 2026
858d61e
[overnight] symlink ~/.cache/uv root to nvme; subdir symlinks aren't …
erictang000 May 1, 2026
840c360
[overnight] document run09 degenerate output + start standalone vllm …
erictang000 May 1, 2026
d35c58d
[overnight] try legacy inference path (_SKYRL_USE_NEW_INFERENCE=0)
erictang000 May 1, 2026
262dcb5
[overnight] async_engine=false to dodge OpenAIServingRender API misma…
erictang000 May 1, 2026
697b5b5
[overnight] step 1 reward = 0.940! legacy sync path works
erictang000 May 1, 2026
3380f3b
[overnight] step 2 reward 0.952 (+0.012). reward rising
erictang000 May 1, 2026
dfff3b7
[overnight] dapo: same _SKYRL_USE_NEW_INFERENCE=0 fix as gsm8k
erictang000 May 1, 2026
76a77f4
[overnight] step 4: 0.952. trajectory oscillating around ceiling
erictang000 May 1, 2026
5cb2815
[overnight] step 5+6 + eval: validation 0.953
erictang000 May 1, 2026
ddf69d0
[overnight] step 11 + eval@10: validation plateaued at 0.95. plan DAP…
erictang000 May 1, 2026
8f7cad7
[overnight] gsm8k 16 steps + 3 evals. validation flat at 0.952. cutti…
erictang000 May 1, 2026
432ecb1
[overnight] DAPO launch: bump eval_interval 5->10 to limit eval overhead
erictang000 May 1, 2026
0b49c58
[overnight] DAPO baseline AIME pass@32 = 0.50 (15/30 problems solved)
erictang000 May 1, 2026
ef0281a
[overnight] DAPO run02: shrink micro batches + expandable_segments af…
erictang000 May 1, 2026
7d3a90e
[overnight] DAPO run03: drop expandable_segments (vLLM incompatible),…
erictang000 May 1, 2026
0fdd0af
[overnight] DAPO run03 step 1 OK: pass@16=0.375, no OOM, 25min/step
erictang000 May 1, 2026
b75b272
[overnight] DAPO trajectory through step 4: pass@16 0.375 -> 0.391 (r…
erictang000 May 1, 2026
29c3483
[overnight] DAPO step 6 new peak: pass@16=0.445 (+0.070 vs step 1)
erictang000 May 1, 2026
2fec2f9
[overnight] DAPO 8 steps: peak pass@16=0.445 at step 6, mean ~0.378
erictang000 May 1, 2026
e00407f
[overnight] DAPO step 10 = 0.422 (new peak). final summary + TL;DR
erictang000 May 1, 2026
4bbd4c1
[overnight] DAPO eval@10: pass@32 0.30 -> 0.333 (+3.3pp), mean_pos +44%
erictang000 May 1, 2026
282c268
[overnight] DAPO step 11 = 0.484 pass@16, +11pp vs step 1
erictang000 May 1, 2026
b982ed6
[overnight] DAPO step 12 = 0.539 pass@16 (+16.4pp). still climbing
erictang000 May 1, 2026
903353a
[overnight] DAPO step 13-14: 0.453, 0.484. settling around 0.48 band
erictang000 May 1, 2026
b7b5184
[overnight] DAPO step 15 = 0.523. mean of last 5 = 0.501 vs first 5 =…
erictang000 May 1, 2026
d5c4545
[overnight] DAPO step 16-17: 0.531, 0.539. mean of last 7 = 0.508 (+1…
erictang000 May 1, 2026
c4962c6
[overnight] DAPO step 18 = 0.672 pass@16 (+29.7pp). massive jump
erictang000 May 1, 2026
cc01899
[overnight] DAPO step 20 = 0.719 pass@16 (+34.4pp vs step 1). eval@20…
erictang000 May 1, 2026
81e2fa5
[overnight] DAPO eval@20: AIME pass@32 = 0.500 (+20pp absolute, +67% …
erictang000 May 1, 2026
898e94d
[overnight] DAPO step 22 = 0.727 pass@16 (+35.2pp). steady gains cont…
erictang000 May 1, 2026
33d9873
[overnight] DAPO step 23-25: pass@16 peak now 0.742 (+36.7pp). still …
erictang000 May 1, 2026
831c3ca
[overnight] DAPO step 29 = 0.797 pass@16 (+42.2pp). still climbing
erictang000 May 1, 2026
c71d173
[overnight] DAPO eval@30: AIME pass@32 = 0.567 (17/30, +26.7pp). exce…
erictang000 May 1, 2026
30bc58b
[overnight] DAPO step 31-34: pass@16 peak now 0.844 (+46.9pp). platea…
erictang000 May 1, 2026
4aca79a
[overnight] DAPO eval@40 regression: 0.567 -> 0.433 (overfit signal)
erictang000 May 2, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
511 changes: 511 additions & 0 deletions .claude/runs/PROGRESS.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.12
137 changes: 137 additions & 0 deletions examples/train/megatron/run_megatron_dapo_nemotron3_nano.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
set -x

# Use the legacy (non-chunked) inference path to avoid the vLLM 0.20
# layerwise-reload corruption that derails post-sync generation for nemotron_h.
# See PROGRESS.md / gsm8k_run09 → run11 for the diagnosis.
export _SKYRL_USE_NEW_INFERENCE=0
# NOTE: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True is incompatible with
# vLLM's CuMemAllocator (assertion in vllm/device_allocator/cumem.py:132,
# pytorch/pytorch#147851). Rely on smaller micro batches + shorter
# MAX_RESPONSE_LENGTH instead.

# Colocated DAPO training+generation for Nemotron3-Nano-30B-A3B on DAPO with Megatron.
# Should run on 1 node of 8xB2000

# bash examples/train/algorithms/dapo/prepare_dapo_data.sh
# bash examples/train/megatron/run_megatron_dapo_nemotron3_nano.sh

MODEL_NAME="nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16"
DATA_DIR="$HOME/data/dapo"
TRAIN_FILE="$DATA_DIR/dapo-math-17k-cleaned.parquet"
TEST_FILE="$DATA_DIR/aime-2024-cleaned.parquet"
NUM_NODES=1
NUM_GPUS_PER_NODE=8
NUM_INFERENCE_ENGINES=1
INFERENCE_ENGINE_TENSOR_PARALLEL_SIZE=8
LOGGER="wandb" # change to "console" to print to stdout

CLIP_RATIO_LOW=0.2
CLIP_RATIO_HIGH=0.28
# use token mean loss reduction
LOSS_REDUCTION="token_mean"
# applies overlong filtering (but not soft overlong punishment)
APPLY_OVERLONG_FILTERING=true
# apply soft overlong punishment with custom trainer impl in main_dapo.py
OVERLONG_BUFFER_LEN=$((1024 * 4))
OVERLONG_BUFFER_PENALTY_FACTOR=1.0

# other DAPO parameters
USE_KL_LOSS=false
TEMPERATURE=1.0
TOP_P=1.0
EVAL_TOP_P=0.7
CLIP_RATIO_C=10.0
MAX_PROMPT_LENGTH=$((1024 * 2))
# Reduced from 8192 to 4096 for the overnight smoke run — full 8k responses
# pushed Megatron's packed activations OOM (run01) and we don't have headroom
# at this batch size. AIME problems usually fit in 4k.
MAX_RESPONSE_LENGTH=$((1024 * 4))

# repro run parameters
TRAIN_BATCH_SIZE=128
MINI_BATCH_SIZE=32
N_SAMPLES_PER_PROMPT=16
EVAL_N_SAMPLES_PER_PROMPT=32
ENFORCE_EAGER=true # cuda graphs can cause some instability
LR=1e-6

# megatron config
MEGATRON_TP=4
MEGATRON_PP=1
MEGATRON_CP=1
MEGATRON_EP=8
MEGATRON_ETP=1


# TIS parameters
TIS_IMP_RATIO_CAP=2.0
TIS_TYPE=token

uv run --isolated --extra megatron -m examples.train.algorithms.dapo.main_dapo \
data.train_data="['$TRAIN_FILE']" \
data.val_data="['$TEST_FILE']" \
trainer.algorithm.advantage_estimator="grpo" \
trainer.algorithm.policy_loss_type="dual_clip" \
trainer.algorithm.overlong_buffer_len=$OVERLONG_BUFFER_LEN \
trainer.algorithm.overlong_buffer_penalty_factor=$OVERLONG_BUFFER_PENALTY_FACTOR \
trainer.algorithm.loss_reduction=$LOSS_REDUCTION \
generator.inference_engine.enforce_eager=$ENFORCE_EAGER \
generator.apply_overlong_filtering=$APPLY_OVERLONG_FILTERING \
generator.sampling_params.temperature=$TEMPERATURE \
generator.sampling_params.top_p=$TOP_P \
generator.eval_sampling_params.top_p=$EVAL_TOP_P \
generator.eval_sampling_params.temperature=$TEMPERATURE \
generator.eval_sampling_params.max_generate_length=$MAX_RESPONSE_LENGTH \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.algorithm.clip_ratio_c=$CLIP_RATIO_C \
trainer.policy.model.path="$MODEL_NAME" \
trainer.placement.colocate_all=true \
trainer.strategy=megatron \
trainer.placement.policy_num_nodes=$NUM_NODES \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS_PER_NODE \
generator.inference_engine.num_engines=$NUM_INFERENCE_ENGINES \
generator.inference_engine.tensor_parallel_size=$INFERENCE_ENGINE_TENSOR_PARALLEL_SIZE \
trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \
trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
trainer.epochs=20 \
trainer.algorithm.eps_clip_low=$CLIP_RATIO_LOW \
trainer.algorithm.eps_clip_high=$CLIP_RATIO_HIGH \
trainer.eval_batch_size=1024 \
trainer.eval_before_train=true \
trainer.eval_interval=10 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=$TRAIN_BATCH_SIZE \
trainer.policy_mini_batch_size=$MINI_BATCH_SIZE \
trainer.micro_forward_batch_size_per_gpu=2 \
trainer.micro_train_batch_size_per_gpu=1 \
trainer.ckpt_interval=-1 \
trainer.max_prompt_length=$MAX_PROMPT_LENGTH \
generator.sampling_params.max_generate_length=$MAX_RESPONSE_LENGTH \
trainer.policy.optimizer_config.lr=$LR \
trainer.policy.optimizer_config.num_warmup_steps=40 \
trainer.policy.optimizer_config.weight_decay=0.1 \
trainer.policy.optimizer_config.max_grad_norm=1.0 \
generator.inference_engine.backend=vllm \
generator.inference_engine.run_engines_locally=true \
generator.inference_engine.weight_sync_backend=nccl \
generator.inference_engine.async_engine=false \
generator.batched=true \
environment.env_class=aime \
generator.n_samples_per_prompt=$N_SAMPLES_PER_PROMPT \
generator.eval_n_samples_per_prompt=$EVAL_N_SAMPLES_PER_PROMPT \
generator.inference_engine.gpu_memory_utilization=0.6 \
generator.inference_engine.engine_init_kwargs="{moe_backend: triton, max_model_len: 8192}" \
trainer.logger="$LOGGER" \
trainer.project_name="dapo_nemotron3_nano" \
trainer.run_name="dapo_nemotron3_nano_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}" \
trainer.export_path="$HOME/exports/dapo_nemotron3_nano_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}" \
trainer.hf_save_interval=-1 \
trainer.resume_mode=latest \
trainer.max_ckpts_to_keep=3 \
trainer.ckpt_path="$HOME/ckpts/dapo_nemotron3_nano_30b_a3b_base_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_ep${MEGATRON_EP}_etp${MEGATRON_ETP}" \
$@
87 changes: 87 additions & 0 deletions examples/train/megatron/run_megatron_nemotron3_nano.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
set -x

# Use the legacy (non-chunked) inference path. The new path goes through
# vLLM's layerwise reload, which re-runs `process_weights_after_loading` and
# (likely) re-creates view-buffer aliases that corrupt MoE/conv weights for
# nemotron_h beyond the `conv_weights` skip we already added. Standalone
# vLLM with HF weights at T=0.7 produces correct gsm8k answers; post-Megatron-
# sync vLLM produces degenerate output. Legacy path uses CUDA IPC + direct
# model.load_weights, no reload machinery.
export _SKYRL_USE_NEW_INFERENCE=0

# Colocated GRPO training+generation for Nemotron3-Nano-30B-A3B on GSM8K with Megatron.

# uv run examples/train/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
# export WANDB_API_KEY=<your_key_here>
# bash examples/train/megatron/run_megatron_nemotron3_nano.sh

DATA_DIR="$HOME/data/gsm8k"
LOGGER="wandb" # change to "console" to print to stdout
MODEL_NAME="nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16"

INFERENCE_BACKEND="vllm" # currently only vllm is supported for megatron

NUM_NODES=1
NUM_GPUS=8

MEGATRON_TP=4
MEGATRON_PP=1
MEGATRON_CP=1
MEGATRON_EP=8
MEGATRON_ETP=1

INFERENCE_ENGINE_TP=8

# # Qwen3.5 flags
# USE_SAMPLE_PACKING=false # sample packing is not yet supported for GDN layers in megatron - see: https://github.com/NVIDIA/Megatron-LM/pull/2644

uv run --isolated --extra megatron -m skyrl.train.entrypoints.main_base \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.algorithm.advantage_estimator="grpo" \
trainer.policy.model.path=$MODEL_NAME \
trainer.placement.colocate_all=true \
trainer.strategy=megatron \
trainer.placement.policy_num_nodes=$NUM_NODES \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
trainer.placement.critic_num_gpus_per_node=$NUM_GPUS \
trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
generator.inference_engine.num_engines=1 \
generator.inference_engine.tensor_parallel_size=$INFERENCE_ENGINE_TP \
trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \
trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \
trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \
trainer.use_sample_packing=true \
trainer.epochs=20 \
trainer.eval_batch_size=256 \
trainer.eval_before_train=false \
trainer.eval_interval=5 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=256 \
trainer.policy_mini_batch_size=64 \
trainer.micro_forward_batch_size_per_gpu=4 \
trainer.micro_train_batch_size_per_gpu=4 \
trainer.ckpt_interval=-1 \
trainer.max_prompt_length=512 \
generator.sampling_params.max_generate_length=3000 \
generator.sampling_params.temperature=0.7 \
generator.sampling_params.top_p=0.9 \
trainer.policy.optimizer_config.lr=1.0e-6 \
trainer.algorithm.use_kl_loss=true \
generator.inference_engine.backend=$INFERENCE_BACKEND \
generator.inference_engine.run_engines_locally=true \
generator.inference_engine.weight_sync_backend=nccl \
generator.inference_engine.async_engine=false \
generator.batched=true \
environment.env_class=gsm8k \
generator.n_samples_per_prompt=5 \
generator.inference_engine.gpu_memory_utilization=0.6 \
generator.inference_engine.engine_init_kwargs="{moe_backend: triton, max_model_len: 4096}" \
trainer.logger="$LOGGER" \
trainer.project_name="nemotron3_nano" \
trainer.run_name="nemotron3_nano_megatron" \
trainer.resume_mode=null \
trainer.ckpt_path="$HOME/ckpts/nemotron3_nano_megatron_ckpt" \
$@
73 changes: 55 additions & 18 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,36 +98,38 @@ skyrl-train = [

fsdp = [
"skyrl[skyrl-train]",
"vllm==0.19.0; sys_platform == 'linux'",
"vllm==0.20.0; sys_platform == 'linux'",
"vllm-router; sys_platform == 'linux'",
"nixl; sys_platform == 'linux'",
"flash-linear-attention; sys_platform == 'linux'",
"causal-conv1d; sys_platform == 'linux'",
"flash-attn==2.8.3; sys_platform == 'linux'",
"torch==2.10.0; sys_platform == 'linux'",
"flashinfer-python==0.6.6; sys_platform == 'linux' and platform_machine == 'x86_64'",
"flashinfer-jit-cache==0.6.6; sys_platform == 'linux' and platform_machine == 'x86_64'",
"torch==2.11.0; sys_platform == 'linux'",
"flashinfer-python==0.6.8.post1; sys_platform == 'linux' and platform_machine == 'x86_64'",
"flashinfer-jit-cache==0.6.8.post1; sys_platform == 'linux' and platform_machine == 'x86_64'",
"flashinfer-cubin==0.6.8.post1; sys_platform == 'linux' and platform_machine == 'x86_64'",
"torchvision; sys_platform == 'linux'",
]

megatron = [
"skyrl[skyrl-train]",
"transformer-engine[pytorch]==2.10.0; sys_platform == 'linux'",
"transformer-engine[pytorch]==2.11.0; sys_platform == 'linux'",
"flash-attn==2.8.3; sys_platform == 'linux'",
"flash-linear-attention; sys_platform == 'linux'",
"causal-conv1d; sys_platform == 'linux'",
"mamba-ssm>=2.3.0; sys_platform == 'linux'",
"vllm==0.19.0; sys_platform == 'linux'",
"vllm==0.20.0; sys_platform == 'linux'",
"vllm-router; sys_platform == 'linux'",
"nixl; sys_platform == 'linux'",
"torch==2.10.0; sys_platform == 'linux'",
"flashinfer-python==0.6.6; sys_platform == 'linux' and platform_machine == 'x86_64'",
"torch==2.11.0; sys_platform == 'linux'",
"flashinfer-python==0.6.8.post1; sys_platform == 'linux' and platform_machine == 'x86_64'",
"torchvision; sys_platform == 'linux'",
# megatron-bridge requires Python 3.12+; pin megatron-core to the same
# constraint so both packages are consistently available (or absent).
"megatron-bridge; sys_platform == 'linux' and python_version >= '3.12'",
"megatron-core; sys_platform == 'linux' and python_version >= '3.12'",
"flashinfer-jit-cache==0.6.6; sys_platform == 'linux' and platform_machine == 'x86_64'",
"flashinfer-jit-cache==0.6.8.post1; sys_platform == 'linux' and platform_machine == 'x86_64'",
"flashinfer-cubin==0.6.8.post1; sys_platform == 'linux' and platform_machine == 'x86_64'",
"nvidia-modelopt; sys_platform == 'linux'",
]

Expand Down Expand Up @@ -184,7 +186,8 @@ required-environments = [
]

constraint-dependencies = [
"flashinfer-jit-cache==0.6.6",
"flashinfer-jit-cache==0.6.8.post1",
"flashinfer-cubin==0.6.8.post1",
]
# each backend should have separate dependencies that can potentially clash
# megatron also clashes with the jax dependency from gpu and tpu extras
Expand All @@ -208,12 +211,17 @@ no-build-isolation-package = [
"transformer-engine-torch",
"transformer-engine",
"nv-grouped-gemm",
# causal-conv1d and mamba-ssm need to compile against torch 2.11 (no
# upstream wheels yet); building with isolation would pin a different
# torch in the build env than the runtime.
"causal-conv1d",
"mamba-ssm",
]
# override unnecessary dependencies and pin versions to override Megatron-Bridge
# unpinned dependencies.
override-dependencies = [
"nvidia-resiliency-ext; sys_platform == 'never'",
"transformer-engine[pytorch]==2.10.0; sys_platform == 'linux'",
"transformer-engine[pytorch]==2.11.0; sys_platform == 'linux'",
"transformers>=5.0.0,<=5.3.0; sys_platform == 'linux'",
"megatron-core>=0.16.0; sys_platform == 'linux'",
"ml_dtypes>=0.5.0; sys_platform == 'linux'",
Expand All @@ -223,6 +231,10 @@ override-dependencies = [
flash-attn = [{requirement = "torch", match-runtime = true}]
transformer-engine = [{requirement = "torch", match-runtime = true}, "build_tools", "ninja"]
transformer-engine-torch = [{requirement = "torch", match-runtime = true}, "build_tools", "ninja"]
# causal-conv1d / mamba-ssm need torch + ninja in the build env (we run them
# with build isolation disabled but uv still uses extra-build-dependencies).
causal-conv1d = [{requirement = "torch", match-runtime = true}, "ninja", "packaging", "wheel", "setuptools"]
mamba-ssm = [{requirement = "torch", match-runtime = true}, "ninja", "packaging", "wheel", "setuptools"]

[tool.uv.extra-build-variables]
flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE"}
Expand All @@ -232,6 +244,11 @@ name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

[[tool.uv.index]]
name = "pytorch-cu129"
url = "https://download.pytorch.org/whl/cu129"
explicit = true

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
Expand All @@ -246,16 +263,36 @@ name = "flashinfer-cu128"
url = "https://flashinfer.ai/whl/cu128"
explicit = true

[[tool.uv.index]]
name = "flashinfer-cu129"
url = "https://flashinfer.ai/whl/cu129"
explicit = true

[[tool.uv.index]]
name = "vllm-cu129"
url = "https://wheels.vllm.ai/0.20.0/cu129"
explicit = true

[tool.uv.sources]
skyrl-gym = { path = "./skyrl-gym", editable = true }
# flashinfer wheels are only available from the custom cu128 index
# flashinfer-jit-cache 0.6.8 is only published against cu128 / cu129. Keep the
# cu128 index since torch is also cu128 here.
flashinfer-jit-cache = { index = "flashinfer-cu128", marker = "sys_platform == 'linux'" }
causal-conv1d = { url = "https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.6.1.post4/causal_conv1d-1.6.1%2Bcu12torch2.10cxx11abiTRUE-cp312-cp312-linux_x86_64.whl", marker = "sys_platform == 'linux' and python_version == '3.12' and platform_machine == 'x86_64'" }
mamba-ssm = { url = "https://github.com/state-spaces/mamba/releases/download/v2.3.1/mamba_ssm-2.3.1%2Bcu12torch2.10cxx11abiTRUE-cp312-cp312-linux_x86_64.whl", marker = "sys_platform == 'linux' and python_version == '3.12' and platform_machine == 'x86_64'" }
# TODO (aaron): Once PyTorch 2.10 is officially supported (stable PyPI torch + matching
# flash-attn wheels), drop the custom wheel URL
flash-attn = { url = "https://github.com/lesj0610/flash-attention/releases/download/v2.8.3-cu12-torch2.10-cp312/flash_attn-2.8.3%2Bcu12torch2.10cxx11abiTRUE-cp312-cp312-linux_x86_64.whl", marker = "sys_platform == 'linux' and python_version == '3.12' and platform_machine == 'x86_64'" }
# Use CUDA torch on Linux, CPU torch on macOS (must match skyrl-train config)
# vllm 0.20.0 PyPI wheel is built against CUDA 13 (libcudart.so.13). The system
# has CUDA 12.9 with torch 2.11+cu129, so use the cu129 wheel from the vllm
# wheels index (not on PyPI).
vllm = [
{ index = "vllm-cu129", marker = "sys_platform == 'linux'" },
]
# NOTE (overnight 2026-04-30): bumped to torch 2.11 so vllm 0.20.0 install
# resolves cleanly. There are no upstream torch-2.11 wheels for causal-conv1d
# or mamba-ssm yet, so those build from source against torch 2.11. Keep the
# flash-attn URL pinned to the lesj0610 fork's torch-2.11 wheel.
flash-attn = { url = "https://github.com/lesj0610/flash-attention/releases/download/v2.8.3-cu12-torch2.11/flash_attn-2.8.3%2Bcu12torch2.11cxx11abiTRUE-cp312-cp312-linux_x86_64.whl", marker = "sys_platform == 'linux' and python_version == '3.12' and platform_machine == 'x86_64'" }

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a personal fork (lesj0610/flash-attention) for a critical dependency like flash-attn is a security and maintainability risk. It is recommended to use the official repository or build from source if a specific patch is needed. If this is a temporary workaround, please add a TODO to revert to the official source once a compatible version is released.

# Use CUDA torch on Linux, CPU torch on macOS (must match skyrl-train config).
# Stay on the cu128 index because torch 2.11+cu128 exists there and the
# flashrl extra requires torch 2.7 (only on cu128). The vllm 0.20 wheel pulled
# from cu129 still loads against cu12 libcudart.so.12 supplied by torch+cu128.
torch = [
{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
Expand Down
10 changes: 9 additions & 1 deletion skyrl-gym/skyrl_gym/envs/gsm8k/env.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from skyrl_gym.envs.base_text_env import BaseTextEnv, BaseTextEnvStepOutput
from skyrl_gym.envs.gsm8k import utils
from typing import Dict, Any
Expand All @@ -14,9 +15,16 @@ def __init__(self, env_config: Any = None, extras: Dict[str, Any] = {}):
assert "reward_spec" in extras, "reward_spec field is required"
assert "ground_truth" in extras["reward_spec"], "ground_truth is required in reward_spec field"
self.ground_truth = extras["reward_spec"]["ground_truth"]
# Default to flexible scoring. The strict "#### NUMBER" extraction is
# too brittle for modern instruct/thinking models, which typically end
# with "The answer is 42." or "$\boxed{42}$" rather than the GSM8K
# ground-truth format. Flexible takes the last number in the output,
# which works across response styles. Override with
# SKYRL_GSM8K_SCORING_METHOD=strict for the original behavior.
self._scoring_method = os.environ.get("SKYRL_GSM8K_SCORING_METHOD", "flexible")

def _get_reward(self, action: str) -> float:
return utils.compute_score(action, self.ground_truth)
return utils.compute_score(action, self.ground_truth, method=self._scoring_method)

def step(self, action: str) -> BaseTextEnvStepOutput:
done = True # always done after one step
Expand Down
Loading
Loading