From 20dee2b48805ff5badd1c00604535389503cc010 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Tue, 10 Mar 2026 05:52:09 -0400 Subject: [PATCH 01/13] Log exceptions instead of silently swallowing them in calculate The axes code path silently discarded all exceptions (`pass`), causing variables like NJ gross income to return null with no error trace. Now logs the full traceback via logging.exception(). Fixes #3322 Co-Authored-By: Claude Opus 4.6 --- changelog.d/fix-silent-exception-swallowing.fixed.md | 1 + policyengine_api/country.py | 7 +++---- 2 files changed, 4 insertions(+), 4 deletions(-) create mode 100644 changelog.d/fix-silent-exception-swallowing.fixed.md diff --git a/changelog.d/fix-silent-exception-swallowing.fixed.md b/changelog.d/fix-silent-exception-swallowing.fixed.md new file mode 100644 index 000000000..4b10062e5 --- /dev/null +++ b/changelog.d/fix-silent-exception-swallowing.fixed.md @@ -0,0 +1 @@ +Log exceptions instead of silently swallowing them during household calculations. 
diff --git a/policyengine_api/country.py b/policyengine_api/country.py index befa49851..0cc7f3806 100644 --- a/policyengine_api/country.py +++ b/policyengine_api/country.py @@ -1,4 +1,5 @@ import importlib +import logging from flask import Response import json from policyengine_core.taxbenefitsystems import TaxBenefitSystem @@ -445,11 +446,9 @@ def calculate( entity_result ) except Exception as e: - if "axes" in household: - pass - else: + logging.exception(f"Error computing {variable_name} for {entity_id}") + if "axes" not in household: household[entity_plural][entity_id][variable_name][period] = None - print(f"Error computing {variable_name} for {entity_id}: {e}") tracer_output = simulation.tracer.computation_log log_lines = tracer_output.lines(aggregate=False, max_depth=10) From bf46019d4ab698f4a0e7f1b86931668ee6cda668 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 8 Apr 2026 20:59:21 -0400 Subject: [PATCH 02/13] Add budget window batch economy endpoint --- changelog.d/budget-window-batch.fixed.md | 1 + policyengine_api/routes/economy_routes.py | 67 +++ policyengine_api/services/economy_service.py | 490 ++++++++++++++++--- tests/unit/services/test_economy_service.py | 269 ++++++++++ 4 files changed, 759 insertions(+), 68 deletions(-) create mode 100644 changelog.d/budget-window-batch.fixed.md diff --git a/changelog.d/budget-window-batch.fixed.md b/changelog.d/budget-window-batch.fixed.md new file mode 100644 index 000000000..9bd03006f --- /dev/null +++ b/changelog.d/budget-window-batch.fixed.md @@ -0,0 +1 @@ +Added a budget-window economy endpoint that batches yearly impact calculations with bounded server-side concurrency and returns aggregated progress plus totals. 
diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index 4279a1b1b..a55e75d37 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -2,6 +2,7 @@ from policyengine_api.services.economy_service import ( EconomyService, EconomicImpactResult, + BudgetWindowEconomicImpactResult, ) from policyengine_api.utils import get_current_law_policy_id from policyengine_api.utils.payload_validators import validate_country @@ -68,3 +69,69 @@ def get_economic_impact(country_id: str, policy_id: int, baseline_policy_id: int status=200, mimetype="application/json", ) + + +@validate_country +@economy_bp.route( + "/<country_id>/economy/<policy_id>/over/<baseline_policy_id>/budget-window", + methods=["GET"], +) +def get_budget_window_economic_impact( + country_id: str, policy_id: int, baseline_policy_id: int +): + policy_id = int(policy_id or get_current_law_policy_id(country_id)) + baseline_policy_id = int( + baseline_policy_id or get_current_law_policy_id(country_id) + ) + + query_parameters = request.args + options = dict(query_parameters) + options = json.loads(json.dumps(options)) + region = options.pop("region") + dataset = options.pop("dataset", "default") + start_year = options.pop("start_year") + window_size = int(options.pop("window_size")) + + include_district_breakdowns_raw = options.pop( + "include_district_breakdowns", "false" + ) + include_district_breakdowns = include_district_breakdowns_raw.lower() == "true" + if include_district_breakdowns and country_id == "us" and region == "us": + dataset = "national-with-breakdowns" + + target: Literal["general", "cliff"] = options.pop("target", "general") + api_version = options.pop("version", COUNTRY_PACKAGE_VERSIONS.get(country_id)) + + economic_impact_result: BudgetWindowEconomicImpactResult = ( + economy_service.get_budget_window_economic_impact( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, 
start_year=start_year, + window_size=window_size, + options=options, + api_version=api_version, + target=target, + ) + ) + + result_dict = economic_impact_result.to_dict() + + return Response( + json.dumps( + { + "status": result_dict["status"], + "message": result_dict["message"], + "result": result_dict["data"], + "progress": result_dict["progress"], + "completed_years": result_dict["completed_years"], + "computing_years": result_dict["computing_years"], + "queued_years": result_dict["queued_years"], + "error": result_dict["error"], + } + ), + status=200, + mimetype="application/json", + ) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 031696286..caac4d2f6 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -4,7 +4,6 @@ ) from policyengine_api.constants import ( COUNTRY_PACKAGE_VERSIONS, - REGION_PREFIXES, EXECUTION_STATUSES_SUCCESS, EXECUTION_STATUSES_FAILURE, EXECUTION_STATUSES_PENDING, @@ -24,9 +23,10 @@ import datetime from typing import Literal, Any, Optional, Annotated, Union from dotenv import load_dotenv -from pydantic import BaseModel +from pydantic import BaseModel, Field import numpy as np from enum import Enum +from concurrent.futures import ThreadPoolExecutor load_dotenv() @@ -57,6 +57,7 @@ class ImpactStatus(Enum): COMPLETE_STATUSES = [ImpactStatus.OK.value, ImpactStatus.ERROR.value] COMPUTING_STATUS = ImpactStatus.COMPUTING.value +BUDGET_WINDOW_MAX_ACTIVE_YEARS = 3 class EconomicImpactSetupOptions(BaseModel): @@ -118,6 +119,79 @@ def error(cls, message: str) -> "EconomicImpactResult": return cls(status=ImpactStatus.ERROR, data=None) +class BudgetWindowEconomicImpactResult(BaseModel): + """ + Model for a batch budget-window economic impact response. 
+ """ + + status: ImpactStatus + data: Optional[dict] = None + progress: Optional[int] = None + completed_years: list[str] = Field(default_factory=list) + computing_years: list[str] = Field(default_factory=list) + queued_years: list[str] = Field(default_factory=list) + message: Optional[str] = None + error: Optional[str] = None + + model_config = {"frozen": True} + + def to_dict(self) -> dict[str, Any]: + return { + "status": self.status.value, + "data": self.data, + "progress": self.progress, + "completed_years": self.completed_years, + "computing_years": self.computing_years, + "queued_years": self.queued_years, + "message": self.message, + "error": self.error, + } + + @classmethod + def completed(cls, data: dict) -> "BudgetWindowEconomicImpactResult": + return cls(status=ImpactStatus.OK, data=data, progress=100) + + @classmethod + def computing( + cls, + *, + progress: int, + completed_years: list[str], + computing_years: list[str], + queued_years: list[str], + message: str, + ) -> "BudgetWindowEconomicImpactResult": + return cls( + status=ImpactStatus.COMPUTING, + data=None, + progress=progress, + completed_years=completed_years, + computing_years=computing_years, + queued_years=queued_years, + message=message, + ) + + @classmethod + def failed( + cls, + message: str, + *, + completed_years: Optional[list[str]] = None, + computing_years: Optional[list[str]] = None, + queued_years: Optional[list[str]] = None, + ) -> "BudgetWindowEconomicImpactResult": + logger.log_struct({"message": message}, severity="ERROR") + return cls( + status=ImpactStatus.ERROR, + data=None, + completed_years=completed_years or [], + computing_years=computing_years or [], + queued_years=queued_years or [], + message=message, + error=message, + ) + + class EconomyService: """ Service for calculating economic impact of policy reforms; this is connected @@ -151,96 +225,376 @@ def get_economic_impact( # regions that don't contain a region prefix. 
if country_id == "us": region = normalize_us_region(region) + economic_impact_setup_options = self._build_economic_impact_setup_options( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=time_period, + options=options, + api_version=api_version, + target=target, + ) + + return self._get_or_create_economic_impact( + setup_options=economic_impact_setup_options + ) - # Set up logging - process_id: str = self._create_process_id() + except Exception as e: + print(f"Error getting economic impact: {str(e)}") + raise e - options_hash = ( - "[" + "&".join([f"{k}={v}" for k, v in options.items()]) + "]" + def get_budget_window_economic_impact( + self, + country_id: str, + policy_id: int, + baseline_policy_id: int, + region: str, + dataset: str, + start_year: str, + window_size: int, + options: dict, + api_version: str, + target: Literal["general", "cliff"] = "general", + max_active_years: int = BUDGET_WINDOW_MAX_ACTIVE_YEARS, + ) -> BudgetWindowEconomicImpactResult: + try: + if country_id == "us": + region = normalize_us_region(region) + + start_year_int = int(start_year) + if window_size < 1: + raise ValueError("window_size must be at least 1") + + years = [str(start_year_int + index) for index in range(window_size)] + setup_options_by_year = { + year: self._build_economic_impact_setup_options( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=year, + options=dict(options), + api_version=api_version, + target=target, + ) + for year in years + } + + completed_impacts: dict[str, dict] = {} + computing_years: list[str] = [] + queued_years: list[str] = [] + + for year in years: + result = self._get_existing_economic_impact( + setup_options=setup_options_by_year[year] + ) + + if result is None: + queued_years.append(year) + continue + + if result.status == ImpactStatus.OK: + completed_impacts[year] = 
self._extract_budget_window_annual_impact( + year=year, impact_data=result.data or {} + ) + continue + + if result.status == ImpactStatus.COMPUTING: + computing_years.append(year) + continue + + completed_years = [ + completed_year + for completed_year in years + if completed_year in completed_impacts + ] + return BudgetWindowEconomicImpactResult.failed( + result.data.get("message") + if isinstance(result.data, dict) + else f"Budget-window calculation failed for {year}", + completed_years=completed_years, + computing_years=computing_years, + queued_years=queued_years, + ) + + available_slots = max(0, max_active_years - len(computing_years)) + years_to_start = queued_years[:available_slots] + remaining_queued_years = queued_years[available_slots:] + + if years_to_start: + max_workers = min(len(years_to_start), max_active_years) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_year_pairs = [ + ( + year, + executor.submit( + self.get_economic_impact, + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=year, + options=dict(options), + api_version=api_version, + target=target, + ), + ) + for year in years_to_start + ] + + for year, future in future_year_pairs: + result = future.result() + + if result.status == ImpactStatus.OK: + completed_impacts[year] = ( + self._extract_budget_window_annual_impact( + year=year, impact_data=result.data or {} + ) + ) + elif result.status == ImpactStatus.COMPUTING: + computing_years.append(year) + else: + completed_years = [ + completed_year + for completed_year in years + if completed_year in completed_impacts + ] + return BudgetWindowEconomicImpactResult.failed( + f"Budget-window calculation failed for {year}", + completed_years=completed_years, + computing_years=computing_years, + queued_years=remaining_queued_years, + ) + + completed_years = [ + completed_year + for completed_year in years + if completed_year in 
completed_impacts + ] + + if len(completed_years) == len(years): + ordered_annual_impacts = [ + completed_impacts[year] + for year in years + if year in completed_impacts + ] + return BudgetWindowEconomicImpactResult.completed( + self._build_budget_window_output( + start_year=start_year, + window_size=window_size, + annual_impacts=ordered_annual_impacts, + ) + ) + + progress = round((len(completed_years) / len(years)) * 100) + return BudgetWindowEconomicImpactResult.computing( + progress=progress, + completed_years=completed_years, + computing_years=computing_years, + queued_years=remaining_queued_years, + message=self._build_budget_window_progress_message( + completed_years=completed_years, + total_years=len(years), + computing_years=computing_years, + queued_years=remaining_queued_years, + ), ) + except Exception as e: + print(f"Error getting budget-window economic impact: {str(e)}") + raise e - country_package_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) + def _build_economic_impact_setup_options( + self, + *, + country_id: str, + policy_id: int, + baseline_policy_id: int, + region: str, + dataset: str, + time_period: str, + options: dict, + api_version: str, + target: Literal["general", "cliff"] = "general", + ) -> EconomicImpactSetupOptions: + process_id: str = self._create_process_id() + options_hash = "[" + "&".join([f"{k}={v}" for k, v in options.items()]) + "]" + + country_package_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) + if country_id == "uk": + country_package_version = None + + return EconomicImpactSetupOptions.model_validate( + { + "process_id": process_id, + "country_id": country_id, + "reform_policy_id": policy_id, + "baseline_policy_id": baseline_policy_id, + "region": region, + "dataset": dataset, + "time_period": time_period, + "options": options, + "api_version": api_version, + "target": target, + "model_version": country_package_version, + "data_version": get_dataset_version(country_id), + "options_hash": options_hash, + } + ) + 
+ def _get_or_create_economic_impact( + self, setup_options: EconomicImpactSetupOptions + ) -> EconomicImpactResult: + logger.log_struct( + { + "message": "Received request for economic impact; checking if already in reform_impacts table", + **setup_options.model_dump(), + }, + severity="INFO", + ) - if country_id == "uk": - country_package_version = None + most_recent_impact: dict | None = self._get_most_recent_impact( + setup_options=setup_options + ) - economic_impact_setup_options = EconomicImpactSetupOptions.model_validate( + impact_action: ImpactAction = self._determine_impact_action( + most_recent_impact=most_recent_impact + ) + + if impact_action == ImpactAction.COMPLETED: + logger.log_struct( { - "process_id": process_id, - "country_id": country_id, - "reform_policy_id": policy_id, - "baseline_policy_id": baseline_policy_id, - "region": region, - "dataset": dataset, - "time_period": time_period, - "options": options, - "api_version": api_version, - "target": target, - "model_version": country_package_version, - "data_version": get_dataset_version(country_id), - "options_hash": options_hash, - } + "message": "Found completed economic impact in db; returning result", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_completed_impact(most_recent_impact=most_recent_impact) + + if impact_action == ImpactAction.COMPUTING: + logger.log_struct( + { + "message": "Found computing economic impact record in db; confirming this is still computing", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_computing_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, ) - # Logging that we've received a request + if impact_action == ImpactAction.CREATE: logger.log_struct( { - "message": "Received request for economic impact; checking if already in reform_impacts table", - **economic_impact_setup_options.model_dump(), + "message": "No previous economic impact record found in db; creating 
new simulation run", + **setup_options.model_dump(), }, severity="INFO", ) + return self._handle_create_impact(setup_options=setup_options) + + raise ValueError(f"Unexpected impact action: {impact_action}") - most_recent_impact: dict | None = self._get_most_recent_impact( - setup_options=economic_impact_setup_options, + def _get_existing_economic_impact( + self, setup_options: EconomicImpactSetupOptions + ) -> Optional[EconomicImpactResult]: + most_recent_impact = self._get_most_recent_impact(setup_options=setup_options) + if not most_recent_impact: + return None + + status = most_recent_impact.get("status") + if status == ImpactStatus.ERROR.value: + error_message = most_recent_impact.get("message") or ( + f"Economic impact failed for {setup_options.time_period}" + ) + return EconomicImpactResult( + status=ImpactStatus.ERROR, + data={"message": error_message}, ) - impact_action: ImpactAction = self._determine_impact_action( + if status == ImpactStatus.OK.value: + return self._handle_completed_impact(most_recent_impact=most_recent_impact) + + if status == ImpactStatus.COMPUTING.value: + return self._handle_computing_impact( + setup_options=setup_options, most_recent_impact=most_recent_impact, ) - if impact_action == ImpactAction.COMPLETED: - logger.log_struct( - { - "message": "Found completed economic impact in db; returning result", - **economic_impact_setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_completed_impact( - most_recent_impact=most_recent_impact, - ) + raise ValueError(f"Unknown impact status: {status}") - if impact_action == ImpactAction.COMPUTING: - logger.log_struct( - { - "message": "Found computing economic impact record in db; confirming this is still computing", - **economic_impact_setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_computing_impact( - setup_options=economic_impact_setup_options, - most_recent_impact=most_recent_impact, - ) + def _extract_budget_window_annual_impact( + self, 
year: str, impact_data: dict + ) -> dict[str, Union[str, int, float]]: + budget = impact_data.get("budget", {}) + state_tax_revenue_impact = budget.get("state_tax_revenue_impact", 0) + tax_revenue_impact = budget.get("tax_revenue_impact", 0) - if impact_action == ImpactAction.CREATE: - logger.log_struct( - { - "message": "No previous economic impact record found in db; creating new simulation run", - **economic_impact_setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_create_impact( - setup_options=economic_impact_setup_options, - ) + return { + "year": year, + "taxRevenueImpact": tax_revenue_impact, + "federalTaxRevenueImpact": tax_revenue_impact - state_tax_revenue_impact, + "stateTaxRevenueImpact": state_tax_revenue_impact, + "benefitSpendingImpact": budget.get("benefit_spending_impact", 0), + "budgetaryImpact": budget.get("budgetary_impact", 0), + } - raise ValueError(f"Unexpected impact action: {impact_action}") + def _sum_budget_window_annual_impacts(self, annual_impacts: list[dict]) -> dict: + totals = { + "year": "Total", + "taxRevenueImpact": 0, + "federalTaxRevenueImpact": 0, + "stateTaxRevenueImpact": 0, + "benefitSpendingImpact": 0, + "budgetaryImpact": 0, + } - except Exception as e: - print(f"Error getting economic impact: {str(e)}") - raise e + for annual_impact in annual_impacts: + totals["taxRevenueImpact"] += annual_impact["taxRevenueImpact"] + totals["federalTaxRevenueImpact"] += annual_impact[ + "federalTaxRevenueImpact" + ] + totals["stateTaxRevenueImpact"] += annual_impact["stateTaxRevenueImpact"] + totals["benefitSpendingImpact"] += annual_impact["benefitSpendingImpact"] + totals["budgetaryImpact"] += annual_impact["budgetaryImpact"] + + return totals + + def _build_budget_window_output( + self, *, start_year: str, window_size: int, annual_impacts: list[dict] + ) -> dict: + return { + "kind": "budgetWindow", + "startYear": start_year, + "endYear": str(int(start_year) + window_size - 1), + "windowSize": window_size, 
+ "annualImpacts": annual_impacts, + "totals": self._sum_budget_window_annual_impacts(annual_impacts), + } + + def _build_budget_window_progress_message( + self, + *, + completed_years: list[str], + total_years: int, + computing_years: list[str], + queued_years: list[str], + ) -> str: + completed_count = len(completed_years) + if computing_years: + active_years = ", ".join(computing_years[:2]) + if len(computing_years) > 2: + active_years = f"{active_years} + {len(computing_years) - 2} more" + return f"Scoring {active_years} ({completed_count} of {total_years} complete)..." + + if queued_years: + return f"Queued {queued_years[0]} ({completed_count} of {total_years} complete)..." + + return f"Scoring budget window ({completed_count} of {total_years} complete)..." def _get_previous_impacts( self, diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index c49783bad..0bb2f3c18 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -35,6 +35,23 @@ ) +def make_mock_budget_impact_data( + *, + tax_revenue_impact: int, + state_tax_revenue_impact: int, + benefit_spending_impact: int, + budgetary_impact: int, +): + return { + "budget": { + "tax_revenue_impact": tax_revenue_impact, + "state_tax_revenue_impact": state_tax_revenue_impact, + "benefit_spending_impact": benefit_spending_impact, + "budgetary_impact": budgetary_impact, + } + } + + class TestEconomyService: class TestGetEconomicImpact: @pytest.fixture @@ -233,6 +250,258 @@ def test__given_exception__raises_error( economy_service.get_economic_impact(**base_params) assert str(exc_info.value) == "Database error" + class TestGetBudgetWindowEconomicImpact: + @pytest.fixture + def economy_service(self): + return EconomyService() + + @pytest.fixture + def base_params(self): + return { + "country_id": MOCK_COUNTRY_ID, + "policy_id": MOCK_POLICY_ID, + "baseline_policy_id": MOCK_BASELINE_POLICY_ID, + "region": MOCK_REGION, + 
"dataset": MOCK_DATASET, + "start_year": "2026", + "window_size": 3, + "options": MOCK_OPTIONS, + "api_version": MOCK_API_VERSION, + "target": "general", + } + + def test__given_all_years_completed__returns_aggregated_budget_window_result( + self, economy_service, base_params + ): + def make_setup(*, time_period, **_kwargs): + return EconomicImpactSetupOptions( + process_id=MOCK_PROCESS_ID, + country_id=MOCK_COUNTRY_ID, + reform_policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=time_period, + options=MOCK_OPTIONS, + api_version=MOCK_API_VERSION, + target="general", + options_hash=MOCK_OPTIONS_HASH, + ) + + yearly_results = { + "2026": EconomicImpactResult.completed( + make_mock_budget_impact_data( + tax_revenue_impact=100, + state_tax_revenue_impact=20, + benefit_spending_impact=-10, + budgetary_impact=90, + ) + ), + "2027": EconomicImpactResult.completed( + make_mock_budget_impact_data( + tax_revenue_impact=120, + state_tax_revenue_impact=30, + benefit_spending_impact=-20, + budgetary_impact=100, + ) + ), + "2028": EconomicImpactResult.completed( + make_mock_budget_impact_data( + tax_revenue_impact=140, + state_tax_revenue_impact=40, + benefit_spending_impact=-30, + budgetary_impact=110, + ) + ), + } + + with ( + patch.object( + economy_service, + "_build_economic_impact_setup_options", + side_effect=make_setup, + ), + patch.object( + economy_service, + "_get_existing_economic_impact", + side_effect=lambda setup_options: yearly_results[ + setup_options.time_period + ], + ) as mock_get_existing, + patch.object( + economy_service, "get_economic_impact" + ) as mock_get_economic_impact, + ): + result = economy_service.get_budget_window_economic_impact( + **base_params + ) + + assert result.status == ImpactStatus.OK + assert result.progress == 100 + assert result.data["annualImpacts"] == [ + { + "year": "2026", + "taxRevenueImpact": 100, + "federalTaxRevenueImpact": 80, + 
"stateTaxRevenueImpact": 20, + "benefitSpendingImpact": -10, + "budgetaryImpact": 90, + }, + { + "year": "2027", + "taxRevenueImpact": 120, + "federalTaxRevenueImpact": 90, + "stateTaxRevenueImpact": 30, + "benefitSpendingImpact": -20, + "budgetaryImpact": 100, + }, + { + "year": "2028", + "taxRevenueImpact": 140, + "federalTaxRevenueImpact": 100, + "stateTaxRevenueImpact": 40, + "benefitSpendingImpact": -30, + "budgetaryImpact": 110, + }, + ] + assert result.data["totals"] == { + "year": "Total", + "taxRevenueImpact": 360, + "federalTaxRevenueImpact": 270, + "stateTaxRevenueImpact": 90, + "benefitSpendingImpact": -60, + "budgetaryImpact": 300, + } + assert mock_get_existing.call_count == 3 + mock_get_economic_impact.assert_not_called() + + def test__given_missing_years__starts_only_up_to_remaining_active_slots( + self, economy_service, base_params + ): + def make_setup(*, time_period, **_kwargs): + return EconomicImpactSetupOptions( + process_id=MOCK_PROCESS_ID, + country_id=MOCK_COUNTRY_ID, + reform_policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=time_period, + options=MOCK_OPTIONS, + api_version=MOCK_API_VERSION, + target="general", + options_hash=MOCK_OPTIONS_HASH, + ) + + base_params["window_size"] = 5 + + existing_results = { + "2026": EconomicImpactResult.completed( + make_mock_budget_impact_data( + tax_revenue_impact=100, + state_tax_revenue_impact=20, + benefit_spending_impact=-10, + budgetary_impact=90, + ) + ), + "2027": EconomicImpactResult.computing(), + "2028": None, + "2029": None, + "2030": None, + } + + with ( + patch.object( + economy_service, + "_build_economic_impact_setup_options", + side_effect=make_setup, + ), + patch.object( + economy_service, + "_get_existing_economic_impact", + side_effect=lambda setup_options: existing_results[ + setup_options.time_period + ], + ), + patch.object( + economy_service, + "get_economic_impact", + 
return_value=EconomicImpactResult.computing(), + ) as mock_get_economic_impact, + ): + result = economy_service.get_budget_window_economic_impact( + **base_params + ) + + assert result.status == ImpactStatus.COMPUTING + assert result.progress == 20 + assert result.completed_years == ["2026"] + assert result.computing_years == ["2027", "2028", "2029"] + assert result.queued_years == ["2030"] + assert "1 of 5 complete" in result.message + assert mock_get_economic_impact.call_count == 2 + started_years = sorted( + call.kwargs["time_period"] + for call in mock_get_economic_impact.call_args_list + ) + assert started_years == ["2028", "2029"] + + def test__given_year_error__returns_budget_window_error( + self, economy_service, base_params + ): + def make_setup(*, time_period, **_kwargs): + return EconomicImpactSetupOptions( + process_id=MOCK_PROCESS_ID, + country_id=MOCK_COUNTRY_ID, + reform_policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=time_period, + options=MOCK_OPTIONS, + api_version=MOCK_API_VERSION, + target="general", + options_hash=MOCK_OPTIONS_HASH, + ) + + with ( + patch.object( + economy_service, + "_build_economic_impact_setup_options", + side_effect=make_setup, + ), + patch.object( + economy_service, + "_get_existing_economic_impact", + side_effect=[ + EconomicImpactResult.completed( + make_mock_budget_impact_data( + tax_revenue_impact=100, + state_tax_revenue_impact=20, + benefit_spending_impact=-10, + budgetary_impact=90, + ) + ), + EconomicImpactResult( + status=ImpactStatus.ERROR, + data={"message": "Calculation failed for 2027"}, + ), + None, + ], + ), + patch.object( + economy_service, "get_economic_impact" + ) as mock_get_economic_impact, + ): + result = economy_service.get_budget_window_economic_impact( + **base_params + ) + + assert result.status == ImpactStatus.ERROR + assert result.error == "Calculation failed for 2027" + assert result.completed_years == ["2026"] 
+ mock_get_economic_impact.assert_not_called() + class TestGetPreviousImpacts: @pytest.fixture def economy_service(self): From b0f0f8ad46ebc7f8155d4b91c27fbd1901ed9cdd Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 8 Apr 2026 21:51:06 -0400 Subject: [PATCH 03/13] Harden budget window batch API --- policyengine_api/routes/economy_routes.py | 114 +++++++++++------- policyengine_api/services/economy_service.py | 68 ++++++++--- .../services/reform_impacts_service.py | 3 +- .../test_economy_budget_window_routes.py | 107 ++++++++++++++++ tests/unit/services/test_economy_service.py | 19 ++- 5 files changed, 247 insertions(+), 64 deletions(-) create mode 100644 tests/to_refactor/python/test_economy_budget_window_routes.py diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index a55e75d37..9b71532da 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -14,6 +14,25 @@ economy_service = EconomyService() +def _json_response(payload: dict, status: int = 200) -> Response: + return Response( + json.dumps(payload), + status=status, + mimetype="application/json", + ) + + +def _bad_request_response(message: str) -> Response: + return _json_response( + { + "status": "error", + "message": message, + "result": None, + }, + status=400, + ) + + @validate_country @economy_bp.route( "/<country_id>/economy/<policy_id>/over/<baseline_policy_id>", @@ -58,16 +77,12 @@ def get_economic_impact(country_id: str, policy_id: int, baseline_policy_id: int result_dict: dict[str, str | dict | None] = economic_impact_result.to_dict() - return Response( - json.dumps( - { - "status": result_dict["status"], - "message": None, - "result": result_dict["data"], - } - ), - status=200, - mimetype="application/json", + return _json_response( + { + "status": result_dict["status"], + "message": None, + "result": result_dict["data"], + } ) @@ -87,10 +102,23 @@ def get_budget_window_economic_impact( query_parameters = request.args options = 
dict(query_parameters) options = json.loads(json.dumps(options)) - region = options.pop("region") + region = options.pop("region", None) + if not region: + return _bad_request_response("Missing required query parameter: region") + dataset = options.pop("dataset", "default") - start_year = options.pop("start_year") - window_size = int(options.pop("window_size")) + start_year = options.pop("start_year", None) + if not start_year: + return _bad_request_response("Missing required query parameter: start_year") + + window_size_raw = options.pop("window_size", None) + if window_size_raw is None: + return _bad_request_response("Missing required query parameter: window_size") + + try: + window_size = int(window_size_raw) + except (TypeError, ValueError): + return _bad_request_response("window_size must be an integer") include_district_breakdowns_raw = options.pop( "include_district_breakdowns", "false" @@ -100,38 +128,42 @@ def get_budget_window_economic_impact( dataset = "national-with-breakdowns" target: Literal["general", "cliff"] = options.pop("target", "general") + if target != "general": + return _bad_request_response( + "Budget-window calculations only support target=general" + ) + api_version = options.pop("version", COUNTRY_PACKAGE_VERSIONS.get(country_id)) - economic_impact_result: BudgetWindowEconomicImpactResult = ( - economy_service.get_budget_window_economic_impact( - country_id=country_id, - policy_id=policy_id, - baseline_policy_id=baseline_policy_id, - region=region, - dataset=dataset, - start_year=start_year, - window_size=window_size, - options=options, - api_version=api_version, - target=target, + try: + economic_impact_result: BudgetWindowEconomicImpactResult = ( + economy_service.get_budget_window_economic_impact( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + start_year=start_year, + window_size=window_size, + options=options, + api_version=api_version, + target=target, + 
) ) - ) + except ValueError as error: + return _bad_request_response(str(error)) result_dict = economic_impact_result.to_dict() - return Response( - json.dumps( - { - "status": result_dict["status"], - "message": result_dict["message"], - "result": result_dict["data"], - "progress": result_dict["progress"], - "completed_years": result_dict["completed_years"], - "computing_years": result_dict["computing_years"], - "queued_years": result_dict["queued_years"], - "error": result_dict["error"], - } - ), - status=200, - mimetype="application/json", + return _json_response( + { + "status": result_dict["status"], + "message": result_dict["message"], + "result": result_dict["data"], + "progress": result_dict["progress"], + "completed_years": result_dict["completed_years"], + "computing_years": result_dict["computing_years"], + "queued_years": result_dict["queued_years"], + "error": result_dict["error"], + } ) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index caac4d2f6..c91d6bd3e 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -27,6 +27,7 @@ import numpy as np from enum import Enum from concurrent.futures import ThreadPoolExecutor +from threading import Lock load_dotenv() @@ -58,6 +59,7 @@ class ImpactStatus(Enum): COMPLETE_STATUSES = [ImpactStatus.OK.value, ImpactStatus.ERROR.value] COMPUTING_STATUS = ImpactStatus.COMPUTING.value BUDGET_WINDOW_MAX_ACTIVE_YEARS = 3 +IMPACT_CREATION_LOCK = Lock() class EconomicImpactSetupOptions(BaseModel): @@ -263,6 +265,11 @@ def get_budget_window_economic_impact( if country_id == "us": region = normalize_us_region(region) + if target != "general": + raise ValueError( + "Budget-window calculations only support target='general'" + ) + start_year_int = int(start_year) if window_size < 1: raise ValueError("window_size must be at least 1") @@ -331,16 +338,8 @@ def get_budget_window_economic_impact( ( year, 
executor.submit( - self.get_economic_impact, - country_id=country_id, - policy_id=policy_id, - baseline_policy_id=baseline_policy_id, - region=region, - dataset=dataset, - time_period=year, - options=dict(options), - api_version=api_version, - target=target, + self._get_or_create_economic_impact, + setup_options_by_year[year], ), ) for year in years_to_start @@ -488,14 +487,47 @@ def _get_or_create_economic_impact( ) if impact_action == ImpactAction.CREATE: - logger.log_struct( - { - "message": "No previous economic impact record found in db; creating new simulation run", - **setup_options.model_dump(), - }, - severity="INFO", - ) - return self._handle_create_impact(setup_options=setup_options) + with IMPACT_CREATION_LOCK: + most_recent_impact = self._get_most_recent_impact( + setup_options=setup_options + ) + impact_action = self._determine_impact_action( + most_recent_impact=most_recent_impact + ) + + if impact_action == ImpactAction.COMPLETED: + logger.log_struct( + { + "message": "Found completed economic impact in db after locking; returning result", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_completed_impact( + most_recent_impact=most_recent_impact + ) + + if impact_action == ImpactAction.COMPUTING: + logger.log_struct( + { + "message": "Found computing economic impact in db after locking; returning progress", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_computing_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + logger.log_struct( + { + "message": "No previous economic impact record found in db; creating new simulation run", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_create_impact(setup_options=setup_options) raise ValueError(f"Unexpected impact action: {impact_action}") diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index 
ca44ea10c..fabe345d5 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -25,7 +25,8 @@ def get_all_reform_impacts( "SELECT reform_impact_json, status, message, start_time, execution_id FROM " "reform_impact WHERE country_id = ? AND reform_policy_id = ? AND " "baseline_policy_id = ? AND region = ? AND time_period = ? AND " - "options_hash = ? AND api_version = ? AND dataset = ?" + "options_hash = ? AND api_version = ? AND dataset = ? " + "ORDER BY start_time DESC" ) return local_database.query( query, diff --git a/tests/to_refactor/python/test_economy_budget_window_routes.py b/tests/to_refactor/python/test_economy_budget_window_routes.py new file mode 100644 index 000000000..10148e973 --- /dev/null +++ b/tests/to_refactor/python/test_economy_budget_window_routes.py @@ -0,0 +1,107 @@ +import json +from unittest.mock import Mock, patch + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_rejects_cliff_target( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=10&target=cliff" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "target=general" in data["message"] + mock_get_budget_window_economic_impact.assert_not_called() + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_requires_window_size( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window?region=us&start_year=2026" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "window_size" in data["message"] + 
mock_get_budget_window_economic_impact.assert_not_called() + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_requires_integer_window_size( + mock_get_budget_window_economic_impact, rest_client +): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=abc" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "window_size must be an integer" == data["message"] + mock_get_budget_window_economic_impact.assert_not_called() + + +@patch( + "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" +) +def test_budget_window_route_passes_version_to_service( + mock_get_budget_window_economic_impact, rest_client +): + mock_result = Mock() + mock_result.to_dict.return_value = { + "status": "ok", + "message": None, + "data": { + "kind": "budgetWindow", + "startYear": "2026", + "endYear": "2027", + "windowSize": 2, + "annualImpacts": [], + "totals": {}, + }, + "progress": 100, + "completed_years": ["2026", "2027"], + "computing_years": [], + "queued_years": [], + "error": None, + } + mock_get_budget_window_economic_impact.return_value = mock_result + + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=2&version=1.2.3" + ) + + data = json.loads(response.data) + + assert response.status_code == 200 + assert data["status"] == "ok" + mock_get_budget_window_economic_impact.assert_called_once_with( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="default", + start_year="2026", + window_size=2, + options={}, + api_version="1.2.3", + target="general", + ) diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 0bb2f3c18..2c892d5a1 100644 --- 
a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -329,7 +329,7 @@ def make_setup(*, time_period, **_kwargs): ], ) as mock_get_existing, patch.object( - economy_service, "get_economic_impact" + economy_service, "_get_or_create_economic_impact" ) as mock_get_economic_impact, ): result = economy_service.get_budget_window_economic_impact( @@ -425,7 +425,7 @@ def make_setup(*, time_period, **_kwargs): ), patch.object( economy_service, - "get_economic_impact", + "_get_or_create_economic_impact", return_value=EconomicImpactResult.computing(), ) as mock_get_economic_impact, ): @@ -441,7 +441,7 @@ def make_setup(*, time_period, **_kwargs): assert "1 of 5 complete" in result.message assert mock_get_economic_impact.call_count == 2 started_years = sorted( - call.kwargs["time_period"] + call.args[0].time_period for call in mock_get_economic_impact.call_args_list ) assert started_years == ["2028", "2029"] @@ -490,7 +490,7 @@ def make_setup(*, time_period, **_kwargs): ], ), patch.object( - economy_service, "get_economic_impact" + economy_service, "_get_or_create_economic_impact" ) as mock_get_economic_impact, ): result = economy_service.get_budget_window_economic_impact( @@ -502,6 +502,17 @@ def make_setup(*, time_period, **_kwargs): assert result.completed_years == ["2026"] mock_get_economic_impact.assert_not_called() + def test__given_cliff_target__raises_value_error( + self, economy_service, base_params + ): + base_params["target"] = "cliff" + + with pytest.raises( + ValueError, + match="Budget-window calculations only support target='general'", + ): + economy_service.get_budget_window_economic_impact(**base_params) + class TestGetPreviousImpacts: @pytest.fixture def economy_service(self): From cdfbf010bafb8a68f26275436627f4d5d5558ff1 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 9 Apr 2026 07:42:06 -0400 Subject: [PATCH 04/13] Address budget window review findings --- policyengine_api/services/economy_service.py | 38 
+++++-- .../test_economy_budget_window_routes.py | 13 +++ tests/unit/services/test_economy_service.py | 106 ++++++++++++++++++ 3 files changed, 149 insertions(+), 8 deletions(-) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index d2d3227e6..070bba4ca 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -60,6 +60,7 @@ class ImpactStatus(Enum): COMPLETE_STATUSES = [ImpactStatus.OK.value, ImpactStatus.ERROR.value] COMPUTING_STATUS = ImpactStatus.COMPUTING.value BUDGET_WINDOW_MAX_ACTIVE_YEARS = 3 +BUDGET_WINDOW_MAX_YEARS = 20 IMPACT_CREATION_LOCK = Lock() @@ -87,6 +88,7 @@ class EconomicImpactResult(BaseModel): status: ImpactStatus data: Optional[dict] = None + message: Optional[str] = None model_config = {"frozen": True} # Make model immutable @@ -119,7 +121,7 @@ def error(cls, message: str) -> "EconomicImpactResult": Create an EconomicImpactResult for an error in the impact calculation. 
""" logger.log_struct({"message": message}, severity="ERROR") - return cls(status=ImpactStatus.ERROR, data=None) + return cls(status=ImpactStatus.ERROR, data=None, message=message) class BudgetWindowEconomicImpactResult(BaseModel): @@ -272,8 +274,10 @@ def get_budget_window_economic_impact( ) start_year_int = int(start_year) - if window_size < 1: - raise ValueError("window_size must be at least 1") + if not 1 <= window_size <= BUDGET_WINDOW_MAX_YEARS: + raise ValueError( + f"window_size must be between 1 and {BUDGET_WINDOW_MAX_YEARS}" + ) years = [str(start_year_int + index) for index in range(window_size)] setup_options_by_year = { @@ -320,9 +324,10 @@ def get_budget_window_economic_impact( if completed_year in completed_impacts ] return BudgetWindowEconomicImpactResult.failed( - result.data.get("message") - if isinstance(result.data, dict) - else f"Budget-window calculation failed for {year}", + self._get_economic_impact_error_message( + result=result, + year=year, + ), completed_years=completed_years, computing_years=computing_years, queued_years=queued_years, @@ -364,7 +369,10 @@ def get_budget_window_economic_impact( if completed_year in completed_impacts ] return BudgetWindowEconomicImpactResult.failed( - f"Budget-window calculation failed for {year}", + self._get_economic_impact_error_message( + result=result, + year=year, + ), completed_years=completed_years, computing_years=computing_years, queued_years=remaining_queued_years, @@ -547,7 +555,8 @@ def _get_existing_economic_impact( ) return EconomicImpactResult( status=ImpactStatus.ERROR, - data={"message": error_message}, + data=None, + message=error_message, ) if status == ImpactStatus.OK.value: @@ -561,6 +570,19 @@ def _get_existing_economic_impact( raise ValueError(f"Unknown impact status: {status}") + def _get_economic_impact_error_message( + self, result: EconomicImpactResult, year: str + ) -> str: + if result.message: + return result.message + + if isinstance(result.data, dict): + data_message = 
result.data.get("message") + if isinstance(data_message, str) and data_message: + return data_message + + return f"Budget-window calculation failed for {year}" + def _extract_budget_window_annual_impact( self, year: str, impact_data: dict ) -> dict[str, Union[str, int, float]]: diff --git a/tests/to_refactor/python/test_economy_budget_window_routes.py b/tests/to_refactor/python/test_economy_budget_window_routes.py index 10148e973..fca938948 100644 --- a/tests/to_refactor/python/test_economy_budget_window_routes.py +++ b/tests/to_refactor/python/test_economy_budget_window_routes.py @@ -58,6 +58,19 @@ def test_budget_window_route_requires_integer_window_size( mock_get_budget_window_economic_impact.assert_not_called() +def test_budget_window_route_rejects_oversized_window(rest_client): + response = rest_client.get( + "/us/economy/123/over/456/budget-window" + "?region=us&start_year=2026&window_size=999" + ) + + data = json.loads(response.data) + + assert response.status_code == 400 + assert data["status"] == "error" + assert "window_size must be between 1 and" in data["message"] + + @patch( "policyengine_api.routes.economy_routes.economy_service.get_budget_window_economic_impact" ) diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index e52855975..0e0be9d5b 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -4,6 +4,7 @@ from typing import Literal from policyengine_api.services.economy_service import ( + BUDGET_WINDOW_MAX_YEARS, EconomyService, EconomicImpactResult, EconomicImpactSetupOptions, @@ -552,6 +553,104 @@ def test__given_cliff_target__raises_value_error( ): economy_service.get_budget_window_economic_impact(**base_params) + def test__given_oversized_window__raises_value_error( + self, economy_service, base_params + ): + base_params["window_size"] = BUDGET_WINDOW_MAX_YEARS + 1 + + with pytest.raises( + ValueError, + match=(f"window_size must be 
between 1 and {BUDGET_WINDOW_MAX_YEARS}"), + ): + economy_service.get_budget_window_economic_impact(**base_params) + + def test__given_started_year_error__returns_specific_budget_window_error( + self, economy_service, base_params + ): + with ( + patch.object( + economy_service, + "_get_existing_economic_impact", + side_effect=[None, None, None], + ), + patch.object( + economy_service, + "_get_or_create_economic_impact", + side_effect=[ + EconomicImpactResult.error("Calculation failed for 2026"), + EconomicImpactResult.computing(), + EconomicImpactResult.computing(), + ], + ), + ): + result = economy_service.get_budget_window_economic_impact( + **base_params + ) + + assert result.status == ImpactStatus.ERROR + assert result.error == "Calculation failed for 2026" + assert result.completed_years == [] + + def test__given_runtime_cache_version__uses_versioned_cache_key_for_budget_window( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_logger, + mock_datetime, + mock_numpy_random, + monkeypatch, + ): + cache_version = "e1cache01" + seen_existing_calls = [] + seen_create_calls = [] + + monkeypatch.setattr( + "policyengine_api.services.economy_service.get_economy_impact_cache_version", + lambda country_id, api_version=None: cache_version, + ) + + def fake_get_existing(setup_options): + seen_existing_calls.append( + (setup_options.time_period, setup_options.api_version) + ) + return None + + def fake_get_or_create(setup_options): + seen_create_calls.append( + (setup_options.time_period, setup_options.api_version) + ) + return EconomicImpactResult.computing() + + with ( + patch.object( + economy_service, + "_get_existing_economic_impact", + side_effect=fake_get_existing, + ), + patch.object( + economy_service, + "_get_or_create_economic_impact", + side_effect=fake_get_or_create, + ), + ): + result = economy_service.get_budget_window_economic_impact( + **base_params + ) + + assert result.status == 
ImpactStatus.COMPUTING + assert seen_existing_calls == [ + ("2026", cache_version), + ("2027", cache_version), + ("2028", cache_version), + ] + assert seen_create_calls == [ + ("2026", cache_version), + ("2027", cache_version), + ("2028", cache_version), + ] + class TestGetPreviousImpacts: @pytest.fixture def economy_service(self): @@ -730,6 +829,7 @@ def test__given_failed_state__returns_error_result( assert result.status == ImpactStatus.ERROR assert result.data is None + assert result.message == "Simulation API execution failed" mock_reform_impacts_service.set_error_reform_impact.assert_called_once() def test__given_active_state__returns_computing_result( @@ -801,6 +901,7 @@ def test__given_modal_failed_state__then_returns_error_result( # Then assert result.status == ImpactStatus.ERROR assert result.data is None + assert result.message == "Simulation API execution failed" mock_reform_impacts_service.set_error_reform_impact.assert_called_once() def test__given_modal_failed_state_with_error_message__then_includes_error_in_message( @@ -822,6 +923,10 @@ def test__given_modal_failed_state_with_error_message__then_includes_error_in_me # Then assert result.status == ImpactStatus.ERROR + assert ( + result.message + == "Simulation API execution failed: Simulation timed out" + ) # Verify the error message was passed to the service call_args = mock_reform_impacts_service.set_error_reform_impact.call_args assert "Simulation timed out" in call_args[1]["message"] @@ -919,6 +1024,7 @@ def test__given_error__creates_correct_instance_and_logs(self): assert result.status == ImpactStatus.ERROR assert result.data is None + assert result.message == "Test error message" mock_logger.log_struct.assert_called_once() From f9062cb74f3a1d1662fb9c748d37affd240a40cb Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 9 Apr 2026 10:03:07 -0400 Subject: [PATCH 05/13] Prevent duplicate budget window jobs across workers --- .../endpoints/economy/reform_impact.py | 4 +- 
policyengine_api/endpoints/simulation.py | 4 +- policyengine_api/services/economy_service.py | 13 +- .../services/reform_impacts_service.py | 91 ++++++++++++-- tests/unit/services/test_economy_service.py | 4 +- .../services/test_reform_impacts_service.py | 117 ++++++++++++++++++ 6 files changed, 216 insertions(+), 17 deletions(-) create mode 100644 tests/unit/services/test_reform_impacts_service.py diff --git a/policyengine_api/endpoints/economy/reform_impact.py b/policyengine_api/endpoints/economy/reform_impact.py index 42795243d..cc778bcef 100644 --- a/policyengine_api/endpoints/economy/reform_impact.py +++ b/policyengine_api/endpoints/economy/reform_impact.py @@ -1,4 +1,4 @@ -from policyengine_api.data import local_database +from policyengine_api.data import database def set_comment_on_job( @@ -17,7 +17,7 @@ def set_comment_on_job( "time_period = ? AND options_hash = ? AND dataset = ?" ) - local_database.query( + database.query( query, ( comment, diff --git a/policyengine_api/endpoints/simulation.py b/policyengine_api/endpoints/simulation.py index 132e5b2d6..394f2a7e9 100644 --- a/policyengine_api/endpoints/simulation.py +++ b/policyengine_api/endpoints/simulation.py @@ -1,4 +1,4 @@ -from policyengine_api.data import local_database +from policyengine_api.data import database """ @@ -28,7 +28,7 @@ def get_simulations( desc_limit = f"DESC LIMIT {max_results}" if max_results is not None else "" - result = local_database.query( + result = database.query( f"SELECT * FROM reform_impact ORDER BY start_time {desc_limit}", ).fetchall() diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 070bba4ca..1ecfcdfd2 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -28,7 +28,6 @@ import numpy as np from enum import Enum from concurrent.futures import ThreadPoolExecutor -from threading import Lock load_dotenv() @@ -61,7 +60,6 @@ class ImpactStatus(Enum): 
COMPUTING_STATUS = ImpactStatus.COMPUTING.value BUDGET_WINDOW_MAX_ACTIVE_YEARS = 3 BUDGET_WINDOW_MAX_YEARS = 20 -IMPACT_CREATION_LOCK = Lock() class EconomicImpactSetupOptions(BaseModel): @@ -497,7 +495,16 @@ def _get_or_create_economic_impact( ) if impact_action == ImpactAction.CREATE: - with IMPACT_CREATION_LOCK: + with reform_impacts_service.claim_lock( + country_id=setup_options.country_id, + policy_id=setup_options.reform_policy_id, + baseline_policy_id=setup_options.baseline_policy_id, + region=setup_options.region, + dataset=setup_options.dataset, + time_period=setup_options.time_period, + options_hash=setup_options.options_hash, + api_version=setup_options.api_version, + ): most_recent_impact = self._get_most_recent_impact( setup_options=setup_options ) diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index fabe345d5..3ca5c6c7d 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -1,14 +1,89 @@ -from policyengine_api.data import local_database +from contextlib import contextmanager +import hashlib +from threading import Lock +from policyengine_api.data import database import datetime +LOCAL_REFORM_IMPACT_LOCK = Lock() +REFORM_IMPACT_LOCK_TIMEOUT_SECONDS = 5 + + class ReformImpactsService: """ Service for storing and retrieving economy-wide reform impacts; - this is connected to the locally-stored reform_impact table - and no existing route + this is connected to the shared reform_impact table. 
""" + def _build_lock_name( + self, + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version, + ) -> str: + raw_key = ( + f"{country_id}:{policy_id}:{baseline_policy_id}:{region}:{dataset}:" + f"{time_period}:{options_hash}:{api_version}" + ) + digest = hashlib.sha256(raw_key.encode("utf-8")).hexdigest() + return f"ri:{digest[:61]}" + + @contextmanager + def claim_lock( + self, + *, + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + api_version, + timeout_seconds: int = REFORM_IMPACT_LOCK_TIMEOUT_SECONDS, + ): + if database.local: + with LOCAL_REFORM_IMPACT_LOCK: + yield + return + + lock_name = self._build_lock_name( + country_id=country_id, + policy_id=policy_id, + baseline_policy_id=baseline_policy_id, + region=region, + dataset=dataset, + time_period=time_period, + options_hash=options_hash, + api_version=api_version, + ) + with database.pool.connect() as conn: + acquired = ( + conn.exec_driver_sql( + "SELECT GET_LOCK(%s, %s) AS acquired", + (lock_name, timeout_seconds), + ) + .mappings() + .first() + ) + if acquired is None or acquired["acquired"] != 1: + raise TimeoutError( + f"Could not acquire reform impact lock for {country_id}/{policy_id}/{time_period}" + ) + + try: + yield + finally: + conn.exec_driver_sql( + "SELECT RELEASE_LOCK(%s) AS released", (lock_name,) + ) + conn.commit() + def get_all_reform_impacts( self, country_id, @@ -28,7 +103,7 @@ def get_all_reform_impacts( "options_hash = ? AND api_version = ? AND dataset = ? 
" "ORDER BY start_time DESC" ) - return local_database.query( + return database.query( query, ( country_id, @@ -67,7 +142,7 @@ def set_reform_impact( "region, dataset, time_period, options_json, options_hash, status, api_version, " "reform_impact_json, start_time, execution_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" ) - local_database.query( + database.query( query, ( country_id, @@ -107,7 +182,7 @@ def delete_reform_impact( "dataset = ? AND status = 'computing'" ) - local_database.query( + database.query( query, ( country_id, @@ -142,7 +217,7 @@ def set_error_reform_impact( "region = ? AND time_period = ? AND options_hash = ? AND dataset = ? AND " "execution_id = ?" ) - local_database.query( + database.query( query, ( "error", @@ -186,7 +261,7 @@ def set_complete_reform_impact( "baseline_policy_id = ? AND region = ? AND time_period = ? AND " "options_hash = ? AND dataset = ? AND execution_id = ?" ) - local_database.query( + database.query( query, ( "ok", diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 0e0be9d5b..829b5caf8 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -487,7 +487,7 @@ def make_setup(*, time_period, **_kwargs): assert started_years == ["2028", "2029"] def test__given_year_error__returns_budget_window_error( - self, economy_service, base_params + self, economy_service, base_params, mock_logger ): def make_setup(*, time_period, **_kwargs): return EconomicImpactSetupOptions( @@ -565,7 +565,7 @@ def test__given_oversized_window__raises_value_error( economy_service.get_budget_window_economic_impact(**base_params) def test__given_started_year_error__returns_specific_budget_window_error( - self, economy_service, base_params + self, economy_service, base_params, mock_logger ): with ( patch.object( diff --git a/tests/unit/services/test_reform_impacts_service.py b/tests/unit/services/test_reform_impacts_service.py new file mode 
100644 index 000000000..4f44b63a2 --- /dev/null +++ b/tests/unit/services/test_reform_impacts_service.py @@ -0,0 +1,117 @@ +from unittest.mock import MagicMock + +import pytest + +from policyengine_api.services.reform_impacts_service import ReformImpactsService + + +class TestReformImpactsService: + def test__given_remote_database__claim_lock_uses_advisory_lock(self, monkeypatch): + service = ReformImpactsService() + + acquired_result = MagicMock() + acquired_result.mappings.return_value.first.return_value = {"acquired": 1} + release_result = MagicMock() + + mock_connection = MagicMock() + mock_connection.exec_driver_sql.side_effect = [ + acquired_result, + release_result, + ] + + mock_connection_context = MagicMock() + mock_connection_context.__enter__.return_value = mock_connection + mock_connection_context.__exit__.return_value = False + + mock_pool = MagicMock() + mock_pool.connect.return_value = mock_connection_context + + mock_database = MagicMock() + mock_database.local = False + mock_database.pool = mock_pool + + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) + + with service.claim_lock( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + api_version="e1cache01", + ): + pass + + assert mock_connection.exec_driver_sql.call_count == 2 + + acquire_call = mock_connection.exec_driver_sql.call_args_list[0] + assert acquire_call.args == ( + "SELECT GET_LOCK(%s, %s) AS acquired", + ( + service._build_lock_name( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + api_version="e1cache01", + ), + 5, + ), + ) + assert len(acquire_call.args[1][0]) <= 64 + + release_call = mock_connection.exec_driver_sql.call_args_list[1] + assert release_call.args == ( + "SELECT RELEASE_LOCK(%s) AS released", + 
(acquire_call.args[1][0],), + ) + mock_connection.commit.assert_called_once() + + def test__given_remote_database_lock_timeout__claim_lock_raises(self, monkeypatch): + service = ReformImpactsService() + + acquired_result = MagicMock() + acquired_result.mappings.return_value.first.return_value = {"acquired": 0} + + mock_connection = MagicMock() + mock_connection.exec_driver_sql.return_value = acquired_result + + mock_connection_context = MagicMock() + mock_connection_context.__enter__.return_value = mock_connection + mock_connection_context.__exit__.return_value = False + + mock_pool = MagicMock() + mock_pool.connect.return_value = mock_connection_context + + mock_database = MagicMock() + mock_database.local = False + mock_database.pool = mock_pool + + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) + + with pytest.raises( + TimeoutError, + match="Could not acquire reform impact lock", + ): + with service.claim_lock( + country_id="us", + policy_id=123, + baseline_policy_id=456, + region="us", + dataset="enhanced_cps", + time_period="2026", + options_hash="[option=value]", + api_version="e1cache01", + ): + pass From 2542fd97ee0d9090b2209375ce0d558c113dc1a6 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 9 Apr 2026 10:25:18 -0400 Subject: [PATCH 06/13] Harden reform impact claim deduping --- policyengine_api/data/data.py | 20 +- .../endpoints/economy/reform_impact.py | 4 +- policyengine_api/endpoints/simulation.py | 4 +- policyengine_api/services/economy_service.py | 227 ++++++++++++++---- .../services/reform_impacts_service.py | 39 ++- tests/fixtures/services/economy_service.py | 10 +- tests/unit/data/test_sqlalchemy_v2.py | 32 +++ tests/unit/services/test_economy_service.py | 193 +++++++++++++++ 8 files changed, 469 insertions(+), 60 deletions(-) diff --git a/policyengine_api/data/data.py b/policyengine_api/data/data.py index 7dcb96c43..a2c0dd1ae 100644 --- a/policyengine_api/data/data.py +++ 
b/policyengine_api/data/data.py @@ -75,16 +75,20 @@ def _create_pool(self): with open(".dbpw") as f: db_pass = f.read().strip() db_name = "policyengine" - conn = self.connector.connect( - instance_connection_string=instance_connection_name, - driver="pymysql", - db=db_name, - user=db_user, - password=db_pass, - ) + + def get_connection(): + return self.connector.connect( + instance_connection_string=instance_connection_name, + driver="pymysql", + db=db_name, + user=db_user, + password=db_pass, + ) + self.pool = sqlalchemy.create_engine( "mysql+pymysql://", - creator=lambda: conn, + creator=get_connection, + pool_pre_ping=True, ) def _close_pool(self): diff --git a/policyengine_api/endpoints/economy/reform_impact.py b/policyengine_api/endpoints/economy/reform_impact.py index cc778bcef..42795243d 100644 --- a/policyengine_api/endpoints/economy/reform_impact.py +++ b/policyengine_api/endpoints/economy/reform_impact.py @@ -1,4 +1,4 @@ -from policyengine_api.data import database +from policyengine_api.data import local_database def set_comment_on_job( @@ -17,7 +17,7 @@ def set_comment_on_job( "time_period = ? AND options_hash = ? AND dataset = ?" 
) - database.query( + local_database.query( query, ( comment, diff --git a/policyengine_api/endpoints/simulation.py b/policyengine_api/endpoints/simulation.py index 394f2a7e9..132e5b2d6 100644 --- a/policyengine_api/endpoints/simulation.py +++ b/policyengine_api/endpoints/simulation.py @@ -1,4 +1,4 @@ -from policyengine_api.data import database +from policyengine_api.data import local_database """ @@ -28,7 +28,7 @@ def get_simulations( desc_limit = f"DESC LIMIT {max_results}" if max_results is not None else "" - result = database.query( + result = local_database.query( f"SELECT * FROM reform_impact ORDER BY start_time {desc_limit}", ).fetchall() diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 1ecfcdfd2..a51af49dd 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -60,6 +60,11 @@ class ImpactStatus(Enum): COMPUTING_STATUS = ImpactStatus.COMPUTING.value BUDGET_WINDOW_MAX_ACTIVE_YEARS = 3 BUDGET_WINDOW_MAX_YEARS = 20 +PENDING_EXECUTION_ID_PREFIX = "pending:" +PROVISIONAL_CLAIM_TTL_SECONDS = 90 +STALE_PROVISIONAL_IMPACT_MESSAGE = ( + "Simulation claim expired before job submission completed" +) class EconomicImpactSetupOptions(BaseModel): @@ -495,56 +500,88 @@ def _get_or_create_economic_impact( ) if impact_action == ImpactAction.CREATE: - with reform_impacts_service.claim_lock( - country_id=setup_options.country_id, - policy_id=setup_options.reform_policy_id, - baseline_policy_id=setup_options.baseline_policy_id, - region=setup_options.region, - dataset=setup_options.dataset, - time_period=setup_options.time_period, - options_hash=setup_options.options_hash, - api_version=setup_options.api_version, - ): - most_recent_impact = self._get_most_recent_impact( - setup_options=setup_options - ) - impact_action = self._determine_impact_action( - most_recent_impact=most_recent_impact - ) - - if impact_action == ImpactAction.COMPLETED: - 
logger.log_struct( - { - "message": "Found completed economic impact in db after locking; returning result", - **setup_options.model_dump(), - }, - severity="INFO", + try: + with reform_impacts_service.claim_lock( + country_id=setup_options.country_id, + policy_id=setup_options.reform_policy_id, + baseline_policy_id=setup_options.baseline_policy_id, + region=setup_options.region, + dataset=setup_options.dataset, + time_period=setup_options.time_period, + options_hash=setup_options.options_hash, + api_version=setup_options.api_version, + ): + most_recent_impact = self._get_most_recent_impact( + setup_options=setup_options ) - return self._handle_completed_impact( + impact_action = self._determine_impact_action( most_recent_impact=most_recent_impact ) - if impact_action == ImpactAction.COMPUTING: - logger.log_struct( - { - "message": "Found computing economic impact in db after locking; returning progress", - **setup_options.model_dump(), - }, - severity="INFO", + if impact_action == ImpactAction.COMPLETED: + logger.log_struct( + { + "message": "Found completed economic impact in db after locking; returning result", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_completed_impact( + most_recent_impact=most_recent_impact + ) + + if impact_action == ImpactAction.COMPUTING: + logger.log_struct( + { + "message": "Found computing economic impact in db after locking; returning progress", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_computing_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + if self._is_stale_provisional_impact(most_recent_impact): + self._expire_stale_provisional_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + + provisional_execution_id = self._build_provisional_execution_id( + setup_options.process_id ) - return self._handle_computing_impact( + self._set_reform_impact_computing( setup_options=setup_options, - 
most_recent_impact=most_recent_impact, + execution_id=provisional_execution_id, ) - + except TimeoutError: logger.log_struct( { - "message": "No previous economic impact record found in db; creating new simulation run", + "message": "Timed out waiting for economic impact claim lock; re-checking existing claim", **setup_options.model_dump(), }, - severity="INFO", + severity="WARNING", + ) + existing_impact = self._get_existing_economic_impact( + setup_options=setup_options ) - return self._handle_create_impact(setup_options=setup_options) + if existing_impact is not None: + return existing_impact + return EconomicImpactResult.computing() + + logger.log_struct( + { + "message": "No previous economic impact record found in db; creating new simulation run", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_create_impact( + setup_options=setup_options, + provisional_execution_id=provisional_execution_id, + ) raise ValueError(f"Unexpected impact action: {impact_action}") @@ -570,6 +607,8 @@ def _get_existing_economic_impact( return self._handle_completed_impact(most_recent_impact=most_recent_impact) if status == ImpactStatus.COMPUTING.value: + if self._is_stale_provisional_impact(most_recent_impact): + return None return self._handle_computing_impact( setup_options=setup_options, most_recent_impact=most_recent_impact, @@ -711,6 +750,63 @@ def _get_most_recent_impact( return None + def _build_provisional_execution_id(self, process_id: str) -> str: + return f"{PENDING_EXECUTION_ID_PREFIX}{process_id}" + + def _is_provisional_execution_id(self, execution_id: Any) -> bool: + return isinstance(execution_id, str) and execution_id.startswith( + PENDING_EXECUTION_ID_PREFIX + ) + + def _coerce_impact_start_time(self, start_time: Any) -> Optional[datetime.datetime]: + if start_time is None: + return None + + if isinstance(start_time, str): + parsed_start_time = datetime.datetime.fromisoformat(start_time) + elif hasattr(start_time, "tzinfo") and 
hasattr(start_time, "isoformat"): + parsed_start_time = start_time + else: + return None + + if parsed_start_time.tzinfo is None: + return parsed_start_time.replace(tzinfo=datetime.timezone.utc) + + return parsed_start_time.astimezone(datetime.timezone.utc) + + def _is_stale_provisional_impact(self, impact: dict | None) -> bool: + if not impact: + return False + + if not self._is_provisional_execution_id(impact.get("execution_id")): + return False + + start_time = self._coerce_impact_start_time(impact.get("start_time")) + if start_time is None: + return False + + current_time = datetime.datetime.now(datetime.timezone.utc) + if current_time.tzinfo is None: + current_time = current_time.replace(tzinfo=datetime.timezone.utc) + + claim_age = current_time - start_time + return claim_age.total_seconds() > PROVISIONAL_CLAIM_TTL_SECONDS + + def _expire_stale_provisional_impact( + self, + setup_options: EconomicImpactSetupOptions, + most_recent_impact: dict, + ) -> None: + execution_id = most_recent_impact.get("execution_id") + if not self._is_provisional_execution_id(execution_id): + return + + self._set_reform_impact_error( + setup_options=setup_options, + message=STALE_PROVISIONAL_IMPACT_MESSAGE, + execution_id=execution_id, + ) + def _determine_impact_action( self, most_recent_impact: dict | None, @@ -723,6 +819,8 @@ def _determine_impact_action( if status in [ImpactStatus.OK.value, ImpactStatus.ERROR.value]: return ImpactAction.COMPLETED elif status == ImpactStatus.COMPUTING.value: + if self._is_stale_provisional_impact(most_recent_impact): + return ImpactAction.CREATE return ImpactAction.COMPUTING else: raise ValueError(f"Unknown impact status: {status}") @@ -798,10 +896,11 @@ def _handle_computing_impact( setup_options: EconomicImpactSetupOptions, most_recent_impact: dict, ) -> EconomicImpactResult: + execution_id = most_recent_impact["execution_id"] + if self._is_provisional_execution_id(execution_id): + return EconomicImpactResult.computing() - execution = 
simulation_api.get_execution_by_id( - most_recent_impact["execution_id"] - ) + execution = simulation_api.get_execution_by_id(execution_id) execution_state = simulation_api.get_execution_status(execution) return self._handle_execution_state( execution_state=execution_state, @@ -813,6 +912,7 @@ def _handle_computing_impact( def _handle_create_impact( self, setup_options: EconomicImpactSetupOptions, + provisional_execution_id: str, ) -> EconomicImpactResult: baseline_policy = policy_service.get_policy_json( @@ -852,8 +952,17 @@ def _handle_create_impact( "process_id": setup_options.process_id, } - sim_api_execution = simulation_api.run(sim_params) - execution_id = simulation_api.get_execution_id(sim_api_execution) + try: + sim_api_execution = simulation_api.run(sim_params) + execution_id = simulation_api.get_execution_id(sim_api_execution) + except Exception as error: + error_message = f"Failed to start simulation API job: {str(error)}" + self._set_reform_impact_error( + setup_options=setup_options, + message=error_message, + execution_id=provisional_execution_id, + ) + return EconomicImpactResult.error(message=error_message) progress_log = { **setup_options.model_dump(), @@ -862,9 +971,10 @@ def _handle_create_impact( } logger.log_struct(progress_log, severity="INFO") - self._set_reform_impact_computing( + self._update_reform_impact_execution_id( setup_options=setup_options, - execution_id=execution_id, + current_execution_id=provisional_execution_id, + new_execution_id=execution_id, ) return EconomicImpactResult.computing() @@ -1006,6 +1116,33 @@ def _set_reform_impact_computing( ) raise e + def _update_reform_impact_execution_id( + self, + setup_options: EconomicImpactSetupOptions, + current_execution_id: str, + new_execution_id: str, + ): + try: + reform_impacts_service.update_reform_impact_execution_id( + country_id=setup_options.country_id, + policy_id=setup_options.reform_policy_id, + baseline_policy_id=setup_options.baseline_policy_id, + 
region=setup_options.region, + dataset=setup_options.dataset, + time_period=setup_options.time_period, + options_hash=setup_options.options_hash, + current_execution_id=current_execution_id, + new_execution_id=new_execution_id, + ) + except Exception as e: + logger.log_struct( + { + "message": f"Error updating reform impact execution id: {str(e)}", + **setup_options.model_dump(), + } + ) + raise e + def _set_reform_impact_complete( self, setup_options: EconomicImpactSetupOptions, diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index 3ca5c6c7d..05f5c756a 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -101,7 +101,7 @@ def get_all_reform_impacts( "reform_impact WHERE country_id = ? AND reform_policy_id = ? AND " "baseline_policy_id = ? AND region = ? AND time_period = ? AND " "options_hash = ? AND api_version = ? AND dataset = ? " - "ORDER BY start_time DESC" + "ORDER BY start_time DESC, reform_impact_id DESC" ) return database.query( query, @@ -164,6 +164,43 @@ def set_reform_impact( print(f"Error setting reform impact: {str(e)}") raise e + def update_reform_impact_execution_id( + self, + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options_hash, + current_execution_id, + new_execution_id, + ): + try: + query = ( + "UPDATE reform_impact SET execution_id = ? WHERE country_id = ? AND " + "reform_policy_id = ? AND baseline_policy_id = ? AND region = ? AND " + "time_period = ? AND options_hash = ? AND dataset = ? AND " + "execution_id = ? 
AND status = 'computing'" + ) + database.query( + query, + ( + new_execution_id, + country_id, + policy_id, + baseline_policy_id, + region, + time_period, + options_hash, + dataset, + current_execution_id, + ), + ) + except Exception as e: + print(f"Error updating reform impact execution id: {str(e)}") + raise e + def delete_reform_impact( self, country_id, diff --git a/tests/fixtures/services/economy_service.py b/tests/fixtures/services/economy_service.py index 687a82a48..f02b4159a 100644 --- a/tests/fixtures/services/economy_service.py +++ b/tests/fixtures/services/economy_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch, MagicMock import json import datetime +from contextlib import nullcontext from policyengine_api.constants import ( MODAL_EXECUTION_STATUS_SUBMITTED, @@ -91,8 +92,10 @@ def mock_reform_impacts_service(): mock_service = MagicMock() mock_service.get_all_reform_impacts.return_value = [] mock_service.set_reform_impact.return_value = None + mock_service.update_reform_impact_execution_id.return_value = None mock_service.set_complete_reform_impact.return_value = None mock_service.set_error_reform_impact.return_value = None + mock_service.claim_lock.side_effect = lambda **kwargs: nullcontext() with patch( "policyengine_api.services.economy_service.reform_impacts_service", @@ -147,7 +150,10 @@ def mock_numpy_random(): def create_mock_reform_impact( - status="ok", reform_impact_json=None, execution_id=MOCK_MODAL_JOB_ID + status="ok", + reform_impact_json=None, + execution_id=MOCK_MODAL_JOB_ID, + start_time=None, ): """Helper function to create mock reform impact records.""" return { @@ -163,7 +169,7 @@ def create_mock_reform_impact( "api_version": MOCK_API_VERSION, "reform_impact_json": reform_impact_json or json.dumps(MOCK_REFORM_IMPACT_DATA), "execution_id": execution_id, - "start_time": datetime.datetime(2025, 6, 26, 12, 0, 0), + "start_time": start_time or datetime.datetime(2025, 6, 26, 12, 0, 0), "end_time": ( datetime.datetime(2025, 6, 26, 
12, 5, 0) if status == "ok" else None ), diff --git a/tests/unit/data/test_sqlalchemy_v2.py b/tests/unit/data/test_sqlalchemy_v2.py index 3882bb0f7..2ea63f0f0 100644 --- a/tests/unit/data/test_sqlalchemy_v2.py +++ b/tests/unit/data/test_sqlalchemy_v2.py @@ -12,6 +12,7 @@ import pytest import sqlalchemy +from unittest.mock import MagicMock from policyengine_api.data.data import _ResultProxy, PolicyEngineDatabase @@ -180,3 +181,34 @@ def test_remote_delete(self): db._execute_remote(["DELETE FROM test_table WHERE id = ?", (1,)]) result = db._execute_remote(["SELECT * FROM test_table WHERE id = ?", (1,)]) assert result.fetchone() is None + + +class TestRemotePoolCreation: + def test_create_pool_uses_fresh_connection_creator(self, monkeypatch): + first_connection = MagicMock(name="first_connection") + second_connection = MagicMock(name="second_connection") + mock_connector = MagicMock() + mock_connector.connect.side_effect = [first_connection, second_connection] + + captured_kwargs = {} + + def fake_create_engine(url, **kwargs): + captured_kwargs.update(kwargs) + return MagicMock() + + monkeypatch.setenv("POLICYENGINE_DB_PASSWORD", "test-password") + monkeypatch.setattr( + "policyengine_api.data.data.Connector", lambda: mock_connector + ) + monkeypatch.setattr( + "policyengine_api.data.data.sqlalchemy.create_engine", + fake_create_engine, + ) + + db = PolicyEngineDatabase.__new__(PolicyEngineDatabase) + db._create_pool() + + creator = captured_kwargs["creator"] + assert creator() is first_connection + assert creator() is second_connection + assert captured_kwargs["pool_pre_ping"] is True diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 829b5caf8..b4f94f02e 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -1,3 +1,4 @@ +import datetime import json import pytest from unittest.mock import patch, MagicMock @@ -10,6 +11,9 @@ EconomicImpactSetupOptions, 
ImpactAction, ImpactStatus, + PENDING_EXECUTION_ID_PREFIX, + PROVISIONAL_CLAIM_TTL_SECONDS, + STALE_PROVISIONAL_IMPACT_MESSAGE, ) from tests.fixtures.services.economy_service import ( MOCK_COUNTRY_ID, @@ -199,6 +203,17 @@ def test__given_no_previous_impact__creates_new_simulation( assert result.data is None mock_simulation_api.run.assert_called_once() mock_reform_impacts_service.set_reform_impact.assert_called_once() + mock_reform_impacts_service.update_reform_impact_execution_id.assert_called_once_with( + country_id=MOCK_COUNTRY_ID, + policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=MOCK_TIME_PERIOD, + options_hash=MOCK_OPTIONS_HASH, + current_execution_id=f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}", + new_execution_id=MOCK_EXECUTION_ID, + ) def test__given_no_previous_impact__includes_metadata_in_simulation_params( self, @@ -230,6 +245,114 @@ def test__given_no_previous_impact__includes_metadata_in_simulation_params( ) assert sim_params["_metadata"]["process_id"] == MOCK_PROCESS_ID + def test__given_simulation_api_submission_failure__marks_provisional_claim_error( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_simulation_api.run.side_effect = RuntimeError("gateway unavailable") + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert ( + result.message + == "Failed to start simulation API job: gateway unavailable" + ) + mock_reform_impacts_service.set_reform_impact.assert_called_once() + mock_reform_impacts_service.set_error_reform_impact.assert_called_once_with( + country_id=MOCK_COUNTRY_ID, + policy_id=MOCK_POLICY_ID, + 
baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=MOCK_TIME_PERIOD, + options_hash=MOCK_OPTIONS_HASH, + message="Failed to start simulation API job: gateway unavailable", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}", + ) + mock_reform_impacts_service.update_reform_impact_execution_id.assert_not_called() + + def test__given_claim_lock_timeout_and_existing_provisional_claim__returns_computing( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_numpy_random, + ): + provisional_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_other", + start_time=datetime.datetime.now(datetime.timezone.utc), + ) + mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ + [], + [provisional_impact], + ] + mock_reform_impacts_service.claim_lock.side_effect = TimeoutError( + "lock busy" + ) + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + mock_simulation_api.run.assert_not_called() + + def test__given_stale_provisional_claim__expires_and_recreates_simulation( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + ): + stale_start_time = datetime.datetime.now( + datetime.timezone.utc + ) - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1) + stale_provisional_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + start_time=stale_start_time, + ) + mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ + [stale_provisional_impact], + [stale_provisional_impact], + ] + + result = 
economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + mock_reform_impacts_service.set_error_reform_impact.assert_called_once_with( + country_id=MOCK_COUNTRY_ID, + policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=MOCK_TIME_PERIOD, + options_hash=MOCK_OPTIONS_HASH, + message=STALE_PROVISIONAL_IMPACT_MESSAGE, + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + ) + mock_reform_impacts_service.set_reform_impact.assert_called_once() + mock_simulation_api.run.assert_called_once() + def test__given_runtime_cache_version__uses_versioned_economy_cache_key( self, economy_service, @@ -733,6 +856,47 @@ def test__given_no_impacts__returns_none( # Assert assert result is None + class TestGetExistingEconomicImpact: + @pytest.fixture + def economy_service(self): + return EconomyService() + + @pytest.fixture + def setup_options(self): + return EconomicImpactSetupOptions( + process_id=MOCK_PROCESS_ID, + country_id=MOCK_COUNTRY_ID, + reform_policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=MOCK_TIME_PERIOD, + options=MOCK_OPTIONS, + api_version=MOCK_API_VERSION, + target="general", + options_hash=MOCK_OPTIONS_HASH, + ) + + def test__given_stale_provisional_impact__returns_none( + self, + economy_service, + setup_options, + mock_reform_impacts_service, + ): + stale_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + start_time=datetime.datetime.now(datetime.timezone.utc) + - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1), + ) + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + stale_impact + ] + + result = economy_service._get_existing_economic_impact(setup_options) + + assert result is None + class TestDetermineImpactAction: @pytest.fixture def economy_service(self): 
@@ -764,6 +928,20 @@ def test__given_computing_status__returns_computing(self, economy_service): assert result == ImpactAction.COMPUTING + def test__given_stale_provisional_computing_status__returns_create( + self, economy_service + ): + impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_stale", + start_time=datetime.datetime.now(datetime.timezone.utc) + - datetime.timedelta(seconds=PROVISIONAL_CLAIM_TTL_SECONDS + 1), + ) + + result = economy_service._determine_impact_action(impact) + + assert result == ImpactAction.CREATE + def test__given_unknown_status__raises_error(self, economy_service): impact = create_mock_reform_impact(status="unknown") @@ -844,6 +1022,21 @@ def test__given_active_state__returns_computing_result( assert result.status == ImpactStatus.COMPUTING assert result.data is None + def test__given_provisional_claim__returns_computing_without_polling( + self, economy_service, setup_options, mock_simulation_api, mock_logger + ): + reform_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_pending", + ) + + result = economy_service._handle_computing_impact( + setup_options, reform_impact + ) + + assert result.status == ImpactStatus.COMPUTING + mock_simulation_api.get_execution_by_id.assert_not_called() + def test__given_unknown_state__raises_error( self, economy_service, setup_options ): From f089af97fad464f34a8d2723081117695eaee58c Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 9 Apr 2026 10:41:27 -0400 Subject: [PATCH 07/13] Tighten budget window claim recovery --- policyengine_api/api.py | 8 ++- policyengine_api/data/data.py | 1 + policyengine_api/services/economy_service.py | 62 +++++++++++++++---- .../services/reform_impacts_service.py | 3 +- tests/fixtures/services/economy_service.py | 2 +- tests/unit/services/test_economy_service.py | 36 +++++++++++ 6 files changed, 95 insertions(+), 17 deletions(-) diff --git 
a/policyengine_api/api.py b/policyengine_api/api.py index 112cce9ac..eb3eba9ee 100644 --- a/policyengine_api/api.py +++ b/policyengine_api/api.py @@ -4,6 +4,7 @@ import time import sys +import os start_time = time.time() @@ -157,8 +158,11 @@ def log_timing(message): app.register_blueprint(user_profile_bp) log_timing("User profile routes registered") -app.route("/simulations", methods=["GET"])(get_simulations) -log_timing("Simulations endpoint registered") +if os.environ.get("FLASK_DEBUG") == "1": + app.route("/simulations", methods=["GET"])(get_simulations) + log_timing("Simulations endpoint registered") +else: + log_timing("Simulations endpoint skipped outside debug mode") app.register_blueprint(tracer_analysis_bp) log_timing("Tracer analysis routes registered") diff --git a/policyengine_api/data/data.py b/policyengine_api/data/data.py index a2c0dd1ae..fdaea4b9b 100644 --- a/policyengine_api/data/data.py +++ b/policyengine_api/data/data.py @@ -19,6 +19,7 @@ class _ResultProxy: Provides fetchone()/fetchall() with dict-like row access.""" def __init__(self, cursor_result): + self.rowcount = getattr(cursor_result, "rowcount", -1) try: # Use .mappings() so rows behave like dicts self._rows = list(cursor_result.mappings()) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index a51af49dd..1071eed65 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -543,10 +543,10 @@ def _get_or_create_economic_impact( most_recent_impact=most_recent_impact, ) + stale_provisional_execution_id = None if self._is_stale_provisional_impact(most_recent_impact): - self._expire_stale_provisional_impact( - setup_options=setup_options, - most_recent_impact=most_recent_impact, + stale_provisional_execution_id = most_recent_impact.get( + "execution_id" ) provisional_execution_id = self._build_provisional_execution_id( @@ -556,6 +556,11 @@ def _get_or_create_economic_impact( 
setup_options=setup_options, execution_id=provisional_execution_id, ) + if stale_provisional_execution_id: + self._expire_stale_provisional_impact( + setup_options=setup_options, + execution_id=stale_provisional_execution_id, + ) except TimeoutError: logger.log_struct( { @@ -795,9 +800,8 @@ def _is_stale_provisional_impact(self, impact: dict | None) -> bool: def _expire_stale_provisional_impact( self, setup_options: EconomicImpactSetupOptions, - most_recent_impact: dict, + execution_id: str, ) -> None: - execution_id = most_recent_impact.get("execution_id") if not self._is_provisional_execution_id(execution_id): return @@ -971,11 +975,40 @@ def _handle_create_impact( } logger.log_struct(progress_log, severity="INFO") - self._update_reform_impact_execution_id( - setup_options=setup_options, - current_execution_id=provisional_execution_id, - new_execution_id=execution_id, - ) + try: + updated_rows = self._update_reform_impact_execution_id( + setup_options=setup_options, + current_execution_id=provisional_execution_id, + new_execution_id=execution_id, + ) + except Exception as error: + logger.log_struct( + { + "message": "Failed to promote provisional reform impact row; inserting replacement tracking row", + **setup_options.model_dump(), + "execution_id": execution_id, + "provisional_execution_id": provisional_execution_id, + "error": str(error), + }, + severity="WARNING", + ) + updated_rows = 0 + + if updated_rows != 1: + logger.log_struct( + { + "message": "Provisional reform impact row was not updated; inserting replacement tracking row", + **setup_options.model_dump(), + "execution_id": execution_id, + "provisional_execution_id": provisional_execution_id, + "updated_rows": updated_rows, + }, + severity="WARNING", + ) + self._set_reform_impact_computing( + setup_options=setup_options, + execution_id=execution_id, + ) return EconomicImpactResult.computing() @@ -1092,6 +1125,9 @@ def _set_reform_impact_computing( In the reform_impact table, set the status of the 
impact to "computing". """ try: + start_time = datetime.datetime.now(datetime.timezone.utc).replace( + tzinfo=None + ) reform_impacts_service.set_reform_impact( country_id=setup_options.country_id, policy_id=setup_options.reform_policy_id, @@ -1104,7 +1140,7 @@ def _set_reform_impact_computing( status=ImpactStatus.COMPUTING.value, api_version=setup_options.api_version, reform_impact_json=json.dumps({}), - start_time=datetime.datetime.now(), + start_time=start_time, execution_id=execution_id, ) except Exception as e: @@ -1121,9 +1157,9 @@ def _update_reform_impact_execution_id( setup_options: EconomicImpactSetupOptions, current_execution_id: str, new_execution_id: str, - ): + ) -> int | None: try: - reform_impacts_service.update_reform_impact_execution_id( + return reform_impacts_service.update_reform_impact_execution_id( country_id=setup_options.country_id, policy_id=setup_options.reform_policy_id, baseline_policy_id=setup_options.baseline_policy_id, diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index 05f5c756a..b1c1f41cc 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -183,7 +183,7 @@ def update_reform_impact_execution_id( "time_period = ? AND options_hash = ? AND dataset = ? AND " "execution_id = ? 
AND status = 'computing'" ) - database.query( + result = database.query( query, ( new_execution_id, @@ -197,6 +197,7 @@ def update_reform_impact_execution_id( current_execution_id, ), ) + return getattr(result, "rowcount", None) except Exception as e: print(f"Error updating reform impact execution id: {str(e)}") raise e diff --git a/tests/fixtures/services/economy_service.py b/tests/fixtures/services/economy_service.py index f02b4159a..14c566772 100644 --- a/tests/fixtures/services/economy_service.py +++ b/tests/fixtures/services/economy_service.py @@ -92,7 +92,7 @@ def mock_reform_impacts_service(): mock_service = MagicMock() mock_service.get_all_reform_impacts.return_value = [] mock_service.set_reform_impact.return_value = None - mock_service.update_reform_impact_execution_id.return_value = None + mock_service.update_reform_impact_execution_id.return_value = 1 mock_service.set_complete_reform_impact.return_value = None mock_service.set_error_reform_impact.return_value = None mock_service.claim_lock.side_effect = lambda **kwargs: nullcontext() diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index b4f94f02e..b6bc2b6c7 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -203,6 +203,10 @@ def test__given_no_previous_impact__creates_new_simulation( assert result.data is None mock_simulation_api.run.assert_called_once() mock_reform_impacts_service.set_reform_impact.assert_called_once() + assert any( + call.args == (datetime.timezone.utc,) + for call in mock_datetime.now.call_args_list + ) mock_reform_impacts_service.update_reform_impact_execution_id.assert_called_once_with( country_id=MOCK_COUNTRY_ID, policy_id=MOCK_POLICY_ID, @@ -353,6 +357,38 @@ def test__given_stale_provisional_claim__expires_and_recreates_simulation( mock_reform_impacts_service.set_reform_impact.assert_called_once() mock_simulation_api.run.assert_called_once() + def 
test__given_provisional_promotion_updates_zero_rows__inserts_replacement_tracking_row( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + mock_reform_impacts_service.update_reform_impact_execution_id.return_value = 0 + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert mock_reform_impacts_service.set_reform_impact.call_count == 2 + first_insert = mock_reform_impacts_service.set_reform_impact.call_args_list[ + 0 + ] + second_insert = ( + mock_reform_impacts_service.set_reform_impact.call_args_list[1] + ) + assert ( + first_insert.kwargs["execution_id"] + == f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}" + ) + assert second_insert.kwargs["execution_id"] == MOCK_EXECUTION_ID + def test__given_runtime_cache_version__uses_versioned_economy_cache_key( self, economy_service, From e34d4520c59326d2625b581b8a3a0c29de1d99e7 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 9 Apr 2026 10:55:41 -0400 Subject: [PATCH 08/13] Prevent stale claim takeover --- policyengine_api/endpoints/simulation.py | 4 +- policyengine_api/services/economy_service.py | 90 ++++++++++++++++++-- tests/unit/endpoints/test_simulation.py | 16 ++++ tests/unit/services/test_economy_service.py | 39 +++++++++ 4 files changed, 139 insertions(+), 10 deletions(-) create mode 100644 tests/unit/endpoints/test_simulation.py diff --git a/policyengine_api/endpoints/simulation.py b/policyengine_api/endpoints/simulation.py index 132e5b2d6..394f2a7e9 100644 --- a/policyengine_api/endpoints/simulation.py +++ b/policyengine_api/endpoints/simulation.py @@ -1,4 +1,4 @@ -from policyengine_api.data import local_database +from policyengine_api.data import database """ @@ -28,7 +28,7 @@ def 
get_simulations( desc_limit = f"DESC LIMIT {max_results}" if max_results is not None else "" - result = local_database.query( + result = database.query( f"SELECT * FROM reform_impact ORDER BY start_time {desc_limit}", ).fetchall() diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index 1071eed65..f9da31294 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -995,22 +995,96 @@ def _handle_create_impact( updated_rows = 0 if updated_rows != 1: + self._recover_failed_execution_id_promotion( + setup_options=setup_options, + provisional_execution_id=provisional_execution_id, + execution_id=execution_id, + updated_rows=updated_rows, + ) + + return EconomicImpactResult.computing() + + def _recover_failed_execution_id_promotion( + self, + *, + setup_options: EconomicImpactSetupOptions, + provisional_execution_id: str, + execution_id: str, + updated_rows: int | None, + ) -> None: + logger.log_struct( + { + "message": "Provisional reform impact row was not updated; checking whether tracking has already been superseded", + **setup_options.model_dump(), + "execution_id": execution_id, + "provisional_execution_id": provisional_execution_id, + "updated_rows": updated_rows, + }, + severity="WARNING", + ) + + try: + with reform_impacts_service.claim_lock( + country_id=setup_options.country_id, + policy_id=setup_options.reform_policy_id, + baseline_policy_id=setup_options.baseline_policy_id, + region=setup_options.region, + dataset=setup_options.dataset, + time_period=setup_options.time_period, + options_hash=setup_options.options_hash, + api_version=setup_options.api_version, + ): + most_recent_impact = self._get_most_recent_impact( + setup_options=setup_options + ) + if most_recent_impact is not None: + impact_status = most_recent_impact.get("status") + tracked_execution_id = most_recent_impact.get("execution_id") + if tracked_execution_id == execution_id: + return + 
+ if ( + impact_status == ImpactStatus.COMPUTING.value + and tracked_execution_id == provisional_execution_id + ): + retry_updated_rows = self._update_reform_impact_execution_id( + setup_options=setup_options, + current_execution_id=provisional_execution_id, + new_execution_id=execution_id, + ) + if retry_updated_rows == 1: + return + elif impact_status in ( + ImpactStatus.OK.value, + ImpactStatus.COMPUTING.value, + ): + logger.log_struct( + { + "message": "Skipping replacement tracking row because another claim is already authoritative", + **setup_options.model_dump(), + "execution_id": execution_id, + "provisional_execution_id": provisional_execution_id, + "tracked_execution_id": tracked_execution_id, + "tracked_status": impact_status, + }, + severity="WARNING", + ) + return + + self._set_reform_impact_computing( + setup_options=setup_options, + execution_id=execution_id, + ) + except TimeoutError: logger.log_struct( { - "message": "Provisional reform impact row was not updated; inserting replacement tracking row", + "message": "Timed out while recovering failed provisional promotion; leaving the newer claim authoritative", **setup_options.model_dump(), "execution_id": execution_id, "provisional_execution_id": provisional_execution_id, - "updated_rows": updated_rows, }, severity="WARNING", ) - self._set_reform_impact_computing( - setup_options=setup_options, - execution_id=execution_id, - ) - - return EconomicImpactResult.computing() def _setup_sim_options( self, diff --git a/tests/unit/endpoints/test_simulation.py b/tests/unit/endpoints/test_simulation.py new file mode 100644 index 000000000..c29837eec --- /dev/null +++ b/tests/unit/endpoints/test_simulation.py @@ -0,0 +1,16 @@ +from unittest.mock import MagicMock, patch + +from policyengine_api.endpoints.simulation import get_simulations + + +def test_get_simulations_reads_from_shared_database(): + mock_database = MagicMock() + mock_database.query.return_value.fetchall.return_value = [{"id": 1}] + + with 
patch("policyengine_api.endpoints.simulation.database", mock_database): + result = get_simulations() + + mock_database.query.assert_called_once_with( + "SELECT * FROM reform_impact ORDER BY start_time DESC LIMIT 100", + ) + assert result == {"result": [{"id": 1}]} diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index b6bc2b6c7..0042a8887 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -389,6 +389,45 @@ def test__given_provisional_promotion_updates_zero_rows__inserts_replacement_tra ) assert second_insert.kwargs["execution_id"] == MOCK_EXECUTION_ID + def test__given_provisional_promotion_updates_zero_rows_but_newer_claim_exists__does_not_insert_fallback( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + replacement_impact = create_mock_reform_impact( + status="computing", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}job_replacement", + start_time=datetime.datetime.now(datetime.timezone.utc), + ) + mock_reform_impacts_service.get_all_reform_impacts.side_effect = [ + [], + [], + [replacement_impact], + ] + mock_reform_impacts_service.update_reform_impact_execution_id.return_value = 0 + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.COMPUTING + assert mock_reform_impacts_service.set_reform_impact.call_count == 1 + inserted_execution_id = ( + mock_reform_impacts_service.set_reform_impact.call_args.kwargs[ + "execution_id" + ] + ) + assert ( + inserted_execution_id + == f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}" + ) + def test__given_runtime_cache_version__uses_versioned_economy_cache_key( self, economy_service, From d9dd0059083df6a5f0f32ab70baa678468b17211 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: 
Thu, 9 Apr 2026 12:41:29 -0400 Subject: [PATCH 09/13] Backfill reform impact schema lazily --- .../services/reform_impacts_service.py | 43 ++++++++++++ .../services/test_reform_impacts_service.py | 66 +++++++++++++++++++ 2 files changed, 109 insertions(+) diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index b1c1f41cc..cf739417e 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -6,6 +6,7 @@ LOCAL_REFORM_IMPACT_LOCK = Lock() +REFORM_IMPACT_SCHEMA_LOCK = Lock() REFORM_IMPACT_LOCK_TIMEOUT_SECONDS = 5 @@ -15,6 +16,42 @@ class ReformImpactsService: this is connected to the shared reform_impact table. """ + def __init__(self): + self._schema_checked = False + + def _ensure_remote_schema(self) -> None: + if database.local or self._schema_checked: + return + + with REFORM_IMPACT_SCHEMA_LOCK: + if self._schema_checked: + return + + existing_columns = { + row["Field"] + for row in database.query("SHOW COLUMNS FROM reform_impact").fetchall() + } + required_columns = { + "execution_id": ( + "ALTER TABLE reform_impact " + "ADD COLUMN execution_id VARCHAR(255) NULL" + ), + "end_time": ( + "ALTER TABLE reform_impact ADD COLUMN end_time DATETIME NULL" + ), + } + + for column_name, alter_query in required_columns.items(): + if column_name in existing_columns: + continue + try: + database.query(alter_query) + except Exception as error: + if "Duplicate column name" not in str(error): + raise + + self._schema_checked = True + def _build_lock_name( self, country_id, @@ -96,6 +133,7 @@ def get_all_reform_impacts( api_version, ): try: + self._ensure_remote_schema() query = ( "SELECT reform_impact_json, status, message, start_time, execution_id FROM " "reform_impact WHERE country_id = ? AND reform_policy_id = ? 
AND " @@ -137,6 +175,7 @@ def set_reform_impact( execution_id: str, ): try: + self._ensure_remote_schema() query = ( "INSERT INTO reform_impact (country_id, reform_policy_id, baseline_policy_id, " "region, dataset, time_period, options_json, options_hash, status, api_version, " @@ -177,6 +216,7 @@ def update_reform_impact_execution_id( new_execution_id, ): try: + self._ensure_remote_schema() query = ( "UPDATE reform_impact SET execution_id = ? WHERE country_id = ? AND " "reform_policy_id = ? AND baseline_policy_id = ? AND region = ? AND " @@ -213,6 +253,7 @@ def delete_reform_impact( options_hash, ): try: + self._ensure_remote_schema() query = ( "DELETE FROM reform_impact WHERE country_id = ? AND " "reform_policy_id = ? AND baseline_policy_id = ? AND " @@ -249,6 +290,7 @@ def set_error_reform_impact( execution_id: str, ): try: + self._ensure_remote_schema() query = ( "UPDATE reform_impact SET status = ?, message = ?, end_time = ? WHERE " "country_id = ? AND reform_policy_id = ? AND baseline_policy_id = ? AND " @@ -293,6 +335,7 @@ def set_complete_reform_impact( execution_id, ): try: + self._ensure_remote_schema() query = ( "UPDATE reform_impact SET status = ?, message = ?, end_time = ?, " "reform_impact_json = ? WHERE country_id = ? AND reform_policy_id = ? 
AND " diff --git a/tests/unit/services/test_reform_impacts_service.py b/tests/unit/services/test_reform_impacts_service.py index 4f44b63a2..4456dffad 100644 --- a/tests/unit/services/test_reform_impacts_service.py +++ b/tests/unit/services/test_reform_impacts_service.py @@ -6,6 +6,72 @@ class TestReformImpactsService: + def test__given_remote_database_missing_columns__ensure_remote_schema_adds_them( + self, monkeypatch + ): + service = ReformImpactsService() + + show_columns_result = MagicMock() + show_columns_result.fetchall.return_value = [ + {"Field": "reform_impact_id"}, + {"Field": "status"}, + {"Field": "start_time"}, + ] + alter_execution_result = MagicMock() + alter_end_time_result = MagicMock() + + mock_database = MagicMock() + mock_database.local = False + mock_database.query.side_effect = [ + show_columns_result, + alter_execution_result, + alter_end_time_result, + ] + + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) + + service._ensure_remote_schema() + + assert mock_database.query.call_args_list[0].args == ( + "SHOW COLUMNS FROM reform_impact", + ) + assert mock_database.query.call_args_list[1].args == ( + "ALTER TABLE reform_impact ADD COLUMN execution_id VARCHAR(255) NULL", + ) + assert mock_database.query.call_args_list[2].args == ( + "ALTER TABLE reform_impact ADD COLUMN end_time DATETIME NULL", + ) + + def test__given_remote_database_existing_columns__ensure_remote_schema_skips_alter( + self, monkeypatch + ): + service = ReformImpactsService() + + show_columns_result = MagicMock() + show_columns_result.fetchall.return_value = [ + {"Field": "reform_impact_id"}, + {"Field": "status"}, + {"Field": "start_time"}, + {"Field": "execution_id"}, + {"Field": "end_time"}, + ] + + mock_database = MagicMock() + mock_database.local = False + mock_database.query.return_value = show_columns_result + + monkeypatch.setattr( + "policyengine_api.services.reform_impacts_service.database", + mock_database, + ) 
+ + service._ensure_remote_schema() + + mock_database.query.assert_called_once_with("SHOW COLUMNS FROM reform_impact") + def test__given_remote_database__claim_lock_uses_advisory_lock(self, monkeypatch): service = ReformImpactsService() From e0b62cb1e62e5ec251954ff16160e8ac500df32d Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 10 Apr 2026 09:18:36 -0400 Subject: [PATCH 10/13] Backfill reform impact dataset column --- policyengine_api/services/reform_impacts_service.py | 4 ++++ tests/unit/services/test_reform_impacts_service.py | 8 +++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/policyengine_api/services/reform_impacts_service.py b/policyengine_api/services/reform_impacts_service.py index cf739417e..27fca80bd 100644 --- a/policyengine_api/services/reform_impacts_service.py +++ b/policyengine_api/services/reform_impacts_service.py @@ -32,6 +32,10 @@ def _ensure_remote_schema(self) -> None: for row in database.query("SHOW COLUMNS FROM reform_impact").fetchall() } required_columns = { + "dataset": ( + "ALTER TABLE reform_impact " + "ADD COLUMN dataset VARCHAR(255) NOT NULL DEFAULT 'default'" + ), "execution_id": ( "ALTER TABLE reform_impact " "ADD COLUMN execution_id VARCHAR(255) NULL" diff --git a/tests/unit/services/test_reform_impacts_service.py b/tests/unit/services/test_reform_impacts_service.py index 4456dffad..106cf8757 100644 --- a/tests/unit/services/test_reform_impacts_service.py +++ b/tests/unit/services/test_reform_impacts_service.py @@ -17,6 +17,7 @@ def test__given_remote_database_missing_columns__ensure_remote_schema_adds_them( {"Field": "status"}, {"Field": "start_time"}, ] + alter_dataset_result = MagicMock() alter_execution_result = MagicMock() alter_end_time_result = MagicMock() @@ -24,6 +25,7 @@ def test__given_remote_database_missing_columns__ensure_remote_schema_adds_them( mock_database.local = False mock_database.query.side_effect = [ show_columns_result, + alter_dataset_result, alter_execution_result, 
alter_end_time_result, ] @@ -39,9 +41,12 @@ def test__given_remote_database_missing_columns__ensure_remote_schema_adds_them( "SHOW COLUMNS FROM reform_impact", ) assert mock_database.query.call_args_list[1].args == ( - "ALTER TABLE reform_impact ADD COLUMN execution_id VARCHAR(255) NULL", + "ALTER TABLE reform_impact ADD COLUMN dataset VARCHAR(255) NOT NULL DEFAULT 'default'", ) assert mock_database.query.call_args_list[2].args == ( + "ALTER TABLE reform_impact ADD COLUMN execution_id VARCHAR(255) NULL", + ) + assert mock_database.query.call_args_list[3].args == ( "ALTER TABLE reform_impact ADD COLUMN end_time DATETIME NULL", ) @@ -55,6 +60,7 @@ def test__given_remote_database_existing_columns__ensure_remote_schema_skips_alt {"Field": "reform_impact_id"}, {"Field": "status"}, {"Field": "start_time"}, + {"Field": "dataset"}, {"Field": "execution_id"}, {"Field": "end_time"}, ] From 7f1732949cebe81f3f9648d30aa5cbdf35af7c37 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 10 Apr 2026 09:46:20 -0400 Subject: [PATCH 11/13] Address budget window review findings --- policyengine_api/data/__init__.py | 7 +- policyengine_api/data/data.py | 8 ++ policyengine_api/endpoints/simulation.py | 12 +- policyengine_api/openapi_spec.yaml | 132 +++++++++++++++++++ policyengine_api/services/economy_service.py | 55 ++++++-- tests/unit/endpoints/test_simulation.py | 7 +- tests/unit/services/test_economy_service.py | 34 ++++- 7 files changed, 238 insertions(+), 17 deletions(-) diff --git a/policyengine_api/data/__init__.py b/policyengine_api/data/__init__.py index 15673afdb..94703ee36 100644 --- a/policyengine_api/data/__init__.py +++ b/policyengine_api/data/__init__.py @@ -1 +1,6 @@ -from .data import PolicyEngineDatabase, database, local_database +from .data import ( + PolicyEngineDatabase, + database, + get_remote_database, + local_database, +) diff --git a/policyengine_api/data/data.py b/policyengine_api/data/data.py index fdaea4b9b..ad521e386 100644 --- 
a/policyengine_api/data/data.py +++ b/policyengine_api/data/data.py @@ -199,3 +199,11 @@ def initialize(self): database = PolicyEngineDatabase(local=False, initialize=False) local_database = PolicyEngineDatabase(local=True, initialize=False) +remote_database = None + + +def get_remote_database() -> PolicyEngineDatabase: + global remote_database + if remote_database is None: + remote_database = PolicyEngineDatabase(local=False, initialize=False) + return remote_database diff --git a/policyengine_api/endpoints/simulation.py b/policyengine_api/endpoints/simulation.py index 394f2a7e9..be14e115f 100644 --- a/policyengine_api/endpoints/simulation.py +++ b/policyengine_api/endpoints/simulation.py @@ -1,4 +1,4 @@ -from policyengine_api.data import database +from policyengine_api.data import get_remote_database """ @@ -28,9 +28,13 @@ def get_simulations( desc_limit = f"DESC LIMIT {max_results}" if max_results is not None else "" - result = database.query( - f"SELECT * FROM reform_impact ORDER BY start_time {desc_limit}", - ).fetchall() + result = ( + get_remote_database() + .query( + f"SELECT * FROM reform_impact ORDER BY start_time {desc_limit}", + ) + .fetchall() + ) # Format into [{}] diff --git a/policyengine_api/openapi_spec.yaml b/policyengine_api/openapi_spec.yaml index a49268c8c..77daadc9e 100644 --- a/policyengine_api/openapi_spec.yaml +++ b/policyengine_api/openapi_spec.yaml @@ -660,6 +660,138 @@ paths: type: string message: type: string + /{country_id}/economy/{policy_id}/over/{baseline_policy_id}/budget-window: + get: + summary: Calculate budget-window economic impacts + operationId: get_budget_window_economic_impact + description: Calculate annual and total budget impacts for a policy over a multi-year budget window. + parameters: + - name: country_id + in: path + description: The country ID. + required: true + schema: + type: string + - name: policy_id + in: path + description: The reform policy ID. 
+ required: true + schema: + type: string + - name: baseline_policy_id + in: path + description: The baseline policy ID. + required: true + schema: + type: string + - name: region + in: query + description: The sub-national region. + required: true + schema: + type: string + - name: start_year + in: query + description: First year in the budget window. + required: true + schema: + type: string + - name: window_size + in: query + description: Number of years to include in the budget window. + required: true + schema: + type: integer + - name: dataset + in: query + description: Dataset selection. + required: false + schema: + type: string + default: default + - name: version + in: query + description: API version number. + required: false + schema: + type: string + - name: include_district_breakdowns + in: query + description: Whether to include congressional district breakdowns for US national simulations. + required: false + schema: + type: boolean + default: false + - name: target + in: query + description: Impact target. Budget-window calculations only support general impacts. + required: false + schema: + type: string + default: general + responses: + 200: + description: Budget-window economic impact, progress, or error state. + content: + application/json: + schema: + type: object + properties: + status: + type: string + enum: + - ok + - computing + - error + message: + type: string + nullable: true + result: + type: object + nullable: true + progress: + type: integer + nullable: true + completed_years: + type: array + items: + type: string + computing_years: + type: array + items: + type: string + queued_years: + type: array + items: + type: string + error: + type: string + nullable: true + 400: + description: Invalid budget-window request. + content: + application/json: + schema: + type: object + properties: + status: + type: string + message: + type: string + result: + type: object + nullable: true + 404: + description: Invalid country ID. 
+ content: + text/html: + schema: + type: object + properties: + status: + type: string + message: + type: string /{country_id}/analysis: post: summary: Get or trigger policy analysis diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index f9da31294..f525baf9e 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -44,6 +44,7 @@ class ImpactAction(Enum): COMPLETED = "completed" COMPUTING = "computing" CREATE = "create" + ERROR = "error" class ImpactStatus(Enum): @@ -499,6 +500,19 @@ def _get_or_create_economic_impact( most_recent_impact=most_recent_impact, ) + if impact_action == ImpactAction.ERROR: + logger.log_struct( + { + "message": "Found failed economic impact in db; returning error", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_error_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + if impact_action == ImpactAction.CREATE: try: with reform_impacts_service.claim_lock( @@ -543,6 +557,19 @@ def _get_or_create_economic_impact( most_recent_impact=most_recent_impact, ) + if impact_action == ImpactAction.ERROR: + logger.log_struct( + { + "message": "Found failed economic impact in db after locking; returning error", + **setup_options.model_dump(), + }, + severity="INFO", + ) + return self._handle_error_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, + ) + stale_provisional_execution_id = None if self._is_stale_provisional_impact(most_recent_impact): stale_provisional_execution_id = most_recent_impact.get( @@ -599,13 +626,9 @@ def _get_existing_economic_impact( status = most_recent_impact.get("status") if status == ImpactStatus.ERROR.value: - error_message = most_recent_impact.get("message") or ( - f"Economic impact failed for {setup_options.time_period}" - ) - return EconomicImpactResult( - status=ImpactStatus.ERROR, - data=None, - 
message=error_message, + return self._handle_error_impact( + setup_options=setup_options, + most_recent_impact=most_recent_impact, ) if status == ImpactStatus.OK.value: @@ -820,8 +843,10 @@ def _determine_impact_action( return ImpactAction.CREATE status = most_recent_impact.get("status") - if status in [ImpactStatus.OK.value, ImpactStatus.ERROR.value]: + if status == ImpactStatus.OK.value: return ImpactAction.COMPLETED + elif status == ImpactStatus.ERROR.value: + return ImpactAction.ERROR elif status == ImpactStatus.COMPUTING.value: if self._is_stale_provisional_impact(most_recent_impact): return ImpactAction.CREATE @@ -895,6 +920,20 @@ def _handle_completed_impact( data=json.loads(most_recent_impact["reform_impact_json"]) ) + def _handle_error_impact( + self, + setup_options: EconomicImpactSetupOptions, + most_recent_impact: dict, + ) -> EconomicImpactResult: + error_message = most_recent_impact.get("message") or ( + f"Economic impact failed for {setup_options.time_period}" + ) + return EconomicImpactResult( + status=ImpactStatus.ERROR, + data=None, + message=error_message, + ) + def _handle_computing_impact( self, setup_options: EconomicImpactSetupOptions, diff --git a/tests/unit/endpoints/test_simulation.py b/tests/unit/endpoints/test_simulation.py index c29837eec..a9013a056 100644 --- a/tests/unit/endpoints/test_simulation.py +++ b/tests/unit/endpoints/test_simulation.py @@ -3,11 +3,14 @@ from policyengine_api.endpoints.simulation import get_simulations -def test_get_simulations_reads_from_shared_database(): +def test_get_simulations_reads_from_remote_database(): mock_database = MagicMock() mock_database.query.return_value.fetchall.return_value = [{"id": 1}] - with patch("policyengine_api.endpoints.simulation.database", mock_database): + with patch( + "policyengine_api.endpoints.simulation.get_remote_database", + return_value=mock_database, + ): result = get_simulations() mock_database.query.assert_called_once_with( diff --git 
a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index 0042a8887..d47e9bbee 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -102,6 +102,36 @@ def test__given_completed_impact__returns_completed_result( mock_reform_impacts_service.get_all_reform_impacts.assert_called_once() mock_simulation_api.run.assert_not_called() + def test__given_error_impact__returns_error_result( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + error_impact = create_mock_reform_impact( + status="error", + reform_impact_json=json.dumps({}), + ) + error_impact["message"] = "Failed to start simulation API job" + mock_reform_impacts_service.get_all_reform_impacts.return_value = [ + error_impact + ] + + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert result.data is None + assert result.message == "Failed to start simulation API job" + mock_reform_impacts_service.get_all_reform_impacts.assert_called_once() + mock_simulation_api.run.assert_not_called() + def test__given_computing_impact_with_succeeded_execution__returns_completed_result( self, economy_service, @@ -989,12 +1019,12 @@ def test__given_ok_status__returns_completed(self, economy_service): assert result == ImpactAction.COMPLETED - def test__given_error_status__returns_completed(self, economy_service): + def test__given_error_status__returns_error(self, economy_service): impact = create_mock_reform_impact(status="error") result = economy_service._determine_impact_action(impact) - assert result == ImpactAction.COMPLETED + assert result == ImpactAction.ERROR def test__given_computing_status__returns_computing(self, economy_service): impact = 
create_mock_reform_impact(status="computing") From 6035d0ec27869452ec4609361f6de0338cd8d7ee Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 10 Apr 2026 10:03:09 -0400 Subject: [PATCH 12/13] Avoid stale Utah macro cache in CI --- tests/to_refactor/python/test_us_policy_macro.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/to_refactor/python/test_us_policy_macro.py b/tests/to_refactor/python/test_us_policy_macro.py index f4228523c..2b3499993 100644 --- a/tests/to_refactor/python/test_us_policy_macro.py +++ b/tests/to_refactor/python/test_us_policy_macro.py @@ -47,7 +47,11 @@ def utah_reform_runner(rest_client, region: str = "us"): policy_id = policy_create.json["result"]["policy_id"] assert policy_id is not None - query = f"/us/economy/{policy_id}/over/{default_policy}?region={region}&time_period={test_year}" + cache_buster = int(time.time() * 1000) + query = ( + f"/us/economy/{policy_id}/over/{default_policy}" + f"?region={region}&time_period={test_year}&test_run={cache_buster}" + ) economy_response = rest_client.get(query) assert economy_response.status_code == 200 assert economy_response.json["status"] == "computing", ( From b82ace8c2c91f82e1d25aa30a5103c7cc2a19161 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 10 Apr 2026 10:38:28 -0400 Subject: [PATCH 13/13] Mark pre-submission setup failures as errors --- policyengine_api/services/economy_service.py | 68 ++++++++++---------- tests/unit/services/test_economy_service.py | 41 ++++++++++++ 2 files changed, 75 insertions(+), 34 deletions(-) diff --git a/policyengine_api/services/economy_service.py b/policyengine_api/services/economy_service.py index f525baf9e..165e3f03d 100644 --- a/policyengine_api/services/economy_service.py +++ b/policyengine_api/services/economy_service.py @@ -958,44 +958,44 @@ def _handle_create_impact( provisional_execution_id: str, ) -> EconomicImpactResult: - baseline_policy = policy_service.get_policy_json( - setup_options.country_id, 
setup_options.baseline_policy_id - ) - reform_policy = policy_service.get_policy_json( - setup_options.country_id, setup_options.reform_policy_id - ) + try: + baseline_policy = policy_service.get_policy_json( + setup_options.country_id, setup_options.baseline_policy_id + ) + reform_policy = policy_service.get_policy_json( + setup_options.country_id, setup_options.reform_policy_id + ) - sim_config: SimulationOptions = self._setup_sim_options( - country_id=setup_options.country_id, - reform_policy=reform_policy, - baseline_policy=baseline_policy, - region=setup_options.region, - time_period=setup_options.time_period, - dataset=setup_options.dataset, - scope="macro", - include_cliffs=setup_options.target == "cliff", - model_version=setup_options.model_version, - data_version=setup_options.data_version, - ) + sim_config: SimulationOptions = self._setup_sim_options( + country_id=setup_options.country_id, + reform_policy=reform_policy, + baseline_policy=baseline_policy, + region=setup_options.region, + time_period=setup_options.time_period, + dataset=setup_options.dataset, + scope="macro", + include_cliffs=setup_options.target == "cliff", + model_version=setup_options.model_version, + data_version=setup_options.data_version, + ) - logger.log_struct( - { - "message": "Setting up sim API job", - **setup_options.model_dump(), - } - ) + logger.log_struct( + { + "message": "Setting up sim API job", + **setup_options.model_dump(), + } + ) - # Build params with metadata for Logfire tracing in the simulation API. - # The _metadata field will be captured by the Logfire span before - # SimulationOptions validation (which silently ignores extra fields). - sim_params = sim_config.model_dump() - sim_params["_metadata"] = { - "reform_policy_id": setup_options.reform_policy_id, - "baseline_policy_id": setup_options.baseline_policy_id, - "process_id": setup_options.process_id, - } + # Build params with metadata for Logfire tracing in the simulation API. 
+ # The _metadata field will be captured by the Logfire span before + # SimulationOptions validation (which silently ignores extra fields). + sim_params = sim_config.model_dump() + sim_params["_metadata"] = { + "reform_policy_id": setup_options.reform_policy_id, + "baseline_policy_id": setup_options.baseline_policy_id, + "process_id": setup_options.process_id, + } - try: sim_api_execution = simulation_api.run(sim_params) execution_id = simulation_api.get_execution_id(sim_api_execution) except Exception as error: diff --git a/tests/unit/services/test_economy_service.py b/tests/unit/services/test_economy_service.py index d47e9bbee..dbe25ffdf 100644 --- a/tests/unit/services/test_economy_service.py +++ b/tests/unit/services/test_economy_service.py @@ -316,6 +316,47 @@ def test__given_simulation_api_submission_failure__marks_provisional_claim_error ) mock_reform_impacts_service.update_reform_impact_execution_id.assert_not_called() + def test__given_simulation_setup_failure__marks_provisional_claim_error( + self, + economy_service, + base_params, + mock_country_package_versions, + mock_get_dataset_version, + mock_policy_service, + mock_reform_impacts_service, + mock_simulation_api, + mock_logger, + mock_datetime, + mock_numpy_random, + ): + mock_reform_impacts_service.get_all_reform_impacts.return_value = [] + with patch.object( + economy_service, + "_setup_sim_options", + side_effect=ValueError("Invalid US state: 'zz'"), + ): + result = economy_service.get_economic_impact(**base_params) + + assert result.status == ImpactStatus.ERROR + assert ( + result.message + == "Failed to start simulation API job: Invalid US state: 'zz'" + ) + mock_reform_impacts_service.set_reform_impact.assert_called_once() + mock_reform_impacts_service.set_error_reform_impact.assert_called_once_with( + country_id=MOCK_COUNTRY_ID, + policy_id=MOCK_POLICY_ID, + baseline_policy_id=MOCK_BASELINE_POLICY_ID, + region=MOCK_REGION, + dataset=MOCK_DATASET, + time_period=MOCK_TIME_PERIOD, + 
options_hash=MOCK_OPTIONS_HASH, + message="Failed to start simulation API job: Invalid US state: 'zz'", + execution_id=f"{PENDING_EXECUTION_ID_PREFIX}{MOCK_PROCESS_ID}", + ) + mock_simulation_api.run.assert_not_called() + mock_reform_impacts_service.update_reform_impact_execution_id.assert_not_called() + def test__given_claim_lock_timeout_and_existing_provisional_claim__returns_computing( self, economy_service,