Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 202 additions & 11 deletions src/runpod_flash/cli/commands/build_utils/handler_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,49 @@ def handler(job):
runpod.serverless.start({{"handler": handler}})
'''

DEPLOYED_ASYNC_HANDLER_TEMPLATE = '''"""
Auto-generated deployed handler for resource: {resource_name}
Generated at: {timestamp}

Concurrent async endpoint handler: accepts plain JSON, no cloudpickle.
max_concurrency={max_concurrency}

This file is generated by the Flash build process. Do not edit manually.
"""

import importlib
import logging
import traceback

_logger = logging.getLogger(__name__)

# Import the function for this endpoint
{import_statement}


async def handler(job):
"""Async handler for concurrent QB endpoint. Accepts plain JSON kwargs."""
job_input = job.get("input", {{}})
try:
result = await {function_name}(**job_input)
return result
except Exception as e:
_logger.error(
"Deployed handler error for {function_name}: %s",
e,
exc_info=True,
)
return {{"error": str(e), "traceback": traceback.format_exc()}}


if __name__ == "__main__":
import runpod
runpod.serverless.start({{
"handler": handler,
"concurrency_modifier": lambda current: {max_concurrency},
}})
'''

DEPLOYED_CLASS_HANDLER_TEMPLATE = '''"""
Auto-generated deployed handler for class-based resource: {resource_name}
Generated at: {timestamp}
Expand Down Expand Up @@ -155,6 +198,84 @@ def handler(job):
'''


# Template for class-based QB endpoints with max_concurrency > 1.
# Rendered with str.format(); placeholders: resource_name, timestamp,
# max_concurrency, import_statement, class_name, methods_dict. Double
# braces ({{ }}) render as literal braces in the generated handler file.
#
# Fix vs previous revision: "method" was popped from job_input before
# dispatch, so the error log's job_input.get("method", "<default>") always
# reported "<default>". The generated handler now resolves method_name
# before the try block and logs that instead.
DEPLOYED_ASYNC_CLASS_HANDLER_TEMPLATE = '''"""
Auto-generated deployed handler for class-based resource: {resource_name}
Generated at: {timestamp}

Concurrent async class endpoint handler: accepts plain JSON, no cloudpickle.
max_concurrency={max_concurrency}

This file is generated by the Flash build process. Do not edit manually.
"""

import importlib
import logging
import traceback

_logger = logging.getLogger(__name__)

# import the class for this endpoint
{import_statement}

# instantiate once at cold start
_instance = {class_name}()

# public methods available for dispatch
_METHODS = {methods_dict}


async def handler(job):
    """Async handler for concurrent class-based QB endpoint.

    Dispatches to a method on the singleton class instance.
    If the class has exactly one public method, input is passed
    directly as kwargs. If multiple methods exist, input must
    include a "method" key to select the target.
    """
    job_input = job.get("input", {{}})
    # Resolve the method name outside the try block so the error log can
    # report it even after "method" has been popped from job_input.
    method_name = None
    try:
        if len(_METHODS) == 1:
            method_name = next(iter(_METHODS))
        else:
            method_name = job_input.pop("method", None)
            if method_name is None:
                return {{
                    "error": (
                        "class {class_name} has multiple methods: "
                        + ", ".join(sorted(_METHODS))
                        + ". include a \\"method\\" key in the input."
                    )
                }}
            if method_name not in _METHODS:
                return {{
                    "error": (
                        f"unknown method '{{method_name}}' on {class_name}. "
                        f"available: {{', '.join(sorted(_METHODS))}}"
                    )
                }}

        method = getattr(_instance, method_name)
        result = await method(**job_input)
        return result
    except Exception as e:
        _logger.error(
            "Deployed handler error for {class_name}.%s: %s",
            method_name or "<default>",
            e,
            exc_info=True,
        )
        return {{"error": str(e), "traceback": traceback.format_exc()}}


if __name__ == "__main__":
    import runpod
    runpod.serverless.start({{
        "handler": handler,
        "concurrency_modifier": lambda current: {max_concurrency},
    }})
'''


class HandlerGenerator:
"""Generates handler_<name>.py files for each resource config."""

Expand Down Expand Up @@ -190,12 +311,7 @@ def generate_handlers(self) -> List[Path]:
return handler_paths

def _generate_handler(self, resource_name: str, resource_data: Any) -> Path:
"""Generate a handler file for a deployed QB endpoint.

Produces a plain-JSON handler that accepts raw dicts (no cloudpickle).
For classes, instantiates once at cold start and dispatches to methods.
For functions, calls the function directly with job input as kwargs.
"""
"""Generate a handler file for a deployed QB endpoint."""
handler_filename = f"handler_{resource_name}.py"
handler_path = self.build_dir / handler_filename

Expand All @@ -211,8 +327,14 @@ def _generate_handler(self, resource_name: str, resource_data: Any) -> Path:
else resource_data.get("functions", [])
)

max_concurrency = (
resource_data.max_concurrency
if hasattr(resource_data, "max_concurrency")
else resource_data.get("max_concurrency", 1)
)

handler_code = self._generate_deployed_handler_code(
resource_name, timestamp, functions
resource_name, timestamp, functions, max_concurrency
)

handler_path.write_text(handler_code)
Expand All @@ -225,11 +347,14 @@ def _generate_deployed_handler_code(
resource_name: str,
timestamp: str,
functions: List[Any],
max_concurrency: int = 1,
) -> str:
"""Generate deployed handler code for a QB endpoint.

Selects the class template for class-based workers and the
function template for plain functions.
Selects template based on max_concurrency and is_async:
- max_concurrency=1: current sync templates (unchanged behavior)
- max_concurrency>1 + async: new async templates with concurrency_modifier
- max_concurrency>1 + sync: sync templates with concurrency_modifier injected
"""
if not functions:
raise ValueError(
Expand All @@ -243,34 +368,100 @@ def _generate_deployed_handler_code(
is_class = (
func.is_class if hasattr(func, "is_class") else func.get("is_class", False)
)
is_async = (
func.is_async if hasattr(func, "is_async") else func.get("is_async", False)
)

import_statement = (
f"{name} = importlib.import_module('{module}').{name}"
if module and name
else "# No function to import"
)

# 100 is a soft upper bound: most GPU workloads saturate VRAM well
# below this. The warning nudges users to verify resource capacity.
if max_concurrency > 100:
logger.warning(
"High max_concurrency=%d for resource '%s'. Ensure your handler "
"and GPU can support this level of concurrent execution.",
max_concurrency,
resource_name,
)

if max_concurrency > 1 and not is_async:
logger.warning(
"max_concurrency=%d set on sync handler '%s'. "
"Only async handlers benefit from concurrent execution. "
"Consider making the handler async.",
max_concurrency,
resource_name,
)

if is_class:
class_methods = (
func.class_methods
if hasattr(func, "class_methods")
else func.get("class_methods", [])
)
methods_dict = {m: m for m in class_methods} if class_methods else {}
return DEPLOYED_CLASS_HANDLER_TEMPLATE.format(

if max_concurrency > 1 and is_async:
return DEPLOYED_ASYNC_CLASS_HANDLER_TEMPLATE.format(
resource_name=resource_name,
timestamp=timestamp,
import_statement=import_statement,
class_name=name or "None",
methods_dict=repr(methods_dict),
max_concurrency=max_concurrency,
)

code = DEPLOYED_CLASS_HANDLER_TEMPLATE.format(
resource_name=resource_name,
timestamp=timestamp,
import_statement=import_statement,
class_name=name or "None",
methods_dict=repr(methods_dict),
)
if max_concurrency > 1:
code = self._inject_concurrency_modifier(code, max_concurrency)
return code

# Function-based handler
if max_concurrency > 1 and is_async:
return DEPLOYED_ASYNC_HANDLER_TEMPLATE.format(
resource_name=resource_name,
timestamp=timestamp,
import_statement=import_statement,
function_name=name or "None",
max_concurrency=max_concurrency,
)

return DEPLOYED_HANDLER_TEMPLATE.format(
code = DEPLOYED_HANDLER_TEMPLATE.format(
resource_name=resource_name,
timestamp=timestamp,
import_statement=import_statement,
function_name=name or "None",
)
if max_concurrency > 1:
code = self._inject_concurrency_modifier(code, max_concurrency)
return code

@staticmethod
def _inject_concurrency_modifier(code: str, max_concurrency: int) -> str:
"""Replace the default runpod.serverless.start call with one including concurrency_modifier."""
start_call = 'runpod.serverless.start({"handler": handler})'
if start_call not in code:
raise ValueError(
"Unable to inject concurrency_modifier: expected "
f"{start_call!r} in generated handler code."
)
return code.replace(
start_call,
"runpod.serverless.start({\n"
' "handler": handler,\n'
f' "concurrency_modifier": lambda current: {max_concurrency},\n'
" })",
)

def _validate_handler_imports(self, handler_path: Path) -> None:
"""Validate that generated handler has valid Python syntax.
Expand Down
28 changes: 28 additions & 0 deletions src/runpod_flash/cli/commands/build_utils/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def _extract_deployment_config(

try:
resource_config = None
remote_cfg = None

if config_variable and hasattr(module, config_variable):
resource_config = getattr(module, config_variable)
Expand All @@ -180,6 +181,20 @@ def _extract_deployment_config(
if resource_config is None:
return config

# Extract max_concurrency from Endpoint facade before unwrapping.
# For module-level Endpoint variables, read from the instance.
# For inline @Endpoint() decorators, the facade is gone by the
# time we reach here -- read from __remote_config__ instead.
if hasattr(resource_config, "_max_concurrency"):
mc = resource_config._max_concurrency
if mc > 1:
config["max_concurrency"] = mc
elif (
isinstance(remote_cfg, dict)
and remote_cfg.get("max_concurrency", 1) > 1
):
config["max_concurrency"] = remote_cfg["max_concurrency"]

# unwrap Endpoint facade to the internal resource config
if hasattr(resource_config, "_build_resource_config"):
resource_config = resource_config._build_resource_config()
Expand Down Expand Up @@ -439,6 +454,19 @@ def build(self) -> Dict[str, Any]:
**deployment_config, # Include imageName, templateId, gpuIds, workers config
}

# max_concurrency is QB-only; warn and remove for LB endpoints
if (
is_load_balanced
and resources_dict[resource_name].get("max_concurrency", 1) > 1
):
logger.warning(
"max_concurrency=%d on LB endpoint '%s' is ignored. "
"LB endpoints handle concurrency via uvicorn.",
resources_dict[resource_name]["max_concurrency"],
resource_name,
)
resources_dict[resource_name].pop("max_concurrency", None)

if not is_load_balanced:
resources_dict[resource_name]["handler_file"] = (
f"handler_{resource_name}.py"
Expand Down
9 changes: 9 additions & 0 deletions src/runpod_flash/cli/commands/build_utils/scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,15 @@ def _metadata_from_remote_config(
is_async = inspect.iscoroutinefunction(original) or inspect.iscoroutinefunction(
obj
)
elif is_class and target_class is not None:
# A class is async if any of its public methods are coroutines.
# This drives handler template selection for concurrent endpoints.
is_async = any(
inspect.iscoroutinefunction(m)
for name, m in vars(target_class).items()
if not name.startswith("_")
and (inspect.isfunction(m) or inspect.iscoroutinefunction(m))
)

# function/class name: for classes, use the original class name.
# for functions, use __name__ from the unwrapped function.
Expand Down
17 changes: 16 additions & 1 deletion src/runpod_flash/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def __init__(
scaler_value: int = 4,
template: Optional[PodTemplate] = None,
min_cuda_version: Optional[CudaVersion | str] = CudaVersion.V12_8,
max_concurrency: int = 1,
):
if gpu is not None and cpu is not None:
raise ValueError(
Expand All @@ -378,6 +379,10 @@ def __init__(
if name is None and id is None and image is not None:
raise ValueError("name or id is required when image= is set.")

if max_concurrency < 1:
raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}")
self._max_concurrency = max_concurrency

# name can be None here for QB decorator mode (@Endpoint(gpu=...)).
# it gets derived from the decorated function/class in __call__().
self.name = name
Expand Down Expand Up @@ -606,14 +611,24 @@ async def process(data: dict) -> dict: ...

from .client import remote as remote_decorator

return remote_decorator(
decorated = remote_decorator(
resource_config=resource_config,
dependencies=self.dependencies,
system_dependencies=self.system_dependencies,
accelerate_downloads=self.accelerate_downloads,
_internal=True,
)(func_or_class)

# Persist max_concurrency into __remote_config__ so the manifest
# builder can read it for inline @Endpoint() decorators where the
# Endpoint facade is not reachable via a module-level variable.
if self._max_concurrency > 1:
remote_cfg = getattr(decorated, "__remote_config__", None)
if isinstance(remote_cfg, dict):
remote_cfg["max_concurrency"] = self._max_concurrency

return decorated

# -- route decorators (lb mode) --

def _route(self, method: str, path: str):
Expand Down
Loading
Loading