Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,57 @@
JobSubmitResponse,
PingRequest,
PingResponse,
PolicyEngineBundle,
SimulationRequest,
)

logger = logging.getLogger(__name__)

router = APIRouter()
JOB_METADATA_DICT_NAME = "simulation-api-job-metadata"
DATASET_URIS = {
"us": {
"enhanced_cps": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0",
"enhanced_cps_2024": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0",
"cps": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.77.0",
"cps_2023": "hf://policyengine/policyengine-us-data/cps_2023.h5@1.77.0",
"pooled_cps": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.77.0",
"pooled_3_year_cps_2023": "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5@1.77.0",
},
"uk": {
"enhanced_frs": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3",
"enhanced_frs_2023_24": "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3",
"frs": "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.40.3",
"frs_2023_24": "hf://policyengine/policyengine-uk-data-private/frs_2023_24.h5@1.40.3",
},
}


def _job_metadata_store():
return modal.Dict.from_name(JOB_METADATA_DICT_NAME, create_if_missing=True)


def _build_policyengine_bundle(
country: str, resolved_version: str, payload: dict
) -> PolicyEngineBundle:
dataset = payload.get("data")
if isinstance(dataset, str) and "://" in dataset:
resolved_dataset = dataset
elif isinstance(dataset, str):
resolved_dataset = DATASET_URIS.get(country.lower(), {}).get(dataset, dataset)
else:
resolved_dataset = None
return PolicyEngineBundle(
model_version=resolved_version,
dataset=resolved_dataset,
)


def _serialize_job_metadata(resolved_app_name: str, bundle: PolicyEngineBundle) -> dict:
return {
"resolved_app_name": resolved_app_name,
"policyengine_bundle": bundle.model_dump(),
}


def get_app_name(country: str, version: Optional[str]) -> tuple[str, str]:
Expand Down Expand Up @@ -74,12 +119,18 @@ async def submit_simulation(request: SimulationRequest):
# Spawn the job (returns immediately)
call = sim_func.spawn(payload)

bundle = _build_policyengine_bundle(request.country, resolved_version, payload)
job_metadata = _serialize_job_metadata(app_name, bundle)
_job_metadata_store()[call.object_id] = job_metadata

return JobSubmitResponse(
job_id=call.object_id,
status="submitted",
poll_url=f"/jobs/{call.object_id}",
country=request.country,
version=resolved_version,
resolved_app_name=app_name,
policyengine_bundle=bundle,
)


Expand All @@ -99,18 +150,32 @@ async def get_job_status(job_id: str):
except Exception:
raise HTTPException(status_code=404, detail=f"Job not found: {job_id}")

job_metadata = _job_metadata_store().get(job_id)

try:
result = call.get(timeout=0)
return JobStatusResponse(status="complete", result=result)
return JobStatusResponse(
status="complete", result=result, **(job_metadata or {})
)
except TimeoutError:
return JSONResponse(
status_code=202,
content={"status": "running", "result": None, "error": None},
content={
"status": "running",
"result": None,
"error": None,
**(job_metadata or {}),
},
)
except Exception as e:
return JSONResponse(
status_code=500,
content={"status": "failed", "result": None, "error": str(e)},
content={
"status": "failed",
"result": None,
"error": str(e),
**(job_metadata or {}),
},
)


Expand Down
13 changes: 13 additions & 0 deletions projects/policyengine-api-simulation/src/modal/gateway/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ class SimulationRequest(BaseModel):
model_config = ConfigDict(extra="allow") # Pass through all other fields


class PolicyEngineBundle(BaseModel):
"""Resolved runtime provenance returned by the gateway."""

model_version: str
policyengine_version: Optional[str] = None
data_version: Optional[str] = None
dataset: Optional[str] = None


class JobSubmitResponse(BaseModel):
"""Response model for job submission."""

Expand All @@ -23,6 +32,8 @@ class JobSubmitResponse(BaseModel):
poll_url: str
country: str
version: str
resolved_app_name: str
policyengine_bundle: PolicyEngineBundle


class JobStatusResponse(BaseModel):
Expand All @@ -31,6 +42,8 @@ class JobStatusResponse(BaseModel):
status: str
result: Optional[dict] = None
error: Optional[str] = None
resolved_app_name: Optional[str] = None
policyengine_bundle: Optional[PolicyEngineBundle] = None


class PingRequest(BaseModel):
Expand Down
35 changes: 33 additions & 2 deletions projects/policyengine-api-simulation/tests/fixtures/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ def __getitem__(self, key: str):
raise KeyError(key)
return self._data[key]

def __setitem__(self, key: str, value):
self._data[key] = value

def get(self, key: str, default=None):
return self._data.get(key, default)

@classmethod
def from_name(cls, name: str):
"""Mock from_name that returns a MockDict based on name."""
Expand All @@ -28,8 +34,27 @@ def from_name(cls, name: str):
class MockFunctionCall:
"""Mock for Modal FunctionCall returned by spawn."""

registry = {}

def __init__(self, object_id: str = "mock-job-id-123"):
self.object_id = object_id
self.result = {"budget": {"total": 1000000}}
self.error = None
self.running = False
self.__class__.registry[object_id] = self

def get(self, timeout: int = 0):
if self.running:
raise TimeoutError()
if self.error is not None:
raise self.error
return self.result

@classmethod
def from_id(cls, object_id: str):
if object_id not in cls.registry:
raise KeyError(object_id)
return cls.registry[object_id]


class MockFunction:
Expand All @@ -38,10 +63,12 @@ class MockFunction:
def __init__(self):
self.last_payload = None
self.last_from_name_call = None
self.last_call = None

def spawn(self, payload: dict) -> MockFunctionCall:
self.last_payload = payload
return MockFunctionCall()
self.last_call = MockFunctionCall()
return self.last_call

@classmethod
def from_name(cls, app_name: str, func_name: str):
Expand Down Expand Up @@ -73,10 +100,13 @@ def test_something(mock_modal, client):
# Create mock objects
mock_func = MockFunction()
mock_dicts = {}
MockFunctionCall.registry = {}

class MockModalDict:
@staticmethod
def from_name(name: str):
def from_name(name: str, create_if_missing: bool = False):
if create_if_missing and name not in mock_dicts:
mock_dicts[name] = {}
if name not in mock_dicts:
raise KeyError(f"Mock dict not configured for: {name}")
return MockDict(mock_dicts[name])
Expand All @@ -91,6 +121,7 @@ def from_name(app_name: str, func_name: str):
class MockModal:
Dict = MockModalDict
Function = MockModalFunction
FunctionCall = MockFunctionCall

# Patch the modal import in the endpoints module
monkeypatch.setattr(endpoints, "modal", MockModal)
Expand Down
145 changes: 144 additions & 1 deletion projects/policyengine-api-simulation/tests/gateway/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from fastapi.testclient import TestClient

from tests.fixtures.endpoints import mock_modal # noqa: F401 - pytest fixture
pytest_plugins = ("tests.fixtures.endpoints",)


class TestGetAppName:
Expand Down Expand Up @@ -200,3 +200,146 @@ def test__given_submission__then_returns_job_id_and_poll_url(
assert data["job_id"] == "mock-job-id-123"
assert data["poll_url"] == "/jobs/mock-job-id-123"
assert data["status"] == "submitted"

def test__given_submission_with_data__then_returns_resolved_bundle_metadata(
self, mock_modal, client: TestClient
):
"""
Given a simulation submission with an explicit data URI
When the request completes
Then the response exposes the resolved app and submitted dataset provenance.
"""
# Given
mock_modal["dicts"]["simulation-api-us-versions"] = {
"latest": "1.500.0",
"1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0",
}

request_body = {
"country": "us",
"scope": "macro",
"reform": {},
"data": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0",
}

# When
response = client.post("/simulate/economy/comparison", json=request_body)

# Then
assert response.status_code == 200
data = response.json()
assert data["resolved_app_name"] == "policyengine-simulation-us1-500-0-uk2-66-0"
assert data["policyengine_bundle"] == {
"model_version": "1.500.0",
"policyengine_version": None,
"data_version": None,
"dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0",
}

def test__given_submission_with_alias_data__then_bundle_dataset_stays_unresolved(
self, mock_modal, client: TestClient
):
mock_modal["dicts"]["simulation-api-us-versions"] = {
"latest": "1.500.0",
"1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0",
}

request_body = {
"country": "us",
"scope": "macro",
"reform": {},
"data": "enhanced_cps_2024",
}

response = client.post("/simulate/economy/comparison", json=request_body)

assert response.status_code == 200
data = response.json()
assert (
data["policyengine_bundle"]["dataset"]
== "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0"
)

def test__given_submission_with_uk_alias_data__then_bundle_dataset_is_versioned_uri(
self, mock_modal, client: TestClient
):
mock_modal["dicts"]["simulation-api-uk-versions"] = {
"latest": "2.66.0",
"2.66.0": "policyengine-simulation-us1-500-0-uk2-66-0",
}

request_body = {
"country": "uk",
"scope": "macro",
"reform": {},
"data": "enhanced_frs",
}

response = client.post("/simulate/economy/comparison", json=request_body)

assert response.status_code == 200
data = response.json()
assert (
data["policyengine_bundle"]["dataset"]
== "hf://policyengine/policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.40.3"
)

def test__given_submission_with_unknown_alias_data__then_bundle_dataset_is_preserved(
self, mock_modal, client: TestClient
):
mock_modal["dicts"]["simulation-api-us-versions"] = {
"latest": "1.500.0",
"1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0",
}

request_body = {
"country": "us",
"scope": "macro",
"reform": {},
"data": "custom_dataset_label",
}

response = client.post("/simulate/economy/comparison", json=request_body)

assert response.status_code == 200
data = response.json()
assert data["policyengine_bundle"]["dataset"] == "custom_dataset_label"

def test__given_submitted_job__then_job_status_includes_bundle_metadata(
self, mock_modal, client: TestClient
):
"""
Given a submitted simulation job
When polling job status
Then the resolved bundle metadata is returned with the status response.
"""
# Given
mock_modal["dicts"]["simulation-api-us-versions"] = {
"latest": "1.500.0",
"1.500.0": "policyengine-simulation-us1-500-0-uk2-66-0",
}

submit_response = client.post(
"/simulate/economy/comparison",
json={
"country": "us",
"scope": "macro",
"reform": {},
"data": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0",
},
)

# When
response = client.get(f"/jobs/{submit_response.json()['job_id']}")

# Then
assert response.status_code == 200
data = response.json()
assert data["status"] == "complete"
assert data["resolved_app_name"] == "policyengine-simulation-us1-500-0-uk2-66-0"
assert data["policyengine_bundle"] == {
"model_version": "1.500.0",
"policyengine_version": None,
"data_version": None,
"dataset": "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5@1.77.0",
}
Loading
Loading