[megatron] Accept dtype-string optimizer_config_kwargs (coerce exp_avg_dtype etc. to torch.dtype) #1805
Open: dyurk-lila wants to merge 1 commit into NovaSky-AI:main from dyurk-lila:feat/optimizer-state-dtype-coercion
96 changes: 96 additions & 0 deletions
skyrl/backends/skyrl_train/distributed/megatron/optimizer_dtype.py
```python
"""Pure-Python (torch-only) coercion of Megatron optimizer ``*_dtype`` kwargs.

This is intentionally kept free of any ``megatron.core`` import so the coercion
logic can be unit-tested on the cheap CPU CI lane (which installs torch but not
megatron-core). ``optimizer.py`` imports the public helper from here.
"""

from typing import Any, Dict, Set

import torch

# Canonical dtype-name -> torch.dtype mapping, mirroring Megatron-LM's own
# ``dtype_map`` in ``megatron/training/arguments.py`` (the short forms 'fp32',
# 'bf16', 'fp16', 'fp8' are the ones Megatron itself maps). The extra aliases
# ('bfloat16', 'float16'/'half', 'float32'/'float', 'float8'/'uint8') accept the
# spellings a user might reasonably write in YAML. ``fp8`` maps to ``torch.uint8``
# because TransformerEngine represents FP8 optimizer state as uint8.
_DTYPE_NAME_TO_TORCH: Dict[str, torch.dtype] = {
    "fp32": torch.float32,
    "float32": torch.float32,
    "float": torch.float32,
    "bf16": torch.bfloat16,
    "bfloat16": torch.bfloat16,
    "fp16": torch.float16,
    "float16": torch.float16,
    "half": torch.float16,
    "fp8": torch.uint8,
    "float8": torch.uint8,
    "uint8": torch.uint8,
}

# Per-field legal dtypes, enforced before the kwargs reach FusedAdam (which would
# otherwise raise a cryptic error deep inside TransformerEngine). Only fields that
# are actually forwarded to TE FusedAdam — and whose accepted set is verifiable in
# the TE source — are listed here. ``main_params_dtype`` (a.k.a. master weights)
# maps to FusedAdam's ``master_weight_dtype``, which only supports fp32/fp16;
# ``exp_avg_dtype`` / ``exp_avg_sq_dtype`` additionally allow bf16/fp8.
# ``main_grads_dtype`` is deliberately NOT listed: at the pinned megatron-core rev
# it is not forwarded to FusedAdam (see megatron/core/optimizer/__init__.py, which
# only passes exp_avg_dtype, exp_avg_sq_dtype, and main_params_dtype as
# master_weight_dtype), so there is no TE-backed legal set to enforce — it is still
# coerced str->dtype here and left for ``OptimizerConfig.__post_init__`` to validate
# (mirroring how ``params_dtype`` is handled: coerced, not field-validated).
# Fields not listed here accept any value in ``_DTYPE_NAME_TO_TORCH``.
_LEGAL_FIELD_DTYPES: Dict[str, Set[torch.dtype]] = {
    "main_params_dtype": {torch.float32, torch.float16},
    "exp_avg_dtype": {torch.float32, torch.bfloat16, torch.float16, torch.uint8},
    "exp_avg_sq_dtype": {torch.float32, torch.bfloat16, torch.float16, torch.uint8},
}


def coerce_optimizer_dtype_kwargs(optimizer_config_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Coerce ``*_dtype`` string values in Megatron optimizer kwargs to ``torch.dtype``.

    Megatron's precision-aware ``OptimizerConfig`` types ``exp_avg_dtype`` /
    ``exp_avg_sq_dtype`` / ``main_params_dtype`` (and friends) as real ``torch.dtype``,
    but SkyRL forwards ``optimizer_config_kwargs`` verbatim from YAML/Hydra, which
    delivers plain strings (e.g. ``"bf16"``). This converts any ``*_dtype`` key whose
    value is a dtype-name string into the corresponding ``torch.dtype`` using Megatron-LM's
    canonical mapping, validates the result against the per-field legal set, and leaves
    everything else (non-``*_dtype`` keys, values already ``torch.dtype``) untouched.

    Returns a new dict; the input is not mutated.

    Raises:
        ValueError: if a ``*_dtype`` value is an unrecognized dtype name, or if a coerced
            dtype is illegal for that specific field (e.g. bf16/fp8 for ``main_params_dtype``).
    """
    coerced: Dict[str, Any] = {}
    for key, value in optimizer_config_kwargs.items():
        if not key.endswith("_dtype"):
            coerced[key] = value
            continue

        if isinstance(value, torch.dtype):
            dtype = value
        elif isinstance(value, str):
            name = value.strip().lower()
            if name not in _DTYPE_NAME_TO_TORCH:
                raise ValueError(
                    f"Unrecognized dtype name {value!r} for optimizer kwarg {key!r}. "
                    f"Expected one of {sorted(_DTYPE_NAME_TO_TORCH)} or a torch.dtype."
                )
            dtype = _DTYPE_NAME_TO_TORCH[name]
        else:
            # Not a dtype-name string or torch.dtype (e.g. None); pass through so
            # Megatron's own validation surfaces any problem.
            coerced[key] = value
            continue

        legal = _LEGAL_FIELD_DTYPES.get(key)
        if legal is not None and dtype not in legal:
            legal_names = sorted({n for n, d in _DTYPE_NAME_TO_TORCH.items() if d in legal})
            raise ValueError(f"Illegal dtype {dtype} for optimizer kwarg {key!r}; legal values are {legal_names}.")
        coerced[key] = dtype
    return coerced
```
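To make the control flow of the helper easier to follow, here is a minimal torch-free sketch of the same coerce-then-validate pattern applied to a YAML-like dict. Everything in it is a hypothetical stand-in and not part of the PR: `Dtype` plays the role of `torch.dtype`, `NAME_TO_DTYPE` is a trimmed copy of the alias table, `LEGAL` holds only the `main_params_dtype` row, and `coerce` approximates `coerce_optimizer_dtype_kwargs`.

```python
from enum import Enum
from typing import Any, Dict, Set


class Dtype(Enum):
    """Hypothetical stand-in for torch.dtype so the sketch runs without torch."""

    FLOAT32 = "float32"
    BFLOAT16 = "bfloat16"
    FLOAT16 = "float16"
    UINT8 = "uint8"


# Trimmed alias table (assumption: only a subset of the PR's names).
NAME_TO_DTYPE: Dict[str, Dtype] = {
    "fp32": Dtype.FLOAT32,
    "float32": Dtype.FLOAT32,
    "bf16": Dtype.BFLOAT16,
    "fp16": Dtype.FLOAT16,
    "half": Dtype.FLOAT16,
    "fp8": Dtype.UINT8,
}

# Per-field legal sets; fields absent here accept any recognized name.
LEGAL: Dict[str, Set[Dtype]] = {
    "main_params_dtype": {Dtype.FLOAT32, Dtype.FLOAT16},
}


def coerce(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    out: Dict[str, Any] = {}  # new dict, so the input is never mutated
    for key, value in kwargs.items():
        # Non-*_dtype keys, and values that are neither str nor Dtype, pass through.
        if not key.endswith("_dtype") or not isinstance(value, (str, Dtype)):
            out[key] = value
            continue
        if isinstance(value, str):
            name = value.strip().lower()
            if name not in NAME_TO_DTYPE:
                raise ValueError(f"Unrecognized dtype name {value!r} for {key!r}")
            value = NAME_TO_DTYPE[name]
        legal = LEGAL.get(key)
        if legal is not None and value not in legal:
            raise ValueError(f"Illegal dtype {value} for {key!r}")
        out[key] = value
    return out


yaml_like = {
    "use_precision_aware_optimizer": True,
    "exp_avg_dtype": " BF16 ",
    "main_grads_dtype": None,
}
coerced = coerce(yaml_like)
```

In this sketch `coerced["exp_avg_dtype"]` is the enum member rather than the string, the boolean flag and the `None` value pass through unchanged, and `coerce({"main_params_dtype": "bf16"})` raises `ValueError` because bf16 is outside that field's legal set.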
207 changes: 207 additions & 0 deletions
tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py
```python
"""CPU tests for central Megatron optimizer-state dtype coercion.

Verifies that ``*_dtype`` string values forwarded through
``MegatronConfig.optimizer_config_kwargs`` (e.g. "bf16" from YAML/Hydra) are
coerced to real ``torch.dtype`` before reaching Megatron's precision-aware
``OptimizerConfig`` / TransformerEngine FusedAdam, using Megatron-LM's own
canonical dtype mapping; that illegal values for ``main_params_dtype`` (master
weights, fp32/fp16 only) are rejected; and that unrelated kwargs pass through
untouched.

``TestCoerceOptimizerDtypeKwargs`` exercises the pure-Python (torch-only)
coercion helper, which lives in a module that does NOT import megatron-core, so
it runs unconditionally on the cheap CPU CI lane (which installs torch but not
megatron-core). ``TestInitMegatronOptimConfigDtypeCoercion`` builds a real
``OptimizerConfig`` and is therefore skipped when megatron-core is not installed
(mirroring ``test_megatron_correctness.py``). No CUDA is required by either.

uv run --isolated --extra megatron --extra dev pytest \
    tests/backends/skyrl_train/distributed/test_optimizer_dtype_coercion.py -v
"""

import sys

import pytest
import torch

from skyrl.backends.skyrl_train.distributed.megatron.optimizer_dtype import (
    coerce_optimizer_dtype_kwargs,
)

_has_megatron = "megatron" in sys.modules or __import__("importlib").util.find_spec("megatron") is not None


class TestCoerceOptimizerDtypeKwargs:
    """``coerce_optimizer_dtype_kwargs`` maps dtype-name strings to torch.dtype.

    Runs unconditionally: the helper module is torch-only (no megatron-core).
    """

    def _coerce(self, kwargs: dict) -> dict:
        return coerce_optimizer_dtype_kwargs(kwargs)

    @pytest.mark.parametrize(
        "name,expected",
        [
            ("bf16", torch.bfloat16),
            ("bfloat16", torch.bfloat16),
            ("fp16", torch.float16),
            ("float16", torch.float16),
            ("half", torch.float16),
            ("fp32", torch.float32),
            ("float32", torch.float32),
            ("float", torch.float32),
            ("fp8", torch.uint8),
            ("float8", torch.uint8),
            ("uint8", torch.uint8),
        ],
    )
    def test_string_names_coerce_to_torch_dtype(self, name, expected):
        """Each canonical/alias dtype name maps to the right torch.dtype."""
        # exp_avg_dtype legally accepts fp32/fp16/bf16/fp8, so it can exercise all names.
        out = self._coerce({"exp_avg_dtype": name})
        assert out["exp_avg_dtype"] == expected
        assert isinstance(out["exp_avg_dtype"], torch.dtype)

    def test_fp8_maps_to_uint8(self):
        """TE represents fp8 optimizer state as uint8."""
        out = self._coerce({"exp_avg_sq_dtype": "fp8"})
        assert out["exp_avg_sq_dtype"] is torch.uint8

    def test_case_and_whitespace_insensitive(self):
        out = self._coerce({"exp_avg_dtype": " BF16 "})
        assert out["exp_avg_dtype"] is torch.bfloat16

    def test_already_torch_dtype_passes_through(self):
        """A value already a torch.dtype is preserved as-is."""
        out = self._coerce({"exp_avg_dtype": torch.bfloat16})
        assert out["exp_avg_dtype"] is torch.bfloat16

    def test_main_params_dtype_accepts_fp32_and_fp16(self):
        """main_params_dtype (master weights) legally accepts only fp32/fp16."""
        assert self._coerce({"main_params_dtype": "fp32"})["main_params_dtype"] is torch.float32
        assert self._coerce({"main_params_dtype": "fp16"})["main_params_dtype"] is torch.float16

    @pytest.mark.parametrize("bad", ["bf16", "fp8"])
    def test_main_params_dtype_rejects_bf16_and_fp8(self, bad):
        """bf16/fp8 are illegal master-weight dtypes and must raise."""
        with pytest.raises(ValueError, match="main_params_dtype"):
            self._coerce({"main_params_dtype": bad})

    @pytest.mark.parametrize(
        "name,expected", [("bf16", torch.bfloat16), ("fp16", torch.float16), ("fp32", torch.float32)]
    )
    def test_params_dtype_is_coerced_with_no_field_restriction(self, name, expected):
        """``params_dtype`` ends in ``_dtype`` and has no legal-set entry, so it is
        coerced for any recognized alias and overrides the bf16 default that
        ``init_megatron_optim_config`` seeds (see optimizer.py)."""
        out = self._coerce({"params_dtype": name})
        assert out["params_dtype"] is expected

    def test_main_grads_dtype_coerced_but_not_field_validated(self):
        """``main_grads_dtype`` is not forwarded to FusedAdam at the pinned rev, so it
        has no legal-set row: it is coerced str->dtype but any value is accepted here,
        leaving megatron-core's ``__post_init__`` to validate it. bf16 (which a legal
        set would reject) coerces fine."""
        out = self._coerce({"main_grads_dtype": "bf16"})
        assert out["main_grads_dtype"] is torch.bfloat16

    def test_unrecognized_dtype_name_raises(self):
        with pytest.raises(ValueError, match="Unrecognized dtype name"):
            self._coerce({"exp_avg_dtype": "bf17"})

    def test_unrelated_kwargs_pass_through_untouched(self):
        """Non-``*_dtype`` keys are returned unchanged."""
        kwargs = {
            "use_precision_aware_optimizer": True,
            "optimizer_offload_fraction": 0.5,
            "overlap_cpu_optimizer_d2h_h2d": False,
            "exp_avg_dtype": "bf16",
        }
        out = self._coerce(kwargs)
        assert out["use_precision_aware_optimizer"] is True
        assert out["optimizer_offload_fraction"] == 0.5
        assert out["overlap_cpu_optimizer_d2h_h2d"] is False
        assert out["exp_avg_dtype"] is torch.bfloat16

    def test_non_string_non_dtype_dtype_value_passes_through(self):
        """A ``*_dtype`` key whose value is neither a dtype name nor torch.dtype
        (e.g. None) passes through so Megatron's own validation surfaces it."""
        out = self._coerce({"main_grads_dtype": None})
        assert out["main_grads_dtype"] is None

    def test_input_not_mutated(self):
        """The helper returns a new dict and does not mutate the input."""
        kwargs = {"exp_avg_dtype": "bf16"}
        self._coerce(kwargs)
        assert kwargs["exp_avg_dtype"] == "bf16"


@pytest.mark.skipif(not _has_megatron, reason="megatron-core not installed")
class TestInitMegatronOptimConfigDtypeCoercion:
    """End-to-end: ``init_megatron_optim_config`` builds a real OptimizerConfig
    with coerced dtypes from string kwargs."""

    def test_string_dtype_kwargs_reach_optimizer_config(self):
        from skyrl.backends.skyrl_train.distributed.megatron.optimizer import (
            init_megatron_optim_config,
        )
        from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig

        optim_config = SkyRLOptimizerConfig()
        config = init_megatron_optim_config(
            optim_config,
            {
                "use_precision_aware_optimizer": True,
                "exp_avg_dtype": "bf16",
                "exp_avg_sq_dtype": "fp8",
                "main_params_dtype": "fp32",
            },
        )
        assert config.exp_avg_dtype is torch.bfloat16
        assert config.exp_avg_sq_dtype is torch.uint8
        assert config.main_params_dtype is torch.float32

    def test_params_dtype_string_override_reaches_optimizer_config(self):
        """A string ``params_dtype`` override is coerced and replaces the seeded
        bf16 default in the constructed OptimizerConfig."""
        from skyrl.backends.skyrl_train.distributed.megatron.optimizer import (
            init_megatron_optim_config,
        )
        from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig

        config = init_megatron_optim_config(SkyRLOptimizerConfig(), {"params_dtype": "fp16"})
        assert config.params_dtype is torch.float16

    def test_default_kwargs_leave_dtypes_at_megatron_defaults(self):
        """With no ``*_dtype`` overrides, OptimizerConfig keeps its fp32 defaults
        (byte-identical to prior behavior)."""
        from skyrl.backends.skyrl_train.distributed.megatron.optimizer import (
            init_megatron_optim_config,
        )
        from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig

        config = init_megatron_optim_config(SkyRLOptimizerConfig(), {})
        assert config.exp_avg_dtype is torch.float32
        assert config.exp_avg_sq_dtype is torch.float32
        assert config.main_params_dtype is torch.float32

    def test_precision_aware_off_with_nonfp32_state_fast_fails_in_megatron(self):
        """Coercing ``exp_avg_dtype='bf16'`` passes the helper's own validation, but
        megatron-core's ``OptimizerConfig.__post_init__`` then asserts that
        exp_avg_dtype can only be fp32 when ``use_precision_aware_optimizer`` is False.

        This documents that the fast-fail is megatron's (a real AssertionError),
        not a silent mis-coercion: the helper coerces the string fine, the rejection
        happens downstream in OptimizerConfig construction.
        """
        from skyrl.backends.skyrl_train.distributed.megatron.optimizer import (
            init_megatron_optim_config,
        )
        from skyrl.train.config import OptimizerConfig as SkyRLOptimizerConfig

        with pytest.raises(AssertionError, match="exp_avg_dtype can only be fp32"):
            init_megatron_optim_config(
                SkyRLOptimizerConfig(),
                {"use_precision_aware_optimizer": False, "exp_avg_dtype": "bf16"},
            )
```
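The test module detects megatron-core with `__import__("importlib").util.find_spec(...)`, which relies on the `importlib.util` submodule having already been imported by something else (typically the case under pytest). The sketch below shows one more explicit way to write the same optional-dependency check; `has_module` and `HAS_MEGATRON` are hypothetical names, not taken from the PR.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module called `name` can be located."""
    # find_spec returns None when no finder can locate a top-level module.
    return importlib.util.find_spec(name) is not None


# Hypothetical replacement for the `_has_megatron` flag in the test module.
HAS_MEGATRON = has_module("megatron")
```

The flag can then feed `pytest.mark.skipif` exactly as `_has_megatron` does in the diff. Calling `pytest.importorskip("megatron")` inside the test class is another option, at the cost of actually importing the package.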
If `optimizer_config_kwargs` is `None` (e.g., if it is omitted or set to `null` in the YAML configuration), calling `.items()` on it will raise an `AttributeError`. Adding a defensive `None` check at the beginning of the function ensures robustness and prevents runtime crashes.
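One way the guard suggested in this comment could look is sketched below. The helper name `kwargs_or_empty` is hypothetical and the PR may resolve the comment differently (for example with an early `return {}` inside `coerce_optimizer_dtype_kwargs` itself); the sketch only shows the `None`-to-empty normalization.

```python
from typing import Any, Dict, Optional


def kwargs_or_empty(optimizer_config_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Treat a missing/null kwargs mapping as empty and return a shallow copy."""
    if optimizer_config_kwargs is None:
        # YAML `null` (or an omitted key) arrives as None; .items() would fail on it.
        return {}
    # Copy so that later coercion never mutates the caller's dict.
    return dict(optimizer_config_kwargs)


for key, value in kwargs_or_empty(None).items():
    pass  # loop body never runs: there is nothing to coerce
```

Normalizing at the top of the function keeps the rest of the loop unchanged and preserves the "returns a new dict" contract stated in the docstring.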