[megatron] Enable Nemotron-3-Ultra-550B GRPO RL + fix multi-rank (EP>16/PP>2) weight sync#1816
[megatron] Enable Nemotron-3-Ultra-550B GRPO RL + fix multi-rank (EP>16/PP>2) weight sync#1816erictang000 wants to merge 3 commits into
Conversation
…ht sync Adds an end-to-end full-finetuning GRPO recipe for NVIDIA-Nemotron-3-Ultra-550B (NemotronH hybrid Mamba2+attention, latent MoE, reasoning) colocated with vLLM on 8x8 H200 (EFA), plus the weight-sync/reload correctness fixes it depends on. Validated: avg_raw_reward ~0.9, GSM8K eval ~0.94, grad_norm > 0. Core fix (general, affects any MoE synced at EP>16 / PP>2): the CUDA-IPC weight transport sent only rank-0's per-param slicing metadata, but each Megatron rank packs its own (per-rank-divergent) buffer. Each vLLM worker rebuilt its own GPU's buffer yet sliced it with rank-0's metadata -> correct bytes loaded under wrong names -> coherent- but-garbage generations and reward stuck at 0. Now sends per-GPU metadata and each worker slices its own buffer (cuda_ipc_strategy.py, new_inference_worker_wrap.py). Also included: - Preserve fp32 for the MoE router bias (gate.e_score_correction_bias) through sync; bf16 ULP at its ~25-57 magnitude collapses the per-expert offsets and corrupts routing. - Guard vLLM layerwise-reload get_numel_loaded (cf. vllm-project/vllm#44814): the composed-weight-loader double-count finalizes a layer early and drops Mamba mixer.D (uninitialized -> NaN). Mirrors the existing conv_weights reload workaround. - Forward HF_*/cache dirs and SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S to Ray worker actors (prepare_runtime_environment + GPU-CI conftest). - Reasoning-aware GSM8K reward: strip the <think> trace, score strict `#### <n>` else last-number with comma/$ normalization. - Nemotron-Ultra logprob round-trip test params (EP16/PP2 baseline; EP32/PP4 regressions). - Example recipe (run_megatron_nemotron_ultra.sh), staging helper, and README. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Code Review
This pull request adds support for training the Nemotron-3-Ultra-550B model using GRPO RL on GSM8K with Megatron and vLLM. Key changes include a new multi-node launch script, a staging script, robust GSM8K reward parsing for reasoning models, and critical fixes for CUDA-IPC weight synchronization, vLLM layerwise-reload, and native fp32 precision syncing for MoE router biases. The review feedback suggests improving the number normalization function to handle float representations of integers, simplifying regex group extraction in the staging script, and double-quoting the arguments in the shell script to prevent word splitting.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def _norm_num(s: str) -> str: | ||
| """Normalize a parsed number to compare against the (comma-free integer) ground truth.""" | ||
| return s.strip().rstrip(".").replace(",", "").replace("$", "") |
There was a problem hiding this comment.
The current implementation of _norm_num does not normalize float representations of integers (e.g., '72.0' or '72.'). If the model outputs a decimal representation of an integer, the comparison against the ground truth (which is typically a clean integer string like '72') will fail, leading to a reward of 0.0 instead of 1.0.
We can make this more robust by attempting to parse the value as a float and checking if it represents an integer.
| def _norm_num(s: str) -> str: | |
| """Normalize a parsed number to compare against the (comma-free integer) ground truth.""" | |
| return s.strip().rstrip(".").replace(",", "").replace("$", "") | |
| def _norm_num(s: str) -> str: | |
| """Normalize a parsed number to compare against the (comma-free integer) ground truth.""" | |
| val = s.strip().rstrip(".").replace(",", "").replace("$", "") | |
| try: | |
| f_val = float(val) | |
| if f_val.is_integer(): | |
| return str(int(f_val)) | |
| except ValueError: | |
| pass | |
| return val |
|
|
||
| def to_row(example, idx, split): | ||
| q = example["question"] | ||
| sol = re.search(r"#### (\-?[0-9\.\,]+)", example["answer"]).group(0).split("#### ")[1].replace(",", "") |
There was a problem hiding this comment.
Using group(0) followed by .split('#### ')[1] is redundant and can be simplified by directly accessing the captured group group(1). Additionally, wrapping this in a safe check prevents potential AttributeError if re.search ever returns None.
match = re.search(r"#### (\-?[0-9\.\,]+)", example["answer"])
sol = match.group(1).replace(",", "") if match else ""| trainer.max_ckpts_to_keep=3 \ | ||
| trainer.ckpt_interval=20 \ | ||
| trainer.ckpt_path="$HOME/ckpts/gsm8k_nemotron_ultra_ckpt" \ | ||
| $@ |
…dings
Adds a sweep harness (real Megatron fwd+bwd on fabricated rollouts, no vLLM
generation) to map, on 64xH200, the max max_tokens_per_microbatch and the
parallelism (TP/PP/CP/EP/DP) that maximizes full-FT GRPO training throughput
for NVIDIA-Nemotron-3-Ultra-550B, plus a long-context (variable-length) study.
Findings (examples/train/megatron/NEMOTRON_ULTRA_THROUGHPUT.md):
- Max MTPM ~= 64k tokens/microbatch at the validated TP8/PP4/EP16/DP2 config.
- Highest throughput: TP8/PP2/EP32/DP4 (~8.5k tok/s, +11% over PP4/DP2) for
short/medium seqs; config space is pinned (PP8 invalid for 108 layers, EP8
OOMs, TP4 doubles activations via sequence parallelism).
- Long context: single-sequence ceiling ~40-48k tokens (CP1/PP4/DP2); CP gives
little net benefit (CP=2 forces PP2 whose 2x weights cancel the savings). Long
seqs are more throughput-efficient per token (~12k tok/s at ~39k mean).
New files:
- examples/train_scripts/full_context/{trainer_ultra_sweep,main_ultra_sweep,analyze_sweep}.py
- examples/train/megatron/run_ultra_sweep.sh
- examples/train/megatron/NEMOTRON_ULTRA_THROUGHPUT.md
worker.py: get_cuda_memory() now also returns max_allocated/max_reserved
high-water marks (capture in-step peak even when queried after offload).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Follow-up to the throughput sweep. CP composes with EP in Megatron-Core (EP divides TP*CP*DP, with ETP=1), so adding CP does NOT force EP down. Measured: - TP8/PP4/CP2/EP16/DP1 fits a single 96k sequence (128k OOMs) -- CP2 roughly doubles the single-sequence ceiling (~40-48k -> ~96k) while keeping PP4's low weights and baseline expert memory. Best long-context config. - TP8/PP2/CP4/EP32/DP1 is valid and loads but still OOMs at 128k: dropping to PP2 to free GPUs for CP4 doubles the weights and eats the budget CP frees. So the 60k+-30k distribution is mostly trainable with PP4/CP2 (clamp ~96k, ~10% truncated) at the cost of DP->1; the full 131k tail still OOMs. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
What
Enables end-to-end full-finetuning GRPO RL of NVIDIA-Nemotron-3-Ultra-550B-A55B (NemotronH hybrid Mamba2 + attention, latent MoE with 512 experts, reasoning model) colocated with vLLM on 8× nodes of 8×H200 (64 GPUs, EFA) — and fixes the weight-sync/reload correctness bugs that block it (and other large MoE models).
Validated: trains end-to-end with the included recipe —
avg_raw_reward ≈ 0.9, GSM8Keval ≈ 0.94,grad_norm > 0. Megatron mesh TP8 / PP4 / EP16 / ETP1 (DP2); vLLM TP8 × PP4.Replication guide:
examples/train/megatron/README_nemotron_ultra.md.Why (root cause)
vLLM produced coherent-looking garbage after every weight sync → all rewards 0 → no learning. The bridge export was proven bit-correct (0 mismatches over 108k tensors), which localized the bug to the CUDA-IPC weight transport:
This is general to any MoE synced at EP>16/PP>2, not just Nemotron.
Changes
cuda_ipc_strategy.py,new_inference_worker_wrap.py)gate.e_score_correction_bias) preserved through sync — bf16 ULP at its ~25–57 magnitude collapses the tiny per-expert offsets (std ~7e-4) and corrupts routing. (megatron_worker.py)composed_weight_loaderdouble-count finalizes a layer early and drops Mambamixer.D(uninitialized → NaN). Guarded monkeypatch mirroring the existingconv_weightsworkaround; remove once on a vLLM that includes #44814. (layerwise_reload.py)HF_*/cache dirs andSKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S→ Ray worker actors. (utils.py, GPU-CIconftest.py)<think>, score strict#### <n>else last-number with comma/$ normalization. (skyrl_gym/envs/gsm8k/env.py)run_megatron_nemotron_ultra.sh,stage_nemotron_ultra.py(model+data staging incl.chat_template.jinja),README_nemotron_ultra.md.h100-marked / gated on the 550B model.Testing
test_logprobs_matching_roundtrip[nemotron3-ultra…]round-trip on 8×8 H200 (EP16/PP2 and EP16/PP4 pass; pre-sync Megatron↔vLLM ~0.06, post-sync ~0.15–0.30).Note
The
get_numel_loadedandconv_weightsreload patches are temporary vLLM workarounds (pending #44814 and #42481); the per-GPU-metadata transport fix and the rest are permanent.🤖 Generated with Claude Code