Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/release-manifest-consumer.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Align the bundled UK release manifest with the pinned `policyengine-uk` package version and updated data package revisions.
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ dependencies = [
[project.optional-dependencies]
uk = [
"policyengine_core>=3.23.6",
"policyengine-uk>=2.51.0",
"policyengine-uk==2.78.0",
]
us = [
"policyengine_core>=3.23.6",
"policyengine-us>=1.213.1",
"policyengine-us==1.602.0",
]
dev = [
"pytest",
Expand All @@ -45,8 +45,8 @@ dev = [
"pytest-asyncio>=0.26.0",
"ruff>=0.9.0",
"policyengine_core>=3.23.6",
"policyengine-uk>=2.51.0",
"policyengine-us>=1.213.1",
"policyengine-uk==2.78.0",
"policyengine-us==1.602.0",
"towncrier>=24.8.0",
"mypy>=1.11.0",
"pytest-cov>=5.0.0",
Expand Down
7 changes: 7 additions & 0 deletions src/policyengine/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
from .region import Region as Region
from .region import RegionRegistry as RegionRegistry
from .region import RegionType as RegionType
from .release_manifest import CountryReleaseManifest as CountryReleaseManifest
from .release_manifest import DataPackageVersion as DataPackageVersion
from .release_manifest import DataReleaseArtifact as DataReleaseArtifact
from .release_manifest import DataReleaseManifest as DataReleaseManifest
from .release_manifest import PackageVersion as PackageVersion
from .release_manifest import get_data_release_manifest as get_data_release_manifest
from .release_manifest import get_release_manifest as get_release_manifest
from .scoping_strategy import RegionScopingStrategy as RegionScopingStrategy
from .scoping_strategy import RowFilterStrategy as RowFilterStrategy
from .scoping_strategy import ScopingStrategy as ScopingStrategy
Expand Down
178 changes: 178 additions & 0 deletions src/policyengine/core/release_manifest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import os
from functools import lru_cache
from importlib.resources import files
from pathlib import Path

import requests
from pydantic import BaseModel, Field

# Timeout, in seconds, passed to ``requests.get`` when fetching a data release
# manifest from Hugging Face.
HF_REQUEST_TIMEOUT_SECONDS = 30


class PackageVersion(BaseModel):
    """A package identified by its name and a version string."""

    # Distribution name, e.g. the "name" entry of a manifest's package section.
    name: str
    # Version string exactly as written in the manifest.
    version: str


class DataPackageVersion(PackageVersion):
    """A data package version together with the repository that hosts it.

    The inherited ``version`` is also used by this module as the repository
    revision when building download URLs and ``hf://`` URIs.
    """

    # Repository identifier on the hosting service.
    repo_id: str
    # Kind of repository; defaults to "model".
    repo_type: str = "model"
    # Location of the data release manifest inside the repository.
    release_manifest_path: str = "release_manifest.json"


class CompatibleModelPackage(BaseModel):
    """A model package and the version specifier a data release declares for it."""

    # Name of the model package.
    name: str
    # Version specifier string; its syntax is not interpreted in this module.
    specifier: str


class ArtifactPathReference(BaseModel):
    """A fixed path to a dataset file inside the data package repository."""

    # Path of the file relative to the repository root.
    path: str


class ArtifactPathTemplate(BaseModel):
    """A dataset path containing ``str.format`` placeholders."""

    # Template string, filled in by ``resolve``.
    path_template: str

    def resolve(self, **kwargs: str) -> str:
        """Return the template with its placeholders replaced by ``kwargs``.

        Raises:
            KeyError: If the template names a placeholder that was not given.
        """
        template = self.path_template
        return template.format(**kwargs)


class DataReleaseArtifact(BaseModel):
    """One file listed in a data release manifest."""

    # Free-form artifact category as written in the manifest.
    kind: str
    # Path of the file inside the repository.
    path: str
    # Repository that hosts the file.
    repo_id: str
    # Repository revision the file is pinned to.
    revision: str
    # Optional integrity metadata; both may be absent from the manifest.
    sha256: str | None = None
    size_bytes: int | None = None

    @property
    def uri(self) -> str:
        """Return the ``hf://`` URI built from repo id, path and revision."""
        return build_hf_uri(self.repo_id, self.path, self.revision)


class DataReleaseManifest(BaseModel):
    """Manifest published alongside a data package release."""

    # Version number of the manifest schema.
    schema_version: int
    # The data package this manifest describes.
    data_package: PackageVersion
    # Model packages the data release declares compatibility with.
    compatible_model_packages: list[CompatibleModelPackage] = Field(default_factory=list)
    # Mapping of string keys to default dataset names.
    default_datasets: dict[str, str] = Field(default_factory=dict)
    # Artifacts keyed by dataset name.
    artifacts: dict[str, DataReleaseArtifact] = Field(default_factory=dict)


class CountryReleaseManifest(BaseModel):
    """Release manifest bundled with this package for a single country."""

    # Country identifier, also used to look the manifest up.
    country_id: str
    # Version string recorded under "policyengine_version" in the manifest.
    policyengine_version: str
    # Pinned country model package.
    model_package: PackageVersion
    # Pinned data package and the repository that hosts it.
    data_package: DataPackageVersion
    # Name (or URI) of the dataset used by default.
    default_dataset: str
    # Named datasets mapped to fixed paths in the data repository.
    datasets: dict[str, ArtifactPathReference] = Field(default_factory=dict)
    # Region types mapped to path templates in the data repository.
    region_datasets: dict[str, ArtifactPathTemplate] = Field(default_factory=dict)

    @property
    def default_dataset_uri(self) -> str:
        """Return the resolved reference for ``default_dataset``."""
        country = self.country_id
        dataset_name = self.default_dataset
        return resolve_dataset_reference(country, dataset_name)


def build_hf_uri(repo_id: str, path_in_repo: str, revision: str) -> str:
    """Compose an ``hf://<repo_id>/<path_in_repo>@<revision>`` URI.

    Args:
        repo_id: Repository identifier.
        path_in_repo: File path relative to the repository root.
        revision: Revision the file is pinned to.

    Returns:
        The three parts joined into a single URI string, unvalidated.
    """
    return "hf://{}/{}@{}".format(repo_id, path_in_repo, revision)


@lru_cache
def get_release_manifest(country_id: str) -> CountryReleaseManifest:
    """Load and validate the release manifest bundled for ``country_id``.

    The manifest is read from ``data/release_manifests/<country_id>.json``
    inside the ``policyengine`` package. Results are memoised per country id.

    Args:
        country_id: Country identifier used as the manifest file stem.

    Returns:
        The parsed ``CountryReleaseManifest``.

    Raises:
        ValueError: If no manifest file is bundled for ``country_id``.
    """
    manifest_path = files("policyengine").joinpath(
        "data", "release_manifests", f"{country_id}.json"
    )
    if not manifest_path.is_file():
        raise ValueError(f"No bundled release manifest for country '{country_id}'")

    # Decode explicitly as UTF-8 rather than relying on the locale's default
    # encoding, so the manifest parses the same way on every platform.
    return CountryReleaseManifest.model_validate_json(
        manifest_path.read_text(encoding="utf-8")
    )


def _data_release_manifest_url(data_package: DataPackageVersion) -> str:
    """Build the Hugging Face download URL for a data release manifest.

    Args:
        data_package: Data package whose ``repo_id``, ``repo_type``,
            ``version`` (used as the revision) and ``release_manifest_path``
            locate the manifest.

    Returns:
        An ``https://huggingface.co/.../resolve/<revision>/<path>`` URL.
    """
    from urllib.parse import quote

    # Model repos are served from the site root; dataset and space repos live
    # under a plural prefix. Unrecognised types keep the root-level URL that
    # this function always produced.
    repo_prefix = {"dataset": "datasets/", "space": "spaces/"}.get(
        data_package.repo_type, ""
    )
    # Revisions may contain "/" (e.g. "refs/pr/1"), which must be escaped to
    # stay a single URL segment; the file path keeps its separators.
    revision = quote(data_package.version, safe="")
    manifest_path = quote(data_package.release_manifest_path)
    return (
        "https://huggingface.co/"
        f"{repo_prefix}{data_package.repo_id}/resolve/{revision}/"
        f"{manifest_path}"
    )


@lru_cache
def get_data_release_manifest(country_id: str) -> DataReleaseManifest:
    """Download and validate the data release manifest for ``country_id``.

    The data package is taken from the bundled country manifest and its
    manifest is fetched over HTTPS from Hugging Face. When the
    ``HUGGING_FACE_TOKEN`` environment variable is set it is sent as a bearer
    token. Successful results are memoised per country id.

    Raises:
        ValueError: If the server answers 401 or 403.
        requests.HTTPError: For any other unsuccessful status code.
    """
    package = get_release_manifest(country_id).data_package

    hf_token = os.environ.get("HUGGING_FACE_TOKEN")
    request_headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}

    reply = requests.get(
        _data_release_manifest_url(package),
        headers=request_headers,
        timeout=HF_REQUEST_TIMEOUT_SECONDS,
    )

    # Authentication failures get a hint about the token instead of a bare
    # HTTP error.
    if reply.status_code == 401 or reply.status_code == 403:
        raise ValueError(
            "Could not fetch the data release manifest from Hugging Face. "
            "If this country uses a private data repo, set HUGGING_FACE_TOKEN."
        )

    reply.raise_for_status()
    return DataReleaseManifest.model_validate_json(reply.text)


def resolve_dataset_reference(country_id: str, dataset: str) -> str:
    """Turn a dataset name into a URI, passing existing URIs through.

    Resolution order:
      1. A value containing ``"://"`` is returned unchanged.
      2. A name listed in the bundled country manifest becomes an ``hf://``
         URI pinned to the manifest's data package version.
      3. Otherwise the remote data release manifest is fetched and the
         matching artifact's URI is returned.

    Args:
        country_id: Country whose manifests are consulted.
        dataset: Dataset name or a full URI.

    Returns:
        The resolved dataset URI.

    Raises:
        ValueError: If the name is found in neither manifest.
    """
    if "://" in dataset:
        return dataset

    manifest = get_release_manifest(country_id)
    path_reference = manifest.datasets.get(dataset)
    if path_reference is not None:
        return build_hf_uri(
            repo_id=manifest.data_package.repo_id,
            path_in_repo=path_reference.path,
            revision=manifest.data_package.version,
        )

    data_release_manifest = get_data_release_manifest(country_id)
    artifact = data_release_manifest.artifacts.get(dataset)
    if artifact is None:
        # Both sources were searched, so report the names from both; listing
        # only the bundled ones would hide valid remote artifact names.
        known_datasets = sorted(
            set(manifest.datasets) | set(data_release_manifest.artifacts)
        )
        raise ValueError(
            f"Unknown dataset '{dataset}' for country '{country_id}'. "
            f"Known datasets: {known_datasets}"
        )

    return artifact.uri


def dataset_logical_name(dataset: str) -> str:
    """Return the file stem of a dataset reference, ignoring any ``@revision``.

    Everything after the last ``"@"`` is dropped, then the final path
    component is returned without its last suffix.
    """
    head, separator, _revision = dataset.rpartition("@")
    # rpartition leaves the whole string in the last slot when "@" is absent.
    location = head if separator else dataset
    return Path(location).stem


def resolve_default_datasets(country_id: str) -> list[str]:
    """Return the dataset names listed in the bundled manifest, in file order."""
    return [*get_release_manifest(country_id).datasets]


def resolve_region_dataset_path(
    country_id: str,
    region_type: str,
    **kwargs: str,
) -> str | None:
    """Resolve the dataset location for a region type of ``country_id``.

    Args:
        country_id: Country whose bundled manifest is consulted.
        region_type: Key into the manifest's ``region_datasets`` mapping.
        **kwargs: Values substituted into the region's path template.

    Returns:
        ``None`` when the manifest has no template for ``region_type``;
        the formatted template itself when it already contains ``"://"``;
        otherwise an ``hf://`` URI pinned to the data package version.
    """
    release = get_release_manifest(country_id)
    path_template = release.region_datasets.get(region_type)
    if path_template is None:
        return None

    location = path_template.resolve(**kwargs)
    if "://" not in location:
        # Relative paths are anchored in the pinned data package repository.
        package = release.data_package
        location = build_hf_uri(
            repo_id=package.repo_id,
            path_in_repo=location,
            revision=package.version,
        )
    return location
14 changes: 14 additions & 0 deletions src/policyengine/core/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,17 @@ def save(self):
def load(self):
    """Populate this simulation's output dataset via its model version."""
    model_version = self.tax_benefit_model_version
    model_version.load(self)

@property
def release_bundle(self) -> dict[str, str | None]:
bundle = (
self.tax_benefit_model_version.release_bundle
if self.tax_benefit_model_version is not None
else {}
)
return {
**bundle,
"dataset_filepath": self.dataset.filepath
if self.dataset is not None
else None,
}
30 changes: 30 additions & 0 deletions src/policyengine/core/tax_benefit_model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pydantic import BaseModel, Field

from .release_manifest import CountryReleaseManifest, PackageVersion
from .tax_benefit_model import TaxBenefitModel

if TYPE_CHECKING:
Expand Down Expand Up @@ -32,6 +33,13 @@ class TaxBenefitModelVersion(BaseModel):
region_registry: "RegionRegistry | None" = Field(
default=None, description="Registry of supported geographic regions"
)
release_manifest: CountryReleaseManifest | None = Field(
default=None,
exclude=True,
)
model_package: PackageVersion | None = Field(default=None)
data_package: PackageVersion | None = Field(default=None)
default_dataset_uri: str | None = Field(default=None)

@property
def parameter_values(self) -> list["ParameterValue"]:
Expand Down Expand Up @@ -116,6 +124,28 @@ def get_region(self, code: str) -> "Region | None":
return None
return self.region_registry.get(code)

@property
def release_bundle(self) -> dict[str, str | None]:
return {
"country_id": self.release_manifest.country_id
if self.release_manifest is not None
else None,
"policyengine_version": self.release_manifest.policyengine_version
if self.release_manifest is not None
else None,
"model_package": self.model_package.name
if self.model_package is not None
else None,
"model_version": self.version,
"data_package": self.data_package.name
if self.data_package is not None
else None,
"data_version": self.data_package.version
if self.data_package is not None
else None,
"default_dataset_uri": self.default_dataset_uri,
}

def __repr__(self) -> str:
    """Return a short summary: the id followed by the size of each collection."""
    # Start with the id, then append one "<label>=<count>" pair per collection.
    pieces = [f"<TaxBenefitModelVersion id={self.id}"]
    pieces.append(f"variables={len(self.variables)}")
    pieces.append(f"parameters={len(self.parameters)}")
    pieces.append(f"parameter_nodes={len(self.parameter_nodes)}")
    pieces.append(f"parameter_values={len(self.parameter_values)}")
    return " ".join(pieces) + ">"
3 changes: 2 additions & 1 deletion src/policyengine/countries/uk/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import TYPE_CHECKING

from policyengine.core.region import Region, RegionRegistry
from policyengine.core.release_manifest import resolve_region_dataset_path
from policyengine.core.scoping_strategy import (
RowFilterStrategy,
WeightReplacementStrategy,
Expand Down Expand Up @@ -127,7 +128,7 @@ def build_uk_region_registry(
code="uk",
label="United Kingdom",
region_type="national",
dataset_path=f"{UK_DATA_BUCKET}/enhanced_frs_2023_24.h5",
dataset_path=resolve_region_dataset_path("uk", "national"),
)
)

Expand Down
15 changes: 12 additions & 3 deletions src/policyengine/countries/us/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

from policyengine.core.region import Region, RegionRegistry
from policyengine.core.release_manifest import resolve_region_dataset_path
from policyengine.core.scoping_strategy import RowFilterStrategy

from .data import AT_LARGE_STATES, DISTRICT_COUNTS, US_PLACES, US_STATES
Expand Down Expand Up @@ -40,7 +41,7 @@ def build_us_region_registry() -> RegionRegistry:
code="us",
label="United States",
region_type="national",
dataset_path=f"{US_DATA_BUCKET}/enhanced_cps_2024.h5",
dataset_path=resolve_region_dataset_path("us", "national"),
)
)

Expand All @@ -52,7 +53,11 @@ def build_us_region_registry() -> RegionRegistry:
label=name,
region_type="state",
parent_code="us",
dataset_path=f"{US_DATA_BUCKET}/states/{abbrev}.h5",
dataset_path=resolve_region_dataset_path(
"us",
"state",
state_code=abbrev,
),
state_code=abbrev,
state_name=name,
)
Expand All @@ -76,7 +81,11 @@ def build_us_region_registry() -> RegionRegistry:
label=label,
region_type="congressional_district",
parent_code=f"state/{state_abbrev.lower()}",
dataset_path=f"{US_DATA_BUCKET}/districts/{district_code}.h5",
dataset_path=resolve_region_dataset_path(
"us",
"congressional_district",
district_code=district_code,
),
state_code=state_abbrev,
state_name=state_name,
)
Expand Down
27 changes: 27 additions & 0 deletions src/policyengine/data/release_manifests/uk.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{
"country_id": "uk",
"policyengine_version": "3.4.1",
"model_package": {
"name": "policyengine-uk",
"version": "2.78.0"
},
"data_package": {
"name": "policyengine-uk-data",
"version": "1.40.3",
"repo_id": "policyengine/policyengine-uk-data-private"
},
"default_dataset": "enhanced_frs_2023_24",
"datasets": {
"frs_2023_24": {
"path": "frs_2023_24.h5"
},
"enhanced_frs_2023_24": {
"path": "enhanced_frs_2023_24.h5"
}
},
"region_datasets": {
"national": {
"path_template": "enhanced_frs_2023_24.h5"
}
}
}
Loading
Loading