Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions docs/content/docs/configuration/config.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,34 @@ Some rules for configuring these parameters:
`optimizer_config_kwargs.use_precision_aware_optimizer=true` can cause checkpointing to fail. See: https://github.com/nvidia/megatron-lm/issues/1820. We recommend leaving this setting to `false`.
</Callout>

`OptimizerConfig`'s `*_dtype` fields (`main_params_dtype`, `exp_avg_dtype`, `exp_avg_sq_dtype`, `params_dtype`) are typed as `torch.dtype`, but `optimizer_config_kwargs` is forwarded from YAML as plain strings. We coerce these `*_dtype` strings to `torch.dtype` before constructing `OptimizerConfig`, so low-precision optimizer state can be configured directly from YAML, e.g.:

```yaml
optimizer_config_kwargs:
use_precision_aware_optimizer: true
exp_avg_dtype: bf16
exp_avg_sq_dtype: fp8
main_params_dtype: fp32
```

The accepted dtype-name strings (case- and whitespace-insensitive) are:

| Name | Aliases | `torch.dtype` |
|------|---------|---------------|
| `fp32` | `float32`, `float` | `torch.float32` |
| `fp16` | `float16`, `half` | `torch.float16` |
| `bf16` | `bfloat16` | `torch.bfloat16` |
| `fp8` | `float8`, `uint8` | `torch.uint8` |

`fp8` maps to `torch.uint8` because TransformerEngine represents FP8 optimizer state as `uint8`. Per-field legal sets are enforced before the kwargs reach FusedAdam, so an illegal value fails fast with a clear `ValueError`:

- `main_params_dtype` (master weights): `fp32`, `fp16`
- `exp_avg_dtype` / `exp_avg_sq_dtype`: `fp32`, `fp16`, `bf16`, `fp8`

<Callout type="info">
These short forms are specific to `optimizer_config_kwargs` and differ from the full `bfloat16` / `float16` / `float32` spellings accepted by `str_to_torch_dtype` used elsewhere in SkyRL.
</Callout>

## Optimizer Configuration

For both the critic and policy model, we provide a common optimizer configuration
Expand Down
12 changes: 12 additions & 0 deletions docs/content/docs/examples/megatron.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ empty_cuda_cache: true

These default values can be overridden by passing in the corresponding arguments to `trainer.policy.megatron_config` in the launch script.

`optimizer_config_kwargs` accepts dtype-name strings for its `*_dtype` fields (`main_params_dtype`, `exp_avg_dtype`, `exp_avg_sq_dtype`, `params_dtype`), which are coerced to `torch.dtype` before constructing `OptimizerConfig`. This lets you configure low-precision optimizer state from YAML:

```yaml
optimizer_config_kwargs:
use_precision_aware_optimizer: true
exp_avg_dtype: bf16
exp_avg_sq_dtype: fp8
main_params_dtype: fp32
```

The accepted short forms are `fp32` (aliases `float32`/`float`), `fp16` (`float16`/`half`), `bf16` (`bfloat16`), and `fp8` (`float8`/`uint8`); `fp8` maps to `torch.uint8` since TransformerEngine represents FP8 optimizer state as `uint8`. `main_params_dtype` is restricted to `fp32`/`fp16`, while `exp_avg_dtype`/`exp_avg_sq_dtype` additionally allow `bf16`/`fp8`; illegal values fail fast with a `ValueError`. See the [Megatron configuration guide](../configuration/config#megatron-configuration) for the full table.

## Parallelism Resources

Understanding and configuring parallelism strategies for large models can be challenging.
Expand Down
7 changes: 6 additions & 1 deletion skyrl/backends/skyrl_train/distributed/megatron/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
from omegaconf import DictConfig

from skyrl.backends.skyrl_train.distributed.megatron.optimizer_dtype import (
coerce_optimizer_dtype_kwargs,
)
from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig


Expand All @@ -45,7 +48,9 @@ def init_megatron_optim_config(
"params_dtype": torch.bfloat16,
"use_distributed_optimizer": True,
}
optim_args.update(optimizer_config_kwargs)
# Coerce any ``*_dtype`` string (e.g. "bf16" from YAML) into a real torch.dtype
# before it reaches Megatron's OptimizerConfig / FusedAdam, which require dtypes.
optim_args.update(coerce_optimizer_dtype_kwargs(optimizer_config_kwargs))

config = OptimizerConfig(**optim_args)
return config
Expand Down
96 changes: 96 additions & 0 deletions skyrl/backends/skyrl_train/distributed/megatron/optimizer_dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Pure-Python (torch-only) coercion of Megatron optimizer ``*_dtype`` kwargs.

This is intentionally kept free of any ``megatron.core`` import so the coercion
logic can be unit-tested on the cheap CPU CI lane (which installs torch but not
megatron-core). ``optimizer.py`` imports the public helper from here.
"""

from typing import Any, Dict, Set

import torch

# Canonical dtype-name -> torch.dtype mapping, mirroring Megatron-LM's own
# ``dtype_map`` in ``megatron/training/arguments.py`` (the short forms 'fp32',
# 'bf16', 'fp16', 'fp8' are the ones Megatron itself maps). The extra aliases
# ('bfloat16', 'float16'/'half', 'float32'/'float', 'float8'/'uint8') accept the
# spellings a user might reasonably write in YAML. ``fp8`` maps to ``torch.uint8``
# because TransformerEngine represents FP8 optimizer state as uint8.
_DTYPE_NAME_TO_TORCH: Dict[str, torch.dtype] = {
"fp32": torch.float32,
"float32": torch.float32,
"float": torch.float32,
"bf16": torch.bfloat16,
"bfloat16": torch.bfloat16,
"fp16": torch.float16,
"float16": torch.float16,
"half": torch.float16,
"fp8": torch.uint8,
"float8": torch.uint8,
"uint8": torch.uint8,
}

# Per-field legal dtypes, enforced before the kwargs reach FusedAdam (which would
# otherwise raise a cryptic error deep inside TransformerEngine). Only fields that
# are actually forwarded to TE FusedAdam — and whose accepted set is verifiable in
# the TE source — are listed here. ``main_params_dtype`` (a.k.a. master weights)
# maps to FusedAdam's ``master_weight_dtype``, which only supports fp32/fp16;
# ``exp_avg_dtype`` / ``exp_avg_sq_dtype`` additionally allow bf16/fp8.
# ``main_grads_dtype`` is deliberately NOT listed: at the pinned megatron-core rev
# it is not forwarded to FusedAdam (see megatron/core/optimizer/__init__.py, which
# only passes exp_avg_dtype, exp_avg_sq_dtype, and main_params_dtype as
# master_weight_dtype), so there is no TE-backed legal set to enforce — it is still
# coerced str->dtype here and left for ``OptimizerConfig.__post_init__`` to validate
# (mirroring how ``params_dtype`` is handled: coerced, not field-validated).
# Fields not listed here accept any value in ``_DTYPE_NAME_TO_TORCH``.
_LEGAL_FIELD_DTYPES: Dict[str, Set[torch.dtype]] = {
"main_params_dtype": {torch.float32, torch.float16},
"exp_avg_dtype": {torch.float32, torch.bfloat16, torch.float16, torch.uint8},
"exp_avg_sq_dtype": {torch.float32, torch.bfloat16, torch.float16, torch.uint8},
}


def coerce_optimizer_dtype_kwargs(optimizer_config_kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Coerce ``*_dtype`` string values in Megatron optimizer kwargs to ``torch.dtype``.

Megatron's precision-aware ``OptimizerConfig`` types ``exp_avg_dtype`` /
``exp_avg_sq_dtype`` / ``main_params_dtype`` (and friends) as real ``torch.dtype``,
but SkyRL forwards ``optimizer_config_kwargs`` verbatim from YAML/Hydra, which
delivers plain strings (e.g. ``"bf16"``). This converts any ``*_dtype`` key whose
value is a dtype-name string into the corresponding ``torch.dtype`` using Megatron-LM's
canonical mapping, validates the result against the per-field legal set, and leaves
everything else (non-``*_dtype`` keys, values already ``torch.dtype``) untouched.

Returns a new dict; the input is not mutated.

Raises:
ValueError: if a ``*_dtype`` value is an unrecognized dtype name, or if a coerced
dtype is illegal for that specific field (e.g. bf16/fp8 for ``main_params_dtype``).
"""
coerced: Dict[str, Any] = {}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If optimizer_config_kwargs is None (e.g., if it is omitted or set to null in the YAML configuration), calling .items() on it will raise an AttributeError. Adding a defensive None check at the beginning of the function ensures robustness and prevents runtime crashes.

    if optimizer_config_kwargs is None:
        return {}
    coerced: Dict[str, Any] = {}

for key, value in optimizer_config_kwargs.items():
if not key.endswith("_dtype"):
coerced[key] = value
continue

if isinstance(value, torch.dtype):
dtype = value
elif isinstance(value, str):
name = value.strip().lower()
if name not in _DTYPE_NAME_TO_TORCH:
raise ValueError(
f"Unrecognized dtype name {value!r} for optimizer kwarg {key!r}. "
f"Expected one of {sorted(_DTYPE_NAME_TO_TORCH)} or a torch.dtype."
)
dtype = _DTYPE_NAME_TO_TORCH[name]
else:
# Not a dtype-name string or torch.dtype (e.g. None); pass through so
# Megatron's own validation surfaces any problem.
coerced[key] = value
continue

legal = _LEGAL_FIELD_DTYPES.get(key)
if legal is not None and dtype not in legal:
legal_names = sorted({n for n, d in _DTYPE_NAME_TO_TORCH.items() if d in legal})
raise ValueError(f"Illegal dtype {dtype} for optimizer kwarg {key!r}; legal values are {legal_names}.")
coerced[key] = dtype
return coerced
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
"""CPU tests for central Megatron optimizer-state dtype coercion.

Verifies that ``*_dtype`` string values forwarded through
``MegatronConfig.optimizer_config_kwargs`` (e.g. "bf16" from YAML/Hydra) are
coerced to real ``torch.dtype`` before reaching Megatron's precision-aware
``OptimizerConfig`` / TransformerEngine FusedAdam, using Megatron-LM's own
canonical dtype mapping; that illegal values for ``main_params_dtype`` (master
weights, fp32/fp16 only) are rejected; and that unrelated kwargs pass through
untouched.

``TestCoerceOptimizerDtypeKwargs`` exercises the pure-Python (torch-only)
coercion helper, which lives in a module that does NOT import megatron-core, so
it runs unconditionally on the cheap CPU CI lane (which installs torch but not
megatron-core). ``TestInitMegatronOptimConfigDtypeCoercion`` builds a real
``OptimizerConfig`` and is therefore skipped when megatron-core is not installed
(mirroring ``test_megatron_correctness.py``). No CUDA is required by either.

uv run --isolated --extra megatron --extra dev pytest \
tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py -v
"""

import sys

import pytest
import torch

from skyrl.backends.skyrl_train.distributed.megatron.optimizer_dtype import (
coerce_optimizer_dtype_kwargs,
)

_has_megatron = "megatron" in sys.modules or __import__("importlib").util.find_spec("megatron") is not None


class TestCoerceOptimizerDtypeKwargs:
"""``coerce_optimizer_dtype_kwargs`` maps dtype-name strings to torch.dtype.

Runs unconditionally: the helper module is torch-only (no megatron-core).
"""

def _coerce(self, kwargs: dict) -> dict:
return coerce_optimizer_dtype_kwargs(kwargs)

@pytest.mark.parametrize(
"name,expected",
[
("bf16", torch.bfloat16),
("bfloat16", torch.bfloat16),
("fp16", torch.float16),
("float16", torch.float16),
("half", torch.float16),
("fp32", torch.float32),
("float32", torch.float32),
("float", torch.float32),
("fp8", torch.uint8),
("float8", torch.uint8),
("uint8", torch.uint8),
],
)
def test_string_names_coerce_to_torch_dtype(self, name, expected):
"""Each canonical/alias dtype name maps to the right torch.dtype."""
# exp_avg_dtype legally accepts fp32/fp16/bf16/fp8, so it can exercise all names.
out = self._coerce({"exp_avg_dtype": name})
assert out["exp_avg_dtype"] == expected
assert isinstance(out["exp_avg_dtype"], torch.dtype)

def test_fp8_maps_to_uint8(self):
"""TE represents fp8 optimizer state as uint8."""
out = self._coerce({"exp_avg_sq_dtype": "fp8"})
assert out["exp_avg_sq_dtype"] is torch.uint8

def test_case_and_whitespace_insensitive(self):
out = self._coerce({"exp_avg_dtype": " BF16 "})
assert out["exp_avg_dtype"] is torch.bfloat16

def test_already_torch_dtype_passes_through(self):
"""A value already a torch.dtype is preserved as-is."""
out = self._coerce({"exp_avg_dtype": torch.bfloat16})
assert out["exp_avg_dtype"] is torch.bfloat16

def test_main_params_dtype_accepts_fp32_and_fp16(self):
"""main_params_dtype (master weights) legally accepts only fp32/fp16."""
assert self._coerce({"main_params_dtype": "fp32"})["main_params_dtype"] is torch.float32
assert self._coerce({"main_params_dtype": "fp16"})["main_params_dtype"] is torch.float16

@pytest.mark.parametrize("bad", ["bf16", "fp8"])
def test_main_params_dtype_rejects_bf16_and_fp8(self, bad):
"""bf16/fp8 are illegal master-weight dtypes and must raise."""
with pytest.raises(ValueError, match="main_params_dtype"):
self._coerce({"main_params_dtype": bad})

@pytest.mark.parametrize(
"name,expected", [("bf16", torch.bfloat16), ("fp16", torch.float16), ("fp32", torch.float32)]
)
def test_params_dtype_is_coerced_with_no_field_restriction(self, name, expected):
"""``params_dtype`` ends in ``_dtype`` and has no legal-set entry, so it is
coerced for any recognized alias and overrides the bf16 default that
``init_megatron_optim_config`` seeds (see optimizer.py)."""
out = self._coerce({"params_dtype": name})
assert out["params_dtype"] is expected

def test_main_grads_dtype_coerced_but_not_field_validated(self):
"""``main_grads_dtype`` is not forwarded to FusedAdam at the pinned rev, so it
has no legal-set row: it is coerced str->dtype but any value is accepted here,
leaving megatron-core's ``__post_init__`` to validate it. bf16 (which a legal
set would reject) coerces fine."""
out = self._coerce({"main_grads_dtype": "bf16"})
assert out["main_grads_dtype"] is torch.bfloat16

def test_unrecognized_dtype_name_raises(self):
with pytest.raises(ValueError, match="Unrecognized dtype name"):
self._coerce({"exp_avg_dtype": "bf17"})

def test_unrelated_kwargs_pass_through_untouched(self):
"""Non-``*_dtype`` keys are returned unchanged."""
kwargs = {
"use_precision_aware_optimizer": True,
"optimizer_offload_fraction": 0.5,
"overlap_cpu_optimizer_d2h_h2d": False,
"exp_avg_dtype": "bf16",
}
out = self._coerce(kwargs)
assert out["use_precision_aware_optimizer"] is True
assert out["optimizer_offload_fraction"] == 0.5
assert out["overlap_cpu_optimizer_d2h_h2d"] is False
assert out["exp_avg_dtype"] is torch.bfloat16

def test_non_string_non_dtype_dtype_value_passes_through(self):
"""A ``*_dtype`` key whose value is neither a dtype name nor torch.dtype
(e.g. None) passes through so Megatron's own validation surfaces it."""
out = self._coerce({"main_grads_dtype": None})
assert out["main_grads_dtype"] is None

def test_input_not_mutated(self):
"""The helper returns a new dict and does not mutate the input."""
kwargs = {"exp_avg_dtype": "bf16"}
self._coerce(kwargs)
assert kwargs["exp_avg_dtype"] == "bf16"


@pytest.mark.skipif(not _has_megatron, reason="megatron-core not installed")
class TestInitMegatronOptimConfigDtypeCoercion:
"""End-to-end: ``init_megatron_optim_config`` builds a real OptimizerConfig
with coerced dtypes from string kwargs."""

def test_string_dtype_kwargs_reach_optimizer_config(self):
from skyrl.backends.skyrl_train.distributed.megatron.optimizer import (
init_megatron_optim_config,
)
from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig

optim_config = SkyRLOptimizerConfig()
config = init_megatron_optim_config(
optim_config,
{
"use_precision_aware_optimizer": True,
"exp_avg_dtype": "bf16",
"exp_avg_sq_dtype": "fp8",
"main_params_dtype": "fp32",
},
)
assert config.exp_avg_dtype is torch.bfloat16
assert config.exp_avg_sq_dtype is torch.uint8
assert config.main_params_dtype is torch.float32

def test_params_dtype_string_override_reaches_optimizer_config(self):
"""A string ``params_dtype`` override is coerced and replaces the seeded
bf16 default in the constructed OptimizerConfig."""
from skyrl.backends.skyrl_train.distributed.megatron.optimizer import (
init_megatron_optim_config,
)
from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig

config = init_megatron_optim_config(SkyRLOptimizerConfig(), {"params_dtype": "fp16"})
assert config.params_dtype is torch.float16

def test_default_kwargs_leave_dtypes_at_megatron_defaults(self):
"""With no ``*_dtype`` overrides, OptimizerConfig keeps its fp32 defaults
(byte-identical to prior behavior)."""
from skyrl.backends.skyrl_train.distributed.megatron.optimizer import (
init_megatron_optim_config,
)
from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig

config = init_megatron_optim_config(SkyRLOptimizerConfig(), {})
assert config.exp_avg_dtype is torch.float32
assert config.exp_avg_sq_dtype is torch.float32
assert config.main_params_dtype is torch.float32

def test_precision_aware_off_with_nonfp32_state_fast_fails_in_megatron(self):
"""Coercing ``exp_avg_dtype='bf16'`` passes the helper's own validation, but
megatron-core's ``OptimizerConfig.__post_init__`` then asserts that
exp_avg_dtype can only be fp32 when ``use_precision_aware_optimizer`` is False.

This documents that the fast-fail is megatron's (a real AssertionError),
not a silent mis-coercion: the helper coerces the string fine, the rejection
happens downstream in OptimizerConfig construction.
"""
from skyrl.backends.skyrl_train.distributed.megatron.optimizer import (
init_megatron_optim_config,
)
from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig

with pytest.raises(AssertionError, match="exp_avg_dtype can only be fp32"):
init_megatron_optim_config(
SkyRLOptimizerConfig(),
{"use_precision_aware_optimizer": False, "exp_avg_dtype": "bf16"},
)
Loading