Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ray
from packaging import version
from ray.actor import ActorHandle
from ray.exceptions import ActorDiedError

if TYPE_CHECKING:
from skyrl.backends.skyrl_train.weight_sync.transfer_strategy import (
Expand Down Expand Up @@ -322,6 +323,13 @@ def create_ray_wrapped_inference_engines(
# NOTE(shu): set to 1 for LoRA
sleep_level = 1 if enable_lora else sleep_level
sleep_refs = [engine.inference_engine_actor.sleep.remote(level=sleep_level) for engine in engines]
ray.get(sleep_refs)
try:
ray.get(sleep_refs)
except ActorDiedError as e:
from skyrl.train.utils.ray_logging import reraise_with_actor_diagnostics

reraise_with_actor_diagnostics(
e, inference_engine_actors, "Inference engine actor(s) died during engine initialization."
)

return engines
38 changes: 21 additions & 17 deletions skyrl/backends/skyrl_train/inference_servers/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

import ray
from loguru import logger
from ray.exceptions import RayTaskError
Comment thread
jamesbraza marked this conversation as resolved.
from ray.util.placement_group import placement_group as ray_placement_group

from skyrl.env_vars import SKYRL_RAY_PG_TIMEOUT_IN_S
from skyrl.train.config import InferenceEngineConfig, SkyRLTrainConfig
from skyrl.train.utils.ray_logging import reraise_with_actor_diagnostics
from skyrl.train.utils.utils import (
ResolvedPlacementGroup,
get_ray_pg_ready_with_timeout,
Expand All @@ -29,6 +31,23 @@
VLLM_START_PORT = 8000


def _start_server_groups(server_groups: List[ServerGroup]) -> None:
"""Start all groups and block until every server actor reports ready."""
# Start all server groups in parallel (non-blocking)
all_refs = [ref for g in server_groups for ref in g.start(blocking=False)]
try:
# Wait for all servers to be ready in one shot
ray.get(all_refs)
except RayTaskError as e:
Comment thread
jamesbraza marked this conversation as resolved.
# Engine init runs inside the actor's start(), so its failure arrives as a
# RayTaskError from the still-alive actor
reraise_with_actor_diagnostics(
e,
[actor for g in server_groups for actor in g.get_actors()],
"Inference server actor(s) failed during startup.",
)


@dataclass
class InferenceServerSetup:
"""Inference server setup result with optional router and groups.
Expand Down Expand Up @@ -136,16 +155,7 @@ def create_inference_servers(
for i in range(num_decode)
]

# Start all prefill and decode groups in parallel (non-blocking)
all_refs = []
for g in prefill_server_groups:
all_refs.extend(g.start(blocking=False))

for g in decode_server_groups:
all_refs.extend(g.start(blocking=False))

# Wait for all servers to be ready in one shot
ray.get(all_refs)
_start_server_groups(prefill_server_groups + decode_server_groups)

# Collect URLs — refs are already resolved so lazy property returns immediately
prefill_urls = [info.url for g in prefill_server_groups for info in g.server_infos]
Expand Down Expand Up @@ -192,13 +202,7 @@ def create_inference_servers(
for i in range(ie_cfg.num_engines)
]

# Start all engine groups in parallel (non-blocking)
all_refs = []
for g in server_groups:
all_refs.extend(g.start(blocking=False))

# Wait for all servers to be ready in one shot
ray.get(all_refs)
_start_server_groups(server_groups)

# Collect URLs — refs are already resolved so lazy property returns immediately
server_urls = [info.url for g in server_groups for info in g.server_infos]
Expand Down
150 changes: 147 additions & 3 deletions skyrl/train/utils/ray_logging.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
"""
Helper to redirect Ray actor stdout/stderr to log file.
Helpers for Ray actor log handling.

This prevents infrastructure logs from polluting the driver's stdout,
allowing only training progress to be displayed to the user.
Redirects actor stdout/stderr to a log file so infrastructure logs don't pollute the
driver's stdout, and recovers those logs' tails for diagnostics when an actor fails.
"""

import contextlib
import os
import sys
import time
from typing import TYPE_CHECKING, NoReturn, Sequence

if TYPE_CHECKING:
from ray.actor import ActorHandle

# Caps total diagnostics size, keeping the tail, sized so the whole cap survives
# SkyRL Tracking.log_exception's truncation of the formatted exception
_MAX_DIAGNOSTICS_CHARS = 10_000


def redirect_actor_output_to_file():
Expand Down Expand Up @@ -35,3 +45,137 @@ def redirect_actor_output_to_file():
with open(log_file, "a", buffering=1) as log_f:
os.dup2(log_f.fileno(), sys.stdout.fileno())
os.dup2(log_f.fileno(), sys.stderr.fileno())


def _tail_file(path: str, max_lines: int, max_bytes: int = 256 * 1024) -> list[str]:
"""
Return up to the last ``max_lines`` lines of ``path``, reading at most ``max_bytes``.

The byte bound matters if the input file is unbounded in size, where reading it
could take nontrivial time on slow filesystems.
"""
with open(path, "rb") as f:
file_size = f.seek(0, os.SEEK_END)
f.seek(max(file_size - max_bytes, 0))
return (
f.read()
.decode(
# The seek can split a multibyte character;
# don't let one bad byte kill the diagnostics
errors="replace"
)
.splitlines()[-max_lines:]
)


def get_actor_logs_tail(
actor_ids: Sequence[str], *, max_lines_per_actor: int = 100, state_api_timeout_s: int = 10
) -> str | None:
"""
Best-effort collection of log tails for failure diagnostics. Never raises.

Looks in the two places actor stderr can land:
1. The shared `SKYRL_LOG_FILE`, when set: actors calling `redirect_actor_output_to_file`
redirect their stdout/stderr, and that of any subprocess they spawn (like vLLM's
engine core), into this file.
2. Each given actor's Ray worker stderr file (`worker-*.err` in the session logs dir),
which this function fetches via the Ray state API, covering actors on remote nodes.

Returns:
Joined log sections, or None if nothing could be collected.
"""
sections: list[str] = []

# Set by SkyRL `initialize_ray` util, in both its calling process
# and every Ray worker's runtime env
log_file = os.getenv("SKYRL_LOG_FILE")
if log_file:
with contextlib.suppress(Exception):
tail = _tail_file(log_file, max_lines=max_lines_per_actor)
if tail:
sections.append(
f"--- tail of SKYRL_LOG_FILE {log_file} "
f"(last {len(tail)} lines; infra actors redirect stdout/stderr here) ---\n" + "\n".join(tail)
)

with contextlib.suppress(Exception): # Never raise out of diagnostics collection
from ray._private.worker import get_dashboard_url
from ray.util.state import get_log

# The state API is served by Ray's dashboard HTTP server; without it, resolving the
# server URL blocks for 20 x 2-s retries before failing, per internal_kv_get_with_retry:
# https://github.com/ray-project/ray/blob/ray-2.51.1/python/ray/_private/utils.py#L1126-L1147
if get_dashboard_url():
for actor_id in actor_ids:
with contextlib.suppress(Exception):
stderr_tail = "".join(
get_log(
actor_id=actor_id,
suffix="err",
tail=max_lines_per_actor,
timeout=state_api_timeout_s,
errors="replace",
)
).rstrip()
if stderr_tail:
sections.append(
f"--- stderr tail of actor {actor_id} (full log: "
f"`ray logs actor --id {actor_id} --err`) ---\n" + stderr_tail
)

if not sections:
return None
diagnostics = "\n\n".join(sections)
if len(diagnostics) > _MAX_DIAGNOSTICS_CHARS:
prefix = "...(truncated)...\n"
diagnostics = prefix + diagnostics[len(prefix) - _MAX_DIAGNOSTICS_CHARS :]
return diagnostics


def reraise_with_actor_diagnostics(
e: BaseException, actors: "Sequence[ActorHandle]", context_message: str, log_flush_grace_s: float = 2
) -> NoReturn:
"""
Re-raise a Ray actor failure as a RuntimeError carrying the actors' log tails.

Ray surfaces actor-side failures without the actor's stderr, which is where the root
cause actually lives when a subprocess of the actor (e.g. a vLLM engine-core child) dies;
the driver-side exception bottoms out at vLLM's `wait_for_engine_startup` with just:
> RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}

Args:
e: The original exception, chained via `__cause__`.
actors: Actor handles to fall back to when `e` doesn't name the failed actor.
context_message: Lead text describing what was being attempted.
log_flush_grace_s: Seconds to wait before snapshotting the logs.
"""
# ActorDiedError: publicly exposes the failed actor's id
# RayTaskError (a method failed on a still-alive actor): has a private _actor_id
failed_id: str | None = getattr(e, "actor_id", None) or getattr(e, "_actor_id", None)
Comment thread
jamesbraza marked this conversation as resolved.
diagnostics = None
with contextlib.suppress(Exception): # Don't mask the original failure
# vLLM relays the engine-core child's output to the actor's log via a pipe-reader
# thread that can lag the failure reaching the driver; give it a moment to drain
time.sleep(log_flush_grace_s)
diagnostics = get_actor_logs_tail(
actor_ids=[failed_id] if failed_id else [actor._actor_id.hex() for actor in actors]
)
if diagnostics is None:
log_file = os.getenv("SKYRL_LOG_FILE")
if log_file:
location = f"SKYRL_LOG_FILE ({log_file})"
else:
logs_dir = "<ray temp dir>/session_latest/logs"
with contextlib.suppress(Exception): # Tolerate Ray being uninitialized
import ray._private.worker

logs_dir = ray._private.worker._global_node.get_logs_dir_path()
location = f"{logs_dir}/worker-*.err on the failed actor's node"
diagnostics = (
f"(could not fetch actor logs; check {location}, "
f"or run: ray logs actor --id {failed_id or '<actor_id>'} --err)"
)
raise RuntimeError(
f"{context_message} The root-cause traceback usually lives in the failed actor's "
f"logs, not in the current Ray exception.\n\n{diagnostics}"
) from e
Loading
Loading