From 9db491ea46dd2208c1af2b2cd7466d0a72b629bf Mon Sep 17 00:00:00 2001
From: dyurk-lila
Date: Mon, 15 Jun 2026 17:06:32 -0500
Subject: [PATCH] [megatron] Accept dtype-string optimizer_config_kwargs (coerce exp_avg_dtype etc. to torch.dtype)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

Megatron's precision-aware `OptimizerConfig` types its `*_dtype` fields (`exp_avg_dtype`, `exp_avg_sq_dtype`, `main_params_dtype`, `params_dtype`) as real `torch.dtype`, but `optimizer_config_kwargs` is forwarded verbatim from YAML/Hydra, which delivers plain strings (e.g. `"bf16"`). Such a string would reach TransformerEngine FusedAdam and crash. This adds a central `str -> torch.dtype` coercion at the single optimizer-construction choke point so low-precision optimizer state can be configured from YAML.

## What changed

- New torch-only module `skyrl/backends/skyrl_train/distributed/megatron/optimizer_dtype.py` holding `coerce_optimizer_dtype_kwargs(dict) -> dict` plus two mapping tables:
  - `_DTYPE_NAME_TO_TORCH`: canonical dtype-name -> `torch.dtype`. The short forms `fp32`/`bf16`/`fp16`/`fp8` follow Megatron-LM's own `dtype_map`; common alias spellings (`bfloat16`, `float16`/`half`, `float32`/`float`, `float8`/`uint8`) are also accepted. `fp8 -> torch.uint8` is a TransformerEngine convention (TE represents FP8 optimizer state as uint8) and is added here, not sourced from Megatron's `dtype_map`.
  - `_LEGAL_FIELD_DTYPES`: per-field legal sets for the fields actually forwarded to TE FusedAdam: `main_params_dtype` (master weights) is restricted to `{fp32, fp16}`, while `exp_avg_dtype`/`exp_avg_sq_dtype` allow `{fp32, fp16, bf16, fp8}`. Illegal values raise a clear `ValueError` before reaching FusedAdam.
- `optimizer.py`: `init_megatron_optim_config` now passes `optimizer_config_kwargs` through `coerce_optimizer_dtype_kwargs(...)` before `.update(...)`, instead of updating with the raw kwargs.
  The coercion sits at the single shared construction point (its sole caller serves both SFT and the RL policy).
- Docs: documented the new string-dtype support for `optimizer_config_kwargs` in `docs/content/docs/configuration/config.mdx` (next to the existing `use_precision_aware_optimizer` callout, with the full name/alias table) and in `docs/content/docs/examples/megatron.mdx` (a concise note cross-linking to the table). The docs cover the accepted names/aliases, the per-field legal sets, and the `fp8 -> uint8` convention, and note that these short forms differ from the full `bfloat16`/`float16`/`float32` spellings accepted by `str_to_torch_dtype` elsewhere.

The helper is deliberately kept free of any `megatron.core` import so it can be unit-tested on the CPU CI lane (torch only). Coercion lives at the optimizer-construction choke point rather than in `MegatronConfig.__post_init__`, because coercing there would replace the YAML strings with `torch.dtype` objects in the dataclass and break the serializable config path (`asdict`/`yaml.dump`).

## Numerical equivalence / safety

Byte-identical to current behavior unless a `*_dtype` key is explicitly set in `optimizer_config_kwargs`. With no `*_dtype` overrides the coercion is a pure pass-through copy, so `OptimizerConfig` keeps its existing fp32 defaults for `exp_avg_dtype`/`exp_avg_sq_dtype`/`main_params_dtype` and the hardcoded `params_dtype=torch.bfloat16` seed is unchanged. The default optimizer kwargs contain no `*_dtype` keys, so the default path is unchanged.

The only intentional behavior change: a `*_dtype` string that previously would have reached FusedAdam and crashed now becomes the correct `torch.dtype` (enabling low-precision optimizer state), and an illegal `main_params_dtype` now fails fast with a clear message instead of a cryptic TE error. Values that are already `torch.dtype` and non-dtype kwargs pass through untouched; non-string/non-dtype `*_dtype` values (e.g. `None`) pass through so Megatron's own validation surfaces them.
`main_grads_dtype` is coerced str->dtype but intentionally has no legal-set row: at the pinned megatron-core rev it is not forwarded to TE FusedAdam, so there is no TE-backed legal set to enforce; it is left for `OptimizerConfig.__post_init__` to validate (mirroring how `params_dtype` is handled).

## Generality & follow-ups

- Covers all Megatron optimizer construction reachable via `init_megatron_optim_config` (SFT and RL policy). The critic worker does not construct an optimizer through this path.
- The FSDP optimizer-state path is intentionally out of scope: it is a separate code path with its own mixed-precision config, and no change is made there.
- Separate `str -> torch.dtype` helpers already exist (`str_to_torch_dtype` / `PrecisionType.to_dtype`), but neither knows the `fp8 -> uint8` mapping nor performs per-field legal-set validation, both of which the precision-aware optimizer-state feature requires; consolidating the canonical name table is a possible follow-up.

## Test plan

`tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py`:

- `TestCoerceOptimizerDtypeKwargs` (CPU, no skip-guard, so it runs on the CPU lane): parametrized name->dtype coercion for all aliases, `fp8->uint8`, case/whitespace insensitivity, `torch.dtype` pass-through, `main_params_dtype` accepts fp32/fp16 and rejects bf16/fp8, `params_dtype` coercion, `main_grads_dtype` coerced-but-not-field-validated, unrecognized-name `ValueError`, unrelated kwargs untouched, `None` pass-through, input not mutated.
- `TestInitMegatronOptimConfigDtypeCoercion` (megatron-gated via `_has_megatron` skip-guard, no GPU): end-to-end that string kwargs reach a real `OptimizerConfig` with coerced dtypes; `params_dtype` string override replaces the seeded default; default (no override) keeps fp32 defaults; and `use_precision_aware_optimizer=False` + non-fp32 state fast-fails with megatron's own `AssertionError`.
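
To make the behaviors covered by the CPU test class concrete, here is a minimal, torch-free sketch of the coercion rules. It is an illustration under stated assumptions, not the helper from this patch: `FakeDtype`, `NAME_TO_DTYPE`, `LEGAL_FIELD_DTYPES`, and `coerce_sketch` are hypothetical names, and an enum stands in for `torch.dtype` so the snippet runs without torch installed.

```python
from enum import Enum
from typing import Any, Dict, Set


class FakeDtype(Enum):
    """Hypothetical stand-in for torch.dtype so this sketch runs without torch."""

    FLOAT32 = "float32"
    FLOAT16 = "float16"
    BFLOAT16 = "bfloat16"
    UINT8 = "uint8"  # the PR maps fp8 to uint8, following TransformerEngine


# Name/alias table, mirroring the one described under "What changed".
NAME_TO_DTYPE: Dict[str, FakeDtype] = {
    "fp32": FakeDtype.FLOAT32, "float32": FakeDtype.FLOAT32, "float": FakeDtype.FLOAT32,
    "fp16": FakeDtype.FLOAT16, "float16": FakeDtype.FLOAT16, "half": FakeDtype.FLOAT16,
    "bf16": FakeDtype.BFLOAT16, "bfloat16": FakeDtype.BFLOAT16,
    "fp8": FakeDtype.UINT8, "float8": FakeDtype.UINT8, "uint8": FakeDtype.UINT8,
}

# Per-field legal sets; fields without a row accept any recognized dtype.
LEGAL_FIELD_DTYPES: Dict[str, Set[FakeDtype]] = {
    "main_params_dtype": {FakeDtype.FLOAT32, FakeDtype.FLOAT16},
    "exp_avg_dtype": set(FakeDtype),
    "exp_avg_sq_dtype": set(FakeDtype),
}


def coerce_sketch(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict with ``*_dtype`` strings coerced and validated."""
    coerced: Dict[str, Any] = {}
    for key, value in kwargs.items():
        if key.endswith("_dtype") and isinstance(value, str):
            name = value.strip().lower()
            if name not in NAME_TO_DTYPE:
                raise ValueError(f"Unrecognized dtype name {value!r} for {key!r}")
            value = NAME_TO_DTYPE[name]
        if key.endswith("_dtype") and isinstance(value, FakeDtype):
            legal = LEGAL_FIELD_DTYPES.get(key)
            if legal is not None and value not in legal:
                raise ValueError(f"Illegal dtype {value} for {key!r}")
        coerced[key] = value  # everything else passes through unchanged
    return coerced


yaml_kwargs = {
    "use_precision_aware_optimizer": True,
    "exp_avg_dtype": " BF16 ",   # case and whitespace are normalized
    "exp_avg_sq_dtype": "fp8",   # becomes the uint8 stand-in
    "main_grads_dtype": None,    # neither str nor dtype: passed through
}
result = coerce_sketch(yaml_kwargs)

try:
    coerce_sketch({"main_params_dtype": "bf16"})
    rejected = False
except ValueError:
    rejected = True  # bf16 is outside the main_params_dtype legal set
```

The sketch builds a fresh dict rather than editing in place, which matches the "input not mutated" case in the test plan: the original YAML-derived strings stay available for serialization.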

Run:

```bash
uv run --isolated --extra megatron --extra dev pytest \
    tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py -v
```

The CPU class runs on the CPU lane (torch only, megatron-core not required); the megatron-gated class runs wherever megatron-core is installed. No GPU is required by either.
---
 docs/content/docs/configuration/config.mdx    |  28 +++
 docs/content/docs/examples/megatron.mdx       |  12 +
 .../distributed/megatron/optimizer.py         |   7 +-
 .../distributed/megatron/optimizer_dtype.py   |  96 ++++++++
 .../test_optimizer_dtype_coercion.py          | 207 ++++++++++++++++++
 5 files changed, 349 insertions(+), 1 deletion(-)
 create mode 100644 skyrl/backends/skyrl_train/distributed/megatron/optimizer_dtype.py
 create mode 100644 tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py

diff --git a/docs/content/docs/configuration/config.mdx b/docs/content/docs/configuration/config.mdx
index 730beccc15..abb48f1e90 100644
--- a/docs/content/docs/configuration/config.mdx
+++ b/docs/content/docs/configuration/config.mdx
@@ -213,6 +213,34 @@ Some rules for configuring these parameters:
 
 `optimizer_config_kwargs.use_precision_aware_optimizer=true` can cause checkpointing to fail. See: https://github.com/nvidia/megatron-lm/issues/1820. We recommend leaving this setting to `false`.
 
+`OptimizerConfig`'s `*_dtype` fields (`main_params_dtype`, `exp_avg_dtype`, `exp_avg_sq_dtype`, `params_dtype`) are typed as `torch.dtype`, but `optimizer_config_kwargs` is forwarded from YAML as plain strings. We coerce these `*_dtype` strings to `torch.dtype` before constructing `OptimizerConfig`, so low-precision optimizer state can be configured directly from YAML, e.g.:
+
+```yaml
+optimizer_config_kwargs:
+  use_precision_aware_optimizer: true
+  exp_avg_dtype: bf16
+  exp_avg_sq_dtype: fp8
+  main_params_dtype: fp32
+```
+
+The accepted dtype-name strings (case- and whitespace-insensitive) are:
+
+| Name | Aliases | `torch.dtype` |
+|------|---------|---------------|
+| `fp32` | `float32`, `float` | `torch.float32` |
+| `fp16` | `float16`, `half` | `torch.float16` |
+| `bf16` | `bfloat16` | `torch.bfloat16` |
+| `fp8` | `float8`, `uint8` | `torch.uint8` |
+
+`fp8` maps to `torch.uint8` because TransformerEngine represents FP8 optimizer state as `uint8`. Per-field legal sets are enforced before the kwargs reach FusedAdam, so an illegal value fails fast with a clear `ValueError`:
+
+- `main_params_dtype` (master weights): `fp32`, `fp16`
+- `exp_avg_dtype` / `exp_avg_sq_dtype`: `fp32`, `fp16`, `bf16`, `fp8`
+
+
+These short forms are specific to `optimizer_config_kwargs` and differ from the full `bfloat16` / `float16` / `float32` spellings accepted by `str_to_torch_dtype` used elsewhere in SkyRL.
+
+
 ## Optimizer Configuration
 
 For both the critic and policy model, we provide a common optimizer configuration
diff --git a/docs/content/docs/examples/megatron.mdx b/docs/content/docs/examples/megatron.mdx
index f6fd0e22f5..6733d09702 100644
--- a/docs/content/docs/examples/megatron.mdx
+++ b/docs/content/docs/examples/megatron.mdx
@@ -85,6 +85,18 @@ empty_cuda_cache: true
 
 These default values can be overridden by passing in the corresponding arguments to `trainer.policy.megatron_config` in the launch script.
 
+`optimizer_config_kwargs` accepts dtype-name strings for its `*_dtype` fields (`main_params_dtype`, `exp_avg_dtype`, `exp_avg_sq_dtype`, `params_dtype`), which are coerced to `torch.dtype` before constructing `OptimizerConfig`. This lets you configure low-precision optimizer state from YAML:
+
+```yaml
+optimizer_config_kwargs:
+  use_precision_aware_optimizer: true
+  exp_avg_dtype: bf16
+  exp_avg_sq_dtype: fp8
+  main_params_dtype: fp32
+```
+
+The accepted short forms are `fp32` (aliases `float32`/`float`), `fp16` (`float16`/`half`), `bf16` (`bfloat16`), and `fp8` (`float8`/`uint8`); `fp8` maps to `torch.uint8` since TransformerEngine represents FP8 optimizer state as `uint8`. `main_params_dtype` is restricted to `fp32`/`fp16`, while `exp_avg_dtype`/`exp_avg_sq_dtype` additionally allow `bf16`/`fp8`; illegal values fail fast with a `ValueError`. See the [Megatron configuration guide](../configuration/config#megatron-configuration) for the full table.
+
 ## Parallelism Resources
 
 Understanding and configuring parallelism strategies for large models can be challenging.
diff --git a/skyrl/backends/skyrl_train/distributed/megatron/optimizer.py b/skyrl/backends/skyrl_train/distributed/megatron/optimizer.py
index c1cdfbdd3b..8680332e85 100644
--- a/skyrl/backends/skyrl_train/distributed/megatron/optimizer.py
+++ b/skyrl/backends/skyrl_train/distributed/megatron/optimizer.py
@@ -26,6 +26,9 @@
 from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
 from omegaconf import DictConfig
 
+from skyrl.backends.skyrl_train.distributed.megatron.optimizer_dtype import (
+    coerce_optimizer_dtype_kwargs,
+)
 from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig
 
 
@@ -45,7 +48,9 @@ def init_megatron_optim_config(
         "params_dtype": torch.bfloat16,
         "use_distributed_optimizer": True,
     }
-    optim_args.update(optimizer_config_kwargs)
+    # Coerce any ``*_dtype`` string (e.g. "bf16" from YAML) into a real torch.dtype
+    # before it reaches Megatron's OptimizerConfig / FusedAdam, which require dtypes.
+    optim_args.update(coerce_optimizer_dtype_kwargs(optimizer_config_kwargs))
     config = OptimizerConfig(**optim_args)
     return config
 
diff --git a/skyrl/backends/skyrl_train/distributed/megatron/optimizer_dtype.py b/skyrl/backends/skyrl_train/distributed/megatron/optimizer_dtype.py
new file mode 100644
index 0000000000..79dd29a136
--- /dev/null
+++ b/skyrl/backends/skyrl_train/distributed/megatron/optimizer_dtype.py
@@ -0,0 +1,96 @@
+"""Pure-Python (torch-only) coercion of Megatron optimizer ``*_dtype`` kwargs.
+
+This is intentionally kept free of any ``megatron.core`` import so the coercion
+logic can be unit-tested on the cheap CPU CI lane (which installs torch but not
+megatron-core). ``optimizer.py`` imports the public helper from here.
+"""
+
+from typing import Any, Dict, Set
+
+import torch
+
+# Canonical dtype-name -> torch.dtype mapping, mirroring Megatron-LM's own
+# ``dtype_map`` in ``megatron/training/arguments.py`` (the short forms 'fp32',
+# 'bf16', 'fp16', 'fp8' are the ones Megatron itself maps). The extra aliases
+# ('bfloat16', 'float16'/'half', 'float32'/'float', 'float8'/'uint8') accept the
+# spellings a user might reasonably write in YAML. ``fp8`` maps to ``torch.uint8``
+# because TransformerEngine represents FP8 optimizer state as uint8.
+_DTYPE_NAME_TO_TORCH: Dict[str, torch.dtype] = {
+    "fp32": torch.float32,
+    "float32": torch.float32,
+    "float": torch.float32,
+    "bf16": torch.bfloat16,
+    "bfloat16": torch.bfloat16,
+    "fp16": torch.float16,
+    "float16": torch.float16,
+    "half": torch.float16,
+    "fp8": torch.uint8,
+    "float8": torch.uint8,
+    "uint8": torch.uint8,
+}
+
+# Per-field legal dtypes, enforced before the kwargs reach FusedAdam (which would
+# otherwise raise a cryptic error deep inside TransformerEngine). Only fields that
+# are actually forwarded to TE FusedAdam — and whose accepted set is verifiable in
+# the TE source — are listed here. ``main_params_dtype`` (a.k.a. master weights)
+# maps to FusedAdam's ``master_weight_dtype``, which only supports fp32/fp16;
+# ``exp_avg_dtype`` / ``exp_avg_sq_dtype`` additionally allow bf16/fp8.
+# ``main_grads_dtype`` is deliberately NOT listed: at the pinned megatron-core rev
+# it is not forwarded to FusedAdam (see megatron/core/optimizer/__init__.py, which
+# only passes exp_avg_dtype, exp_avg_sq_dtype, and main_params_dtype as
+# master_weight_dtype), so there is no TE-backed legal set to enforce — it is still
+# coerced str->dtype here and left for ``OptimizerConfig.__post_init__`` to validate
+# (mirroring how ``params_dtype`` is handled: coerced, not field-validated).
+# Fields not listed here accept any value in ``_DTYPE_NAME_TO_TORCH``.
+_LEGAL_FIELD_DTYPES: Dict[str, Set[torch.dtype]] = {
+    "main_params_dtype": {torch.float32, torch.float16},
+    "exp_avg_dtype": {torch.float32, torch.bfloat16, torch.float16, torch.uint8},
+    "exp_avg_sq_dtype": {torch.float32, torch.bfloat16, torch.float16, torch.uint8},
+}
+
+
+def coerce_optimizer_dtype_kwargs(optimizer_config_kwargs: Dict[str, Any]) -> Dict[str, Any]:
+    """Coerce ``*_dtype`` string values in Megatron optimizer kwargs to ``torch.dtype``.
+
+    Megatron's precision-aware ``OptimizerConfig`` types ``exp_avg_dtype`` /
+    ``exp_avg_sq_dtype`` / ``main_params_dtype`` (and friends) as real ``torch.dtype``,
+    but SkyRL forwards ``optimizer_config_kwargs`` verbatim from YAML/Hydra, which
+    delivers plain strings (e.g. ``"bf16"``). This converts any ``*_dtype`` key whose
+    value is a dtype-name string into the corresponding ``torch.dtype`` using Megatron-LM's
+    canonical mapping, validates the result against the per-field legal set, and leaves
+    everything else (non-``*_dtype`` keys, values already ``torch.dtype``) untouched.
+
+    Returns a new dict; the input is not mutated.
+
+    Raises:
+        ValueError: if a ``*_dtype`` value is an unrecognized dtype name, or if a coerced
+            dtype is illegal for that specific field (e.g. bf16/fp8 for ``main_params_dtype``).
+    """
+    coerced: Dict[str, Any] = {}
+    for key, value in optimizer_config_kwargs.items():
+        if not key.endswith("_dtype"):
+            coerced[key] = value
+            continue
+
+        if isinstance(value, torch.dtype):
+            dtype = value
+        elif isinstance(value, str):
+            name = value.strip().lower()
+            if name not in _DTYPE_NAME_TO_TORCH:
+                raise ValueError(
+                    f"Unrecognized dtype name {value!r} for optimizer kwarg {key!r}. "
+                    f"Expected one of {sorted(_DTYPE_NAME_TO_TORCH)} or a torch.dtype."
+                )
+            dtype = _DTYPE_NAME_TO_TORCH[name]
+        else:
+            # Not a dtype-name string or torch.dtype (e.g. None); pass through so
+            # Megatron's own validation surfaces any problem.
+            coerced[key] = value
+            continue
+
+        legal = _LEGAL_FIELD_DTYPES.get(key)
+        if legal is not None and dtype not in legal:
+            legal_names = sorted({n for n, d in _DTYPE_NAME_TO_TORCH.items() if d in legal})
+            raise ValueError(f"Illegal dtype {dtype} for optimizer kwarg {key!r}; legal values are {legal_names}.")
+        coerced[key] = dtype
+    return coerced
diff --git a/tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py b/tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py
new file mode 100644
index 0000000000..2edc6a7509
--- /dev/null
+++ b/tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py
@@ -0,0 +1,207 @@
+"""CPU tests for central Megatron optimizer-state dtype coercion.
+
+Verifies that ``*_dtype`` string values forwarded through
+``MegatronConfig.optimizer_config_kwargs`` (e.g. "bf16" from YAML/Hydra) are
+coerced to real ``torch.dtype`` before reaching Megatron's precision-aware
+``OptimizerConfig`` / TransformerEngine FusedAdam, using Megatron-LM's own
+canonical dtype mapping; that illegal values for ``main_params_dtype`` (master
+weights, fp32/fp16 only) are rejected; and that unrelated kwargs pass through
+untouched.
+
+``TestCoerceOptimizerDtypeKwargs`` exercises the pure-Python (torch-only)
+coercion helper, which lives in a module that does NOT import megatron-core, so
+it runs unconditionally on the cheap CPU CI lane (which installs torch but not
+megatron-core). ``TestInitMegatronOptimConfigDtypeCoercion`` builds a real
+``OptimizerConfig`` and is therefore skipped when megatron-core is not installed
+(mirroring ``test_megatron_correctness.py``). No CUDA is required by either.
+
+uv run --isolated --extra megatron --extra dev pytest \
+    tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py -v
+"""
+
+import sys
+
+import pytest
+import torch
+
+from skyrl.backends.skyrl_train.distributed.megatron.optimizer_dtype import (
+    coerce_optimizer_dtype_kwargs,
+)
+
+_has_megatron = "megatron" in sys.modules or __import__("importlib").util.find_spec("megatron") is not None
+
+
+class TestCoerceOptimizerDtypeKwargs:
+    """``coerce_optimizer_dtype_kwargs`` maps dtype-name strings to torch.dtype.
+
+    Runs unconditionally: the helper module is torch-only (no megatron-core).
+    """
+
+    def _coerce(self, kwargs: dict) -> dict:
+        return coerce_optimizer_dtype_kwargs(kwargs)
+
+    @pytest.mark.parametrize(
+        "name,expected",
+        [
+            ("bf16", torch.bfloat16),
+            ("bfloat16", torch.bfloat16),
+            ("fp16", torch.float16),
+            ("float16", torch.float16),
+            ("half", torch.float16),
+            ("fp32", torch.float32),
+            ("float32", torch.float32),
+            ("float", torch.float32),
+            ("fp8", torch.uint8),
+            ("float8", torch.uint8),
+            ("uint8", torch.uint8),
+        ],
+    )
+    def test_string_names_coerce_to_torch_dtype(self, name, expected):
+        """Each canonical/alias dtype name maps to the right torch.dtype."""
+        # exp_avg_dtype legally accepts fp32/fp16/bf16/fp8, so it can exercise all names.
+        out = self._coerce({"exp_avg_dtype": name})
+        assert out["exp_avg_dtype"] == expected
+        assert isinstance(out["exp_avg_dtype"], torch.dtype)
+
+    def test_fp8_maps_to_uint8(self):
+        """TE represents fp8 optimizer state as uint8."""
+        out = self._coerce({"exp_avg_sq_dtype": "fp8"})
+        assert out["exp_avg_sq_dtype"] is torch.uint8
+
+    def test_case_and_whitespace_insensitive(self):
+        out = self._coerce({"exp_avg_dtype": " BF16 "})
+        assert out["exp_avg_dtype"] is torch.bfloat16
+
+    def test_already_torch_dtype_passes_through(self):
+        """A value already a torch.dtype is preserved as-is."""
+        out = self._coerce({"exp_avg_dtype": torch.bfloat16})
+        assert out["exp_avg_dtype"] is torch.bfloat16
+
+    def test_main_params_dtype_accepts_fp32_and_fp16(self):
+        """main_params_dtype (master weights) legally accepts only fp32/fp16."""
+        assert self._coerce({"main_params_dtype": "fp32"})["main_params_dtype"] is torch.float32
+        assert self._coerce({"main_params_dtype": "fp16"})["main_params_dtype"] is torch.float16
+
+    @pytest.mark.parametrize("bad", ["bf16", "fp8"])
+    def test_main_params_dtype_rejects_bf16_and_fp8(self, bad):
+        """bf16/fp8 are illegal master-weight dtypes and must raise."""
+        with pytest.raises(ValueError, match="main_params_dtype"):
+            self._coerce({"main_params_dtype": bad})
+
+    @pytest.mark.parametrize(
+        "name,expected", [("bf16", torch.bfloat16), ("fp16", torch.float16), ("fp32", torch.float32)]
+    )
+    def test_params_dtype_is_coerced_with_no_field_restriction(self, name, expected):
+        """``params_dtype`` ends in ``_dtype`` and has no legal-set entry, so it is
+        coerced for any recognized alias and overrides the bf16 default that
+        ``init_megatron_optim_config`` seeds (see optimizer.py)."""
+        out = self._coerce({"params_dtype": name})
+        assert out["params_dtype"] is expected
+
+    def test_main_grads_dtype_coerced_but_not_field_validated(self):
+        """``main_grads_dtype`` is not forwarded to FusedAdam at the pinned rev, so it
+        has no legal-set row: it is coerced str->dtype but any value is accepted here,
+        leaving megatron-core's ``__post_init__`` to validate it. bf16 (which a legal
+        set would reject) coerces fine."""
+        out = self._coerce({"main_grads_dtype": "bf16"})
+        assert out["main_grads_dtype"] is torch.bfloat16
+
+    def test_unrecognized_dtype_name_raises(self):
+        with pytest.raises(ValueError, match="Unrecognized dtype name"):
+            self._coerce({"exp_avg_dtype": "bf17"})
+
+    def test_unrelated_kwargs_pass_through_untouched(self):
+        """Non-``*_dtype`` keys are returned unchanged."""
+        kwargs = {
+            "use_precision_aware_optimizer": True,
+            "optimizer_offload_fraction": 0.5,
+            "overlap_cpu_optimizer_d2h_h2d": False,
+            "exp_avg_dtype": "bf16",
+        }
+        out = self._coerce(kwargs)
+        assert out["use_precision_aware_optimizer"] is True
+        assert out["optimizer_offload_fraction"] == 0.5
+        assert out["overlap_cpu_optimizer_d2h_h2d"] is False
+        assert out["exp_avg_dtype"] is torch.bfloat16
+
+    def test_non_string_non_dtype_dtype_value_passes_through(self):
+        """A ``*_dtype`` key whose value is neither a dtype name nor torch.dtype
+        (e.g. None) passes through so Megatron's own validation surfaces it."""
+        out = self._coerce({"main_grads_dtype": None})
+        assert out["main_grads_dtype"] is None
+
+    def test_input_not_mutated(self):
+        """The helper returns a new dict and does not mutate the input."""
+        kwargs = {"exp_avg_dtype": "bf16"}
+        self._coerce(kwargs)
+        assert kwargs["exp_avg_dtype"] == "bf16"
+
+
+@pytest.mark.skipif(not _has_megatron, reason="megatron-core not installed")
+class TestInitMegatronOptimConfigDtypeCoercion:
+    """End-to-end: ``init_megatron_optim_config`` builds a real OptimizerConfig
+    with coerced dtypes from string kwargs."""
+
+    def test_string_dtype_kwargs_reach_optimizer_config(self):
+        from skyrl.backends.skyrl_train.distributed.megatron.optimizer import (
+            init_megatron_optim_config,
+        )
+        from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig
+
+        optim_config = SkyRLOptimizerConfig()
+        config = init_megatron_optim_config(
+            optim_config,
+            {
+                "use_precision_aware_optimizer": True,
+                "exp_avg_dtype": "bf16",
+                "exp_avg_sq_dtype": "fp8",
+                "main_params_dtype": "fp32",
+            },
+        )
+        assert config.exp_avg_dtype is torch.bfloat16
+        assert config.exp_avg_sq_dtype is torch.uint8
+        assert config.main_params_dtype is torch.float32
+
+    def test_params_dtype_string_override_reaches_optimizer_config(self):
+        """A string ``params_dtype`` override is coerced and replaces the seeded
+        bf16 default in the constructed OptimizerConfig."""
+        from skyrl.backends.skyrl_train.distributed.megatron.optimizer import (
+            init_megatron_optim_config,
+        )
+        from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig
+
+        config = init_megatron_optim_config(SkyRLOptimizerConfig(), {"params_dtype": "fp16"})
+        assert config.params_dtype is torch.float16
+
+    def test_default_kwargs_leave_dtypes_at_megatron_defaults(self):
+        """With no ``*_dtype`` overrides, OptimizerConfig keeps its fp32 defaults
+        (byte-identical to prior behavior)."""
+        from skyrl.backends.skyrl_train.distributed.megatron.optimizer import (
+            init_megatron_optim_config,
+        )
+        from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig
+
+        config = init_megatron_optim_config(SkyRLOptimizerConfig(), {})
+        assert config.exp_avg_dtype is torch.float32
+        assert config.exp_avg_sq_dtype is torch.float32
+        assert config.main_params_dtype is torch.float32
+
+    def test_precision_aware_off_with_nonfp32_state_fast_fails_in_megatron(self):
+        """Coercing ``exp_avg_dtype='bf16'`` passes the helper's own validation, but
+        megatron-core's ``OptimizerConfig.__post_init__`` then asserts that
+        exp_avg_dtype can only be fp32 when ``use_precision_aware_optimizer`` is False.
+
+        This documents that the fast-fail is megatron's (a real AssertionError),
+        not a silent mis-coercion: the helper coerces the string fine, the rejection
+        happens downstream in OptimizerConfig construction.
+        """
+        from skyrl.backends.skyrl_train.distributed.megatron.optimizer import (
+            init_megatron_optim_config,
+        )
+        from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig
+
+        with pytest.raises(AssertionError, match="exp_avg_dtype can only be fp32"):
+            init_megatron_optim_config(
+                SkyRLOptimizerConfig(),
+                {"use_precision_aware_optimizer": False, "exp_avg_dtype": "bf16"},
+            )