-
Notifications
You must be signed in to change notification settings - Fork 359
[fix] Honor served_model_name and surface HTTP errors in RemoteInferenceEngine #1783
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -115,12 +115,14 @@ def create_ray_wrapped_inference_engines_from_config( | |
|
|
||
|
|
||
| def create_remote_inference_engines_from_config(cfg: SkyRLTrainConfig, tokenizer: PreTrainedTokenizerBase): | ||
| # TODO(tgriggs): We may want a separate config for the model name in case | ||
| # it's different from the name used in the OpenAI API | ||
| ie_cfg = cfg.generator.inference_engine | ||
| # Use served_model_name if provided, otherwise fall back to the model path. | ||
| # served_model_name allows using a different model name for HTTP requests than the actual | ||
| # model path. See InferenceEngineConfig.served_model_name in skyrl/train/config/config.py. | ||
| model_name = ie_cfg.served_model_name if ie_cfg.served_model_name is not None else cfg.trainer.policy.model.path | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
If both `served_model_name` and `trainer.policy.model.path` are unset or empty, `model_name` will be empty; consider validating it:
model_name = ie_cfg.served_model_name if ie_cfg.served_model_name is not None else cfg.trainer.policy.model.path
if not model_name:
raise ValueError(
"Model name must be specified. Please set either `generator.inference_engine.served_model_name` "
"or `trainer.policy.model.path`."
) |
||
| return create_remote_inference_engines( | ||
| urls=ie_cfg.remote_urls, | ||
| model_name=cfg.trainer.policy.model.path, | ||
| model_name=model_name, | ||
| engine_backend=ie_cfg.backend, | ||
| tokenizer=tokenizer, | ||
| tensor_parallel_size=ie_cfg.tensor_parallel_size, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,178 @@ | ||
| """ | ||
| Tests for `skyrl/backends/skyrl_train/inference_engines/remote_inference_engine.py`. | ||
|
|
||
| Run with: | ||
| uv run --isolated --extra skyrl-train --extra dev pytest tests/backends/skyrl_train/inference_engines/test_remote_inference_engine.py | ||
| """ | ||
|
|
||
| import asyncio | ||
| import threading | ||
| import time | ||
|
|
||
| import httpx | ||
| import pytest | ||
| import uvicorn | ||
| from fastapi import FastAPI, Request | ||
| from fastapi.responses import JSONResponse | ||
|
|
||
| from skyrl.backends.skyrl_train.inference_engines.remote_inference_engine import ( | ||
| RemoteInferenceEngine, | ||
| ) | ||
| from skyrl.backends.skyrl_train.inference_servers.common import get_open_port | ||
| from skyrl.train.config import ( | ||
| GeneratorConfig, | ||
| InferenceEngineConfig, | ||
| ModelConfig, | ||
| PolicyConfig, | ||
| SkyRLTrainConfig, | ||
| TrainerConfig, | ||
| ) | ||
|
|
||
| MODEL_PATH = "org/test-model" | ||
| SERVED_MODEL_NAME = "test-model" | ||
|
|
||
|
|
||
class FakeTokenizer:
    """Tokenizer stand-in for tests; `generate()` only calls `encode` on response texts."""

    def encode(self, text, add_special_tokens=False):
        # One fake token id per character: 0, 1, ..., len(text) - 1.
        token_count = len(text)
        return list(range(token_count))
|
|
||
|
|
||
def create_mock_vllm_server() -> FastAPI:
    """Build a mock OpenAI-compatible vLLM app that only recognizes `SERVED_MODEL_NAME`."""
    mock_app = FastAPI()

    @mock_app.get("/health")
    async def health():
        return {"status": "ok"}

    @mock_app.post("/v1/completions")
    async def completions(request: Request):
        payload = await request.json()
        requested_model = payload.get("model")

        if requested_model == SERVED_MODEL_NAME:
            # One canned completion per prompt in the request.
            prompt_count = len(payload.get("prompt", []))
            choices = [
                {"index": idx, "text": f"response {idx}", "finish_reason": "stop"} for idx in range(prompt_count)
            ]
            return {"choices": choices, "model": requested_model}

        # Mirrors vLLM's 404 response body for an unknown model name.
        error_body = {
            "error": {
                "message": f"The model `{requested_model}` does not exist.",
                "type": "NotFoundError",
                "param": "model",
                "code": 404,
            }
        }
        return JSONResponse(status_code=404, content=error_body)

    return mock_app
|
|
||
|
|
||
def wait_ready(url: str, timeout: float = 5.0) -> bool:
    """Poll `{url}/health` until it answers 200 or `timeout` seconds elapse.

    Args:
        url: Base URL of the server, including the scheme.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True as soon as a health check returns status 200, False if the
        deadline passes first.
    """
    # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            if httpx.get(f"{url}/health", timeout=1.0).status_code == 200:
                return True
        except httpx.RequestError:
            # Server is not reachable yet; fall through to the pause below.
            pass
        # Pause after every unsuccessful attempt (including non-200 replies)
        # so the loop never spins without sleeping.
        time.sleep(0.1)
    return False
|
|
||
|
|
||
@pytest.fixture(scope="module")
def mock_server():
    """Run the mock vLLM server in a background thread and yield its host:port (no scheme)."""
    host = "127.0.0.1"
    port = get_open_port()
    server = uvicorn.Server(uvicorn.Config(create_mock_vllm_server(), host=host, port=port, log_level="error"))

    def _serve():
        asyncio.run(server.serve())

    server_thread = threading.Thread(target=_serve, daemon=True)
    server_thread.start()
    assert wait_ready(f"http://{host}:{port}"), "Mock server failed to start"

    yield f"{host}:{port}"

    # Teardown: ask uvicorn to exit, then give it a short moment to do so.
    server.should_exit = True
    time.sleep(0.3)
|
|
||
|
|
||
def _make_engine(mock_server: str, model_name: str) -> RemoteInferenceEngine:
    """Create a vLLM-backed `RemoteInferenceEngine` aimed at the mock server."""
    engine_kwargs = {
        "url": mock_server,
        "model_name": model_name,
        "engine_backend": "vllm",
        "tokenizer": FakeTokenizer(),
    }
    return RemoteInferenceEngine(**engine_kwargs)
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_generate_parses_choices(mock_server):
    """Happy path: a 200 response carrying choices is turned into the engine output."""
    request = {"prompt_token_ids": [[1, 2, 3], [4, 5]], "sampling_params": {"max_tokens": 4}}
    engine = _make_engine(mock_server, SERVED_MODEL_NAME)

    output = await engine.generate(request)

    # Two prompts in, so two responses, stop reasons and id lists out.
    assert output["responses"] == ["response 0", "response 1"]
    assert output["stop_reasons"] == ["stop", "stop"]
    assert len(output["response_ids"]) == 2
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_generate_raises_on_http_error(mock_server):
    """A non-200 reply must raise instead of being parsed into empty outputs.

    The mock returns vLLM's 404 for an unknown model name; the raised error is
    expected to carry both the status and the error body.

    See https://github.com/NovaSky-AI/SkyRL/issues/1672.
    """
    request = {"prompt_token_ids": [[1, 2, 3]], "sampling_params": {"max_tokens": 4}}
    engine = _make_engine(mock_server, MODEL_PATH)

    with pytest.raises(RuntimeError, match=r"404.*does not exist") as exc_info:
        await engine.generate(request)

    # The requested model name should appear in the error to make the mismatch debuggable.
    error_text = str(exc_info.value)
    assert MODEL_PATH in error_text
|
|
||
|
|
||
| # ------------------------------------------- | ||
| # tests for create_remote_inference_engines_from_config | ||
| # -------------------------------------------- | ||
|
|
||
|
|
||
def _make_remote_cfg(served_model_name=None) -> SkyRLTrainConfig:
    """Build a config for one remote vLLM engine whose policy model path is `MODEL_PATH`."""
    trainer = TrainerConfig(policy=PolicyConfig(model=ModelConfig(path=MODEL_PATH)))
    inference_engine = InferenceEngineConfig(
        backend="vllm",
        run_engines_locally=False,
        remote_urls=["127.0.0.1:8000"],
        served_model_name=served_model_name,
    )
    generator = GeneratorConfig(inference_engine=inference_engine)
    return SkyRLTrainConfig(trainer=trainer, generator=generator)
|
|
||
|
|
||
def test_create_remote_engines_uses_served_model_name():
    """When `generator.inference_engine.served_model_name` is set, engines use it as the model name."""
    from skyrl.train.entrypoints.main_base import create_remote_inference_engines_from_config

    cfg = _make_remote_cfg(served_model_name=SERVED_MODEL_NAME)
    engines = create_remote_inference_engines_from_config(cfg, tokenizer=FakeTokenizer())

    model_names = [engine.model_name for engine in engines]
    assert all(name == SERVED_MODEL_NAME for name in model_names)
|
|
||
|
|
||
def test_create_remote_engines_falls_back_to_model_path():
    """When `served_model_name` is not set, engines use the policy model path as the model name."""
    from skyrl.train.entrypoints.main_base import create_remote_inference_engines_from_config

    cfg = _make_remote_cfg(served_model_name=None)
    engines = create_remote_inference_engines_from_config(cfg, tokenizer=FakeTokenizer())

    model_names = [engine.model_name for engine in engines]
    assert all(name == MODEL_PATH for name in model_names)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using `await resp.text()` without specifying error handling can raise a `UnicodeDecodeError` if the remote server returns a non-UTF-8 error response (e.g., binary data or corrupted encoding on a 500 Internal Server Error). This secondary exception would mask the original HTTP status code and make debugging harder. Consider using `errors="replace"` to gracefully handle any decoding issues.