Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions skyrl/backends/skyrl_train/distributed/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from transformers import GenerationConfig, PretrainedConfig, PreTrainedTokenizer

from skyrl.backends.skyrl_train.utils.io import io
from skyrl.utils.tok import check_is_vlm, get_processor

DataT = TypeVar("DataT", bound=Union[Dict[str, Any], torch.Tensor])

Expand Down Expand Up @@ -138,6 +139,19 @@ def save_hf_configs(self, model_config: PretrainedConfig, hf_dir: str, tokenizer
# if the generation config isn't available, we don't save it
logger.warning(f"Could not save generation config for '{model_config.name_or_path}'. Error: {e}")

# VLMs need preprocessor_config.json (+ image/video processor configs) to be
# reloadable by AutoProcessor; the tokenizer alone is insufficient and vLLM
# crashes on load without it. Resolve from the original base model, same as
# generation_config above. No-op for text-only models (no vision_config).
# The VLM check reuses the already-loaded model_config (no extra I/O) and the
# whole block is guarded so a processor-resolution failure can't abort the save.
try:
if check_is_vlm(model_config):
processor = get_processor(model_config.name_or_path)
processor.save_pretrained(work_dir)
except Exception as e:
logger.warning(f"Could not save processor for '{model_config.name_or_path}'. Error: {e}")

@staticmethod
def get_rng_state():
"""Get current RNG state for reproducibility"""
Expand Down
32 changes: 31 additions & 1 deletion skyrl/utils/tok.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,36 @@
"""Tokenization related utilities"""

from transformers import AutoTokenizer, PreTrainedTokenizerFast
from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
PreTrainedTokenizerFast,
)


def check_is_vlm(model_config_or_path) -> bool:
    """Return True if the model config declares a non-null ``vision_config``.

    Accepts either an already-loaded config object or a model name/path.
    Passing the config avoids a redundant ``AutoConfig.from_pretrained``
    round-trip when the caller already holds it.

    Args:
        model_config_or_path: A loaded config object, or a model name / path
            given as ``str`` or ``os.PathLike`` (e.g. ``pathlib.Path``).

    Returns:
        ``True`` when the resolved config has a ``vision_config`` attribute
        whose value is not ``None``; ``False`` otherwise.
    """
    # Function-scope import so the block does not depend on a file-level one.
    import os

    # Path-like objects must be loaded exactly like strings; otherwise a
    # ``pathlib.Path`` would be inspected as if it were the config itself and
    # always be reported as "not a VLM".
    if isinstance(model_config_or_path, (str, os.PathLike)):
        model_config = AutoConfig.from_pretrained(model_config_or_path, trust_remote_code=True)
    else:
        model_config = model_config_or_path
    # A missing attribute and an explicit ``None`` both mean "text-only".
    return getattr(model_config, "vision_config", None) is not None


def get_processor(model_name_or_path, **tokenizer_kwargs) -> AutoProcessor:
    """Load the processor for a base model, guaranteeing a pad token.

    ``trust_remote_code`` defaults to ``True`` unless the caller passes it
    explicitly. If the processor's tokenizer has no pad token ID, the EOS
    token (ID and string) is used as the pad token.

    Args:
        model_name_or_path: Model name or path handed to
            ``AutoProcessor.from_pretrained``.
        **tokenizer_kwargs: Extra keyword arguments forwarded to
            ``AutoProcessor.from_pretrained``.

    Returns:
        The loaded processor.
    """
    tokenizer_kwargs.setdefault("trust_remote_code", True)
    loaded = AutoProcessor.from_pretrained(model_name_or_path, **tokenizer_kwargs)

    tok = loaded.tokenizer
    if tok.pad_token_id is not None:
        # A pad token is already configured; nothing to patch.
        return loaded

    # Fall back to EOS for padding.
    tok.pad_token_id = tok.eos_token_id
    tok.pad_token = tok.eos_token
    return loaded
Comment thread
dinhxuanvu marked this conversation as resolved.


def get_tokenizer(model_name_or_path, **tokenizer_kwargs) -> AutoTokenizer:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""CPU tests for DistributedStrategy.save_hf_configs processor handling.

Verifies that a vision-language checkpoint export writes the HF processor
(``preprocessor_config.json`` etc.) and that text-only exports do not, without
requiring GPUs, distributed init, or network access.
"""

from unittest.mock import MagicMock, patch

from skyrl.backends.skyrl_train.distributed.strategy import DistributedStrategy


class _StubStrategy(DistributedStrategy):
    """Do-nothing concrete ``DistributedStrategy`` for CPU-only tests.

    Each abstract method is overridden with a placeholder that returns
    ``None`` so that ``save_hf_configs`` can be called in isolation.
    """

    def setup_distributed(self):  # pragma: no cover - stub
        return None

    def backward(self, loss, model, optimizer, **kwargs):  # pragma: no cover - stub
        return None

    def optimizer_step(self, optimizer, model, scheduler, name="model", **kwargs):  # pragma: no cover - stub
        return None

    def save_checkpoint(
        self,
        model,
        ckpt_dir,
        node_local_rank,
        optimizer,
        scheduler,
        tokenizer,
    ):  # pragma: no cover - stub
        return None

    def load_checkpoint(
        self,
        model,
        ckpt_dir,
        optimizer,
        scheduler,
        load_module_strict,
        load_optimizer_states,
        load_lr_scheduler_states,
    ):  # pragma: no cover - stub
        return None

    def save_hf_model(self, model, output_dir, tokenizer=None, **kwargs):  # pragma: no cover - stub
        return None


def _make_model_config(name_or_path="some/base-model"):
model_config = MagicMock()
model_config.name_or_path = name_or_path
# save_pretrained is a no-op for the test; we only care about the processor.
model_config.save_pretrained = MagicMock()
return model_config


# Dotted path of the module under test; used as the prefix when building
# ``patch()`` targets so names are replaced where the strategy looks them up.
_STRATEGY_MOD = "skyrl.backends.skyrl_train.distributed.strategy"


class TestSaveHfConfigsProcessor:
    """Processor-handling behavior of ``save_hf_configs``, with collaborators patched."""

    def test_saves_processor_for_vlm(self, tmp_path):
        """For a VLM, the processor is resolved and written to the HF dir."""
        hf_dir = str(tmp_path)
        fake_processor = MagicMock()
        config = _make_model_config()
        with (
            patch(f"{_STRATEGY_MOD}.check_is_vlm", return_value=True) as vlm_check,
            patch(f"{_STRATEGY_MOD}.get_processor", return_value=fake_processor) as processor_loader,
            patch(f"{_STRATEGY_MOD}.GenerationConfig"),
        ):
            _StubStrategy().save_hf_configs(config, hf_dir)

        # The VLM check receives the config object itself, not its name/path.
        vlm_check.assert_called_once_with(config)
        processor_loader.assert_called_once_with("some/base-model")
        fake_processor.save_pretrained.assert_called_once_with(hf_dir)

    def test_no_processor_for_text_only(self, tmp_path):
        """For a text-only model, no processor is resolved or saved."""
        hf_dir = str(tmp_path)
        config = _make_model_config()
        with (
            patch(f"{_STRATEGY_MOD}.check_is_vlm", return_value=False) as vlm_check,
            patch(f"{_STRATEGY_MOD}.get_processor") as processor_loader,
            patch(f"{_STRATEGY_MOD}.GenerationConfig"),
        ):
            _StubStrategy().save_hf_configs(config, hf_dir)

        vlm_check.assert_called_once_with(config)
        processor_loader.assert_not_called()

    def test_vlm_check_failure_does_not_raise(self, tmp_path):
        """An error raised by the VLM check must not escape the export."""
        hf_dir = str(tmp_path)
        with (
            patch(f"{_STRATEGY_MOD}.check_is_vlm", side_effect=RuntimeError("config unavailable")),
            patch(f"{_STRATEGY_MOD}.get_processor") as processor_loader,
            patch(f"{_STRATEGY_MOD}.GenerationConfig"),
        ):
            # Passing means this call returns normally despite the failing check.
            _StubStrategy().save_hf_configs(_make_model_config(), hf_dir)

        processor_loader.assert_not_called()

    def test_processor_failure_does_not_raise(self, tmp_path):
        """An error raised while resolving the processor must not escape the export."""
        hf_dir = str(tmp_path)
        with (
            patch(f"{_STRATEGY_MOD}.check_is_vlm", return_value=True),
            patch(f"{_STRATEGY_MOD}.get_processor", side_effect=RuntimeError("no processor")),
            patch(f"{_STRATEGY_MOD}.GenerationConfig"),
        ):
            # Passing means this call returns normally.
            _StubStrategy().save_hf_configs(_make_model_config(name_or_path="some/base-model"), hf_dir)

    def test_skipped_when_no_name_or_path(self, tmp_path):
        """With an empty ``name_or_path``, neither the VLM check nor the processor lookup runs."""
        hf_dir = str(tmp_path)
        with (
            patch(f"{_STRATEGY_MOD}.check_is_vlm") as vlm_check,
            patch(f"{_STRATEGY_MOD}.get_processor") as processor_loader,
            patch(f"{_STRATEGY_MOD}.GenerationConfig"),
        ):
            _StubStrategy().save_hf_configs(_make_model_config(name_or_path=""), hf_dir)

        vlm_check.assert_not_called()
        processor_loader.assert_not_called()
Loading