From 93595d1fa403a27e3bf6875bf4bf9c890a8809a7 Mon Sep 17 00:00:00 2001 From: Jason Mancuso <7891333+jvmncs@users.noreply.github.com> Date: Thu, 4 Jun 2026 15:38:05 -0400 Subject: [PATCH 1/6] Add opaque HTTP rollout endpoint mode --- slime/backends/sglang_utils/http_endpoint.py | 87 ++++++++++ slime/ray/placement_group.py | 5 + slime/ray/rollout.py | 4 + slime/rollout/sglang_rollout.py | 43 ++++- slime/utils/arguments.py | 34 ++++ tests/test_megatron_argument_validation.py | 2 + tests/test_placement_group.py | 6 + tests/test_rollout_http_endpoint.py | 159 +++++++++++++++++++ 8 files changed, 334 insertions(+), 6 deletions(-) create mode 100644 slime/backends/sglang_utils/http_endpoint.py create mode 100644 tests/test_rollout_http_endpoint.py diff --git a/slime/backends/sglang_utils/http_endpoint.py b/slime/backends/sglang_utils/http_endpoint.py new file mode 100644 index 0000000000..a13ea58ab0 --- /dev/null +++ b/slime/backends/sglang_utils/http_endpoint.py @@ -0,0 +1,87 @@ +"""Helpers for rollout backends served by an opaque HTTP endpoint.""" + +from __future__ import annotations + +import dataclasses +import logging +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + + +def normalize_rollout_http_endpoint_url(url: str) -> str: + """Normalize an HTTP endpoint base URL used by rollout generation.""" + url = url.rstrip("/") + parsed = urlparse(url) + if parsed.scheme not in ("http", "https") or parsed.netloc == "": + raise ValueError( + f"Invalid rollout HTTP endpoint URL {url!r}. Use an absolute http:// or https:// URL." + ) + return url + + +def uses_rollout_http_endpoint(args) -> bool: + return bool(getattr(args, "rollout_http_endpoint_url", None)) + + +def rollout_http_endpoint_url(args, endpoint: str = "/generate") -> str: + base = normalize_rollout_http_endpoint_url(args.rollout_http_endpoint_url) + if not endpoint.startswith("/"): + endpoint = f"/{endpoint}" + return f"{base}{endpoint}" + + +@dataclasses.dataclass +class HttpEndpointRolloutServer: + """Rollout server backed by an opaque HTTP endpoint. + + The endpoint is intentionally not assumed to be an SGLang router: it may not + expose worker-management APIs such as ``/workers`` and it may represent an + elastic fleet with no stable per-engine handles. + """ + + endpoint_url: str + model_name: str = "default" + update_weights: bool = True + router_ip: str | None = None + router_port: int | None = None + server_groups: list = dataclasses.field(default_factory=list) + engines: list = dataclasses.field(default_factory=list) + engine_gpu_counts: list[int] = dataclasses.field(default_factory=list) + engine_gpu_offsets: list[int] = dataclasses.field(default_factory=list) + num_new_engines: int = 0 + + @property + def all_engines(self): + return self.engines + + def recover(self): + logger.warning("Fault tolerance is not supported for opaque HTTP rollout endpoints; skip recover.") + + def offload(self): + return [] + + def onload(self, tags: list[str] | None = None): + return [] + + def onload_weights(self): + return [] + + def onload_kv(self): + return [] + + +def start_http_endpoint_rollout_servers(args) -> dict[str, HttpEndpointRolloutServer]: + endpoint_url = normalize_rollout_http_endpoint_url(args.rollout_http_endpoint_url) + args.rollout_http_endpoint_url = endpoint_url + args.sglang_model_routers = {} + if getattr(args, "rollout_num_engines", None) is None: + args.rollout_num_engines = 1 + logger.info("Using opaque HTTP rollout endpoint: %s", endpoint_url) + return { + "default": HttpEndpointRolloutServer( + endpoint_url=endpoint_url, + model_name="default", + update_weights=True, + ) + } diff --git a/slime/ray/placement_group.py b/slime/ray/placement_group.py index c181c8e7f0..108d723263 100644 --- a/slime/ray/placement_group.py +++ b/slime/ray/placement_group.py @@ -103,6 +103,11 @@ def _get_placement_group_layout(args) -> tuple[int, int]: if args.debug_train_only: return actor_num_gpus, 0 + if getattr(args, "rollout_http_endpoint_url", None): + if args.debug_rollout_only: + return 0, 0 + return actor_num_gpus, actor_num_gpus + if args.rollout_external: if args.debug_rollout_only: return 0, 0 diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 5766d6b171..865682db44 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -15,6 +15,7 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS from slime.backends.sglang_utils.external import start_external_rollout_servers +from slime.backends.sglang_utils.http_endpoint import start_http_endpoint_rollout_servers, uses_rollout_http_endpoint from slime.backends.sglang_utils.sglang_config import ModelConfig, ServerGroupConfig, SglangConfig from slime.backends.sglang_utils.sglang_engine import SGLangEngine from slime.rollout.base_types import call_rollout_fn @@ -1080,6 +1081,9 @@ def start_rollout_servers(args, pg) -> tuple[dict[str, Any], list[Any]]: Note: ``init_http_client`` should be called separately before this, as the HTTP client is shared across all servers. """ + if uses_rollout_http_endpoint(args): + return start_http_endpoint_rollout_servers(args) + if args.rollout_external: return start_external_rollout_servers(args, start_router=_start_router) diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index bb87360639..de486b43f6 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -14,6 +14,7 @@ from packaging.version import parse from tqdm import tqdm +from slime.backends.sglang_utils.http_endpoint import rollout_http_endpoint_url, uses_rollout_http_endpoint from slime.backends.sglang_utils.server_control import abort_servers_until_idle from slime.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput from slime.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter @@ -63,7 +64,7 @@ def _prepare_prompt_ids(sample: Sample, tokenizer, processor: Any) -> list[int]: def get_model_url(args: Namespace, model_name: str, endpoint: str = "/generate") -> str: - """Return the router URL for a named model. + """Return the rollout URL for a named model. Use this in custom rollout functions to route requests to a specific model when multiple models are deployed via ``--sglang-config``:: @@ -71,9 +72,14 @@ def get_model_url(args: Namespace, model_name: str, endpoint: str = "/generate") url = get_model_url(args, "ref", "/generate") resp = await post(url, json=payload) - Falls back to the default router if *model_name* is not found or - ``sglang_model_routers`` is not set. + If ``--rollout-http-endpoint-url`` is set, returns that opaque endpoint + with *endpoint* appended and does not assume SGLang router APIs exist. + Otherwise, falls back to the default router if *model_name* is not found + or ``sglang_model_routers`` is not set. """ + if uses_rollout_http_endpoint(args): + return rollout_http_endpoint_url(args, endpoint) + routers = getattr(args, "sglang_model_routers", None) if routers and model_name in routers: ip, port = routers[model_name] @@ -154,7 +160,7 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A assert isinstance(sample.prompt, str) state = GenerateState(args) - url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate" + url = get_model_url(args, "default", "/generate") assert ( sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED @@ -349,12 +355,13 @@ async def generate_and_rm_group( async def abort(args: Namespace, rollout_id: int) -> list[list[Sample]]: - aborted_samples = [] - state = GenerateState(args) assert not state.aborted state.aborted = True + if getattr(args, "rollout_http_endpoint_abort_strategy", None) == "cancel-only": + return await _cancel_pending_tasks(state) + if parse(sglang_router.__version__) <= parse("0.2.1"): response = await get(f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers") urls = response["urls"] @@ -364,6 +371,30 @@ async def abort(args: Namespace, rollout_id: int) -> list[list[Sample]]: await abort_servers_until_idle(urls) + return await _drain_aborted_pending_tasks(args, rollout_id, state) + + +async def _cancel_pending_tasks(state: GenerateState) -> list[list[Sample]]: + if not state.pendings: + return [] + pending = list(state.pendings) + for task in pending: + task.cancel() + results = await asyncio.gather(*pending, return_exceptions=True) + for result in results: + if isinstance(result, Exception): + logger.warning("Pending rollout task ended during cancel-only abort: %s", result) + state.pendings.clear() + return [] + + +async def _drain_aborted_pending_tasks( + args: Namespace, + rollout_id: int, + state: GenerateState, +) -> list[list[Sample]]: + aborted_samples = [] + # make sure all the pending tasks are finished count = 0 while state.pendings: diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 6efe85eae7..8bd80623ff 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -11,6 +11,7 @@ from slime.backends.sglang_utils.arguments import sglang_parse_args from slime.backends.sglang_utils.arguments import validate_args as sglang_validate_args from slime.backends.sglang_utils.external import apply_external_engine_info_to_args +from slime.backends.sglang_utils.http_endpoint import normalize_rollout_http_endpoint_url from slime.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from slime.utils.logging_utils import configure_logger @@ -545,6 +546,27 @@ def add_rollout_arguments(parser): nargs="+", help="Address and ports of the external engines.", ) + parser.add_argument( + "--rollout-http-endpoint-url", + type=str, + default=None, + help=( + "Opaque HTTP endpoint base URL for rollout generation. " + "When set, slime sends /generate requests to this endpoint " + "without launching or registering SGLang workers." + ), + ) + parser.add_argument( + "--rollout-http-endpoint-abort-strategy", + type=str, + choices=["cancel-only", "router-workers"], + default=None, + help=( + "Abort behavior for the default SGLang rollout. " + "'cancel-only' cancels local pending tasks and does not call router /workers; " + "'router-workers' uses the SGLang router worker list." + ), + ) return parser def add_fault_tolerance_arguments(parser): @@ -1876,6 +1898,18 @@ def slime_validate_args(args): ) args.debug_train_only = True + if args.rollout_http_endpoint_url is not None: + args.rollout_http_endpoint_url = normalize_rollout_http_endpoint_url(args.rollout_http_endpoint_url) + if args.rollout_http_endpoint_abort_strategy is None: + args.rollout_http_endpoint_abort_strategy = "cancel-only" + if getattr(args, "rollout_num_engines", None) is None: + args.rollout_num_engines = 1 + elif args.rollout_http_endpoint_abort_strategy is None: + args.rollout_http_endpoint_abort_strategy = "router-workers" + + if args.rollout_http_endpoint_url is not None and args.rollout_external_engine_addrs is not None: + raise ValueError("--rollout-http-endpoint-url and --rollout-external-engine-addrs are mutually exclusive.") + args.rollout_external = args.rollout_external_engine_addrs is not None if args.rollout_external and not args.debug_train_only: diff --git a/tests/test_megatron_argument_validation.py b/tests/test_megatron_argument_validation.py index 1f435cb577..0d43d849f7 100644 --- a/tests/test_megatron_argument_validation.py +++ b/tests/test_megatron_argument_validation.py @@ -267,6 +267,8 @@ def make_slime_validate_args(**overrides): save_debug_train_data=None, load_debug_rollout_data=None, rollout_external_engine_addrs=None, + rollout_http_endpoint_url=None, + rollout_http_endpoint_abort_strategy=None, debug_train_only=False, actor_num_gpus_per_node=8, actor_num_nodes=1, diff --git a/tests/test_placement_group.py b/tests/test_placement_group.py index c1ae8aedef..ea5da55e7c 100644 --- a/tests/test_placement_group.py +++ b/tests/test_placement_group.py @@ -40,6 +40,12 @@ def _args(**overrides): pytest.param({"colocate": True, "rollout_num_gpus": 0}, (16, 0), id="colocate_zero_rollout_gpus"), pytest.param({"rollout_external": True}, (16, 16), id="external"), pytest.param({"rollout_external": True, "debug_rollout_only": True}, (0, 0), id="external_debug_rollout"), + pytest.param({"rollout_http_endpoint_url": "https://rollout.example"}, (16, 16), id="http_endpoint"), + pytest.param( + {"rollout_http_endpoint_url": "https://rollout.example", "debug_rollout_only": True}, + (0, 0), + id="http_endpoint_debug_rollout", + ), ], ) def test_placement_group_layout(overrides, expected): diff --git a/tests/test_rollout_http_endpoint.py b/tests/test_rollout_http_endpoint.py new file mode 100644 index 0000000000..fae7ef2154 --- /dev/null +++ b/tests/test_rollout_http_endpoint.py @@ -0,0 +1,159 @@ +import asyncio +import sys +from argparse import Namespace +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +try: + import ray # noqa: F401 +except ImportError: + pass + +try: + from tests.plugin_contracts._shared import install_stubs +except ImportError: + from plugin_contracts._shared import install_stubs + +install_stubs(with_sglang_router=True, with_transformers=True) + +from slime.backends.sglang_utils.http_endpoint import ( # noqa: E402 + normalize_rollout_http_endpoint_url, + start_http_endpoint_rollout_servers, +) +from slime.rollout import sglang_rollout # noqa: E402 +from slime.rollout.sglang_rollout import abort, generate, get_model_url # noqa: E402 +from slime.utils.types import Sample # noqa: E402 + +NUM_GPUS = 0 + + +def _args(**overrides): + values = { + "ci_test": False, + "rollout_http_endpoint_url": None, + "rollout_http_endpoint_abort_strategy": "router-workers", + "sglang_router_ip": "10.0.0.1", + "sglang_router_port": 30000, + "sglang_model_routers": None, + "router_policy": None, + "use_rollout_routing_replay": False, + "partial_rollout": False, + "mask_offpolicy_in_partial_rollout": False, + "sglang_speculative_algorithm": None, + } + values.update(overrides) + return Namespace(**values) + + +class _Tokenizer: + def encode(self, prompt, add_special_tokens=False): + assert add_special_tokens is False + return [101, len(prompt)] + + +class _GenerateState: + def __init__(self, args): + self.args = args + self.tokenizer = _Tokenizer() + self.processor = None + self.pendings = set() + self.aborted = False + + +def test_normalize_rollout_http_endpoint_url_requires_absolute_http_url(): + assert normalize_rollout_http_endpoint_url("https://rollout.example/") == "https://rollout.example" + with pytest.raises(ValueError, match="absolute http"): + normalize_rollout_http_endpoint_url("rollout.example") + + +def test_get_model_url_prefers_http_endpoint(): + args = _args( + rollout_http_endpoint_url="https://rollout.example/base/", + sglang_model_routers={"default": ("10.0.0.2", 30001)}, + ) + + assert get_model_url(args, "default", "/generate") == "https://rollout.example/base/generate" + assert get_model_url(args, "reward", "score") == "https://rollout.example/base/score" + + +def test_get_model_url_uses_model_router_without_http_endpoint(): + args = _args(sglang_model_routers={"reward": ("10.0.0.3", 30002)}) + + assert get_model_url(args, "reward", "/generate") == "http://10.0.0.3:30002/generate" + assert get_model_url(args, "missing", "/generate") == "http://10.0.0.1:30000/generate" + + +def test_generate_posts_to_http_endpoint(monkeypatch): + captured = {} + + async def fake_post(url, payload, headers=None): + captured["url"] = url + captured["payload"] = payload + captured["headers"] = headers + return { + "text": " answer", + "meta_info": { + "output_token_logprobs": [[-0.25, 42]], + "finish_reason": {"type": "stop"}, + "prompt_tokens": 2, + "cached_tokens": 1, + }, + } + + monkeypatch.setattr(sglang_rollout, "GenerateState", _GenerateState) + monkeypatch.setattr(sglang_rollout, "post", fake_post) + + args = _args(rollout_http_endpoint_url="https://rollout.example") + sample = asyncio.run(generate(args, Sample(index=0, prompt="hi"), {"max_new_tokens": 8})) + + assert captured["url"] == "https://rollout.example/generate" + assert captured["payload"]["input_ids"] == [101, 2] + assert captured["payload"]["return_logprob"] is True + assert sample.response == " answer" + assert sample.tokens == [101, 2, 42] + assert sample.status == Sample.Status.COMPLETED + + +def test_cancel_only_abort_does_not_query_router_workers(monkeypatch): + async def run(): + async def never_finishes(): + await asyncio.sleep(60) + + task = asyncio.create_task(never_finishes()) + state = _GenerateState(_args()) + state.pendings.add(task) + + def fake_state(_args): + return state + + async def fail_get(_url): + raise AssertionError("cancel-only abort must not query router workers") + + monkeypatch.setattr(sglang_rollout, "GenerateState", fake_state) + monkeypatch.setattr(sglang_rollout, "get", fail_get) + + result = await abort(_args(rollout_http_endpoint_abort_strategy="cancel-only"), rollout_id=7) + + assert result == [] + assert state.pendings == set() + assert task.cancelled() + + asyncio.run(run()) + + +def test_start_http_endpoint_rollout_servers_returns_no_engine_server(): + args = _args(rollout_http_endpoint_url="https://rollout.example/", rollout_num_engines=None) + + servers = start_http_endpoint_rollout_servers(args) + + server = servers["default"] + assert args.rollout_http_endpoint_url == "https://rollout.example" + assert args.rollout_num_engines == 1 + assert server.engines == [] + assert server.server_groups == [] + assert server.router_ip is None From d8526ee55a2c54ecfefdbf6222550ef3fa7f468f Mon Sep 17 00:00:00 2001 From: Jason Mancuso <7891333+jvmncs@users.noreply.github.com> Date: Thu, 4 Jun 2026 15:59:13 -0400 Subject: [PATCH 2/6] Add version-pinned rollout request policy --- slime/rollout/sglang_rollout.py | 81 ++++++++- slime/utils/arguments.py | 11 ++ slime/utils/http_utils.py | 14 +- .../test_plugin_runtime_hook_contracts.py | 23 +++ tests/test_rollout_http_endpoint.py | 169 +++++++++++++++++- 5 files changed, 281 insertions(+), 17 deletions(-) diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index de486b43f6..b1a730e760 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -34,11 +34,12 @@ from .rm_hub import async_rm, batched_async_rm -__all__ = ["generate_rollout", "get_model_url"] +__all__ = ["generate_rollout", "get_model_url", "rollout_request_context"] logger = logging.getLogger(__name__) _PROCESSOR_PROMPT_KEYS = {"input_ids", "attention_mask"} +_MISSING = object() def _prepare_prompt_ids(sample: Sample, tokenizer, processor: Any) -> list[int]: @@ -87,6 +88,67 @@ def get_model_url(args: Namespace, model_name: str, endpoint: str = "/generate") return f"http://{args.sglang_router_ip}:{args.sglang_router_port}{endpoint}" +@contextmanager +def rollout_request_context(args: Namespace, rollout_id: int, *, evaluation: bool = False): + old_rollout_id = getattr(args, "_rollout_request_rollout_id", _MISSING) + old_evaluation = getattr(args, "_rollout_request_evaluation", _MISSING) + args._rollout_request_rollout_id = int(rollout_id) + args._rollout_request_evaluation = bool(evaluation) + + try: + yield + finally: + _restore_context_attr(args, "_rollout_request_rollout_id", old_rollout_id) + _restore_context_attr(args, "_rollout_request_evaluation", old_evaluation) + + +def _restore_context_attr(args: Namespace, name: str, old_value: Any) -> None: + if old_value is _MISSING: + if hasattr(args, name): + delattr(args, name) + else: + setattr(args, name, old_value) + + +async def _post_generate( + args: Namespace, + url: str, + payload: dict[str, Any], + *, + headers: dict | None, + sample: Sample, +): + request = { + "url": url, + "payload": payload, + "headers": headers, + "max_retries": 60, + "retry_sleep": 1.0, + "rollout_id": getattr(args, "_rollout_request_rollout_id", None), + "evaluation": getattr(args, "_rollout_request_evaluation", False), + } + + if (hook_path := getattr(args, "custom_rollout_request_hook_path", None)) is not None: + hook = load_function(hook_path) + result = hook(args, sample, request) + if inspect.isawaitable(result): + result = await result + if result is not None: + if not isinstance(result, dict): + raise TypeError( + f"{hook_path} must return None or a dict of request updates, got {type(result).__name__}" + ) + request.update(result) + + return await post( + request["url"], + request["payload"], + max_retries=request["max_retries"], + headers=request["headers"], + retry_sleep=request["retry_sleep"], + ) + + class GenerateState(metaclass=SingletonMeta): """ The global state for the generation process. @@ -203,7 +265,7 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A headers = {"X-SMG-Routing-Key": sample.session_id} with trace_span(sample, "sglang_generate", attrs={"max_new_tokens": sampling_params["max_new_tokens"]}) as span: - output = await post(url, payload, headers=headers) + output = await _post_generate(args, url, payload, headers=headers, sample=sample) span.update(build_sglang_meta_trace_attrs(output["meta_info"])) if "output_token_logprobs" in output["meta_info"]: @@ -646,11 +708,12 @@ def generate_rollout( RolloutFnTrainOutput | RolloutFnEvalOutput: the output of the rollout """ assert args.rollout_global_dataset - if evaluation: - output, _ = run(eval_rollout(args, rollout_id)) + with rollout_request_context(args, rollout_id, evaluation=evaluation): + if evaluation: + output, _ = run(eval_rollout(args, rollout_id)) + return output + + output, aborted_samples = run(generate_rollout_async(args, rollout_id, data_source.get_samples)) + if aborted_samples: + data_source.add_samples(aborted_samples) return output - - output, aborted_samples = run(generate_rollout_async(args, rollout_id, data_source.get_samples)) - if aborted_samples: - data_source.add_samples(aborted_samples) - return output diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 8bd80623ff..cce4f804e1 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -567,6 +567,17 @@ def add_rollout_arguments(parser): "'router-workers' uses the SGLang router worker list." ), ) + parser.add_argument( + "--custom-rollout-request-hook-path", + type=str, + default=None, + help=( + "Path to a hook called before each default SGLang rollout /generate request. " + "Signature: ``def hook(args, sample, request) -> None | dict``. " + "The request dict contains url, payload, headers, max_retries, retry_sleep, " + "rollout_id, and evaluation. Mutate it in place or return a dict of updates." + ), + ) return parser def add_fault_tolerance_arguments(parser): diff --git a/slime/utils/http_utils.py b/slime/utils/http_utils.py index b8c3a30fb3..ae3e57565c 100644 --- a/slime/utils/http_utils.py +++ b/slime/utils/http_utils.py @@ -162,7 +162,7 @@ def _next_actor(): return actor -async def _post(client, url, payload, max_retries=60, headers=None): +async def _post(client, url, payload, max_retries=60, headers=None, retry_sleep: float = 1.0): retry_count = 0 while retry_count < max_retries: response = None @@ -188,7 +188,7 @@ async def _post(client, url, payload, max_retries=60, headers=None): if retry_count >= max_retries: logger.info(f"Max retries ({max_retries}) reached, failing... (url={url})") raise e - await asyncio.sleep(1) + await asyncio.sleep(retry_sleep) continue finally: if response is not None: @@ -262,8 +262,8 @@ def __init__(self, concurrency: int): trust_env=False, # internal SGLang comm only — never route through system proxy ) - async def do_post(self, url, payload, max_retries=60, headers=None): - return await _post(self._client, url, payload, max_retries, headers=headers) + async def do_post(self, url, payload, max_retries=60, headers=None, retry_sleep: float = 1.0): + return await _post(self._client, url, payload, max_retries, headers=headers, retry_sleep=retry_sleep) # Create actors per node created = [] @@ -288,7 +288,7 @@ async def do_post(self, url, payload, max_retries=60, headers=None): _post_actors = created -async def post(url, payload, max_retries=60, headers=None): +async def post(url, payload, max_retries=60, headers=None, retry_sleep: float = 1.0): # If distributed mode is enabled and actors exist, dispatch via Ray. if _distributed_post_enabled and _post_actors: try: @@ -300,13 +300,13 @@ async def post(url, payload, max_retries=60, headers=None): # `min(32, cpu+4)`), which becomes a hard upper bound on the # number of in-flight POSTs that can be waited on in parallel # and produces large tail latencies under high concurrency. - obj_ref = actor.do_post.remote(url, payload, max_retries, headers=headers) + obj_ref = actor.do_post.remote(url, payload, max_retries, headers=headers, retry_sleep=retry_sleep) return await obj_ref except Exception as e: logger.info(f"[http_utils] Distributed POST failed, falling back to local: {e} (url={url})") # fall through to local - return await _post(_http_client, url, payload, max_retries, headers=headers) + return await _post(_http_client, url, payload, max_retries, headers=headers, retry_sleep=retry_sleep) async def get(url): diff --git a/tests/plugin_contracts/test_plugin_runtime_hook_contracts.py b/tests/plugin_contracts/test_plugin_runtime_hook_contracts.py index a8380feecc..3fb6d365fc 100644 --- a/tests/plugin_contracts/test_plugin_runtime_hook_contracts.py +++ b/tests/plugin_contracts/test_plugin_runtime_hook_contracts.py @@ -37,6 +37,7 @@ def run_contract_test_file() -> None: "custom-reward-post-process-path", "custom-convert-samples-to-train-data-path", "rollout-data-postprocess-path", + "custom-rollout-request-hook-path", ], ) @@ -73,6 +74,11 @@ def reference_rollout_data_postprocess(args, rollout_id, rollout_data) -> None: args.rollout_data_postprocess_called = True +def reference_rollout_request_hook(args, sample, request) -> None: + args.rollout_request_hook_called = True + request["payload"]["hooked"] = sample.index + + def make_sample(index: int, reward: float = 1.0) -> Sample: return Sample( index=index, @@ -128,6 +134,14 @@ def invoke_rollout_data_postprocess(fn): assert args.rollout_data_postprocess_called is True +def invoke_rollout_request_hook(fn): + args = type("Args", (), {})() + request = {"payload": {}} + assert fn(args, Sample(index=7), request) is None + assert args.rollout_request_hook_called is True + assert request["payload"]["hooked"] == 7 + + HOOK_CASES = [ HookCase( "custom_rollout_log", @@ -174,6 +188,15 @@ def invoke_rollout_data_postprocess(fn): ("args", "rollout_id", "rollout_data"), invoke_rollout_data_postprocess, ), + HookCase( + "rollout_request_hook", + "CUSTOM_ROLLOUT_REQUEST_HOOK_PATH", + "plugin_contracts.test_plugin_runtime_hook_contracts.reference_rollout_request_hook", + "slime/rollout/sglang_rollout.py", + "hook(args, sample, request)", + ("args", "sample", "request"), + invoke_rollout_request_hook, + ), ] diff --git a/tests/test_rollout_http_endpoint.py b/tests/test_rollout_http_endpoint.py index fae7ef2154..187cd7887d 100644 --- a/tests/test_rollout_http_endpoint.py +++ b/tests/test_rollout_http_endpoint.py @@ -45,6 +45,7 @@ def _args(**overrides): "partial_rollout": False, "mask_offpolicy_in_partial_rollout": False, "sglang_speculative_algorithm": None, + "custom_rollout_request_hook_path": None, } values.update(overrides) return Namespace(**values) @@ -91,7 +92,7 @@ def test_get_model_url_uses_model_router_without_http_endpoint(): def test_generate_posts_to_http_endpoint(monkeypatch): captured = {} - async def fake_post(url, payload, headers=None): + async def fake_post(url, payload, headers=None, **_kwargs): captured["url"] = url captured["payload"] = payload captured["headers"] = headers @@ -119,6 +120,172 @@ async def fake_post(url, payload, headers=None): assert sample.status == Sample.Status.COMPLETED +def test_generate_request_hook_can_add_exact_weight_version(monkeypatch): + captured = {} + + async def fake_post(url, payload, headers=None, max_retries=60, retry_sleep=1.0): + captured["url"] = url + captured["payload"] = payload + captured["max_retries"] = max_retries + captured["retry_sleep"] = retry_sleep + return { + "text": " answer", + "meta_info": { + "output_token_logprobs": [[-0.25, 42]], + "finish_reason": {"type": "stop"}, + "prompt_tokens": 2, + "cached_tokens": 1, + }, + } + + monkeypatch.setattr(sglang_rollout, "GenerateState", _GenerateState) + monkeypatch.setattr(sglang_rollout, "post", fake_post) + + def hook(args, sample, request): + assert args.rollout_http_endpoint_url == "https://rollout.example" + assert sample.index == 0 + assert request["rollout_id"] == 9 + assert request["evaluation"] is False + request["payload"]["weight_version"] = {"exact_version": request["rollout_id"]} + request["max_retries"] = 123 + request["retry_sleep"] = 0.25 + + monkeypatch.setattr(sglang_rollout, "load_function", lambda path: hook) + + args = _args( + rollout_http_endpoint_url="https://rollout.example", + custom_rollout_request_hook_path="example.hook", + ) + with sglang_rollout.rollout_request_context(args, rollout_id=9): + sample = asyncio.run(generate(args, Sample(index=0, prompt="hi"), {"max_new_tokens": 8})) + + assert captured["url"] == "https://rollout.example/generate" + assert captured["payload"]["weight_version"] == {"exact_version": 9} + assert captured["max_retries"] == 123 + assert captured["retry_sleep"] == 0.25 + assert sample.status == Sample.Status.COMPLETED + + +def test_generate_request_hook_can_return_request_updates(monkeypatch): + captured = {} + + async def fake_post(url, payload, headers=None, max_retries=60, retry_sleep=1.0): + captured["url"] = url + captured["payload"] = payload + captured["max_retries"] = max_retries + captured["retry_sleep"] = retry_sleep + return { + "text": " answer", + "meta_info": { + "output_token_logprobs": [[-0.25, 42]], + "finish_reason": {"type": "stop"}, + "prompt_tokens": 2, + "cached_tokens": 1, + }, + } + + monkeypatch.setattr(sglang_rollout, "GenerateState", _GenerateState) + monkeypatch.setattr(sglang_rollout, "post", fake_post) + + async def hook(_args, _sample, request): + payload = dict(request["payload"]) + payload["weight_version"] = {"min_required_version": request["rollout_id"]} + return { + "payload": payload, + "max_retries": 123, + "retry_sleep": 0.25, + } + + monkeypatch.setattr(sglang_rollout, "load_function", lambda path: hook) + + args = _args( + rollout_http_endpoint_url="https://rollout.example", + custom_rollout_request_hook_path="example.hook", + ) + with sglang_rollout.rollout_request_context(args, rollout_id=9): + sample = asyncio.run(generate(args, Sample(index=0, prompt="hi"), {"max_new_tokens": 8})) + + assert captured["url"] == "https://rollout.example/generate" + assert captured["payload"]["weight_version"] == {"min_required_version": 9} + assert captured["max_retries"] == 123 + assert captured["retry_sleep"] == 0.25 + assert sample.status == Sample.Status.COMPLETED + + +def test_generate_retries_until_exact_weight_version_is_available(monkeypatch): + aiohttp_web = pytest.importorskip("aiohttp.web") + httpx = pytest.importorskip("httpx") + + async def run(): + from slime.utils import http_utils + + attempts = [] + + async def handle_generate(request): + payload = await request.json() + attempts.append(payload) + assert payload["weight_version"] == {"exact_version": 11} + if len(attempts) == 1: + raise aiohttp_web.HTTPNotFound(text="weight version not loaded") + if len(attempts) == 2: + raise aiohttp_web.HTTPConflict(text="weight version still loading") + return aiohttp_web.json_response( + { + "text": " answer", + "meta_info": { + "output_token_logprobs": [[-0.25, 42]], + "finish_reason": {"type": "stop"}, + "prompt_tokens": 2, + "cached_tokens": 0, + }, + } + ) + + app = aiohttp_web.Application() + app.router.add_post("/generate", handle_generate) + runner = aiohttp_web.AppRunner(app) + await runner.setup() + site = aiohttp_web.TCPSite(runner, "127.0.0.1", 0) + await site.start() + port = site._server.sockets[0].getsockname()[1] + + old_client = http_utils._http_client + old_distributed = http_utils._distributed_post_enabled + old_post_actors = http_utils._post_actors + client = httpx.AsyncClient(timeout=httpx.Timeout(None), trust_env=False) + http_utils._http_client = client + http_utils._distributed_post_enabled = False + http_utils._post_actors = [] + try: + monkeypatch.setattr(sglang_rollout, "GenerateState", _GenerateState) + + def hook(_args, _sample, request): + request["payload"]["weight_version"] = {"exact_version": request["rollout_id"]} + request["max_retries"] = 5 + request["retry_sleep"] = 0.01 + + monkeypatch.setattr(sglang_rollout, "load_function", lambda path: hook) + + args = _args( + rollout_http_endpoint_url=f"http://127.0.0.1:{port}", + custom_rollout_request_hook_path="example.hook", + ) + with sglang_rollout.rollout_request_context(args, rollout_id=11): + sample = await generate(args, Sample(index=0, prompt="hi"), {"max_new_tokens": 8}) + finally: + await client.aclose() + http_utils._http_client = old_client + http_utils._distributed_post_enabled = old_distributed + http_utils._post_actors = old_post_actors + await runner.cleanup() + + assert len(attempts) == 3 + assert sample.status == Sample.Status.COMPLETED + assert sample.tokens == [101, 2, 42] + + asyncio.run(run()) + + def test_cancel_only_abort_does_not_query_router_workers(monkeypatch): async def run(): async def never_finishes(): From 4b46fd7b4494d0873dd09163851fd965abaab201 Mon Sep 17 00:00:00 2001 From: Jason Mancuso <7891333+jvmncs@users.noreply.github.com> Date: Thu, 4 Jun 2026 16:17:10 -0400 Subject: [PATCH 3/6] Add publish-only disk delta hooks --- slime/backends/megatron_utils/actor.py | 5 +- .../update_weight_from_distributed_delta.py | 82 +++-- slime/utils/arguments.py | 43 +++ tests/test_delta_publish_only.py | 302 ++++++++++++++++++ tests/test_megatron_argument_validation.py | 3 + 5 files changed, 401 insertions(+), 34 deletions(-) create mode 100644 tests/test_delta_publish_only.py diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 941659f1fe..0e68f8739b 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -609,8 +609,9 @@ def update_weights(self) -> None: ) reconnect_rollout_engines = self.args.offload_train and self.args.use_critic and not self.args.colocate + force_connect_rollout_target = getattr(self.args, "update_weight_delta_publish_only", False) - if not rollout_engines and not reconnect_rollout_engines: + if not rollout_engines and not reconnect_rollout_engines and not force_connect_rollout_target: if dist.get_rank() == 0: logger.info("No updatable SGLang engines are running; skip weight update.") return @@ -620,7 +621,7 @@ def update_weights(self) -> None: elif self.args.offload_train: reload_process_groups() - if num_new_engines > 0 or reconnect_rollout_engines: + if num_new_engines > 0 or reconnect_rollout_engines or force_connect_rollout_target: self.weight_updater.connect_rollout_engines( rollout_engines, rollout_engine_lock, diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py index fbe24bbc1c..b1aefb96a6 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py @@ -509,13 +509,15 @@ def __init__( self.writer: AsyncSafetensorsWriter | None = None self.delta_dir: str | None = None self._pre_push_hook: Callable | None = None - # Disk transport: each pass boundary publishes its accumulated files - # (the only globally-synced flush points, since ``_publish_batch`` - # contains collectives). ``_pre_push_hook`` may return a Future, in - # which case the receiver RPC is deferred behind it via - # ``_rpc_executor`` so the main encode thread continues immediately. - # ``_pending_publishes`` holds the resulting Future[list[ObjectRef]] - # on rank 0; ``_finalize_sync`` awaits them at end of sync. + self._publish_hook: Callable | None = None + self._publish_only = bool(getattr(args, "update_weight_delta_publish_only", False)) + self._publish_wait = getattr(args, "update_weight_delta_publish_wait", "next-sync") + # Direct disk transport publishes at each pass boundary so receiver + # apply can overlap later encoding. Publish-only disk transport emits + # one complete version at finalize time, so external consumers never + # observe a partially published version. ``_pre_push_hook`` may return + # a Future; ``_pending_publishes`` holds Future[list[ObjectRef]] values + # that rank 0 awaits at end of sync. self._pending_files: list[str] = [] self._pending_publishes: list = [] self._published_any: bool = False @@ -531,6 +533,10 @@ def __init__( from slime.utils.misc import load_function self._pre_push_hook = load_function(args.custom_delta_pre_push_path) + if getattr(args, "custom_delta_publish_path", None): + from slime.utils.misc import load_function + + self._publish_hook = load_function(args.custom_delta_publish_path) def connect_rollout_engines( self, @@ -584,7 +590,7 @@ def update_weights(self) -> None: if self._is_pp_src_rank: os.makedirs(self._version_dir, exist_ok=True) - if dist.get_rank() == 0: + if dist.get_rank() == 0 and not self._publish_only: ray.get([engine.pause_generation.remote() for engine in self.rollout_engines]) ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) dist.barrier(group=get_gloo_group()) @@ -649,13 +655,15 @@ def _send_weights(self, pbar: tqdm | None) -> None: def _flush_and_publish(self, bucket: DeltaBucket, pbar: tqdm | None) -> None: """ - End-of-sub-pass: drain the in-flight bucket, barrier all PP ranks, then - (disk-only) fire one publish RPC for everything since the last call. + End-of-sub-pass: drain the in-flight bucket and barrier all PP ranks. + Direct disk transport also fires one receiver RPC for everything since + the last call; publish-only transport waits until finalize so the hook + sees a complete version. """ if bucket.has_updates: self._flush_bucket(bucket, pbar) dist.barrier(group=get_gloo_group()) - if self.transport == "disk": + if self.transport == "disk" and not self._publish_only: self._publish_batch() def _pipeline_pass( @@ -746,13 +754,14 @@ def _flush_bucket(self, bucket: DeltaBucket, pbar: tqdm | None) -> None: def _publish_batch(self) -> None: """ - Drain pending fsyncs, invoke the pre-push hook (may return a Future for an - async durability step on shared FS), then defer rank 0's - ``update_weights_from_disk`` RPC behind that Future via ``_rpc_executor``. - Each deferred dispatch lands in ``_pending_publishes`` as a - Future[list[ObjectRef]]; ``_finalize_sync`` awaits both layers. Safe to call - with empty ``_pending_files``: the all_gather still synchronizes and rank 0 - skips the dispatch when no rank produced files. + Drain pending fsyncs, invoke the pre-push hook (may return a Future for + an async durability step on shared FS), gather filenames, then defer + rank 0's publish/direct-update work behind that Future via + ``_rpc_executor``. Each deferred dispatch lands in + ``_pending_publishes`` as a Future[list[ObjectRef]]; ``_finalize_sync`` + awaits both layers. Safe to call with empty ``_pending_files``: direct + disk transport skips the dispatch, while publish-only still calls the + publish hook so a no-op version can be made visible. """ self.writer.drain() dist.barrier(group=get_gloo_group()) @@ -768,24 +777,31 @@ def _publish_batch(self) -> None: flat = [f for sub in all_files for f in sub] self._pending_files.clear() - if dist.get_rank() == 0 and flat: + if dist.get_rank() == 0 and (flat or self._publish_only): version_dir = self._version_dir engines = list(self.rollout_engines) weight_version = str(self.weight_version) self._published_any = True def _fire_when_committed() -> list: + refs = [] if commit_future is not None: commit_future.result() - return [ - engine.update_weights_from_disk.remote( - model_path=version_dir, - files=flat, - load_format="delta", - weight_version=weight_version, + if self._publish_hook is not None: + hook_refs = self._publish_hook(self.args, version_dir, flat, weight_version, engines) + if hook_refs is not None: + refs.extend(hook_refs) + if not self._publish_only: + refs.extend( + engine.update_weights_from_disk.remote( + model_path=version_dir, + files=flat, + load_format="delta", + weight_version=weight_version, + ) + for engine in engines ) - for engine in engines - ] + return refs self._pending_publishes.append(self._rpc_executor.submit(_fire_when_committed)) @@ -801,24 +817,26 @@ def _finalize_sync(self) -> None: dist.barrier(group=get_gloo_group()) return - if self._pending_files: + if self._pending_files or self._publish_only: self._publish_batch() if dist.get_rank() == 0: # Each entry is a Future returning a list of ObjectRefs. Awaiting the # Futures unblocks the (commit-then-RPC) chain; ray.get waits for the # receivers' apply to finish. object_refs = [ref for fut in self._pending_publishes for ref in fut.result()] - ray.get(object_refs) + if object_refs: + ray.get(object_refs) self._pending_publishes.clear() - if not self._published_any: + if not self._publish_only and not self._published_any: # No delta files needed publishing this sync (e.g. all-zero diff). # Engines never saw the new version via update_weights_from_disk, so # bump it explicitly to keep their recorded version in sync with ours. weight_version = str(self.weight_version) ray.get([engine.set_weight_version.remote(weight_version) for engine in self.rollout_engines]) - if not self.args.update_weight_delta_keep_files: + if not self._publish_only and not self.args.update_weight_delta_keep_files: shutil.rmtree(self._version_dir, ignore_errors=True) - ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) + if not self._publish_only: + ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) dist.barrier(group=get_gloo_group()) def _record_metrics(self) -> None: diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index cce4f804e1..e392878537 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -203,6 +203,15 @@ def add_train_arguments(parser): "release. Prefer the transport-level directory flag for both full and delta disk sync." ), ) + parser.add_argument( + "--update-weight-delta-root", + type=str, + default=None, + help=( + "Optional root directory for publish-based disk delta metadata. " + "Defaults to the parent of --update-weight-delta-dir when omitted." + ), + ) parser.add_argument( "--update-weight-delta-keep-files", action="store_true", @@ -220,6 +229,28 @@ def add_train_arguments(parser): "Called from every trainer rank; the hook gates itself." ), ) + parser.add_argument( + "--custom-delta-publish-path", + type=str, + default=None, + help=( + "Path to a custom rank-0 function called after disk delta filenames are gathered " + "and the pre-push hook has completed. Signature: " + "``def hook(args, version_dir: str, files: list[str], weight_version: str, " + "rollout_engines) -> list | None``. Returned Ray ObjectRefs are awaited before " + "the sync completes." + ), + ) + parser.add_argument( + "--update-weight-delta-publish-only", + action="store_true", + default=False, + help=( + "For disk delta transport, publish gathered delta files through " + "--custom-delta-publish-path without issuing direct rollout-engine update RPCs. " + "Useful for elastic HTTP rollout endpoints that consume published versions." + ), + ) parser.add_argument( "--custom-model-provider-path", type=str, @@ -1762,6 +1793,9 @@ def _resolve_update_weight_disk_dir(args) -> None: def _validate_update_weight_args(args) -> None: _resolve_update_weight_disk_dir(args) + if args.update_weight_delta_publish_only and args.update_weight_mode != "delta": + raise ValueError("--update-weight-delta-publish-only requires --update-weight-mode=delta.") + if args.update_weight_mode == "delta": if args.update_weight_transport not in ("nccl", "disk"): raise ValueError( @@ -1774,6 +1808,15 @@ def _validate_update_weight_args(args) -> None: "weights via CUDA IPC (only a handle crosses processes), so the delta bookkeeping " "(snapshot + diff + sparse encode) is pure overhead." ) + if args.update_weight_transport == "disk" and args.update_weight_delta_root is None: + args.update_weight_delta_root = os.path.dirname(os.path.abspath(args.update_weight_disk_dir)) + if args.update_weight_delta_publish_only: + if args.update_weight_transport != "disk": + raise ValueError("--update-weight-delta-publish-only requires --update-weight-transport=disk.") + if not args.custom_delta_publish_path: + raise ValueError("--update-weight-delta-publish-only requires --custom-delta-publish-path.") + if not args.update_weight_delta_keep_files: + raise ValueError("--update-weight-delta-publish-only requires --update-weight-delta-keep-files.") def slime_validate_args(args): diff --git a/tests/test_delta_publish_only.py b/tests/test_delta_publish_only.py new file mode 100644 index 0000000000..8b78689fee --- /dev/null +++ b/tests/test_delta_publish_only.py @@ -0,0 +1,302 @@ +from __future__ import annotations + +import os +import sys +import types +from argparse import Namespace +from dataclasses import dataclass +from enum import Enum +from pathlib import Path + +import pytest + +torch = pytest.importorskip("torch") + + +def _install_import_stubs() -> None: + if "safetensors.torch" not in sys.modules: + try: + import safetensors.torch # noqa: F401 + except ImportError: + safetensors = types.ModuleType("safetensors") + safetensors_torch = types.ModuleType("safetensors.torch") + safetensors_torch.save = lambda tensors, metadata=None: b"" + safetensors.torch = safetensors_torch + sys.modules["safetensors"] = safetensors + sys.modules["safetensors.torch"] = safetensors_torch + + if "ray" not in sys.modules: + ray = types.ModuleType("ray") + actor = types.ModuleType("ray.actor") + + class ActorHandle: + pass + + class ObjectRef: + pass + + actor.ActorHandle = ActorHandle + ray.actor = actor + ray.ObjectRef = ObjectRef + ray.get = lambda refs: refs + sys.modules["ray"] = ray + sys.modules["ray.actor"] = actor + + if "megatron" not in sys.modules: + megatron = types.ModuleType("megatron") + core = types.ModuleType("megatron.core") + mpu = types.ModuleType("megatron.core.mpu") + mpu.get_data_parallel_rank = lambda with_context_parallel=False: 0 + mpu.get_tensor_model_parallel_rank = lambda: 0 + mpu.get_pipeline_model_parallel_rank = lambda: 0 + mpu.get_expert_model_parallel_world_size = lambda: 1 + mpu.get_expert_model_parallel_group = lambda: None + mpu.get_expert_tensor_parallel_world_size = lambda: 1 + mpu.get_expert_tensor_parallel_group = lambda: None + mpu.get_tensor_model_parallel_world_size = lambda: 1 + mpu.get_tensor_model_parallel_group = lambda: None + mpu.get_expert_model_parallel_rank = lambda: 0 + transformer = types.ModuleType("megatron.core.transformer") + transformer_layer = types.ModuleType("megatron.core.transformer.transformer_layer") + transformer_layer.get_transformer_layer_offset = lambda config, *args, **kwargs: 0 + core.mpu = mpu + core.transformer = transformer + megatron.core = core + sys.modules["megatron"] = megatron + sys.modules["megatron.core"] = core + sys.modules["megatron.core.mpu"] = mpu + sys.modules["megatron.core.transformer"] = transformer + sys.modules["megatron.core.transformer.transformer_layer"] = transformer_layer + + megatron_to_hf = types.ModuleType("slime.backends.megatron_utils.megatron_to_hf") + megatron_to_hf.convert_to_hf = lambda args, model_name, name, param, quantization_config: [(name, param)] + sys.modules.setdefault("slime.backends.megatron_utils.megatron_to_hf", megatron_to_hf) + + if "sglang" not in sys.modules: + sglang = types.ModuleType("sglang") + srt = types.ModuleType("sglang.srt") + sys.modules["sglang"] = sglang + sys.modules["sglang.srt"] = srt + + if "sglang.srt.layers.quantization.fp8_utils" not in sys.modules: + fp8_utils = types.ModuleType("sglang.srt.layers.quantization.fp8_utils") + fp8_utils.quant_weight_ue8m0 = None + fp8_utils.transform_scale_ue8m0 = None + sys.modules["sglang.srt.layers"] = types.ModuleType("sglang.srt.layers") + sys.modules["sglang.srt.layers.quantization"] = types.ModuleType("sglang.srt.layers.quantization") + sys.modules["sglang.srt.layers.quantization.fp8_utils"] = fp8_utils + + if "sglang.srt.model_loader.utils" not in sys.modules: + model_loader_utils = types.ModuleType("sglang.srt.model_loader.utils") + model_loader_utils.should_deepgemm_weight_requant_ue8m0 = None + sys.modules["sglang.srt.model_loader"] = types.ModuleType("sglang.srt.model_loader") + sys.modules["sglang.srt.model_loader.utils"] = model_loader_utils + + utils = sys.modules.get("sglang.srt.utils") + if utils is None: + utils = types.ModuleType("sglang.srt.utils") + utils.__path__ = [] + sys.modules["sglang.srt.utils"] = utils + utils.MultiprocessingSerializer = object + + patch_torch = types.ModuleType("sglang.srt.utils.patch_torch") + patch_torch.monkey_patch_torch_reductions = lambda: None + sys.modules.setdefault("sglang.srt.utils.patch_torch", patch_torch) + sys.modules.setdefault("sglang.srt.patch_torch", patch_torch) + + if "sglang.srt.managers.io_struct" not in sys.modules: + io_struct = types.ModuleType("sglang.srt.managers.io_struct") + + class DeltaEncoding(Enum): + INDICES = "indices" + DELTAS = "deltas" + DELTAS_ZSTD = "deltas_zstd" + + @dataclass + class DeltaParam: + name: str + dtype: str + shape: list[int] + pos_start: int + pos_end: int + pos_width: int + val_start: int + val_end: int + + @dataclass + class DeltaSpec: + encoding: DeltaEncoding + params: list[DeltaParam] + checksum: int + + io_struct.DeltaEncoding = DeltaEncoding + io_struct.DeltaParam = DeltaParam + io_struct.DeltaSpec = DeltaSpec + sys.modules["sglang.srt.managers"] = types.ModuleType("sglang.srt.managers") + sys.modules["sglang.srt.managers.io_struct"] = io_struct + + tensor_bucket = types.ModuleType("sglang.srt.weight_sync.tensor_bucket") + tensor_bucket.FlattenedTensorBucket = object + sys.modules.setdefault("sglang.srt.weight_sync", types.ModuleType("sglang.srt.weight_sync")) + sys.modules.setdefault("sglang.srt.weight_sync.tensor_bucket", tensor_bucket) + + +_install_import_stubs() + +from slime.backends.megatron_utils.update_weight import update_weight_from_distributed_delta as delta_mod # noqa: E402 + + +class _InlineFuture: + def __init__(self, value): + self._value = value + + def result(self): + return self._value + + +class _InlineExecutor: + def submit(self, fn): + return _InlineFuture(fn()) + + +class _FakeWriter: + def __init__(self): + self.drain_calls = 0 + + def drain(self): + self.drain_calls += 1 + + +class _RemoteMethod: + def __init__(self, owner, name): + self._owner = owner + self._name = name + + def remote(self, **kwargs): + self._owner.calls.append((self._name, kwargs)) + return f"{self._name}-ref" + + +class _FakeEngine: + def __init__(self): + self.calls = [] + self.update_weights_from_disk = _RemoteMethod(self, "update_weights_from_disk") + self.set_weight_version = _RemoteMethod(self, "set_weight_version") + self.continue_generation = _RemoteMethod(self, "continue_generation") + + +def _patch_single_rank_dist(monkeypatch): + barrier_calls = [] + gathered = [] + + monkeypatch.setattr(delta_mod, "get_gloo_group", lambda: None) + monkeypatch.setattr(delta_mod.dist, "get_rank", lambda: 0) + monkeypatch.setattr(delta_mod.dist, "get_world_size", lambda: 1) + monkeypatch.setattr(delta_mod.dist, "barrier", lambda group=None: barrier_calls.append(group)) + + def all_gather_object(outputs, value, group=None): + gathered.append((list(value), group)) + outputs[0] = list(value) + + monkeypatch.setattr(delta_mod.dist, "all_gather_object", all_gather_object) + return barrier_calls, gathered + + +def _make_publish_only_updater(tmp_path: Path, publish_hook, *, publish_wait: str = "next-sync"): + updater = delta_mod.UpdateWeightFromDistributedDelta.__new__(delta_mod.UpdateWeightFromDistributedDelta) + updater.args = Namespace(update_weight_delta_keep_files=True, update_weight_delta_publish_wait=publish_wait) + updater.transport = "disk" + updater._publish_only = True + updater._publish_wait = publish_wait + updater._pending_files = [] + updater._pending_publishes = [] + updater._published_any = False + updater._pre_push_hook = None + updater._publish_hook = publish_hook + updater._rpc_executor = _InlineExecutor() + updater.writer = _FakeWriter() + updater.weight_version = 7 + updater._version_dir = os.path.join(tmp_path, "weight_v000007") + os.makedirs(updater._version_dir, exist_ok=True) + updater.rollout_engines = [_FakeEngine()] + return updater + + +def test_publish_only_finalize_calls_publish_hook_without_engine_rpcs_or_cleanup(monkeypatch, tmp_path): + _patch_single_rank_dist(monkeypatch) + ray_get_calls = [] + monkeypatch.setattr(delta_mod.ray, "get", lambda refs: ray_get_calls.append(refs)) + monkeypatch.setattr(delta_mod.shutil, "rmtree", lambda *_args, **_kwargs: pytest.fail("publish-only must keep files")) + + hook_calls = [] + + def publish_hook(args, version_dir, files, weight_version, engines): + hook_calls.append((args, version_dir, files, weight_version, engines)) + return ["publish-ref"] + + updater = _make_publish_only_updater(tmp_path, publish_hook) + updater._pending_files = ["rank0000_flush000000.safetensors"] + + updater._finalize_sync() + + assert updater.writer.drain_calls == 1 + assert updater._pending_files == [] + assert updater._pending_publishes == [] + assert updater._published_any is True + assert hook_calls == [ + ( + updater.args, + updater._version_dir, + ["rank0000_flush000000.safetensors"], + "7", + updater.rollout_engines, + ) + ] + assert updater.rollout_engines[0].calls == [] + assert ray_get_calls == [["publish-ref"]] + assert os.path.isdir(updater._version_dir) + + +def test_publish_only_finalize_publishes_noop_version(monkeypatch, tmp_path): + _patch_single_rank_dist(monkeypatch) + ray_get_calls = [] + monkeypatch.setattr(delta_mod.ray, "get", lambda refs: ray_get_calls.append(refs)) + + hook_calls = [] + + def publish_hook(args, version_dir, files, weight_version, engines): + hook_calls.append((version_dir, files, weight_version, engines)) + return None + + updater = _make_publish_only_updater(tmp_path, publish_hook) + + updater._finalize_sync() + + assert updater.writer.drain_calls == 1 + assert updater._published_any is True + assert hook_calls == [(updater._version_dir, [], "7", updater.rollout_engines)] + assert updater.rollout_engines[0].calls == [] + assert ray_get_calls == [] + + +def test_publish_only_flush_defers_publish_until_finalize(monkeypatch, tmp_path): + barrier_calls, _gathered = _patch_single_rank_dist(monkeypatch) + updater = _make_publish_only_updater(tmp_path, publish_hook=None) + + class FakeBucket: + has_updates = True + + flush_calls = [] + monkeypatch.setattr(updater, "_flush_bucket", lambda bucket, pbar: flush_calls.append((bucket, pbar))) + monkeypatch.setattr( + updater, + "_publish_batch", + lambda: pytest.fail("publish-only should publish once from _finalize_sync"), + ) + + bucket = FakeBucket() + updater._flush_and_publish(bucket, pbar=None) + + assert flush_calls == [(bucket, None)] + assert len(barrier_calls) == 1 + assert updater._pending_publishes == [] diff --git a/tests/test_megatron_argument_validation.py b/tests/test_megatron_argument_validation.py index 0d43d849f7..5178a82143 100644 --- a/tests/test_megatron_argument_validation.py +++ b/tests/test_megatron_argument_validation.py @@ -269,6 +269,7 @@ def make_slime_validate_args(**overrides): rollout_external_engine_addrs=None, rollout_http_endpoint_url=None, rollout_http_endpoint_abort_strategy=None, + update_weight_delta_publish_only=False, debug_train_only=False, actor_num_gpus_per_node=8, actor_num_nodes=1, @@ -362,6 +363,7 @@ def test_update_weight_delta_rejects_colocate(monkeypatch): update_weight_transport="nccl", update_weight_disk_dir=None, update_weight_delta_dir=None, + update_weight_delta_publish_only=False, colocate=True, ) @@ -377,6 +379,7 @@ def test_update_weight_delta_rejects_unknown_transport(monkeypatch): update_weight_transport="tensor", update_weight_disk_dir=None, update_weight_delta_dir=None, + update_weight_delta_publish_only=False, colocate=False, ) From 76510d769865a30b77c0e6b493c47abeb57fca33 Mon Sep 17 00:00:00 2001 From: Jason Mancuso <7891333+jvmncs@users.noreply.github.com> Date: Thu, 11 Jun 2026 00:44:57 -0400 Subject: [PATCH 4/6] Pipeline publish-only hook awaits by one sync In publish-only mode, _finalize_sync now dispatches the publish hook without awaiting its refs; the start of the next sync (or disconnect_rollout_engines) drains them, so the publish overlaps a full training step with at most one version outstanding. Failures surface one sync late. Direct disk transport still drains before cleanup and resume. --- .../update_weight_from_distributed_delta.py | 70 +++++++++++++------ slime/backends/sglang_utils/http_endpoint.py | 4 +- slime/utils/arguments.py | 17 ++++- tests/test_delta_publish_only.py | 49 ++++++++++++- 4 files changed, 113 insertions(+), 27 deletions(-) diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py index b1aefb96a6..0ca1d25556 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py @@ -516,8 +516,11 @@ def __init__( # apply can overlap later encoding. Publish-only disk transport emits # one complete version at finalize time, so external consumers never # observe a partially published version. ``_pre_push_hook`` may return - # a Future; ``_pending_publishes`` holds Future[list[ObjectRef]] values - # that rank 0 awaits at end of sync. + # a Future; ``_pending_publishes`` holds Future[list[ObjectRef]] values. + # Direct transport drains them at end of sync (resume + cleanup depend + # on them); publish-only transport leaves the last publish in flight + # across the training step and drains it at the start of the next sync, + # so at most one version's publish is ever outstanding. self._pending_files: list[str] = [] self._pending_publishes: list = [] self._published_any: bool = False @@ -567,6 +570,8 @@ def connect_rollout_engines( self._group_name = f"slime-pp_{pp_rank}" def disconnect_rollout_engines(self) -> None: + # A queued publish holds engine handles; settle it before dropping them. + self._drain_pending_publishes() if self.transport == "nccl": super().disconnect_rollout_engines() @@ -597,7 +602,10 @@ def update_weights(self) -> None: self.density_nnz = self.density_numel = self.wire_bytes = self._flush_idx = 0 self._pending_files.clear() - self._pending_publishes.clear() + # Publish-only mode leaves the previous sync's publish in flight across + # the training step; settle it (and surface any publish failure, one + # sync late) before producing the next version. + self._drain_pending_publishes() self._published_any = False if self.writer is not None: self.writer.reset_counters() @@ -758,10 +766,11 @@ def _publish_batch(self) -> None: an async durability step on shared FS), gather filenames, then defer rank 0's publish/direct-update work behind that Future via ``_rpc_executor``. Each deferred dispatch lands in - ``_pending_publishes`` as a Future[list[ObjectRef]]; ``_finalize_sync`` - awaits both layers. Safe to call with empty ``_pending_files``: direct - disk transport skips the dispatch, while publish-only still calls the - publish hook so a no-op version can be made visible. + ``_pending_publishes`` as a Future[list[ObjectRef]]; direct disk + transport awaits both layers in ``_finalize_sync``, publish-only at the + start of the next sync. Safe to call with empty ``_pending_files``: + direct disk transport skips the dispatch, while publish-only still + calls the publish hook so a no-op version can be made visible. """ self.writer.drain() dist.barrier(group=get_gloo_group()) @@ -809,7 +818,14 @@ def _finalize_sync(self) -> None: """ Per-transport end-of-sync. NCCL: each flush already broadcasted; just resume. Disk: publish the trailing files, wait for all streamed applies to land, then - cleanup + resume. + cleanup + resume. Publish-only: dispatch the version's publish and return + without awaiting it by default — the hook runs concurrently with the + next training step and the start of the next sync settles it. With + ``--update-weight-delta-publish-wait=sync``, publish-only drains the + hook before returning so the next rollout dispatch starts only after + the hook's readiness contract has been satisfied. In both modes the + version dir must outlive the sync, which publish-only's no-cleanup + rule already ensures. """ if self.transport == "nccl": if dist.get_rank() == 0: @@ -819,26 +835,40 @@ def _finalize_sync(self) -> None: if self._pending_files or self._publish_only: self._publish_batch() - if dist.get_rank() == 0: - # Each entry is a Future returning a list of ObjectRefs. Awaiting the - # Futures unblocks the (commit-then-RPC) chain; ray.get waits for the - # receivers' apply to finish. - object_refs = [ref for fut in self._pending_publishes for ref in fut.result()] - if object_refs: - ray.get(object_refs) - self._pending_publishes.clear() - if not self._publish_only and not self._published_any: + if dist.get_rank() == 0 and self._publish_only and self._publish_wait == "sync": + self._drain_pending_publishes() + if dist.get_rank() == 0 and not self._publish_only: + # Resume + cleanup must order after the receivers' apply, so drain + # the publish chain before either. + self._drain_pending_publishes() + if not self._published_any: # No delta files needed publishing this sync (e.g. all-zero diff). # Engines never saw the new version via update_weights_from_disk, so # bump it explicitly to keep their recorded version in sync with ours. weight_version = str(self.weight_version) ray.get([engine.set_weight_version.remote(weight_version) for engine in self.rollout_engines]) - if not self._publish_only and not self.args.update_weight_delta_keep_files: + if not self.args.update_weight_delta_keep_files: shutil.rmtree(self._version_dir, ignore_errors=True) - if not self._publish_only: - ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) + ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) dist.barrier(group=get_gloo_group()) + def _drain_pending_publishes(self) -> None: + """ + Await every queued (commit-then-publish) Future, then the ObjectRefs it + returned. Awaiting the Futures unblocks the (commit-then-RPC) chain; + ray.get waits for the hook's work and the receivers' apply to finish. + Re-raises a failed publish here, on the draining rank (rank 0; the list + is empty elsewhere). In publish-only ``next-sync`` mode that is one + sync after the failing dispatch; in ``sync`` mode it is still inside the + publishing sync. + """ + if not self._pending_publishes: + return + object_refs = [ref for fut in self._pending_publishes for ref in fut.result()] + self._pending_publishes.clear() + if object_refs: + ray.get(object_refs) + def _record_metrics(self) -> None: """ Allreduce density/byte counters across PP-src ranks; stash on diff --git a/slime/backends/sglang_utils/http_endpoint.py b/slime/backends/sglang_utils/http_endpoint.py index a13ea58ab0..e0b87d0a54 100644 --- a/slime/backends/sglang_utils/http_endpoint.py +++ b/slime/backends/sglang_utils/http_endpoint.py @@ -14,9 +14,7 @@ def normalize_rollout_http_endpoint_url(url: str) -> str: url = url.rstrip("/") parsed = urlparse(url) if parsed.scheme not in ("http", "https") or parsed.netloc == "": - raise ValueError( - f"Invalid rollout HTTP endpoint URL {url!r}. Use an absolute http:// or https:// URL." - ) + raise ValueError(f"Invalid rollout HTTP endpoint URL {url!r}. Use an absolute http:// or https:// URL.") return url diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index e392878537..0713b8654b 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -238,7 +238,22 @@ def add_train_arguments(parser): "and the pre-push hook has completed. Signature: " "``def hook(args, version_dir: str, files: list[str], weight_version: str, " "rollout_engines) -> list | None``. Returned Ray ObjectRefs are awaited before " - "the sync completes." + "the sync completes. With --update-weight-delta-publish-only, " + "--update-weight-delta-publish-wait controls whether this happens in the same " + "sync or at the start of the next sync." + ), + ) + parser.add_argument( + "--update-weight-delta-publish-wait", + type=str, + choices=["next-sync", "sync"], + default="next-sync", + help=( + "When --update-weight-delta-publish-only is set, choose when rank 0 waits for " + "--custom-delta-publish-path to finish. 'next-sync' pipelines publish work " + "across the next training step and surfaces failures one sync late. 'sync' " + "blocks update_weights until the publish hook returns, useful when the hook " + "polls rollout-fleet readiness before allowing the next rollout dispatch." ), ) parser.add_argument( diff --git a/tests/test_delta_publish_only.py b/tests/test_delta_publish_only.py index 8b78689fee..62fbde37db 100644 --- a/tests/test_delta_publish_only.py +++ b/tests/test_delta_publish_only.py @@ -226,7 +226,9 @@ def test_publish_only_finalize_calls_publish_hook_without_engine_rpcs_or_cleanup _patch_single_rank_dist(monkeypatch) ray_get_calls = [] monkeypatch.setattr(delta_mod.ray, "get", lambda refs: ray_get_calls.append(refs)) - monkeypatch.setattr(delta_mod.shutil, "rmtree", lambda *_args, **_kwargs: pytest.fail("publish-only must keep files")) + monkeypatch.setattr( + delta_mod.shutil, "rmtree", lambda *_args, **_kwargs: pytest.fail("publish-only must keep files") + ) hook_calls = [] @@ -241,7 +243,6 @@ def publish_hook(args, version_dir, files, weight_version, engines): assert updater.writer.drain_calls == 1 assert updater._pending_files == [] - assert updater._pending_publishes == [] assert updater._published_any is True assert hook_calls == [ ( @@ -253,9 +254,17 @@ def publish_hook(args, version_dir, files, weight_version, engines): ) ] assert updater.rollout_engines[0].calls == [] - assert ray_get_calls == [["publish-ref"]] + # The publish stays in flight across the training step: finalize must not + # await its refs; the next sync (or disconnect) drains it. + assert len(updater._pending_publishes) == 1 + assert ray_get_calls == [] assert os.path.isdir(updater._version_dir) + updater._drain_pending_publishes() + + assert updater._pending_publishes == [] + assert ray_get_calls == [["publish-ref"]] + def test_publish_only_finalize_publishes_noop_version(monkeypatch, tmp_path): _patch_single_rank_dist(monkeypatch) @@ -276,9 +285,43 @@ def publish_hook(args, version_dir, files, weight_version, engines): assert updater._published_any is True assert hook_calls == [(updater._version_dir, [], "7", updater.rollout_engines)] assert updater.rollout_engines[0].calls == [] + assert len(updater._pending_publishes) == 1 + + updater._drain_pending_publishes() + + assert updater._pending_publishes == [] assert ray_get_calls == [] +def test_publish_only_sync_wait_drains_publish_before_return(monkeypatch, tmp_path): + _patch_single_rank_dist(monkeypatch) + ray_get_calls = [] + monkeypatch.setattr(delta_mod.ray, "get", lambda refs: ray_get_calls.append(refs)) + + updater = _make_publish_only_updater(tmp_path, lambda *a: ["publish-ref"], publish_wait="sync") + updater._pending_files = ["rank0000_flush000000.safetensors"] + + updater._finalize_sync() + + assert updater._pending_publishes == [] + assert ray_get_calls == [["publish-ref"]] + + +def test_disconnect_drains_pending_publish(monkeypatch, tmp_path): + _patch_single_rank_dist(monkeypatch) + ray_get_calls = [] + monkeypatch.setattr(delta_mod.ray, "get", lambda refs: ray_get_calls.append(refs)) + + updater = _make_publish_only_updater(tmp_path, lambda *a: ["publish-ref"]) + updater._finalize_sync() + assert len(updater._pending_publishes) == 1 + + updater.disconnect_rollout_engines() + + assert updater._pending_publishes == [] + assert ray_get_calls == [["publish-ref"]] + + def test_publish_only_flush_defers_publish_until_finalize(monkeypatch, tmp_path): barrier_calls, _gathered = _patch_single_rank_dist(monkeypatch) updater = _make_publish_only_updater(tmp_path, publish_hook=None) From 1499c3efe65322e596e95089172c95c63e37331c Mon Sep 17 00:00:00 2001 From: Jason Mancuso <7891333+jvmncs@users.noreply.github.com> Date: Thu, 11 Jun 2026 00:44:57 -0400 Subject: [PATCH 5/6] [docs] document opaque HTTP rollout endpoint and publish-only delta sync --- docs/en/advanced/delta-weight-sync.md | 30 ++++++++++++++++++++ docs/en/advanced/external-rollout-engines.md | 30 ++++++++++++++++++-- docs/en/get_started/customization.md | 22 +++++++++++++- docs/zh/advanced/delta-weight-sync.md | 30 ++++++++++++++++++++ docs/zh/advanced/external-rollout-engines.md | 30 ++++++++++++++++++-- docs/zh/get_started/customization.md | 19 ++++++++++++- 6 files changed, 153 insertions(+), 8 deletions(-) diff --git a/docs/en/advanced/delta-weight-sync.md b/docs/en/advanced/delta-weight-sync.md index a1e670f81d..9dd4fa2c9f 100644 --- a/docs/en/advanced/delta-weight-sync.md +++ b/docs/en/advanced/delta-weight-sync.md @@ -4,6 +4,7 @@ - [Quick Start](#quick-start) - [Mode vs Transport](#mode-vs-transport) - [How It Works](#how-it-works) +- [Publish-Only Disk Delta](#publish-only-disk-delta) - [Encoding Choice](#encoding-choice) - [Why Not Colocated](#why-not-colocated) @@ -92,6 +93,35 @@ For both transports, the receiver ends up calling the same `_apply_delta_payload Selective overwrite has no arithmetic — the receiver writes the trainer's exact bytes at changed positions — so it's lossless by construction and there's no notion of drift to fight with periodic base re-syncs. +## Publish-Only Disk Delta + +The disk path above pushes each version to known engines: rank 0 calls every engine's `update_weights_from_disk(load_format="delta")` and the sync ends when all engines acknowledge. That requires stable engine handles. When the serving side is an elastic fleet that consumes published versions on its own schedule — e.g. behind an [opaque HTTP rollout endpoint](external-rollout-engines.md#opaque-http-rollout-endpoint) — invert the direction with publish-only mode: + +```bash +--update-weight-mode delta +--update-weight-transport disk +--update-weight-delta-publish-only +--custom-delta-publish-path my_pkg.publish.publish_delta +--update-weight-delta-keep-files +``` + +Instead of firing per-engine RPCs, rank 0 invokes your publish hook once per sync, after every delta file has been written and the optional `--custom-delta-pre-push-path` hook has committed: + +```python +def publish_delta(args, version_dir: str, files: list[str], weight_version: str, rollout_engines) -> list | None: + ... # e.g. upload version_dir to object storage, then announce weight_version +``` + +Returned Ray ObjectRefs are awaited before the version counts as settled. Behavior differences from the direct disk path: + +- **One complete version per sync.** Direct disk transport publishes at each pass boundary so receivers can overlap apply with later encoding; publish-only defers everything to finalize, so external consumers never observe a partially published version. +- **Publish wait is configurable.** By default, `--update-weight-delta-publish-wait=next-sync` leaves the dispatched publish in flight across the next training step and settles it at the start of the next sync (or on disconnect). A failed publish therefore surfaces one sync late, on rank 0. Set `--update-weight-delta-publish-wait=sync` when the publish hook should block `update_weights`, for example because it polls an external rollout fleet until enough replicas report the new version ready. +- **Engines are left alone.** Generation is not paused, caches are not flushed, and no update RPCs are issued; consumers decide when to pick up a version. If the rollout endpoint supports request-level weight constraints, attach them from a `--custom-rollout-request-hook-path` hook so requests routed to lagging replicas fail/retry before doing unusable rollout compute. +- **No cleanup.** slime cannot know when consumers finish reading a version, so `--update-weight-delta-keep-files` is required and version-directory lifecycle belongs to you (e.g. the publish hook can prune old versions once uploaded). +- **No-op versions still publish.** If a sync produces no changed bytes, the hook is still called with an empty file list so consumers' version counters can advance. + +`--update-weight-delta-root` optionally names a root directory for publish-side metadata; it defaults to the parent of `--update-weight-disk-dir` and is passed through to hooks via `args`. + ## Encoding Choice `--update-weight-encoding` picks how positions are packed. All three share the same on-wire layout (`__positions__` uint8 blob + `__values__` tensor + per-param manifest); decoder dispatches on the metadata. diff --git a/docs/en/advanced/external-rollout-engines.md b/docs/en/advanced/external-rollout-engines.md index 498afa0c5f..a48b3a0eb9 100644 --- a/docs/en/advanced/external-rollout-engines.md +++ b/docs/en/advanced/external-rollout-engines.md @@ -2,13 +2,15 @@ An external rollout engine is an SGLang engine that is not launched by the slime training job. Another system deploys and owns the engine lifecycle; slime connects to those engines during training, registers a router, and syncs updated actor weights when needed. -This page is a roadmap. Use it to decide when to use `--rollout-external-engine-addrs`, when to stay with `--sglang-config`, and which weight-update path to pick for external deployments. +This page is a roadmap. Use it to decide when to use `--rollout-external-engine-addrs`, when to use `--rollout-http-endpoint-url`, when to stay with `--sglang-config`, and which weight-update path to pick for external deployments. ## Where To Start | Goal | Recommended entry point | | :--- | :--- | | Engines are already launched externally and slime should only connect for rollout | `--rollout-external-engine-addrs` | +| Rollout serving is an elastic fleet behind a single HTTP URL, with no stable per-engine handles | `--rollout-http-endpoint-url` | +| The serving side pulls published weight versions instead of receiving direct update RPCs | `--update-weight-delta-publish-only`, see [Publish-Only Disk Delta](delta-weight-sync.md#publish-only-disk-delta) | | slime should still launch engines, but you need PD disaggregation, multi-model serving, heterogeneous server groups, or per-group overrides | [SGLang Config](sglang-config.md) | | Trainer and external engines can form an NCCL group | Default `--update-weight-mode full --update-weight-transport nccl` | | Trainer and external engines cannot form an NCCL group, but can see the same filesystem path | `--update-weight-mode full --update-weight-transport disk` | @@ -38,6 +40,27 @@ slime queries each engine's `/server_info` or `/get_server_info` endpoint and in This path fits deployments where serving is owned outside the training job: a separate inference cluster, a separate Ray cluster, manually warmed SGLang engines, or a rollout service managed by another orchestrator. +## Opaque HTTP Rollout Endpoint + +`--rollout-external-engine-addrs` still assumes SGLang engines with stable addresses: slime queries `/server_info` per engine, registers each one with a router, and pushes weight updates to known engine handles. Some deployments cannot offer that contract — for example a serverless or autoscaled inference fleet behind one URL, where workers come and go and no worker-management API is exposed. For those, point slime at the endpoint directly: + +```bash +python train.py \ + --rollout-http-endpoint-url https://rollout.example.com \ + ... +``` + +In this mode slime launches no engines and no router, and assumes nothing about the endpoint beyond the generation route: rollout requests POST to `{url}/generate`, and `get_model_url(args, ...)` in custom rollout functions resolves to the endpoint as well. No rollout GPUs are allocated in the placement group, `/server_info` is never queried, and slime fault tolerance does not manage the fleet — recovery is the endpoint operator's job. `--rollout-http-endpoint-url` and `--rollout-external-engine-addrs` are mutually exclusive. + +Two companion flags adapt the default SGLang rollout to an endpoint that lacks router APIs: + +- `--rollout-http-endpoint-abort-strategy {cancel-only,router-workers}`: how `abort` behaves between rollouts. `cancel-only` (the default when an endpoint URL is set) cancels slime's local pending generation tasks without calling the router's worker-list or per-worker abort APIs. `router-workers` keeps the existing router-based abort and remains the default otherwise. Note that `cancel-only` does not collect partial samples, so it does not compose with `--partial-rollout`. +- `--custom-rollout-request-hook-path`: optional hook called before each default SGLang `/generate` request. Signature: `def hook(args, sample, request) -> None | dict`. The `request` dict contains `url`, `payload`, `headers`, `max_retries`, `retry_sleep`, `rollout_id`, and `evaluation`; mutate it in place or return a dict of updates. + +Use the request hook for rollout-endpoint admission control. For example, a hook may attach `"weight_version": {"exact_version": }` or `"weight_version": {"min_required_version": }` and increase `max_retries`/`retry_sleep`. Those request fields avoid wasted rollout compute when an opaque router sends the request to a replica that has not loaded a usable version yet. They do not define SLIME's off-policy or staleness semantics; the trainer schedule and loss/correction path still decide which versions are valid. + +For weight sync, an elastic fleet usually cannot receive per-engine `update_weights_from_disk` RPCs either. Combine the endpoint with publish-only delta sync, where the trainer publishes each complete weight version through a custom hook and the serving side consumes it on its own schedule — see [Publish-Only Disk Delta](delta-weight-sync.md#publish-only-disk-delta). If request-level minimum-version retry is enough, leave publish-only in its default pipelined mode. If the publish hook polls rollout-fleet status and you want the next rollout dispatch to wait for that readiness threshold, set `--update-weight-delta-publish-wait=sync`. + ## Relationship With `--sglang-config` `--rollout-external-engine-addrs` and `--sglang-config` are mutually exclusive because they own different boundaries: @@ -108,8 +131,9 @@ For encoding choices, wire layout, receiver-side selective overwrite, and tuning - External engines can use an independent SGLang environment; they do not need the slime or Megatron training environment. - Disk transport supports different GPU models or vendors between training and rollout, as long as SGLang supports the target hardware and model format. - Disk transport requires trainer and SGLang engines to see the same `--update-weight-disk-dir` path; a path visible only to the trainer is not enough. -- External engines are not recovered by slime fault tolerance; their lifecycle belongs to the external deployment system. -- `--sglang-config` and `--rollout-external-engine-addrs` are mutually exclusive. +- External engines are not recovered by slime fault tolerance; their lifecycle belongs to the external deployment system. The same applies to fleets behind `--rollout-http-endpoint-url`. +- `--sglang-config` and `--rollout-external-engine-addrs` are mutually exclusive, as are `--rollout-external-engine-addrs` and `--rollout-http-endpoint-url`. +- An opaque HTTP endpoint only needs to serve the generation route; worker-management APIs are never called. If the fleet cannot accept direct weight-update RPCs, use publish-only delta sync. - Delta mode does not support `--colocate`, because colocated sync uses CUDA IPC handles and delta encoding does not reduce the actual transfer. ## Related Work diff --git a/docs/en/get_started/customization.md b/docs/en/get_started/customization.md index 77f5cd5e34..389cda02b6 100644 --- a/docs/en/get_started/customization.md +++ b/docs/en/get_started/customization.md @@ -28,6 +28,7 @@ Below is a summary of all available customization interfaces and their purposes. | [`--custom-megatron-init-path`](#17-megatron-hooks) | Custom initialization after Megatron setup. | | [`--custom-megatron-before-log-prob-hook-path`](#17-megatron-hooks) | Custom logic before log probability computation. | | [`--custom-megatron-before-train-step-hook-path`](#17-megatron-hooks) | Custom logic before each training step. | +| [`--custom-rollout-request-hook-path`](#19-rollout-request-hook---custom-rollout-request-hook-path) | Customize each default SGLang `/generate` request before dispatch. | ## Agentic workflows through customization interfaces @@ -457,6 +458,25 @@ Stabilize MoE RL training by recording and replaying expert routing decisions to | `--use-routing-replay` | Forward-backward routing consistency in training. ([arXiv:2507.18071](https://arxiv.org/abs/2507.18071)) | | `--use-rollout-routing-replay` | R3: Replay routing from rollout during training. Supported by slime's default `sglang_router` path. ([arXiv:2510.11370](https://arxiv.org/abs/2510.11370)) | +--- + +### 19. Rollout Request Hook (`--custom-rollout-request-hook-path`) + +**Signature**: +```python +def hook(args, sample, request) -> None | dict +``` + +**Purpose**: Customize each default SGLang rollout `/generate` request before it +is sent. `request` contains `url`, `payload`, `headers`, `max_retries`, +`retry_sleep`, `rollout_id`, and `evaluation`. Mutate it in place or return a +dict of updates. + +This hook is useful for external rollout providers that need request-level +admission control, for example adding `payload["weight_version"]` so a request +routed to a lagging replica fails and retries before doing unusable rollout +compute. + ## Testing Custom Function Paths slime also provides CPU-only contract tests for customization interfaces. These tests resolve components through import-path strings, so they can validate both built-in hooks and user-defined implementations passed through the same CLI arguments used by training. @@ -470,7 +490,7 @@ The tests live under `tests/plugin_contracts/` and are grouped by hook shape: - `tests/plugin_contracts/test_plugin_path_loading_contracts.py` Covers `--eval-function-path`, `--custom-rm-path`, `--dynamic-sampling-filter-path`, `--buffer-filter-path`, `--data-source-path`, `--rollout-sample-filter-path`, and `--rollout-all-samples-process-path` - `tests/plugin_contracts/test_plugin_runtime_hook_contracts.py` - Covers `--custom-rollout-log-function-path`, `--custom-eval-rollout-log-function-path`, `--custom-reward-post-process-path`, `--custom-convert-samples-to-train-data-path`, and `--rollout-data-postprocess-path` + Covers `--custom-rollout-log-function-path`, `--custom-eval-rollout-log-function-path`, `--custom-reward-post-process-path`, `--custom-convert-samples-to-train-data-path`, `--rollout-data-postprocess-path`, and `--custom-rollout-request-hook-path` Run all customization contract tests locally: diff --git a/docs/zh/advanced/delta-weight-sync.md b/docs/zh/advanced/delta-weight-sync.md index f009dc954a..ad17ac23f9 100644 --- a/docs/zh/advanced/delta-weight-sync.md +++ b/docs/zh/advanced/delta-weight-sync.md @@ -4,6 +4,7 @@ - [快速开始](#快速开始) - [同步模式与传输方式](#同步模式与传输方式) - [工作原理](#工作原理) +- [Publish-Only 磁盘 Delta](#publish-only-磁盘-delta) - [编码选择](#编码选择) - [为何不支持 colocated](#为何不支持-colocated) @@ -88,6 +89,35 @@ Delta NCCL 和 delta 磁盘共用同一条发送管线、同一种 wire 布局 选择性覆写没有任何算术运算 —— 接收端在变化位置直接写入训练端的精确字节 —— 因此天然无损,也不存在数值漂移问题,无需周期性 base 同步。 +## Publish-Only 磁盘 Delta + +上面的磁盘路径把每个版本推送给已知 engine:rank 0 调用每个 engine 的 `update_weights_from_disk(load_format="delta")`,所有 engine 确认后同步才结束。这要求 engine 句柄稳定。当 serving 侧是一个按自己节奏消费已发布版本的弹性集群——例如位于 [opaque HTTP rollout endpoint](external-rollout-engines.md#opaque-http-rollout-endpoint) 之后——可以用 publish-only 模式反转方向: + +```bash +--update-weight-mode delta +--update-weight-transport disk +--update-weight-delta-publish-only +--custom-delta-publish-path my_pkg.publish.publish_delta +--update-weight-delta-keep-files +``` + +rank 0 不再发出 per-engine RPC,而是在每次同步中调用一次你的 publish hook——此时所有 delta 文件已经写完,可选的 `--custom-delta-pre-push-path` hook 也已提交: + +```python +def publish_delta(args, version_dir: str, files: list[str], weight_version: str, rollout_engines) -> list | None: + ... # 例如把 version_dir 上传到对象存储,然后公告 weight_version +``` + +返回的 Ray ObjectRef 会在该版本视为完成之前被等待。与直接磁盘路径的行为差异: + +- **每次同步发布一个完整版本。** 直接磁盘传输在每个 pass 边界发布,让接收端的 apply 与后续编码重叠;publish-only 把所有发布推迟到 finalize,外部消费者永远不会看到只发布了一半的版本。 +- **发布等待可配置。** 默认 `--update-weight-delta-publish-wait=next-sync` 会让已派发的 publish 在下一个训练 step 期间保持 in flight,并在下一次同步开始时(或 disconnect 时)结算。因此 publish 失败会晚一个同步周期才在 rank 0 上暴露。如果 publish hook 会轮询外部 rollout 集群、并且希望下一次 rollout dispatch 等到足够副本就绪后再开始,可以设置 `--update-weight-delta-publish-wait=sync`。 +- **不打扰 engine。** 不暂停生成、不清空 cache、不发出任何 update RPC;消费者自己决定何时拉取新版本。如果 rollout endpoint 支持请求级权重约束,可以在 `--custom-rollout-request-hook-path` hook 中附加这些约束,让路由到落后副本的请求尽早失败并重试,避免生成不可用样本。 +- **不做清理。** slime 无法知道消费者何时读完一个版本,所以必须加 `--update-weight-delta-keep-files`,版本目录的生命周期由你负责(例如 publish hook 可以在上传完成后清理旧版本)。 +- **空 delta 也会发布。** 如果某次同步没有任何字节变化,hook 仍会以空文件列表被调用,让消费者的版本计数得以推进。 + +`--update-weight-delta-root` 可选地指定发布侧元数据的根目录;缺省为 `--update-weight-disk-dir` 的父目录,并通过 `args` 透传给 hook。 + ## 编码选择 `--update-weight-encoding` 决定位置如何打包。三种编码共用同一种 wire 布局(`__positions__` uint8 块 + `__values__` 张量 + per-param manifest),解码端根据 metadata 分派。 diff --git a/docs/zh/advanced/external-rollout-engines.md b/docs/zh/advanced/external-rollout-engines.md index 9aae0ef5ec..46dd999225 100644 --- a/docs/zh/advanced/external-rollout-engines.md +++ b/docs/zh/advanced/external-rollout-engines.md @@ -2,13 +2,15 @@ External rollout engine 指的是:SGLang engine 不由 slime 训练任务启动,而是由外部系统预先部署和管理;slime 只在训练时连接这些 engine,注册 router,并在需要时同步训练后的 actor 权重。 -这篇文档是一个导航页。它帮助你判断什么时候该用 `--rollout-external-engine-addrs`,什么时候该继续使用 `--sglang-config`,以及 external 场景下该选择 full checkpoint update from disk 还是 delta update。 +这篇文档是一个导航页。它帮助你判断什么时候该用 `--rollout-external-engine-addrs`,什么时候该用 `--rollout-http-endpoint-url`,什么时候该继续使用 `--sglang-config`,以及 external 场景下该选择 full checkpoint update from disk 还是 delta update。 ## 从哪里开始 | 目标 | 推荐入口 | | :--- | :--- | | engine 已经由外部系统启动,只想让 slime 连接并做 rollout | `--rollout-external-engine-addrs` | +| rollout serving 是单一 HTTP URL 背后的弹性集群,没有稳定的 per-engine 句柄 | `--rollout-http-endpoint-url` | +| serving 侧主动拉取发布的权重版本,而不是接收直接的 update RPC | `--update-weight-delta-publish-only`,见 [Publish-Only 磁盘 Delta](delta-weight-sync.md#publish-only-磁盘-delta) | | engine 仍由 slime 启动,但需要 PD 分离、多模型、异构 server group 或 per-group overrides | [SGLang Config](sglang-config.md) | | 训练器和 external engine 可以建立 NCCL group | 默认的 `--update-weight-mode full --update-weight-transport nccl` | | 训练器和 external engine 不能建立 NCCL group,但能共享同一路径的文件系统 | `--update-weight-mode full --update-weight-transport disk` | @@ -38,6 +40,27 @@ slime 会请求每个 engine 的 `/server_info` 或 `/get_server_info`,推断 这条路径适合 serving 生命周期由训练任务外部管理的部署:例如独立的推理集群、跨 Ray 集群部署、手工预热的 SGLang engine,或由其他编排系统管理的 rollout service。 +## Opaque HTTP Rollout Endpoint + +`--rollout-external-engine-addrs` 仍然假设 SGLang engine 有稳定地址:slime 会逐个查询 `/server_info`,把每个 engine 注册到 router,并向已知 engine 句柄推送权重更新。有些部署无法提供这种契约——例如单一 URL 背后的 serverless 或自动扩缩容推理集群,worker 随时增减,也不暴露任何 worker 管理 API。这种情况下让 slime 直接指向 endpoint: + +```bash +python train.py \ + --rollout-http-endpoint-url https://rollout.example.com \ + ... +``` + +在这个模式下,slime 不启动任何 engine 和 router,对 endpoint 的假设只有生成路由:rollout 请求 POST 到 `{url}/generate`,自定义 rollout function 里的 `get_model_url(args, ...)` 也解析到该 endpoint。placement group 中不会分配 rollout GPU,`/server_info` 永远不会被查询,slime 的 fault tolerance 也不管理这个集群——故障恢复由 endpoint 运营方负责。`--rollout-http-endpoint-url` 与 `--rollout-external-engine-addrs` 互斥。 + +两个配套参数让默认 SGLang rollout 适配没有 router API 的 endpoint: + +- `--rollout-http-endpoint-abort-strategy {cancel-only,router-workers}`:控制两次 rollout 之间 `abort` 的行为。`cancel-only`(设置 endpoint URL 时的默认值)只取消 slime 本地待完成的生成任务,不调用 router 的 worker 列表或 per-worker abort API。`router-workers` 保留原有基于 router 的 abort,在其他情况下仍是默认值。注意 `cancel-only` 不收集 partial sample,因此与 `--partial-rollout` 不兼容。 +- `--custom-rollout-request-hook-path`:可选 hook,在默认 SGLang `/generate` 请求发出前调用。签名为 `def hook(args, sample, request) -> None | dict`。`request` dict 包含 `url`、`payload`、`headers`、`max_retries`、`retry_sleep`、`rollout_id` 和 `evaluation`;可以原地修改,也可以返回一个 dict 覆盖字段。 + +请求级权重约束应通过这个 hook 添加。例如 hook 可以加入 `"weight_version": {"exact_version": }` 或 `"weight_version": {"min_required_version": }`,并调整 `max_retries`/`retry_sleep`。这些字段用于 opaque router 把请求路由到落后副本时尽早失败并重试,避免浪费 rollout compute;它们不定义 SLIME 的 off-policy 或 staleness 语义,真正的有效版本仍由训练调度和 loss/correction 路径决定。 + +至于权重同步,弹性集群通常也无法接收 per-engine 的 `update_weights_from_disk` RPC。可以把 endpoint 与 publish-only delta 同步组合使用:训练端通过自定义 hook 发布每个完整的权重版本,serving 侧按自己的节奏消费——见 [Publish-Only 磁盘 Delta](delta-weight-sync.md#publish-only-磁盘-delta)。如果请求级最低版本重试已经足够,保留 publish-only 的默认流水线模式即可;如果 publish hook 会轮询 rollout 集群状态、并且你希望下一次 rollout dispatch 等待该就绪阈值,可以设置 `--update-weight-delta-publish-wait=sync`。 + ## 与 `--sglang-config` 的关系 `--rollout-external-engine-addrs` 和 `--sglang-config` 互斥,因为它们负责不同的边界: @@ -108,8 +131,9 @@ delta update 面向大模型、跨集群或跨数据中心训推解耦。它不 - external engine 可以使用独立 SGLang 环境;不需要安装 slime 或 Megatron 训练环境。 - disk transport 支持训练和 rollout 使用不同型号或不同厂家的 GPU,前提是 SGLang 支持对应硬件和模型格式。 - disk transport 要求训练端和 SGLang engine 看到同一个 `--update-weight-disk-dir` 路径;路径只在训练端可见是不够的。 -- external engine 当前不支持 slime 的 fault tolerance 恢复流程;engine 生命周期由外部系统负责。 -- `--sglang-config` 与 `--rollout-external-engine-addrs` 互斥。 +- external engine 当前不支持 slime 的 fault tolerance 恢复流程;engine 生命周期由外部系统负责。`--rollout-http-endpoint-url` 背后的集群同理。 +- `--sglang-config` 与 `--rollout-external-engine-addrs` 互斥;`--rollout-external-engine-addrs` 与 `--rollout-http-endpoint-url` 也互斥。 +- opaque HTTP endpoint 只需要提供生成路由;slime 不会调用任何 worker 管理 API。如果集群无法接收直接的权重更新 RPC,请使用 publish-only delta 同步。 - delta mode 不支持 `--colocate`,因为 colocated 权重同步通过 CUDA IPC 传句柄,delta 编码不会节省实际传输量。 ## 参考工作 diff --git a/docs/zh/get_started/customization.md b/docs/zh/get_started/customization.md index 5b95f05463..eb92095c08 100644 --- a/docs/zh/get_started/customization.md +++ b/docs/zh/get_started/customization.md @@ -28,6 +28,7 @@ slime 通过函数路径参数提供了广泛的自定义能力。这些参数 | [`--custom-megatron-init-path`](#17-megatron-hook) | Megatron 设置后的自定义初始化。 | | [`--custom-megatron-before-log-prob-hook-path`](#17-megatron-hook) | log probability 计算前的自定义逻辑。 | | [`--custom-megatron-before-train-step-hook-path`](#17-megatron-hook) | 每个训练步骤前的自定义逻辑。 | +| [`--custom-rollout-request-hook-path`](#19-rollout-request-hook---custom-rollout-request-hook-path) | 在默认 SGLang `/generate` 请求发出前自定义请求。 | ## 通过 customization 接口实现 agentic workflow @@ -459,6 +460,22 @@ def custom_hook(args, rollout_id, step_id, model, optimizer, opt_param_scheduler | `--use-routing-replay` | 训练中前向-反向路由一致性。([arXiv:2507.18071](https://arxiv.org/abs/2507.18071)) | | `--use-rollout-routing-replay` | R3:在训练时重放 rollout 阶段的路由。slime 默认的 `sglang_router` 路径支持该功能。([arXiv:2510.11370](https://arxiv.org/abs/2510.11370)) | +--- + +### 19. Rollout Request Hook (`--custom-rollout-request-hook-path`) + +**函数签名**: +```python +def hook(args, sample, request) -> None | dict +``` + +**用途**: 在默认 SGLang rollout `/generate` 请求发出前自定义该请求。`request` +包含 `url`、`payload`、`headers`、`max_retries`、`retry_sleep`、`rollout_id` +和 `evaluation`。可以原地修改,也可以返回一个 dict 覆盖字段。 + +这个 hook 适合外部 rollout provider 的请求级 admission control,例如加入 +`payload["weight_version"]`,让路由到落后副本的请求在生成不可用样本前失败并重试。 + ## 自定义函数路径的测试 slime 现在也提供了一组 CPU 契约测试,用于校验这些 customization 接口。测试会通过字符串形式的导入路径来动态加载组件,因此既能回归仓库内置 hook,也能验证用户通过和训练时完全相同的 CLI 参数传入的自定义实现。 @@ -472,7 +489,7 @@ slime 现在也提供了一组 CPU 契约测试,用于校验这些 customizati - `tests/plugin_contracts/test_plugin_path_loading_contracts.py` 覆盖 `--eval-function-path`、`--custom-rm-path`、`--dynamic-sampling-filter-path`、`--buffer-filter-path`、`--data-source-path`、`--rollout-sample-filter-path`、`--rollout-all-samples-process-path` - `tests/plugin_contracts/test_plugin_runtime_hook_contracts.py` - 覆盖 `--custom-rollout-log-function-path`、`--custom-eval-rollout-log-function-path`、`--custom-reward-post-process-path`、`--custom-convert-samples-to-train-data-path`、`--rollout-data-postprocess-path` + 覆盖 `--custom-rollout-log-function-path`、`--custom-eval-rollout-log-function-path`、`--custom-reward-post-process-path`、`--custom-convert-samples-to-train-data-path`、`--rollout-data-postprocess-path`、`--custom-rollout-request-hook-path` 本地运行全部 customization 契约测试: From 4ea02f0ee6a4cef5ebcd90fd1c85888035e3d85b Mon Sep 17 00:00:00 2001 From: Jason Mancuso <7891333+jvmncs@users.noreply.github.com> Date: Wed, 17 Jun 2026 16:20:18 -0400 Subject: [PATCH 6/6] Fix HTTP-endpoint rollout: return (servers, init_handles) tuple MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit start_rollout_servers' HTTP-endpoint branch returned the bare servers dict from start_http_endpoint_rollout_servers, but the caller (RolloutManager.__init__, rollout.py:435) and this function's `-> tuple[dict, list]` annotation expect a (servers, init_handles) tuple — every other branch returns one. In endpoint mode this raised `ValueError: not enough values to unpack (expected 2, got 1)` at RolloutManager init. HTTP endpoints have no local engine init handles, so []. Co-Authored-By: Claude Opus 4.8 (1M context) --- slime/ray/rollout.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 865682db44..b85f34a5e7 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -1082,7 +1082,11 @@ def start_rollout_servers(args, pg) -> tuple[dict[str, Any], list[Any]]: as the HTTP client is shared across all servers. """ if uses_rollout_http_endpoint(args): - return start_http_endpoint_rollout_servers(args) + # HTTP endpoints have no local engines to initialize, so there are no + # pending init handles. Return the (servers, init_handles) tuple the + # caller (RolloutManager.__init__) and this function's annotation expect, + # matching the other branches below. + return start_http_endpoint_rollout_servers(args), [] if args.rollout_external: return start_external_rollout_servers(args, start_router=_start_router)