# Source template for the handler file of a concurrent, function-based QB
# endpoint. Rendered with str.format(); placeholders are resource_name,
# timestamp, import_statement, function_name and max_concurrency. Doubled
# braces are literal braces in the generated file.
DEPLOYED_ASYNC_HANDLER_TEMPLATE = '''"""
Auto-generated deployed handler for resource: {resource_name}
Generated at: {timestamp}

Concurrent async endpoint handler: accepts plain JSON, no cloudpickle.
max_concurrency={max_concurrency}

This file is generated by the Flash build process. Do not edit manually.
"""

import importlib
import logging
import traceback

_logger = logging.getLogger(__name__)

# Import the function for this endpoint
{import_statement}


async def handler(job):
    """Async handler for concurrent QB endpoint. Accepts plain JSON kwargs."""
    job_input = job.get("input", {{}})
    try:
        result = await {function_name}(**job_input)
        return result
    except Exception as e:
        _logger.error(
            "Deployed handler error for {function_name}: %s",
            e,
            exc_info=True,
        )
        return {{"error": str(e), "traceback": traceback.format_exc()}}


if __name__ == "__main__":
    import runpod
    runpod.serverless.start({{
        "handler": handler,
        "concurrency_modifier": lambda current: {max_concurrency},
    }})
'''


# Source template for the handler file of a concurrent, class-based QB
# endpoint. Rendered with str.format(); placeholders are resource_name,
# timestamp, import_statement, class_name, methods_dict and max_concurrency.
#
# Differences from a naive "await method(**job_input)" handler:
#   * the job input is copied before "method" is popped, so the job payload
#     passed in by the caller is never mutated;
#   * the selected method name is kept in a local so the error log can report
#     it (reading it back from the input after the pop always yields nothing);
#   * sync methods are called directly instead of being awaited, because a
#     class is treated as async as soon as any one public method is a
#     coroutine, so sync methods can still be dispatched to.
DEPLOYED_ASYNC_CLASS_HANDLER_TEMPLATE = '''"""
Auto-generated deployed handler for class-based resource: {resource_name}
Generated at: {timestamp}

Concurrent async class endpoint handler: accepts plain JSON, no cloudpickle.
max_concurrency={max_concurrency}

This file is generated by the Flash build process. Do not edit manually.
"""

import importlib
import inspect
import logging
import traceback

_logger = logging.getLogger(__name__)

# import the class for this endpoint
{import_statement}

# instantiate once at cold start
_instance = {class_name}()

# public methods available for dispatch
_METHODS = {methods_dict}


async def handler(job):
    """Async handler for concurrent class-based QB endpoint.

    Dispatches to a method on the singleton class instance.
    If the class has exactly one public method, input is passed
    directly as kwargs. If multiple methods exist, input must
    include a "method" key to select the target.
    """
    # kept outside the try block so the error log can name the method even
    # after "method" has been removed from the kwargs
    method_name = None
    try:
        # copy: popping "method" must not mutate the caller's job payload
        job_input = dict(job.get("input") or {{}})
        if len(_METHODS) == 1:
            method_name = next(iter(_METHODS))
        else:
            method_name = job_input.pop("method", None)
            if method_name is None:
                return {{
                    "error": (
                        "class {class_name} has multiple methods: "
                        + ", ".join(sorted(_METHODS))
                        + ". include a \\"method\\" key in the input."
                    )
                }}
            if method_name not in _METHODS:
                return {{
                    "error": (
                        f"unknown method '{{method_name}}' on {class_name}. "
                        f"available: {{', '.join(sorted(_METHODS))}}"
                    )
                }}

        method = getattr(_instance, method_name)
        if inspect.iscoroutinefunction(method):
            result = await method(**job_input)
        else:
            # sync method on a class that mixes sync and async methods
            result = method(**job_input)
            if inspect.isawaitable(result):
                result = await result
        return result
    except Exception as e:
        _logger.error(
            "Deployed handler error for {class_name}.%s: %s",
            method_name or "",
            e,
            exc_info=True,
        )
        return {{"error": str(e), "traceback": traceback.format_exc()}}


if __name__ == "__main__":
    import runpod
    runpod.serverless.start({{
        "handler": handler,
        "concurrency_modifier": lambda current: {max_concurrency},
    }})
'''
- - Produces a plain-JSON handler that accepts raw dicts (no cloudpickle). - For classes, instantiates once at cold start and dispatches to methods. - For functions, calls the function directly with job input as kwargs. - """ + """Generate a handler file for a deployed QB endpoint.""" handler_filename = f"handler_{resource_name}.py" handler_path = self.build_dir / handler_filename @@ -211,8 +327,14 @@ def _generate_handler(self, resource_name: str, resource_data: Any) -> Path: else resource_data.get("functions", []) ) + max_concurrency = ( + resource_data.max_concurrency + if hasattr(resource_data, "max_concurrency") + else resource_data.get("max_concurrency", 1) + ) + handler_code = self._generate_deployed_handler_code( - resource_name, timestamp, functions + resource_name, timestamp, functions, max_concurrency ) handler_path.write_text(handler_code) @@ -225,11 +347,14 @@ def _generate_deployed_handler_code( resource_name: str, timestamp: str, functions: List[Any], + max_concurrency: int = 1, ) -> str: """Generate deployed handler code for a QB endpoint. - Selects the class template for class-based workers and the - function template for plain functions. + Selects template based on max_concurrency and is_async: + - max_concurrency=1: current sync templates (unchanged behavior) + - max_concurrency>1 + async: new async templates with concurrency_modifier + - max_concurrency>1 + sync: sync templates with concurrency_modifier injected """ if not functions: raise ValueError( @@ -243,6 +368,9 @@ def _generate_deployed_handler_code( is_class = ( func.is_class if hasattr(func, "is_class") else func.get("is_class", False) ) + is_async = ( + func.is_async if hasattr(func, "is_async") else func.get("is_async", False) + ) import_statement = ( f"{name} = importlib.import_module('{module}').{name}" @@ -250,6 +378,25 @@ def _generate_deployed_handler_code( else "# No function to import" ) + # 100 is a soft upper bound: most GPU workloads saturate VRAM well + # below this. 
@staticmethod
def _inject_concurrency_modifier(code: str, max_concurrency: int) -> str:
    """Rewrite the default ``runpod.serverless.start`` call to set concurrency.

    The sync handler templates end with a start call that passes only the
    handler. This swaps that call for one that also passes a
    ``concurrency_modifier`` lambda returning ``max_concurrency``.

    Args:
        code: Generated handler source containing the default start call.
        max_concurrency: Value returned by the injected lambda. It is written
            verbatim into the generated source.

    Returns:
        The handler source with the start call replaced.

    Raises:
        ValueError: If ``max_concurrency`` is not an integer >= 1, or if the
            default start call is not present in ``code``.
    """
    # The value is interpolated into generated source, so anything other than
    # a positive int (a float, a string from a hand-edited manifest, a bool)
    # would produce a wrong or invalid handler. bool is an int subclass and is
    # rejected explicitly.
    if (
        isinstance(max_concurrency, bool)
        or not isinstance(max_concurrency, int)
        or max_concurrency < 1
    ):
        raise ValueError(
            f"max_concurrency must be an integer >= 1, got {max_concurrency!r}"
        )

    start_call = 'runpod.serverless.start({"handler": handler})'
    if start_call not in code:
        raise ValueError(
            "Unable to inject concurrency_modifier: expected "
            f"{start_call!r} in generated handler code."
        )
    return code.replace(
        start_call,
        "runpod.serverless.start({\n"
        '        "handler": handler,\n'
        f'        "concurrency_modifier": lambda current: {max_concurrency},\n'
        "    })",
    )
+ if hasattr(resource_config, "_max_concurrency"): + mc = resource_config._max_concurrency + if mc > 1: + config["max_concurrency"] = mc + elif ( + isinstance(remote_cfg, dict) + and remote_cfg.get("max_concurrency", 1) > 1 + ): + config["max_concurrency"] = remote_cfg["max_concurrency"] + # unwrap Endpoint facade to the internal resource config if hasattr(resource_config, "_build_resource_config"): resource_config = resource_config._build_resource_config() @@ -439,6 +454,19 @@ def build(self) -> Dict[str, Any]: **deployment_config, # Include imageName, templateId, gpuIds, workers config } + # max_concurrency is QB-only; warn and remove for LB endpoints + if ( + is_load_balanced + and resources_dict[resource_name].get("max_concurrency", 1) > 1 + ): + logger.warning( + "max_concurrency=%d on LB endpoint '%s' is ignored. " + "LB endpoints handle concurrency via uvicorn.", + resources_dict[resource_name]["max_concurrency"], + resource_name, + ) + resources_dict[resource_name].pop("max_concurrency", None) + if not is_load_balanced: resources_dict[resource_name]["handler_file"] = ( f"handler_{resource_name}.py" diff --git a/src/runpod_flash/cli/commands/build_utils/scanner.py b/src/runpod_flash/cli/commands/build_utils/scanner.py index 46ca32ce..b38a9cd0 100644 --- a/src/runpod_flash/cli/commands/build_utils/scanner.py +++ b/src/runpod_flash/cli/commands/build_utils/scanner.py @@ -274,6 +274,15 @@ def _metadata_from_remote_config( is_async = inspect.iscoroutinefunction(original) or inspect.iscoroutinefunction( obj ) + elif is_class and target_class is not None: + # A class is async if any of its public methods are coroutines. + # This drives handler template selection for concurrent endpoints. + is_async = any( + inspect.iscoroutinefunction(m) + for name, m in vars(target_class).items() + if not name.startswith("_") + and (inspect.isfunction(m) or inspect.iscoroutinefunction(m)) + ) # function/class name: for classes, use the original class name. 
# for functions, use __name__ from the unwrapped function. diff --git a/src/runpod_flash/endpoint.py b/src/runpod_flash/endpoint.py index ab6a4cc5..c836149b 100644 --- a/src/runpod_flash/endpoint.py +++ b/src/runpod_flash/endpoint.py @@ -365,6 +365,7 @@ def __init__( scaler_value: int = 4, template: Optional[PodTemplate] = None, min_cuda_version: Optional[CudaVersion | str] = CudaVersion.V12_8, + max_concurrency: int = 1, ): if gpu is not None and cpu is not None: raise ValueError( @@ -378,6 +379,10 @@ def __init__( if name is None and id is None and image is not None: raise ValueError("name or id is required when image= is set.") + if max_concurrency < 1: + raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}") + self._max_concurrency = max_concurrency + # name can be None here for QB decorator mode (@Endpoint(gpu=...)). # it gets derived from the decorated function/class in __call__(). self.name = name @@ -606,7 +611,7 @@ async def process(data: dict) -> dict: ... from .client import remote as remote_decorator - return remote_decorator( + decorated = remote_decorator( resource_config=resource_config, dependencies=self.dependencies, system_dependencies=self.system_dependencies, @@ -614,6 +619,16 @@ async def process(data: dict) -> dict: ... _internal=True, )(func_or_class) + # Persist max_concurrency into __remote_config__ so the manifest + # builder can read it for inline @Endpoint() decorators where the + # Endpoint facade is not reachable via a module-level variable. 
# -- concurrent handlers --


def _function_entry(name, module, *, is_async, is_class=False, class_methods=None):
    """Return a manifest function entry; class_methods is added only when given."""
    entry = {
        "name": name,
        "module": module,
        "is_async": is_async,
        "is_class": is_class,
    }
    if class_methods is not None:
        entry["class_methods"] = class_methods
    return entry


def _render_concurrency_handler(resource_name, function_entry, max_concurrency=None):
    """Generate the handler for a one-resource manifest and return its source.

    max_concurrency is left out of the manifest entirely when it is None.
    """
    resource = {"resource_type": "Endpoint"}
    if max_concurrency is not None:
        resource["max_concurrency"] = max_concurrency
    resource["functions"] = [function_entry]
    manifest = {
        "version": "1.0",
        "generated_at": "2026-01-02T10:00:00Z",
        "project_name": "test_app",
        "resources": {resource_name: resource},
    }
    with tempfile.TemporaryDirectory() as tmpdir:
        generated = HandlerGenerator(manifest, Path(tmpdir)).generate_handlers()
        return generated[0].read_text()


def test_async_handler_with_concurrency():
    """max_concurrency > 1 + async produces async handler with concurrency_modifier."""
    source = _render_concurrency_handler(
        "inference",
        _function_entry("generate", "workers.inference", is_async=True),
        max_concurrency=5,
    )
    for fragment in (
        "async def handler(job):",
        "await generate(**job_input)",
        "concurrency_modifier",
        "lambda current: 5",
    ):
        assert fragment in source


def test_sync_handler_with_concurrency():
    """max_concurrency > 1 + sync uses sync template with concurrency_modifier."""
    source = _render_concurrency_handler(
        "worker",
        _function_entry("process", "workers.cpu", is_async=False),
        max_concurrency=3,
    )
    assert "def handler(job):" in source
    assert "async def handler(job):" not in source
    assert "concurrency_modifier" in source
    assert "lambda current: 3" in source


def test_no_concurrency_modifier_when_default():
    """max_concurrency=1 (default) produces no concurrency_modifier."""
    source = _render_concurrency_handler(
        "worker",
        _function_entry("process", "workers.cpu", is_async=True),
    )
    assert "concurrency_modifier" not in source
    assert 'runpod.serverless.start({"handler": handler})' in source


def test_async_handler_valid_syntax():
    """Generated async handler passes ast.parse validation."""
    source = _render_concurrency_handler(
        "inference",
        _function_entry("generate", "workers.inference", is_async=True),
        max_concurrency=10,
    )
    ast.parse(source)


def test_async_class_handler_with_concurrency():
    """max_concurrency > 1 + async class produces async class handler."""
    source = _render_concurrency_handler(
        "vllm_worker",
        _function_entry(
            "VLLMWorker",
            "workers.vllm",
            is_async=True,
            is_class=True,
            class_methods=["generate"],
        ),
        max_concurrency=10,
    )
    for fragment in (
        "async def handler(job):",
        "await method(**job_input)",
        "_instance = VLLMWorker()",
        "concurrency_modifier",
        "lambda current: 10",
    ):
        assert fragment in source
    assert "_run_maybe_async" not in source


def test_async_class_handler_valid_syntax():
    """Generated async class handler passes ast.parse validation."""
    source = _render_concurrency_handler(
        "worker",
        _function_entry(
            "Worker",
            "w",
            is_async=True,
            is_class=True,
            class_methods=["predict", "embed"],
        ),
        max_concurrency=8,
    )
    ast.parse(source)


def test_sync_class_with_concurrency_uses_sync_template():
    """max_concurrency > 1 + sync class uses sync template with concurrency_modifier."""
    source = _render_concurrency_handler(
        "worker",
        _function_entry(
            "SyncWorker",
            "w",
            is_async=False,
            is_class=True,
            class_methods=["run"],
        ),
        max_concurrency=4,
    )
    assert "def handler(job):" in source
    assert "async def handler(job):" not in source
    assert "_run_maybe_async" in source
    assert "concurrency_modifier" in source
    assert "lambda current: 4" in source


def test_sync_handler_with_concurrency_logs_warning(caplog):
    """max_concurrency > 1 + sync handler logs a warning."""
    with caplog.at_level(logging.WARNING):
        _render_concurrency_handler(
            "worker",
            _function_entry("process", "workers.cpu", is_async=False),
            max_concurrency=3,
        )
    matching = [
        record
        for record in caplog.records
        if "max_concurrency=3" in record.message and "sync" in record.message
    ]
    assert matching


def test_high_concurrency_logs_warning(caplog):
    """max_concurrency > 100 logs a high concurrency warning."""
    with caplog.at_level(logging.WARNING):
        _render_concurrency_handler(
            "inference",
            _function_entry("generate", "workers.inference", is_async=True),
            max_concurrency=150,
        )
    assert any("max_concurrency=150" in record.message for record in caplog.records)


def test_inject_concurrency_modifier_raises_on_missing_start_call():
    """_inject_concurrency_modifier raises if the start call string is absent."""
    unrelated_source = "some random code"
    with pytest.raises(ValueError, match="Unable to inject concurrency_modifier"):
        HandlerGenerator._inject_concurrency_modifier(unrelated_source, 5)
class TestClassAsyncDetection:
    """scanner detects is_async for class-based workers."""

    @staticmethod
    def _discover_only_worker(tmp_path, filename, source):
        """Write one worker module, scan tmp_path, and return its single entry."""
        _write_worker(tmp_path, filename, source)
        discovered = RuntimeScanner(tmp_path).discover_remote_functions()
        assert len(discovered) == 1
        return discovered[0]

    def test_sync_class_not_async(self, tmp_path):
        """Class whose only public method is sync is reported with is_async False."""
        worker = self._discover_only_worker(
            tmp_path,
            "sync_cls.py",
            """\
            from runpod_flash import remote, LiveServerless
            cfg = LiveServerless(name="sync-cls")

            @remote(cfg)
            class SyncWorker:
                def predict(self, x):
                    return x
            """,
        )
        assert worker.is_class is True
        assert worker.is_async is False

    def test_async_class_detected(self, tmp_path):
        """Class whose public method is a coroutine is reported with is_async True."""
        worker = self._discover_only_worker(
            tmp_path,
            "async_cls.py",
            """\
            from runpod_flash import remote, LiveServerless
            cfg = LiveServerless(name="async-cls")

            @remote(cfg)
            class AsyncWorker:
                async def predict(self, x):
                    return x
            """,
        )
        assert worker.is_class is True
        assert worker.is_async is True

    def test_mixed_class_detected_as_async(self, tmp_path):
        """Class with both sync and async methods is detected as async."""
        worker = self._discover_only_worker(
            tmp_path,
            "mixed_cls.py",
            """\
            from runpod_flash import remote, LiveServerless
            cfg = LiveServerless(name="mixed-cls")

            @remote(cfg)
            class MixedWorker:
                def setup(self):
                    pass

                async def predict(self, x):
                    return x
            """,
        )
        assert worker.is_class is True
        assert worker.is_async is True
class TestResourceConfigMaxConcurrency:
    """ResourceConfig keeps max_concurrency through construction and dict round-trips."""

    def test_default_is_one(self):
        assert ResourceConfig(resource_type="LiveServerless").max_concurrency == 1

    def test_explicit_value(self):
        config = ResourceConfig(resource_type="LiveServerless", max_concurrency=5)
        assert config.max_concurrency == 5

    def test_from_dict_with_max_concurrency(self):
        config = ResourceConfig.from_dict(
            {
                "resource_type": "LiveServerless",
                "max_concurrency": 10,
                "functions": [],
            }
        )
        assert config.max_concurrency == 10

    def test_from_dict_missing_field_defaults_to_one(self):
        config = ResourceConfig.from_dict(
            {
                "resource_type": "LiveServerless",
                "functions": [],
            }
        )
        assert config.max_concurrency == 1

    def test_round_trip_through_dict(self):
        original = ResourceConfig(resource_type="LiveServerless", max_concurrency=7)
        as_dict = asdict(original)
        assert as_dict["max_concurrency"] == 7
        restored = ResourceConfig.from_dict(as_dict)
        assert restored.max_concurrency == 7


class TestManifestBuilderMaxConcurrency:
    """ManifestBuilder.build() keeps max_concurrency for QB and drops it for LB."""

    @staticmethod
    def _builder_for(func):
        """Return a ManifestBuilder for a single remote function."""
        return ManifestBuilder(
            project_name="test",
            remote_functions=[func],
            build_dir=None,
        )

    def test_qb_resource_includes_max_concurrency_from_deployment_config(self):
        """QB resource with max_concurrency in deployment_config includes it in manifest."""
        qb_func = RemoteFunctionMetadata(
            function_name="generate",
            module_path="app",
            resource_config_name="inference",
            resource_type="LiveServerless",
            is_async=True,
            is_class=False,
            is_load_balanced=False,
            is_live_resource=False,
            file_path=Path("/nonexistent/app.py"),
        )
        builder = self._builder_for(qb_func)

        with patch.object(
            builder,
            "_extract_deployment_config",
            return_value={"max_concurrency": 5},
        ):
            manifest = builder.build()

        assert manifest["resources"]["inference"]["max_concurrency"] == 5

    def test_lb_resource_omits_max_concurrency_and_warns(self):
        """LB resource with max_concurrency > 1 logs warning and omits value."""
        lb_func = RemoteFunctionMetadata(
            function_name="health",
            module_path="app",
            resource_config_name="api",
            resource_type="LiveLoadBalancer",
            is_async=True,
            is_class=False,
            is_load_balanced=True,
            is_live_resource=True,
            http_method="GET",
            http_path="/health",
            file_path=Path("/nonexistent/app.py"),
        )
        builder = self._builder_for(lb_func)

        with patch.object(
            builder,
            "_extract_deployment_config",
            return_value={"max_concurrency": 5},
        ), patch(
            "runpod_flash.cli.commands.build_utils.manifest.logger"
        ) as mock_logger:
            manifest = builder.build()

        assert "max_concurrency" not in manifest["resources"]["api"]
        mock_logger.warning.assert_called_once()
        # the logger receives a %-style template plus args; render it to
        # check the final message text
        warning_args = mock_logger.warning.call_args[0]
        rendered = warning_args[0] % warning_args[1:]
        assert "max_concurrency=5" in rendered

    def test_qb_resource_without_max_concurrency_has_no_field(self):
        """QB resource with no max_concurrency in deployment_config omits the field."""
        qb_func = RemoteFunctionMetadata(
            function_name="process",
            module_path="worker",
            resource_config_name="worker",
            resource_type="LiveServerless",
            is_async=True,
            is_class=False,
            is_load_balanced=False,
            is_live_resource=False,
            file_path=Path("/nonexistent/worker.py"),
        )
        builder = self._builder_for(qb_func)

        with patch.object(
            builder,
            "_extract_deployment_config",
            return_value={},
        ):
            manifest = builder.build()

        assert "max_concurrency" not in manifest["resources"]["worker"]
class TestMaxConcurrency:
    """Endpoint validates max_concurrency and persists it for inline decorators."""

    @staticmethod
    def _remote_config_for(endpoint):
        """Decorate a trivial async function and return its __remote_config__."""

        async def generate(prompt: str) -> str:
            return prompt

        return getattr(endpoint(generate), "__remote_config__", None)

    def test_default_is_one(self):
        assert Endpoint(name="test")._max_concurrency == 1

    def test_explicit_value(self):
        assert Endpoint(name="test", max_concurrency=5)._max_concurrency == 5

    def test_rejects_zero(self):
        with pytest.raises(ValueError, match="max_concurrency must be >= 1"):
            Endpoint(name="test", max_concurrency=0)

    def test_rejects_negative(self):
        with pytest.raises(ValueError, match="max_concurrency must be >= 1"):
            Endpoint(name="test", max_concurrency=-1)

    def test_inline_decorator_persists_max_concurrency(self):
        """Inline @Endpoint() persists max_concurrency into __remote_config__."""
        endpoint = Endpoint(name="test", gpu=GpuGroup.ANY, max_concurrency=5)
        remote_cfg = self._remote_config_for(endpoint)
        assert remote_cfg is not None
        assert remote_cfg["max_concurrency"] == 5

    def test_inline_decorator_omits_default_max_concurrency(self):
        """Inline @Endpoint() with default max_concurrency=1 does not add the key."""
        endpoint = Endpoint(name="test", gpu=GpuGroup.ANY)
        remote_cfg = self._remote_config_for(endpoint)
        assert remote_cfg is not None
        assert "max_concurrency" not in remote_cfg