Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion nemo_gym/cli/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import shlex
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from functools import wraps
from glob import glob
from os import makedirs
from os.path import exists
Expand All @@ -35,17 +36,19 @@
from devtools import pprint
from omegaconf import DictConfig, OmegaConf
from pydantic import Field
from rich.markup import escape
from rich.table import Table
from tqdm.auto import tqdm

from nemo_gym import PARENT_DIR, ROOT_DIR
from nemo_gym.cli_setup_command import run_command, setup_env_command
from nemo_gym.config_types import BaseNeMoGymCLIConfig
from nemo_gym.config_types import BaseNeMoGymCLIConfig, ConfigError
from nemo_gym.global_config import (
DRY_RUN_KEY_NAME,
NEMO_GYM_CONFIG_DICT_ENV_VAR_NAME,
NEMO_GYM_CONFIG_PATH_ENV_VAR_NAME,
NEMO_GYM_RESERVED_TOP_LEVEL_KEYS,
GlobalConfigDictParser,
GlobalConfigDictParserConfig,
get_global_config_dict,
)
Expand All @@ -66,6 +69,26 @@
_FORCE_KILL_REAP_TIMEOUT_SEC: int = 2


def exit_cleanly_on_config_error(fn):
    """Decorator: turn a user-facing ConfigError into a clean message + non-zero exit.

    Config mistakes (missing/typo'd config_paths, malformed config_paths, nothing configured to
    run) should fail fast with an actionable message, not a Python traceback. Exceptions that are
    not a ``ConfigError`` propagate unchanged.

    Args:
        fn: The CLI entry point to wrap.

    Returns:
        A wrapper carrying ``fn``'s metadata (via ``functools.wraps``) that forwards all positional
        and keyword arguments and returns ``fn``'s result on success.

    Raises:
        SystemExit: With exit code 1 when ``fn`` raises ``ConfigError``. The original error is
            attached as ``__cause__`` so callers that intercept the exit can still inspect it.
    """
    # Function-scope import: only needed to address stderr in the error path below.
    import sys

    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except ConfigError as e:
            # escape() so '[...]' in the message (e.g. config_paths examples) isn't eaten as rich markup.
            # The diagnostic goes to stderr so stdout stays clean when the CLI output is piped or
            # redirected; sys.stderr is looked up at call time so redirection is honoured.
            rich.print(f"[red]Error:[/red] {escape(str(e))}", file=sys.stderr)
            # Explicit chaining keeps the ConfigError reachable without printing a traceback.
            raise SystemExit(1) from e

    return wrapper


class RunConfig(BaseNeMoGymCLIConfig):
"""
Start NeMo Gym servers for agents, models, and resources.
Expand Down Expand Up @@ -125,6 +148,10 @@ class RunHelper: # pragma: no cover
def start(self, global_config_dict_parser_config: GlobalConfigDictParserConfig) -> None:
global_config_dict = get_global_config_dict(global_config_dict_parser_config=global_config_dict_parser_config)

# Fail fast before starting Ray if nothing is configured to run (covers env run and the
# e2e rollout-collection path, which both start servers via this method).
GlobalConfigDictParser().raise_on_no_server_instances(global_config_dict)

# Initialize Ray cluster in the main process
# Note: This function will modify the global config dict - update `ray_head_node_address`
initialize_ray()
Expand Down Expand Up @@ -393,6 +420,7 @@ def check_http_server_statuses(self, successful_servers: List[str]) -> List[Tupl
return statuses


@exit_cleanly_on_config_error
def run(
global_config_dict_parser_config: Optional[GlobalConfigDictParserConfig] = None,
): # pragma: no cover
Expand Down
3 changes: 2 additions & 1 deletion nemo_gym/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from tqdm.auto import tqdm

from nemo_gym.benchmarks import BENCHMARKS_DIR, BenchmarkConfig, _load_benchmarks_from_config_paths
from nemo_gym.cli.env import RunHelper
from nemo_gym.cli.env import RunHelper, exit_cleanly_on_config_error
from nemo_gym.config_types import BaseNeMoGymCLIConfig, BenchmarkDatasetConfig
from nemo_gym.global_config import (
ROLLOUT_INDEX_KEY_NAME,
Expand Down Expand Up @@ -225,6 +225,7 @@ def prepare_benchmark() -> None:
list(tqdm(results, total=len(validated)))


@exit_cleanly_on_config_error
def e2e_rollout_collection(): # pragma: no cover
global_config_dict = get_global_config_dict()

Expand Down
22 changes: 22 additions & 0 deletions nemo_gym/config_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,28 @@ def is_server_ref(config_dict: DictConfig) -> Optional[ServerRef]:
return None


class ConfigError(Exception):
    """Root of the user-facing configuration error hierarchy.

    Subclasses describe mistakes the user can act on (a typo'd path, a missing file, malformed
    input) as opposed to internal bugs. The CLI catches `ConfigError` and shows only the message,
    without a traceback. Outside the CLI they behave like any other exception, so programmatic
    callers can catch them and format the message themselves.
    """


class ConfigPathNotFoundError(ConfigError, FileNotFoundError):
    """Raised when a `config_paths` entry does not exist in the cwd or the Gym install location.

    Also a `FileNotFoundError`, so handlers written against the built-in type still match.
    """


class MalformedConfigPathsError(ConfigError, ValueError):
    """Raised when `config_paths` is not a list of paths (for example, a bare scalar string).

    Also a `ValueError`, so handlers written against the built-in type still match.
    """


class NoServerInstancesError(ConfigError, ValueError):
    """Raised when a run is requested but the merged config has no server instances to start.

    Also a `ValueError`, so handlers written against the built-in type still match.
    """


########################################
# Dataset configs for handling and upload/download
########################################
Expand Down
45 changes: 42 additions & 3 deletions nemo_gym/global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@

from nemo_gym import CACHE_DIR, PARENT_DIR, RESULTS_DIR, WORKING_DIR
from nemo_gym.config_types import (
ConfigPathNotFoundError,
MalformedConfigPathsError,
NoServerInstancesError,
ServerInstanceConfig,
WANDBConfig,
is_almost_server,
Expand Down Expand Up @@ -197,13 +200,27 @@ def load_extra_config_paths(self, config_paths: List[str]) -> Tuple[List[str], L
duplicate_config_paths: List[str] = []
# Just a careful note here that we explicitly mutate config_paths as it is being appended to
for config_path in config_paths:
original_entry = config_path
config_path = Path(config_path)
# Check cwd first for user's local configs, then install location
searched_locations = [config_path]
if not config_path.is_absolute():
cwd_path = Path.cwd() / config_path
config_path = cwd_path if cwd_path.exists() else PARENT_DIR / config_path
install_path = PARENT_DIR / config_path
searched_locations = [cwd_path, install_path]
config_path = cwd_path if cwd_path.exists() else install_path

extra_config = OmegaConf.load(config_path)
try:
extra_config = OmegaConf.load(config_path)
except FileNotFoundError as e:
# Dedupe while preserving order (cwd and install root coincide when run from the repo).
unique_locations = list(dict.fromkeys(str(p) for p in searched_locations))
searched = "\n".join(f" - {p}" for p in unique_locations)
raise ConfigPathNotFoundError(
f"""config_paths entry '{original_entry}' was not found. Looked in:
{searched}
Check the path is spelled correctly and is relative to your working directory or the Gym install root."""
) from e
for new_config_path in extra_config.get(CONFIG_PATHS_KEY_NAME) or []:
if new_config_path not in config_paths:
config_paths.append(new_config_path)
Expand Down Expand Up @@ -237,6 +254,21 @@ def filter_for_server_instance_configs(self, global_config_dict: DictConfig) ->

return server_instance_configs

def raise_on_no_server_instances(self, global_config_dict: DictConfig) -> None:
    """Reject a run whose merged config defines no server instances.

    Per the original rationale: without this guard, `ng_run` with an empty/omitted `config_paths`
    starts the head server and Ray and then hangs with nothing to run. Callers invoke this before
    Ray initialises so the user gets an actionable message instead.

    Args:
        global_config_dict: The merged global config to inspect.

    Raises:
        NoServerInstancesError: If `filter_for_server_instance_configs` yields nothing for
            `global_config_dict`.
    """
    configured_instances = self.filter_for_server_instance_configs(global_config_dict)
    if not configured_instances:
        # The second line of the message is part of the runtime string and is kept verbatim.
        raise NoServerInstancesError(
            """No server instances are configured, so there is nothing to run. Pass one or more configs via config_paths, e.g.:
ng_run "+config_paths=[resources_servers/<env>/configs/<env>.yaml,responses_api_models/<model>/configs/<model>.yaml]\""""
        )

def validate_and_populate_defaults(
self,
server_instance_configs: List[ServerInstanceConfig],
Expand Down Expand Up @@ -409,7 +441,14 @@ def parse(self, parse_config: Optional[GlobalConfigDictParserConfig] = None) ->
merged_config_for_config_paths = OmegaConf.merge(dotenv_extra_config, global_config_dict)
ta = TypeAdapter(List[str])
config_paths = merged_config_for_config_paths.get(CONFIG_PATHS_KEY_NAME) or []
config_paths = ta.validate_python(config_paths)
try:
config_paths = ta.validate_python(config_paths)
except ValidationError as e:
raise MalformedConfigPathsError(
f"""'{CONFIG_PATHS_KEY_NAME}' must be a list of paths. Got: {config_paths!r}.
Pass it as a Hydra list, e.g.:
ng_run "+{CONFIG_PATHS_KEY_NAME}=[resources_servers/<env>/configs/<env>.yaml]\""""
) from e

config_paths, extra_configs = self.load_extra_config_paths(config_paths)

Expand Down
34 changes: 33 additions & 1 deletion tests/unit_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@
RunConfig,
RunHelper,
_run_module_tests_all,
exit_cleanly_on_config_error,
init_resources_server,
)
from nemo_gym.cli.general import display_help_legacy
from nemo_gym.config_types import ResourcesServerInstanceConfig
from nemo_gym.config_types import ConfigError, NoServerInstancesError, ResourcesServerInstanceConfig


# TODO: Eventually we want to add more tests to ensure that the CLI flows do not break
Expand Down Expand Up @@ -276,3 +277,34 @@ def run_one(p: Path) -> Path:
_run_module_tests_all(run_one, paths, max_concurrency=4)
# With a pool of 4, multiple modules must have been in flight simultaneously.
assert max_in_flight >= 2


class TestExitCleanlyOnConfigError:
    """Checks for the `exit_cleanly_on_config_error` CLI decorator.

    A `ConfigError` from the wrapped callable must surface as `SystemExit` with code 1; a normal
    return value or an unrelated exception must pass through untouched.
    """

    def test_config_error_becomes_clean_exit(self) -> None:
        def raise_config_error():
            raise NoServerInstancesError("nothing to run")

        guarded = exit_cleanly_on_config_error(raise_config_error)

        with raises(SystemExit) as exit_info:
            guarded()

        assert exit_info.value.code == 1

    def test_non_config_error_propagates(self) -> None:
        def raise_unexpected():
            raise RuntimeError("unexpected")

        guarded = exit_cleanly_on_config_error(raise_unexpected)

        # Anything that is not a ConfigError must not be converted into an exit.
        with raises(RuntimeError):
            guarded()

    def test_success_passes_through(self) -> None:
        guarded = exit_cleanly_on_config_error(lambda: 42)

        assert guarded() == 42

    def test_config_error_base_catches_subclasses(self) -> None:
        # The decorator catches the base class, so concrete errors must derive from it.
        assert issubclass(NoServerInstancesError, ConfigError)
76 changes: 76 additions & 0 deletions tests/unit_tests/test_global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
import nemo_gym.global_config
import nemo_gym.server_utils
from nemo_gym import CACHE_DIR, WORKING_DIR
from nemo_gym.config_types import (
ConfigPathNotFoundError,
MalformedConfigPathsError,
NoServerInstancesError,
)
from nemo_gym.global_config import (
DEFAULT_HEAD_SERVER_PORT,
NEMO_GYM_CONFIG_DICT_ENV_VAR_NAME,
Expand Down Expand Up @@ -977,3 +982,74 @@ def test_help(self, monkeypatch) -> None:

# Without the help override, this will SystemExit.
GlobalConfigDictParser.parse_global_config_dict_from_cli(None)


class TestConfigLoadErrors:
    """Bad, malformed, or empty `config_paths` must raise the dedicated actionable errors."""

    def test_load_extra_config_paths_missing_relative_lists_both_locations(
        self, monkeypatch: MonkeyPatch, tmp_path: Path
    ) -> None:
        # Use distinct cwd and install directories so both must show up in the message.
        work_dir = tmp_path / "cwd"
        install_dir = tmp_path / "parent"
        for directory in (work_dir, install_dir):
            directory.mkdir()
        monkeypatch.chdir(work_dir)
        monkeypatch.setattr(nemo_gym.global_config, "PARENT_DIR", install_dir)

        relative_entry = "missing/nope.yaml"
        with raises(ConfigPathNotFoundError) as caught:
            GlobalConfigDictParser().load_extra_config_paths([relative_entry])

        error_text = str(caught.value)
        assert relative_entry in error_text
        assert str(work_dir / relative_entry) in error_text
        assert str(install_dir / relative_entry) in error_text
        assert "spelled correctly" in error_text

    def test_load_extra_config_paths_missing_dedups_when_cwd_is_install_root(
        self, monkeypatch: MonkeyPatch, tmp_path: Path
    ) -> None:
        # cwd and install root coincide, so the searched location must be listed only once.
        monkeypatch.chdir(tmp_path)
        monkeypatch.setattr(nemo_gym.global_config, "PARENT_DIR", tmp_path)

        with raises(ConfigPathNotFoundError) as caught:
            GlobalConfigDictParser().load_extra_config_paths(["missing/nope.yaml"])

        bullet_count = str(caught.value).count(" - ")
        assert bullet_count == 1

    def test_load_extra_config_paths_missing_absolute_path(self, tmp_path: Path) -> None:
        absent_file = tmp_path / "absent.yaml"

        with raises(ConfigPathNotFoundError) as caught:
            GlobalConfigDictParser().load_extra_config_paths([str(absent_file)])

        error_text = str(caught.value)
        assert str(absent_file) in error_text
        # An absolute path is searched in exactly one place.
        assert error_text.count(" - ") == 1

    def test_parse_malformed_config_paths_raises_actionable_error(self) -> None:
        # A scalar string where a list is required.
        scalar_config = DictConfig({"config_paths": "not_a_list.yaml"})
        parse_config = GlobalConfigDictParserConfig(
            initial_global_config_dict=scalar_config,
            skip_load_from_cli=True,
            skip_load_from_dotenv=True,
        )

        with raises(MalformedConfigPathsError) as caught:
            GlobalConfigDictParser().parse(parse_config)

        error_text = str(caught.value)
        assert "config_paths" in error_text
        assert "list" in error_text

    def test_raise_on_no_server_instances_raises_when_empty(self) -> None:
        empty_run_config = DictConfig({"config_paths": [], "head_server": {"port": 11000}})

        with raises(NoServerInstancesError) as caught:
            GlobalConfigDictParser().raise_on_no_server_instances(empty_run_config)

        assert "config_paths" in str(caught.value)

    def test_raise_on_no_server_instances_passes_with_a_server(self) -> None:
        server_entry = {"resources_servers": {"x": {"entrypoint": "app.py", "domain": "other"}}}
        populated_config = DictConfig({"my_server": server_entry})

        # Must return without raising when a server instance is configured.
        GlobalConfigDictParser().raise_on_no_server_instances(populated_config)
Loading