From 986470ab1228d702124b16184443eacbb2d11a02 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 16 Mar 2026 15:39:22 -0700 Subject: [PATCH 01/29] Initial commit for adding NERSC IRI-API support alongside SFAPI for job submission --- orchestration/flows/bl832/nersc.py | 128 +++++++++++++++- orchestration/globus/token.py | 235 +++++++++++++++++++++++++++++ 2 files changed, 362 insertions(+), 1 deletion(-) create mode 100644 orchestration/globus/token.py diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index f4ffc9fb..a87c6c53 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field import datetime from dotenv import load_dotenv +import enum import json import logging import os @@ -16,6 +17,7 @@ from typing import Any, Optional from orchestration.flows.bl832.config import Config832 + from orchestration.flows.bl832.job_controller import get_controller, HPC, TomographyHPCController from orchestration.mlflow import get_checkpoint_info from orchestration.prune_controller import get_prune_controller, PruneMethod @@ -23,6 +25,7 @@ from orchestration.flows.bl832.streaming_mixin import ( NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block ) +from orchestration.globus.token import get_access_token_confidential, DEFAULT_TOKEN_FILE from orchestration.prefect import schedule_prefect_flow logger = logging.getLogger(__name__) @@ -142,6 +145,37 @@ def _load_job_options( return {**opts, **overrides} +class NERSCLoginMethod(enum.Enum): + """Selects which NERSC API login method to use when creating a NERSC client. + + Each method corresponds to a different set of credentials and API base URL. + """ + + SFAPI = "sfapi" + """Standard Superfacility API via Iris-registered OAuth2 credentials.""" + + IRIAPI = "iriapi" + """Integrated Research Infrastructure API via IRI-registered OAuth2 credentials.""" + + +# Applies only to NERSCLoginMethod.IRIAPI +_IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" +_IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client +_IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" +_IRIAPI_GLOBUS_RESOURCE_SERVER: str = "auth.globus.org" +_IRIAPI_GLOBUS_REQUIRED_SCOPES: frozenset[str] = frozenset({ + "openid", + "profile", + "email", + "urn:globus:auth:scope:auth.globus.org:view_identities", +}) + +_API_BASE_URLS: dict[NERSCLoginMethod, str] = { + NERSCLoginMethod.SFAPI: "https://api.nersc.gov/api/v1.2", + NERSCLoginMethod.IRIAPI: "https://api.iri.nersc.gov", +} + + class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin): """ Implementation for a NERSC-based tomography HPC controller. @@ -158,7 +192,99 @@ def __init__( self.client = client @staticmethod - def create_sfapi_client() -> Client: + def create_nersc_client( + login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, + ) -> Client: + """Create and return a NERSC client for the requested login method. + + Two fundamentally different auth strategies are supported: + + - :attr:`NERSCLoginMethod.SFAPI`: uses an Iris-registered OAuth2 + client ID + private key (NERSC OIDC flow). Set ``PATH_NERSC_CLIENT_ID`` + and ``PATH_NERSC_PRI_KEY`` to the paths of those files. + + - :attr:`NERSCLoginMethod.IRIAPI`: uses a Globus bearer token written + by ``globus_token.py``. Set ``PATH_GLOBUS_TOKEN_FILE`` to the token + file path, or rely on the default (``~/.globus/auth_tokens.json``). + + Args: + login_method: Which NERSC API to authenticate against. + Defaults to :attr:`NERSCLoginMethod.SFAPI`. + + Returns: + An authenticated :class:`sfapi_client.Client` instance. + + Raises: + ValueError: If SFAPI credential environment variables are unset. + FileNotFoundError: If credential or token files are absent. + RuntimeError: If the Globus token is expired. + Exception: If the underlying client construction fails. + """ + logger.info(f"Creating NERSC client using login method: {login_method.value}") + api_url = _API_BASE_URLS[login_method] + logger.info(f"Targeting API base URL: {api_url}") + + if login_method is NERSCLoginMethod.SFAPI: + client = NERSCTomographyHPCController._create_sfapi_client() + + elif login_method is NERSCLoginMethod.IRIAPI: + client = NERSCTomographyHPCController._create_iriapi_client() + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {login_method}") + + logger.info( + f"NERSC client created successfully " + f"(method={login_method.value}, api_url={api_url})." + ) + return client + + @staticmethod + def _create_iriapi_client() -> Client: + """Create a NERSC client for the IRI API using a Globus bearer token. + + Requires ``GLOBUS_CLIENT_ID`` and ``GLOBUS_CLIENT_SECRET`` in the + environment. Reuses a cached token if valid; otherwise mints a new one + via the client credentials grant. No browser or user interaction. + + Returns: + An authenticated :class:`sfapi_client.Client` targeting the IRI API. + + Raises: + ValueError: If ``GLOBUS_CLIENT_ID`` or ``GLOBUS_CLIENT_SECRET`` are unset. + RuntimeError: If the acquired token is missing required scopes. + """ + client_id = os.getenv(_IRIAPI_GLOBUS_CLIENT_ID_ENV) + client_secret = os.getenv(_IRIAPI_GLOBUS_CLIENT_SECRET_ENV) + + if not client_id: + raise ValueError( + f"Globus client ID is unset. Set {_IRIAPI_GLOBUS_CLIENT_ID_ENV}." + ) + if not client_secret: + raise ValueError( + f"Globus client secret is unset. Set {_IRIAPI_GLOBUS_CLIENT_SECRET_ENV}. " + "A Globus Confidential App client is required for automated IRI API auth." + ) + + token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV) + token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE + + access_token = get_access_token_confidential( + client_id=client_id, + client_secret=client_secret, + required_scopes=_IRIAPI_GLOBUS_REQUIRED_SCOPES, + resource_server=_IRIAPI_GLOBUS_RESOURCE_SERVER, + token_file=token_file, + ) + + return Client( + token=access_token, + api_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], + ) + + @staticmethod + def _create_sfapi_client() -> Client: """Create and return an NERSC client instance""" # When generating the SFAPI Key in Iris, make sure to select "asldev" as the user! diff --git a/orchestration/globus/token.py b/orchestration/globus/token.py new file mode 100644 index 00000000..81b5438f --- /dev/null +++ b/orchestration/globus/token.py @@ -0,0 +1,235 @@ +import json +import logging +import os +from pathlib import Path +import stat +import time + +import globus_sdk +from globus_sdk.exc import GlobusAPIError + +logger = logging.getLogger(__name__) + +# Default token file location, matching the Globus SDK convention. +DEFAULT_TOKEN_FILE: Path = Path.home() / ".globus" / "auth_tokens.json" +GLOBUS_OIDC_TOKEN_URL: str = "https://auth.globus.org/v2/oauth2/token" + + +def get_access_token_confidential( + client_id: str, + client_secret: str, + required_scopes: frozenset[str], + resource_server: str, + token_file: Path | None = None, +) -> str: + """Get a valid Globus access token using a Confidential Client (machine-to-machine). + + No browser or user interaction required. If a valid unexpired token exists + on disk it is reused; otherwise a new one is minted via the client + credentials grant and saved. + + Args: + client_id: Globus Confidential App client ID. + client_secret: Globus Confidential App client secret. + required_scopes: Set of OAuth2 scopes that must be present on the token. + resource_server: Resource server key to extract from the token response. + token_file: Path to the JSON token cache file. Defaults to + ``~/.globus/auth_tokens.json``. + + Returns: + A valid Globus access token string. + + Raises: + RuntimeError: If the acquired token is missing required scopes. + KeyError: If ``access_token`` is absent from the token response. + """ + resolved_token_file = token_file or DEFAULT_TOKEN_FILE + + # 1. Do we already have a valid token? + stored = load_token_file(resolved_token_file) + if stored: + expires_at = stored.get("expires_at_seconds") + if expires_at and time.time() < expires_at: + logger.info("Using cached Globus token (still valid).") + return stored["access_token"] + logger.info("Cached Globus token is expired; minting a new one.") + else: + logger.info("No cached Globus token found; minting a new one.") + + # 2. Mint a new token — same call whether first time or expired. + globus_client = globus_sdk.ConfidentialAppAuthClient(client_id, client_secret) + token_response = globus_client.oauth2_client_credentials_tokens( + requested_scopes=" ".join(sorted(required_scopes)) + ) + auth_data = token_response.by_resource_server[resource_server] + + granted = set(auth_data.get("scope", "").split()) + missing = required_scopes - granted + if missing: + raise RuntimeError( + f"Globus token is missing required scopes: {sorted(missing)}" + ) + + save_token_file(resolved_token_file, auth_data) + logger.info(f"New Globus token saved to {resolved_token_file}.") + + return auth_data["access_token"] + + +def load_token_file(token_file: Path) -> dict | None: + """Load saved Globus token data from disk. + + Args: + token_file: Path to the JSON token file. + + Returns: + Parsed token dict, or None if the file does not exist. + """ + if not token_file.exists(): + return None + with token_file.open("r", encoding="utf-8") as f: + return json.load(f) + + +def save_token_file(token_file: Path, tokens: dict) -> None: + """Atomically save Globus token data to disk with owner-only permissions. + + Writes to a temporary file then renames to avoid partial writes. + + Args: + token_file: Destination path for the JSON token file. + tokens: Token dict to serialise. + """ + _ensure_private_parent_dir(token_file) + tmp = token_file.with_suffix(".tmp") + with os.fdopen( + os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), + "w", + encoding="utf-8", + ) as f: + json.dump(tokens, f, indent=2) + os.replace(tmp, token_file) + os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) + + +def interactive_login( + client: globus_sdk.NativeAppAuthClient, + required_scopes: frozenset[str], + resource_server: str, +) -> dict: + """Run an interactive browser-based Globus login flow. + + Prints an authorization URL, waits for the user to paste an auth code, + and exchanges it for tokens. + + Args: + client: Globus NativeAppAuthClient to drive the flow. + required_scopes: Set of OAuth2 scopes to request. + resource_server: Resource server key to extract from the token response + (e.g. ``"auth.globus.org"``). + + Returns: + Token dict for the given resource server. + """ + client.oauth2_start_flow( + requested_scopes=" ".join(sorted(required_scopes)), + refresh_tokens=True, + ) + logger.info("Open this URL in your browser to authenticate with Globus:") + logger.info(client.oauth2_get_authorize_url()) + code = input("\nEnter authorization code: ").strip() + token_response = client.oauth2_exchange_code_for_tokens(code) + return token_response.by_resource_server[resource_server] + + +def refresh_tokens( + client: globus_sdk.NativeAppAuthClient, + refresh_token: str, + resource_server: str, +) -> dict | None: + """Attempt a silent Globus token refresh. + + Args: + client: Globus NativeAppAuthClient to drive the refresh. + refresh_token: The stored refresh token. + resource_server: Resource server key to extract from the token response. + + Returns: + Fresh token dict for the given resource server, or None if refresh failed. + """ + try: + token_response = client.oauth2_refresh_token(refresh_token) + return token_response.by_resource_server[resource_server] + except GlobusAPIError as e: + logger.warning( + f"Globus token refresh failed ({e.http_status}); " + "falling back to interactive login." + ) + return None + + +def get_access_token( + client_id: str, + required_scopes: frozenset[str], + resource_server: str, + token_file: Path | None = None, + force_login: bool = False, +) -> str: + """Get a valid Globus access token, refreshing or logging in as needed. + + Attempts a silent refresh from the saved token file first. Falls back to + interactive browser login if no saved tokens exist, the refresh token is + absent, or the refresh fails. Saves the resulting tokens back to disk. + + Args: + client_id: Globus NativeApp client ID. + required_scopes: Set of OAuth2 scopes that must be present on the token. + resource_server: Resource server key to extract from the token response. + token_file: Path to the JSON token file. Defaults to + ``~/.globus/auth_tokens.json``. + force_login: If True, skip refresh and force interactive login. + + Returns: + A valid Globus access token string. + + Raises: + RuntimeError: If the acquired token is missing required scopes. + KeyError: If ``access_token`` is absent from the token response. + """ + resolved_token_file = token_file or DEFAULT_TOKEN_FILE + globus_client = globus_sdk.NativeAppAuthClient(client_id) + + auth_data: dict | None = None + + if not force_login: + stored = load_token_file(resolved_token_file) + if stored and stored.get("refresh_token"): + auth_data = refresh_tokens( + globus_client, stored["refresh_token"], resource_server + ) + + if auth_data is None: + logger.info("Initiating interactive Globus login.") + auth_data = interactive_login(globus_client, required_scopes, resource_server) + + granted = set(auth_data.get("scope", "").split()) + missing = required_scopes - granted + if missing: + raise RuntimeError( + f"Globus token is missing required scopes: {sorted(missing)}" + ) + + save_token_file(resolved_token_file, auth_data) + logger.info(f"Globus token saved to {resolved_token_file}.") + + return auth_data["access_token"] + + +def _ensure_private_parent_dir(path: Path) -> None: + """Create parent directories for path with owner-only permissions. + + Args: + path: The file path whose parent directory should be created. + """ + path.parent.mkdir(parents=True, exist_ok=True) + os.chmod(path.parent, 0o700) From 0512e587894efb5b9c1b2b05af65adb1ef442e06 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 16 Mar 2026 15:55:45 -0700 Subject: [PATCH 02/29] Adding an abstraction for _submit_job() and _wait_for_job() that use the correct mechanism based on IRI/SF-API --- orchestration/flows/bl832/nersc.py | 68 +++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index a87c6c53..52ffcc25 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -186,10 +186,12 @@ class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin) def __init__( self, client: Client, - config: Config832 + config: Config832, + login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, ) -> None: TomographyHPCController.__init__(self, config) self.client = client + self.login_method = login_method @staticmethod def create_nersc_client( @@ -353,6 +355,70 @@ def _get_segmentation_spec(self, model: str, project: str) -> SegmentationModelS ) return registry[key] + def _submit_job(self, job_script: str) -> str: + """Submit a Slurm job script and return the job ID. + + Dispatches to the appropriate submission mechanism based on + ``self.login_method``. + + Args: + job_script: The full Slurm batch script to submit. + + Returns: + The submitted job ID as a string. + + Raises: + RuntimeError: If job submission fails. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.submit_job(job_script) + return str(job.jobid) + + elif self.login_method is NERSCLoginMethod.IRIAPI: + response = self.client.post( + "/api/v1/compute/job/perlmutter", + json={"script": job_script}, + ) + response.raise_for_status() + return str(response.json()["job_id"]) + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + + def _wait_for_job(self, job_id: str) -> bool: + """Block until a submitted job completes. + + Dispatches to the appropriate polling mechanism based on + ``self.login_method``. + + Args: + job_id: The job ID returned by :meth:`_submit_job`. + + Returns: + True if the job completed successfully, False otherwise. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.job(jobid=job_id) + job.complete() + return True + + elif self.login_method is NERSCLoginMethod.IRIAPI: + while True: + response = self.client.get( + f"/api/v1/compute/status/perlmutter/{job_id}" + ) + response.raise_for_status() + state = response.json().get("state") + logger.info(f"Job {job_id} state: {state}") + if state in ("COMPLETED", "FAILED", "CANCELLED", "TIMEOUT"): + return state == "COMPLETED" + time.sleep(60) + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + def reconstruct( self, file_path: str = "", From fe275199eb25288f2e4eae0aeb0f1929a3d99750 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 17 Mar 2026 10:22:49 -0700 Subject: [PATCH 03/29] moving NERSCLoginMethod(Enum) to the job_controller.py module --- orchestration/flows/bl832/job_controller.py | 26 +++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/orchestration/flows/bl832/job_controller.py b/orchestration/flows/bl832/job_controller.py index b2ff064b..1a23d02a 100644 --- a/orchestration/flows/bl832/job_controller.py +++ b/orchestration/flows/bl832/job_controller.py @@ -10,6 +10,19 @@ load_dotenv() +class NERSCLoginMethod(Enum): + """Selects which NERSC API login method to use when creating a NERSC client. + + Each method corresponds to a different set of credentials and API base URL. + """ + + SFAPI = "sfapi" + """Standard Superfacility API via Iris-registered OAuth2 credentials.""" + + IRIAPI = "iriapi" + """Integrated Research Infrastructure API via IRI-registered OAuth2 credentials.""" + + class TomographyHPCController(ABC): """ Abstract class for tomography HPC controllers. @@ -65,7 +78,8 @@ class HPC(Enum): def get_controller( hpc_type: HPC, - config: Config832 + config: Config832, + login_method: "NERSCLoginMethod | None" = None, ) -> TomographyHPCController: """ Factory function that returns an HPC controller instance for the given HPC environment. @@ -86,10 +100,14 @@ def get_controller( config=config ) elif hpc_type == HPC.NERSC: - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + resolved_login_method = login_method if isinstance(login_method, NERSCLoginMethod) else NERSCLoginMethod.SFAPI return NERSCTomographyHPCController( - client=NERSCTomographyHPCController.create_sfapi_client(), - config=config + client=NERSCTomographyHPCController.create_nersc_client( + login_method=resolved_login_method + ), + config=config, + login_method=resolved_login_method, ) elif hpc_type == HPC.OLCF: # TODO: Implement OLCF controller From eaf02fe8c83a87a22970a7bbe6d6e91f5f8885e3 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 17 Mar 2026 10:25:00 -0700 Subject: [PATCH 04/29] Removed NERSCLoginMethod(Enum) from nersc.py. Created a temporary test flow for reconstruction to test job submission. In reconstruct(), replaced the SFAPI-specific job submission/polling code with the general _submit_job() and _wait_for_job() methods. --- orchestration/flows/bl832/nersc.py | 165 ++++++++++++++--------------- 1 file changed, 82 insertions(+), 83 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 52ffcc25..4e5e1c0e 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -2,6 +2,7 @@ import datetime from dotenv import load_dotenv import enum +import httpx import json import logging import os @@ -17,21 +18,37 @@ from typing import Any, Optional from orchestration.flows.bl832.config import Config832 - -from orchestration.flows.bl832.job_controller import get_controller, HPC, TomographyHPCController -from orchestration.mlflow import get_checkpoint_info -from orchestration.prune_controller import get_prune_controller, PruneMethod -from orchestration.transfer_controller import globus_transfer_task +from orchestration.flows.bl832.job_controller import get_controller, HPC, NERSCLoginMethod, TomographyHPCController from orchestration.flows.bl832.streaming_mixin import ( NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block ) from orchestration.globus.token import get_access_token_confidential, DEFAULT_TOKEN_FILE +from orchestration.mlflow import get_checkpoint_info from orchestration.prefect import schedule_prefect_flow +from orchestration.prune_controller import get_prune_controller, PruneMethod +from orchestration.transfer_controller import globus_transfer_task logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) load_dotenv() +# Applies only to NERSCLoginMethod.IRIAPI +_IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" +_IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client +_IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" +_IRIAPI_GLOBUS_RESOURCE_SERVER: str = "auth.globus.org" +_IRIAPI_GLOBUS_REQUIRED_SCOPES: frozenset[str] = frozenset({ + "openid", + "profile", + "email", + "urn:globus:auth:scope:auth.globus.org:view_identities", +}) + +_API_BASE_URLS: dict[NERSCLoginMethod, str] = { + NERSCLoginMethod.SFAPI: "https://api.nersc.gov/api/v1.2", + NERSCLoginMethod.IRIAPI: "https://api.iri.nersc.gov", +} + @dataclass class SegmentationModelSpec: @@ -158,24 +175,6 @@ class NERSCLoginMethod(enum.Enum): """Integrated Research Infrastructure API via IRI-registered OAuth2 credentials.""" -# Applies only to NERSCLoginMethod.IRIAPI -_IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" -_IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client -_IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" -_IRIAPI_GLOBUS_RESOURCE_SERVER: str = "auth.globus.org" -_IRIAPI_GLOBUS_REQUIRED_SCOPES: frozenset[str] = frozenset({ - "openid", - "profile", - "email", - "urn:globus:auth:scope:auth.globus.org:view_identities", -}) - -_API_BASE_URLS: dict[NERSCLoginMethod, str] = { - NERSCLoginMethod.SFAPI: "https://api.nersc.gov/api/v1.2", - NERSCLoginMethod.IRIAPI: "https://api.iri.nersc.gov", -} - - class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin): """ Implementation for a NERSC-based tomography HPC controller. @@ -185,8 +184,8 @@ class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin) def __init__( self, - client: Client, config: Config832, + client: Client | httpx.Client | None = None, login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, ) -> None: TomographyHPCController.__init__(self, config) @@ -280,9 +279,9 @@ def _create_iriapi_client() -> Client: token_file=token_file, ) - return Client( - token=access_token, - api_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], + return httpx.Client( + base_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], + headers={"Authorization": f"Bearer {access_token}"}, ) @staticmethod @@ -355,6 +354,28 @@ def _get_segmentation_spec(self, model: str, project: str) -> SegmentationModelS ) return registry[key] + def _get_nersc_username(self) -> str: + """Get the NERSC username for constructing pscratch paths. + + Uses the sfapi_client user endpoint for SFAPI, or reads + ``NERSC_USERNAME`` from the environment for IRIAPI. + + Returns: + NERSC username string. + + Raises: + ValueError: If IRIAPI is selected and NERSC_USERNAME is unset. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + return self.client.user().name + else: + username = os.getenv("NERSC_USERNAME") + if not username: + raise ValueError( + "NERSC_USERNAME must be set in the environment when using IRIAPI." + ) + return username + def _submit_job(self, job_script: str) -> str: """Submit a Slurm job script and return the job ID. @@ -393,7 +414,7 @@ def _wait_for_job(self, job_id: str) -> bool: ``self.login_method``. Args: - job_id: The job ID returned by :meth:`_submit_job`. + job_id: The job ID returned by `_submit_job`. Returns: True if the job completed successfully, False otherwise. @@ -433,7 +454,8 @@ def reconstruct( """ logger.info("Starting NERSC reconstruction process.") - user = self.client.user() + # user = self.client.user() + username = self._get_nersc_username() raw_path = self.config.nersc832_alsdev_raw.root_path logger.info(f"{raw_path=}") @@ -447,7 +469,7 @@ def reconstruct( scratch_path = self.config.nersc832_alsdev_scratch.root_path logger.info(f"{scratch_path=}") - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" logger.info(f"{pscratch_path=}") path = Path(file_path) @@ -596,55 +618,23 @@ def reconstruct( echo "JOB_STATUS=SUCCESS" >> $TIMING_FILE echo "JOB_END=$(date +%s)" >> $TIMING_FILE """ + job_id = None try: - logger.info("Submitting reconstruction job script to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - + logger.info("Submitting reconstruction job to Perlmutter.") + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() # Wait until the job completes - logger.info("Reconstruction job completed successfully.") - # Fetch timing data - timing = self._fetch_timing_data(perlmutter, pscratch_path, job.jobid) - - return { - "success": True, - "job_id": job.jobid, - "timing": timing - } - + success = self._wait_for_job(job_id) + timing = self._fetch_timing_data(pscratch_path, job_id) if success else None + return {"success": success, "job_id": job_id, "timing": timing} except Exception as e: - logger.info(f"Error during job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) + logger.error(f"Error during reconstruction job submission or completion: {e}") + return {"success": False, "job_id": job_id, "timing": None} - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.perlmutter.job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Reconstruction job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False - - def _fetch_timing_data(self, perlmutter, pscratch_path: str, job_id: str) -> dict: + def _fetch_timing_data(self, pscratch_path: str, job_id: str) -> dict: """ Fetch and parse timing data from the SLURM job. - :param perlmutter: SFAPI compute object for Perlmutter :param pscratch_path: Path to the user's pscratch directory :param job_id: SLURM job ID :return: Dictionary with timing breakdown @@ -653,17 +643,26 @@ def _fetch_timing_data(self, perlmutter, pscratch_path: str, job_id: str) -> dic try: # Use SFAPI to read the timing file - result = perlmutter.run(f"cat {timing_file}") - - # result might be a string directly, or an object with .output - if isinstance(result, str): - output = result - elif hasattr(result, 'output'): - output = result.output - elif hasattr(result, 'stdout'): - output = result.stdout - else: - output = str(result) + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + result = perlmutter.run(f"cat {timing_file}") + + # result might be a string directly, or an object with .output + if isinstance(result, str): + output = result + elif hasattr(result, 'output'): + output = result.output + elif hasattr(result, 'stdout'): + output = result.stdout + else: + output = str(result) + elif self.login_method is NERSCLoginMethod.IRIAPI: + response = self.client.get( + "/api/v1/filesystem/file/perlmutter", + params={"path": timing_file}, + ) + response.raise_for_status() + output = response.text logger.info(f"Timing file contents:\n{output}") From be2c5716a28c8de382f57318c324cfac58ee9195 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 17 Mar 2026 10:50:35 -0700 Subject: [PATCH 05/29] Updating pytests --- orchestration/_tests/test_bl832/test_nersc.py | 309 +++++++++++++--- orchestration/_tests/test_sfapi_flow.py | 332 ++---------------- 2 files changed, 292 insertions(+), 349 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 8d7056a8..7a8ca07a 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -1,5 +1,4 @@ -# orchestration/_tests/bl832/test_nersc.py - +# orchestration/_tests/test_bl832/test_nersc.py import pytest from uuid import uuid4 @@ -20,18 +19,28 @@ def prefect_test_fixture(): yield -# ────────────────────────────────────────────────────────────────────────────── +# --------------------------------------------------------------------------- # Shared fixtures -# ────────────────────────────────────────────────────────────────────────────── +# --------------------------------------------------------------------------- + +@pytest.fixture +def mock_config(mocker): + config = mocker.MagicMock() + config.ghcr_images832 = { + "recon_image": "mock_recon_image", + "multires_image": "mock_multires_image", + } + return config + @pytest.fixture def mock_sfapi_client(mocker): - """Mock sfapi_client.Client with a completed job on Perlmutter.""" - mock_client = mocker.MagicMock() + """sfapi_client.Client mock with user, compute, submit_job, and job chained.""" + client = mocker.MagicMock() mock_user = mocker.MagicMock() mock_user.name = "testuser" - mock_client.user.return_value = mock_user + client.user.return_value = mock_user mock_job = mocker.MagicMock() mock_job.jobid = "12345" @@ -39,10 +48,9 @@ def mock_sfapi_client(mocker): mock_compute = mocker.MagicMock() mock_compute.submit_job.return_value = mock_job - mock_client.compute.return_value = mock_compute - - mocker.patch("orchestration.flows.bl832.nersc.Client", return_value=mock_client) - return mock_client + client.compute.return_value = mock_compute + mocker.patch("orchestration.flows.bl832.nersc.Client", return_value=client) + return client @pytest.fixture @@ -167,11 +175,28 @@ def _make_future(mocker, value): return f -# ────────────────────────────────────────────────────────────────────────────── -# create_sfapi_client -# ────────────────────────────────────────────────────────────────────────────── +@pytest.fixture +def mock_iriapi_client(mocker): + """httpx.Client mock for IRI API responses.""" + client = mocker.MagicMock() + + submit_response = mocker.MagicMock() + submit_response.json.return_value = {"job_id": "99999"} + client.post.return_value = submit_response + + status_response = mocker.MagicMock() + status_response.json.return_value = {"state": "COMPLETED"} + client.get.return_value = status_response + + return client + + +# --------------------------------------------------------------------------- +# _create_sfapi_client +# --------------------------------------------------------------------------- def test_create_sfapi_client_success(mocker): + """Valid credentials produce a Client instance.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { @@ -179,29 +204,34 @@ def test_create_sfapi_client_success(mocker): "PATH_NERSC_PRI_KEY": "/path/to/client_secret", }.get(x)) mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=True) - mocker.patch("builtins.open", side_effect=[ - mocker.mock_open(read_data="client_id_value")(), - mocker.mock_open(read_data='{"key": "value"}')(), - ]) + mocker.patch( + "builtins.open", + side_effect=[ + mocker.mock_open(read_data="my-client-id")(), + mocker.mock_open(read_data='{"kty": "RSA", "n": "x", "e": "y"}')(), + ] + ) mocker.patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key", return_value="mock_secret") mock_client_cls = mocker.patch("orchestration.flows.bl832.nersc.Client") - client = NERSCTomographyHPCController.create_sfapi_client() + client = NERSCTomographyHPCController._create_sfapi_client() - mock_client_cls.assert_called_once_with("client_id_value", "mock_secret") - assert client == mock_client_cls.return_value + mock_client_cls.assert_called_once_with("my-client-id", "mock_secret") + assert client is mock_client_cls.return_value def test_create_sfapi_client_missing_paths(mocker): + """Unset env vars raise ValueError.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController mocker.patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None) with pytest.raises(ValueError, match="Missing NERSC credentials paths."): - NERSCTomographyHPCController.create_sfapi_client() + NERSCTomographyHPCController._create_sfapi_client() def test_create_sfapi_client_missing_files(mocker): + """Env vars set but files absent raise FileNotFoundError.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { @@ -211,40 +241,7 @@ def test_create_sfapi_client_missing_files(mocker): mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): - NERSCTomographyHPCController.create_sfapi_client() - - -# ────────────────────────────────────────────────────────────────────────────── -# reconstruct -# ────────────────────────────────────────────────────────────────────────────── - -def test_reconstruct_success(mocker, mock_sfapi_client, mock_config832): - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - - result = controller.reconstruct(file_path="folder/file.h5") - - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() - assert isinstance(result, dict) - assert result["success"] is True - assert result["job_id"] == "12345" - - -def test_reconstruct_submission_failure(mocker, mock_sfapi_client, mock_config832): - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - - result = controller.reconstruct(file_path="folder/file.h5") - - assert result is False + NERSCTomographyHPCController._create_sfapi_client() # ────────────────────────────────────────────────────────────────────────────── @@ -386,6 +383,162 @@ def test_segmentation_dinov3_submission_failure(mocker, mock_sfapi_client, mock_ controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) result = controller.segmentation_dinov3(recon_folder_path="folder/recfile") + assert result is False + +# --------------------------------------------------------------------------- +# reconstruct — SFAPI +# --------------------------------------------------------------------------- + + +def test_reconstruct_sfapi_success(mocker, mock_sfapi_client, mock_config832): + """SFAPI reconstruct submits a job and waits for completion.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + from sfapi_client.compute import Machine + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result["success"] is True + assert result["job_id"] == "12345" + assert mock_sfapi_client.compute.call_count == 3 # 1 _submit_job() + 1 _wait_for_job() + 1 _fetch_timing_data() + mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) + mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + mock_sfapi_client.compute.return_value.job.assert_called_once_with(jobid="12345") + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() + + +def test_reconstruct_sfapi_submission_failure(mocker, mock_sfapi_client, mock_config832): + """SFAPI reconstruct returns False when submission raises.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("SFAPI error") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result["success"] is False + + +# --------------------------------------------------------------------------- +# reconstruct — IRIAPI +# --------------------------------------------------------------------------- + +def test_reconstruct_iriapi_success(mocker, mock_iriapi_client, mock_config832, monkeypatch): + """IRIAPI reconstruct POSTs a job and polls for COMPLETED state.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result["success"] is True + assert result["job_id"] == "99999" + mock_iriapi_client.post.assert_called_once() + assert mock_iriapi_client.post.call_args.args[0] == "/api/v1/compute/job/perlmutter" + assert "script" in mock_iriapi_client.post.call_args.kwargs["json"] + assert mock_iriapi_client.get.call_count == 2 + mock_iriapi_client.get.assert_any_call( + "/api/v1/compute/status/perlmutter/99999" + ) + mock_iriapi_client.get.assert_any_call( + "/api/v1/filesystem/file/perlmutter", + params={"path": mocker.ANY}, + ) + + +def test_reconstruct_iriapi_job_failed(mocker, mock_iriapi_client, mock_config832, monkeypatch): + """IRIAPI reconstruct returns False when job state is FAILED.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_iriapi_client.get.return_value.json.return_value = {"state": "FAILED"} + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result["success"] is False + + +def test_reconstruct_iriapi_missing_username(mocker, mock_iriapi_client, mock_config832, monkeypatch): + """IRIAPI reconstruct raises ValueError when NERSC_USERNAME is unset.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.delenv("NERSC_USERNAME", raising=False) + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.IRIAPI, + ) + + with pytest.raises(ValueError, match="NERSC_USERNAME"): + controller.reconstruct(file_path="folder/scan.h5") + + +# --------------------------------------------------------------------------- +# build_multi_resolution — SFAPI +# --------------------------------------------------------------------------- + +def test_build_multi_resolution_sfapi_success(mocker, mock_sfapi_client, mock_config832): + """SFAPI build_multi_resolution submits and waits successfully.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + from sfapi_client.compute import Machine + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is True + assert mock_sfapi_client.compute.call_count == 2 + mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) + + +def test_build_multi_resolution_sfapi_failure(mocker, mock_sfapi_client, mock_config832): + """SFAPI build_multi_resolution returns False when submission raises.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("error") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") assert result is False @@ -794,3 +947,47 @@ def test_moon_segment_flow_no_sam3_no_combine(mocker, mock_config832, mock_recon mock_sam3_task.submit.assert_not_called() mock_combine_task.submit.assert_not_called() +# --------------------------------------------------------------------------- +# build_multi_resolution — IRIAPI +# --------------------------------------------------------------------------- + + +def test_build_multi_resolution_iriapi_success(mocker, mock_iriapi_client, mock_config, monkeypatch): + """IRIAPI build_multi_resolution POSTs and polls successfully.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is True + mock_iriapi_client.post.assert_called_once() + mock_iriapi_client.get.assert_called_once_with( + "/api/v1/compute/status/perlmutter/99999" + ) + + +def test_build_multi_resolution_iriapi_failure(mocker, mock_iriapi_client, mock_config, monkeypatch): + """IRIAPI build_multi_resolution returns False when job state is FAILED.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_iriapi_client.get.return_value.json.return_value = {"state": "FAILED"} + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is False diff --git a/orchestration/_tests/test_sfapi_flow.py b/orchestration/_tests/test_sfapi_flow.py index 6e9bf225..e0d4a854 100644 --- a/orchestration/_tests/test_sfapi_flow.py +++ b/orchestration/_tests/test_sfapi_flow.py @@ -1,8 +1,5 @@ # orchestration/_tests/test_sfapi_flow.py - -from pathlib import Path import pytest -from unittest.mock import MagicMock, patch, mock_open from uuid import uuid4 from prefect.blocks.system import Secret @@ -11,307 +8,56 @@ @pytest.fixture(autouse=True, scope="session") def prefect_test_fixture(): - """ - A pytest fixture that automatically sets up and tears down the Prefect test harness - for the entire test session. It creates and saves test secrets and configurations - required for Globus integration. - - Yields: - None - """ with prefect_test_harness(): - globus_client_id = Secret(value=str(uuid4())) - globus_client_id.save(name="globus-client-id", overwrite=True) - globus_client_secret = Secret(value=str(uuid4())) - globus_client_secret.save(name="globus-client-secret", overwrite=True) - + Secret(value=str(uuid4())).save(name="globus-client-id", overwrite=True) + Secret(value=str(uuid4())).save(name="globus-client-secret", overwrite=True) yield -# ---------------------------- -# Tests for create_sfapi_client -# ---------------------------- - - -def test_create_sfapi_client_success(): - """ - Test successful creation of the SFAPI client. - """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - # Mock data for client_id and client_secret files - mock_client_id = 'value' - mock_client_secret = '{"key": "value"}' - - # Create separate mock_open instances for each file - mock_open_client_id = mock_open(read_data=mock_client_id) - mock_open_client_secret = mock_open(read_data=mock_client_secret) - - with patch("orchestration.flows.bl832.nersc.os.getenv") as mock_getenv, \ - patch("orchestration.flows.bl832.nersc.os.path.isfile") as mock_isfile, \ - patch("builtins.open", side_effect=[ - mock_open_client_id.return_value, - mock_open_client_secret.return_value - ]), \ - patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key") as mock_import_key, \ - patch("orchestration.flows.bl832.nersc.Client") as MockClient: - - # Mock environment variables - mock_getenv.side_effect = lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret" - }.get(x, None) - - # Mock file existence - mock_isfile.return_value = True - - # Mock JsonWebKey.import_key to return a mock secret - mock_import_key.return_value = "mock_secret" - - # Create the client - client = NERSCTomographyHPCController.create_sfapi_client() - - # Assert that Client was instantiated with 'value' and 'mock_secret' - MockClient.assert_called_once_with("value", "mock_secret") - - # Assert that the returned client is the mocked client - assert client == MockClient.return_value, "Client should be the mocked sfapi_client.Client instance" - - -def test_create_sfapi_client_missing_paths(): - """ - Test creation of the SFAPI client with missing credential paths. - """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - with patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None): - with pytest.raises(ValueError, match="Missing NERSC credentials paths."): - NERSCTomographyHPCController.create_sfapi_client() - - -def test_create_sfapi_client_missing_files(): - """ - Test creation of the SFAPI client with missing credential files. - """ - with ( - # Mock environment variables - patch( - "orchestration.flows.bl832.nersc.os.getenv", - side_effect=lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret" - }.get(x, None) - ), - - # Mock file existence to simulate missing files - patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) - ): - # Import the module after applying patches to ensure mocks are in place - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - # Expect a FileNotFoundError due to missing credential files - with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): - NERSCTomographyHPCController.create_sfapi_client() - -# ---------------------------- -# Fixture for Mocking SFAPI Client -# ---------------------------- - - -@pytest.fixture -def mock_sfapi_client(): - """ - Mock the sfapi_client.Client class with necessary methods. - """ - with patch("orchestration.flows.bl832.nersc.Client") as MockClient: - mock_client_instance = MockClient.return_value - - # Mock the user method - mock_user = MagicMock() - mock_user.name = "testuser" - mock_client_instance.user.return_value = mock_user - - # Mock the compute method to return a mocked compute object - mock_compute = MagicMock() - mock_job = MagicMock() - mock_job.jobid = "12345" - mock_job.state = "COMPLETED" - mock_compute.submit_job.return_value = mock_job - mock_client_instance.compute.return_value = mock_compute - - yield mock_client_instance - - -# ---------------------------- -# Fixture for Mocking Config832 -# ---------------------------- - -@pytest.fixture -def mock_config832(): - """ - Mock the Config832 class to provide necessary configurations. - - All settings dicts must be fully populated to match the config YAML schema, - because _load_job_options() passes config_settings directly as the defaults - dict and then accesses keys by name. - """ - with patch("orchestration.flows.bl832.nersc.Config832") as MockConfig: - mock_config = MockConfig.return_value - mock_config.ghcr_images832 = { - "recon_image": "mock_recon_image", - "multires_image": "mock_multires_image", - } - mock_config.nersc_recon_settings = { - "qos": "realtime", - "account": "mock_account", - "reservation": "", - "num_nodes": 4, - "cpus-per-task": 128, - "walltime": "0:30:00", - } - mock_config.nersc_multiresolution_settings = { - "qos": "realtime", - "account": "mock_account", - "reservation": "", - "cpus-per-task": 128, - "walltime": "0:15:00", - } - mock_config.apps = {"als_transfer": "some_config"} - yield mock_config - - -# ---------------------------- -# Tests for NERSCTomographyHPCController -# ---------------------------- - -def test_reconstruct_success(mock_sfapi_client, mock_config832): - """ - Test successful reconstruction job submission. - """ +def test_create_sfapi_client_success(mocker): + """Valid credentials produce a Client instance.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.reconstruct(file_path=file_path) - - # Verify that compute was called with Machine.perlmutter - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - - # Verify that submit_job was called once - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - - # Verify that complete was called on the job - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() - - # Assert that the method returns True - assert isinstance(result, dict) - assert result["success"] is True - assert result["job_id"] == "12345" - -def test_reconstruct_submission_failure(mock_sfapi_client, mock_config832): - """ - Test reconstruction job submission failure. - """ + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret", + }.get(x)) + mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=True) + mocker.patch( + "builtins.open", + side_effect=[ + mocker.mock_open(read_data="my-client-id")(), + mocker.mock_open(read_data='{"kty": "RSA", "n": "x", "e": "y"}')(), + ] + ) + mocker.patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key", return_value="mock_secret") + mock_client_cls = mocker.patch("orchestration.flows.bl832.nersc.Client") + + client = NERSCTomographyHPCController._create_sfapi_client() + + mock_client_cls.assert_called_once_with("my-client-id", "mock_secret") + assert client is mock_client_cls.return_value + + +def test_create_sfapi_client_missing_paths(mocker): + """Unset env vars raise ValueError.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - # Simulate submission failure - mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None) - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.reconstruct(file_path=file_path) + with pytest.raises(ValueError, match="Missing NERSC credentials paths."): + NERSCTomographyHPCController._create_sfapi_client() - # Assert that the method returns False - assert result is False, "reconstruct should return False on submission failure." - -def test_build_multi_resolution_success(mock_sfapi_client, mock_config832): - """ - Test successful multi-resolution job submission. - """ +def test_create_sfapi_client_missing_files(mocker): + """Env vars set but files absent raise FileNotFoundError.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.build_multi_resolution(file_path=file_path) - - # Verify that compute was called with Machine.perlmutter - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - - # Verify that submit_job was called once - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - - # Verify that complete was called on the job - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() - - # Assert that the method returns True - assert result is True, "build_multi_resolution should return True on successful job completion." - - -def test_build_multi_resolution_submission_failure(mock_sfapi_client, mock_config832): - """ - Test multi-resolution job submission failure. - """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - # Simulate submission failure - mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.build_multi_resolution(file_path=file_path) - - # Assert that the method returns False - assert result is False, "build_multi_resolution should return False on submission failure." - - -def test_job_submission(mock_sfapi_client): - """ - Test job submission and status updates. - """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - mock_config = MagicMock() - mock_config.nersc_recon_settings = { - "qos": "realtime", - "account": "mock_account", - "reservation": "", - "num_nodes": 4, - "cpus-per-task": 128, - "walltime": "0:30:00", - } - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config) - file_path = "path/to/file.h5" - - # Mock Path to extract file and folder names - with patch.object(Path, 'parent', new_callable=MagicMock) as mock_parent, \ - patch.object(Path, 'stem', new_callable=MagicMock) as mock_stem: - mock_parent.name = "to" - mock_stem.return_value = "file" - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - controller.reconstruct(file_path=file_path) - - # Verify that compute was called with Machine.perlmutter - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - # Verify that submit_job was called once - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret", + }.get(x)) + mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) - # Verify the returned job has the expected attributes - submitted_job = mock_sfapi_client.compute.return_value.submit_job.return_value - assert submitted_job.jobid == "12345", "Job ID should match the mock job ID." - assert submitted_job.state == "COMPLETED", "Job state should be COMPLETED." + with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): + NERSCTomographyHPCController._create_sfapi_client() From cf15c2041d7985f9fc2fcdfa748ea75530b8f49d Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 17 Mar 2026 10:51:14 -0700 Subject: [PATCH 06/29] Updating multires() method to use the generic _submit_job() and _wait_for_job() helpers --- orchestration/flows/bl832/nersc.py | 85 +++++++++++++++++------------- 1 file changed, 48 insertions(+), 37 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 4e5e1c0e..79a4bdb0 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -7,7 +7,6 @@ import logging import os from pathlib import Path -import re import time from authlib.jose import JsonWebKey @@ -718,7 +717,8 @@ def build_multi_resolution( logger.info("Starting NERSC multiresolution process.") - user = self.client.user() + # user = self.client.user() + username = self._get_nersc_username() multires_image = self.config.ghcr_images832["multires_image"] logger.info(f"{multires_image=}") @@ -729,7 +729,7 @@ def build_multi_resolution( scratch_path = self.config.nersc832_alsdev_scratch.root_path logger.info(f"{scratch_path=}") - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" logger.info(f"{pscratch_path=}") path = Path(file_path) @@ -784,42 +784,53 @@ def build_multi_resolution( date """ try: - logger.info("Submitting Tiff to Zarr job script to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - + logger.info("Submitting Tiff to Zarr job to Perlmutter.") + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() # Wait until the job completes - logger.info("Reconstruction job completed successfully.") - - return True - + success = self._wait_for_job(job_id) + logger.info(f"Multiresolution job {'completed' if success else 'failed'}.") + return success except Exception as e: - logger.warning(f"Error during job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.perlmutter.job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Reconstruction job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False + logger.error(f"Error during multiresolution job submission or completion: {e}") + return False + # try: + # logger.info("Submitting Tiff to Zarr job script to Perlmutter.") + # perlmutter = self.client.compute(Machine.perlmutter) + # job = perlmutter.submit_job(job_script) + # logger.info(f"Submitted job ID: {job.jobid}") + + # try: + # job.update() + # except Exception as update_err: + # logger.warning(f"Initial job update failed, continuing: {update_err}") + + # time.sleep(60) + # logger.info(f"Job {job.jobid} current state: {job.state}") + + # job.complete() # Wait until the job completes + # logger.info("Reconstruction job completed successfully.") + + # return True + + # except Exception as e: + # logger.warning(f"Error during job submission or completion: {e}") + # match = re.search(r"Job not found:\s*(\d+)", str(e)) + + # if match: + # jobid = match.group(1) + # logger.info(f"Attempting to recover job {jobid}.") + # try: + # job = self.client.perlmutter.job(jobid=jobid) + # time.sleep(30) + # job.complete() + # logger.info("Reconstruction job completed successfully after recovery.") + # return True + # except Exception as recovery_err: + # logger.error(f"Failed to recover job {jobid}: {recovery_err}") + # return False + # else: + # return False def segmentation_sam3( self, From d0e80683737337f3dee3b74b5c44e5b6ba29b405 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 30 Mar 2026 14:53:53 -0700 Subject: [PATCH 07/29] successfully ran reconstruction using the IRI-API --- orchestration/flows/bl832/nersc.py | 78 ++++-- orchestration/globus/token.py | 390 +++++++++++++++++++++-------- scripts/get_globus_token.py | 337 +++++++++++++++++++++++++ 3 files changed, 681 insertions(+), 124 deletions(-) create mode 100644 scripts/get_globus_token.py diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 79a4bdb0..2aad35de 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -7,6 +7,7 @@ import logging import os from pathlib import Path +import re import time from authlib.jose import JsonWebKey @@ -21,8 +22,12 @@ from orchestration.flows.bl832.streaming_mixin import ( NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block ) -from orchestration.globus.token import get_access_token_confidential, DEFAULT_TOKEN_FILE from orchestration.mlflow import get_checkpoint_info +from orchestration.globus.token import ( + get_access_token, + DEFAULT_TOKEN_FILE, + IRI_SCOPE, +) from orchestration.prefect import schedule_prefect_flow from orchestration.prune_controller import get_prune_controller, PruneMethod from orchestration.transfer_controller import globus_transfer_task @@ -33,7 +38,9 @@ # Applies only to NERSCLoginMethod.IRIAPI _IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" -_IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client +_IRI_COMPUTE_RESOURCE: str = "compute" +_IRI_SCRATCH_RESOURCE: str = "scratch" +# _IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client _IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" _IRIAPI_GLOBUS_RESOURCE_SERVER: str = "auth.globus.org" _IRIAPI_GLOBUS_REQUIRED_SCOPES: frozenset[str] = frozenset({ @@ -41,6 +48,7 @@ "profile", "email", "urn:globus:auth:scope:auth.globus.org:view_identities", + IRI_SCOPE, }) _API_BASE_URLS: dict[NERSCLoginMethod, str] = { @@ -254,33 +262,33 @@ def _create_iriapi_client() -> Client: ValueError: If ``GLOBUS_CLIENT_ID`` or ``GLOBUS_CLIENT_SECRET`` are unset. RuntimeError: If the acquired token is missing required scopes. """ - client_id = os.getenv(_IRIAPI_GLOBUS_CLIENT_ID_ENV) - client_secret = os.getenv(_IRIAPI_GLOBUS_CLIENT_SECRET_ENV) + client_id = "fae5c579-490a-4d76-b6eb-d78f65caeb63" # os.getenv(_IRIAPI_GLOBUS_CLIENT_ID_ENV) + # client_secret = os.getenv(_IRIAPI_GLOBUS_CLIENT_SECRET_ENV) if not client_id: raise ValueError( f"Globus client ID is unset. Set {_IRIAPI_GLOBUS_CLIENT_ID_ENV}." ) - if not client_secret: - raise ValueError( - f"Globus client secret is unset. Set {_IRIAPI_GLOBUS_CLIENT_SECRET_ENV}. " - "A Globus Confidential App client is required for automated IRI API auth." - ) + # if not client_secret: + # raise ValueError( + # f"Globus client secret is unset. Set {_IRIAPI_GLOBUS_CLIENT_SECRET_ENV}. " + # "A Globus Confidential App client is required for automated IRI API auth." + # ) token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV) token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE - access_token = get_access_token_confidential( + access_token = get_access_token( client_id=client_id, - client_secret=client_secret, - required_scopes=_IRIAPI_GLOBUS_REQUIRED_SCOPES, - resource_server=_IRIAPI_GLOBUS_RESOURCE_SERVER, + requested_scopes=_IRIAPI_GLOBUS_REQUIRED_SCOPES, token_file=token_file, + force_login=False, ) return httpx.Client( base_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], headers={"Authorization": f"Bearer {access_token}"}, + timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0), ) @staticmethod @@ -396,12 +404,39 @@ def _submit_job(self, job_script: str) -> str: return str(job.jobid) elif self.login_method is NERSCLoginMethod.IRIAPI: + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" + + script_body = "\n".join( + line for line in job_script.splitlines() + if not line.startswith("#SBATCH") and not line.startswith("#!/") + ).strip() + + job_spec = { + "executable": "/bin/bash", + "arguments": ["-c", script_body], + "stdout_path": f"{pscratch_path}/tomo_recon_logs/iri_job.out", + "stderr_path": f"{pscratch_path}/tomo_recon_logs/iri_job.err", + "resources": { + "node_count": 1, + "processes_per_node": 1, + "cpu_cores_per_process": 64, + "exclusive_node_use": True, + }, + "attributes": { + "duration": 1800, + "queue_name": "realtime", + "account": "als", + "custom_attributes": {"constraint": "cpu"}, + }, + } + response = self.client.post( - "/api/v1/compute/job/perlmutter", - json={"script": job_script}, + f"/api/v1/compute/job/{_IRI_COMPUTE_RESOURCE}", + json=job_spec, ) response.raise_for_status() - return str(response.json()["job_id"]) + return str(response.json()["id"]) else: raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") @@ -427,13 +462,16 @@ def _wait_for_job(self, job_id: str) -> bool: elif self.login_method is NERSCLoginMethod.IRIAPI: while True: response = self.client.get( - f"/api/v1/compute/status/perlmutter/{job_id}" + f"/api/v1/compute/status/{_IRI_COMPUTE_RESOURCE}/{job_id}" # ← was "perlmutter" ) response.raise_for_status() - state = response.json().get("state") + state = response.json().get("status", {}).get("state") logger.info(f"Job {job_id} state: {state}") - if state in ("COMPLETED", "FAILED", "CANCELLED", "TIMEOUT"): - return state == "COMPLETED" + if state == "completed": + return True + if state in ("failed", "canceled", "timeout"): + logger.error(f"Job {job_id} ended with state: {state}") + return False time.sleep(60) else: diff --git a/orchestration/globus/token.py b/orchestration/globus/token.py index 81b5438f..4970eaa7 100644 --- a/orchestration/globus/token.py +++ b/orchestration/globus/token.py @@ -1,3 +1,4 @@ +# orchestration/globus/token.py import json import logging import os @@ -12,69 +13,20 @@ # Default token file location, matching the Globus SDK convention. DEFAULT_TOKEN_FILE: Path = Path.home() / ".globus" / "auth_tokens.json" -GLOBUS_OIDC_TOKEN_URL: str = "https://auth.globus.org/v2/oauth2/token" +# IRI API Globus scope and resource server. +# The IRI access token lives in other_tokens under this scope, not at the +# top level of the auth.globus.org response. +IRI_SCOPE: str = ( + "https://auth.globus.org/scopes/" + "ed3e577d-f7f3-4639-b96e-ff5a8445d699/iri_api" +) +IRI_RESOURCE_SERVER: str = "ed3e577d-f7f3-4639-b96e-ff5a8445d699" -def get_access_token_confidential( - client_id: str, - client_secret: str, - required_scopes: frozenset[str], - resource_server: str, - token_file: Path | None = None, -) -> str: - """Get a valid Globus access token using a Confidential Client (machine-to-machine). - - No browser or user interaction required. If a valid unexpired token exists - on disk it is reused; otherwise a new one is minted via the client - credentials grant and saved. - - Args: - client_id: Globus Confidential App client ID. - client_secret: Globus Confidential App client secret. - required_scopes: Set of OAuth2 scopes that must be present on the token. - resource_server: Resource server key to extract from the token response. - token_file: Path to the JSON token cache file. Defaults to - ``~/.globus/auth_tokens.json``. - - Returns: - A valid Globus access token string. - - Raises: - RuntimeError: If the acquired token is missing required scopes. - KeyError: If ``access_token`` is absent from the token response. - """ - resolved_token_file = token_file or DEFAULT_TOKEN_FILE - - # 1. Do we already have a valid token? - stored = load_token_file(resolved_token_file) - if stored: - expires_at = stored.get("expires_at_seconds") - if expires_at and time.time() < expires_at: - logger.info("Using cached Globus token (still valid).") - return stored["access_token"] - logger.info("Cached Globus token is expired; minting a new one.") - else: - logger.info("No cached Globus token found; minting a new one.") - - # 2. Mint a new token — same call whether first time or expired. - globus_client = globus_sdk.ConfidentialAppAuthClient(client_id, client_secret) - token_response = globus_client.oauth2_client_credentials_tokens( - requested_scopes=" ".join(sorted(required_scopes)) - ) - auth_data = token_response.by_resource_server[resource_server] - - granted = set(auth_data.get("scope", "").split()) - missing = required_scopes - granted - if missing: - raise RuntimeError( - f"Globus token is missing required scopes: {sorted(missing)}" - ) - - save_token_file(resolved_token_file, auth_data) - logger.info(f"New Globus token saved to {resolved_token_file}.") - - return auth_data["access_token"] +# --------------------------------------------------------------------------- +# File I/O +# --------------------------------------------------------------------------- def load_token_file(token_file: Path) -> dict | None: """Load saved Globus token data from disk. @@ -112,105 +64,345 @@ def save_token_file(token_file: Path, tokens: dict) -> None: os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) +def _ensure_private_parent_dir(path: Path) -> None: + """Create parent directories for path with owner-only permissions. + + Args: + path: The file path whose parent directory should be created. + """ + path.parent.mkdir(parents=True, exist_ok=True) + os.chmod(path.parent, 0o700) + + +# --------------------------------------------------------------------------- +# IRI token helpers +# --------------------------------------------------------------------------- + +def _parse_scope_string(scope_string: str) -> set[str]: + """Split a space-separated scope string into a set. + + Args: + scope_string: Space-separated OAuth2 scope string. + + Returns: + Set of individual scope strings. + """ + return set(scope_string.split()) if scope_string else set() + + +def extract_iri_token(token_response_data: dict) -> dict: + """Extract the IRI access token entry from a Globus token response. + + The IRI token is not returned at the top level — it lives inside + ``other_tokens``, identified by :data:`IRI_SCOPE`. + + Args: + token_response_data: Full token response dict as returned by the + Globus SDK (i.e. ``token_response.data``). + + Returns: + Token dict for the IRI resource server. + + Raises: + RuntimeError: If no token matching the IRI scope is found. + """ + for token_data in token_response_data.get("other_tokens", []): + if IRI_SCOPE in _parse_scope_string(token_data.get("scope", "")): + return token_data + raise RuntimeError( + f"Missing token for required IRI scope: {IRI_SCOPE}. " + "Re-run with --force-login and ensure consent is granted for the IRI scope." + ) + + +def _replace_iri_token(token_response_data: dict, iri_token_data: dict) -> dict: + """Return a copy of token_response_data with the IRI entry replaced. + + Args: + token_response_data: Full stored token response dict. + iri_token_data: Updated IRI token dict to splice in. + + Returns: + Updated token response dict. + """ + merged = dict(token_response_data) + other_tokens = list(merged.get("other_tokens", [])) + for i, token_data in enumerate(other_tokens): + if IRI_SCOPE in _parse_scope_string(token_data.get("scope", "")): + other_tokens[i] = iri_token_data + break + else: + other_tokens.append(iri_token_data) + merged["other_tokens"] = other_tokens + return merged + + +def _get_iri_refresh_token(stored_tokens: dict) -> str | None: + """Extract the IRI refresh token from stored token data, if present. + + Args: + stored_tokens: Full stored token response dict. + + Returns: + The IRI refresh token string, or None if absent. + """ + try: + return extract_iri_token(stored_tokens).get("refresh_token") + except RuntimeError: + return None + + +def _get_auth_refresh_token(stored_tokens: dict) -> str | None: + """Extract the top-level Globus Auth refresh token from stored data. + + Args: + stored_tokens: Full stored token response dict. + + Returns: + The auth refresh token string, or None if absent. + """ + if "refresh_token" in stored_tokens: + return stored_tokens["refresh_token"] + auth_tokens = stored_tokens.get("auth.globus.org") + if isinstance(auth_tokens, dict): + return auth_tokens.get("refresh_token") + return None + + +# --------------------------------------------------------------------------- +# NativeApp flow (interactive) +# --------------------------------------------------------------------------- + def interactive_login( client: globus_sdk.NativeAppAuthClient, - required_scopes: frozenset[str], - resource_server: str, + requested_scopes: frozenset[str], + prompt_login: bool = False, ) -> dict: """Run an interactive browser-based Globus login flow. Prints an authorization URL, waits for the user to paste an auth code, - and exchanges it for tokens. + and returns the full token response data including ``other_tokens``. Args: client: Globus NativeAppAuthClient to drive the flow. - required_scopes: Set of OAuth2 scopes to request. - resource_server: Resource server key to extract from the token response - (e.g. ``"auth.globus.org"``). + requested_scopes: Set of OAuth2 scopes to request. Should include + :data:`IRI_SCOPE` to obtain an IRI API token. + prompt_login: If True, add ``prompt=login`` to the authorize URL to + force a fresh identity-provider login. Returns: - Token dict for the given resource server. + Full token response dict (``token_response.data``), including + ``other_tokens``. + + Raises: + RuntimeError: If no authorization code is entered, or if the code + exchange fails. """ client.oauth2_start_flow( - requested_scopes=" ".join(sorted(required_scopes)), + requested_scopes=" ".join(sorted(requested_scopes)), refresh_tokens=True, ) logger.info("Open this URL in your browser to authenticate with Globus:") - logger.info(client.oauth2_get_authorize_url()) + prompt = "login" if prompt_login else globus_sdk.MISSING + logger.info(client.oauth2_get_authorize_url(prompt=prompt)) code = input("\nEnter authorization code: ").strip() - token_response = client.oauth2_exchange_code_for_tokens(code) - return token_response.by_resource_server[resource_server] + if not code: + raise RuntimeError( + "No authorization code entered. Re-run the script and paste the " + "code shown by Globus after login." + ) + try: + token_response = client.oauth2_exchange_code_for_tokens(code) + except GlobusAPIError as e: + if e.http_status == 400: + raise RuntimeError( + "Authorization code exchange failed — the code was empty, " + "invalid, expired, or already used. Re-run and try again." + ) from e + raise RuntimeError( + f"Authorization code exchange failed with HTTP {e.http_status}." + ) from e + return token_response.data -def refresh_tokens( +def _refresh_single_token( client: globus_sdk.NativeAppAuthClient, refresh_token: str, - resource_server: str, ) -> dict | None: - """Attempt a silent Globus token refresh. + """Attempt a single Globus token refresh, returning raw response data. Args: - client: Globus NativeAppAuthClient to drive the refresh. + client: NativeAppAuthClient to drive the refresh. refresh_token: The stored refresh token. - resource_server: Resource server key to extract from the token response. Returns: - Fresh token dict for the given resource server, or None if refresh failed. + Raw token response data dict, or None if the refresh failed. """ try: token_response = client.oauth2_refresh_token(refresh_token) - return token_response.by_resource_server[resource_server] + return token_response.data except GlobusAPIError as e: logger.warning( f"Globus token refresh failed ({e.http_status}); " - "falling back to interactive login." + "will fall back to interactive login." ) return None +def _refresh_stored_tokens( + client: globus_sdk.NativeAppAuthClient, + stored_tokens: dict, +) -> tuple[dict | None, bool]: + """Try to refresh stored tokens, preferring the IRI refresh token. + + Attempts the IRI-specific refresh token first, then falls back to the + top-level Globus Auth refresh token. + + Args: + client: NativeAppAuthClient to drive the refresh. + stored_tokens: Full stored token response dict. + + Returns: + Tuple of ``(updated_token_data, success)``. On failure both values + are ``(None, False)``. + """ + iri_refresh = _get_iri_refresh_token(stored_tokens) + if iri_refresh: + iri_token_data = _refresh_single_token(client, iri_refresh) + if iri_token_data is not None: + return _replace_iri_token(stored_tokens, iri_token_data), True + + auth_refresh = _get_auth_refresh_token(stored_tokens) + if auth_refresh: + auth_data = _refresh_single_token(client, auth_refresh) + if auth_data is not None: + return auth_data, True + + return None, False + + def get_access_token( client_id: str, - required_scopes: frozenset[str], - resource_server: str, + requested_scopes: frozenset[str], token_file: Path | None = None, force_login: bool = False, + prompt_login: bool = False, ) -> str: - """Get a valid Globus access token, refreshing or logging in as needed. + """Get a valid IRI API access token via the NativeApp interactive flow. Attempts a silent refresh from the saved token file first. Falls back to interactive browser login if no saved tokens exist, the refresh token is absent, or the refresh fails. Saves the resulting tokens back to disk. + The IRI token is extracted from ``other_tokens`` in the response — it is + not the top-level Globus Auth token. + Args: client_id: Globus NativeApp client ID. - required_scopes: Set of OAuth2 scopes that must be present on the token. - resource_server: Resource server key to extract from the token response. + requested_scopes: Set of OAuth2 scopes to request. Must include + :data:`IRI_SCOPE` to obtain a usable IRI API token. token_file: Path to the JSON token file. Defaults to ``~/.globus/auth_tokens.json``. force_login: If True, skip refresh and force interactive login. + prompt_login: If True, add ``prompt=login`` to the authorize URL. Returns: - A valid Globus access token string. + A valid IRI API access token string. Raises: - RuntimeError: If the acquired token is missing required scopes. - KeyError: If ``access_token`` is absent from the token response. + RuntimeError: If the IRI scope token is missing from the response. """ resolved_token_file = token_file or DEFAULT_TOKEN_FILE globus_client = globus_sdk.NativeAppAuthClient(client_id) - auth_data: dict | None = None + token_response_data: dict | None = None + used_refresh = False if not force_login: stored = load_token_file(resolved_token_file) - if stored and stored.get("refresh_token"): - auth_data = refresh_tokens( - globus_client, stored["refresh_token"], resource_server + if stored: + token_response_data, used_refresh = _refresh_stored_tokens( + globus_client, stored ) - if auth_data is None: + if token_response_data is None: logger.info("Initiating interactive Globus login.") - auth_data = interactive_login(globus_client, required_scopes, resource_server) + token_response_data = interactive_login( + globus_client, requested_scopes, prompt_login=prompt_login + ) + + # Extract IRI token — if a refresh ran but didn't return the IRI token, + # fall back to interactive login before raising. + try: + iri_token = extract_iri_token(token_response_data) + except RuntimeError: + if used_refresh: + logger.warning( + "Refreshed tokens did not include the IRI token; " + "falling back to interactive login." + ) + token_response_data = interactive_login( + globus_client, requested_scopes, prompt_login=prompt_login + ) + iri_token = extract_iri_token(token_response_data) + else: + raise + + save_token_file(resolved_token_file, token_response_data) + logger.info(f"Globus token saved to {resolved_token_file}.") + + return iri_token["access_token"] + + +# --------------------------------------------------------------------------- +# Confidential Client flow (machine-to-machine) +# --------------------------------------------------------------------------- + +def get_access_token_confidential( + client_id: str, + client_secret: str, + required_scopes: frozenset[str], + resource_server: str, + token_file: Path | None = None, +) -> str: + """Get a valid Globus access token using a Confidential Client. + + No browser or user interaction required. If a valid unexpired token exists + on disk it is reused; otherwise a new one is minted via the client + credentials grant and saved. + + Args: + client_id: Globus Confidential App client ID. + client_secret: Globus Confidential App client secret. + required_scopes: Set of OAuth2 scopes that must be present on the token. + resource_server: Resource server key to extract from the token response. + token_file: Path to the JSON token cache file. Defaults to + ``~/.globus/auth_tokens.json``. + + Returns: + A valid Globus access token string. + + Raises: + RuntimeError: If the acquired token is missing required scopes. + KeyError: If ``access_token`` is absent from the token response. + """ + resolved_token_file = token_file or DEFAULT_TOKEN_FILE + + stored = load_token_file(resolved_token_file) + if stored: + expires_at = stored.get("expires_at_seconds") + if expires_at and time.time() < expires_at: + logger.info("Using cached Globus token (still valid).") + return stored["access_token"] + logger.info("Cached Globus token is expired; minting a new one.") + else: + logger.info("No cached Globus token found; minting a new one.") + + globus_client = globus_sdk.ConfidentialAppAuthClient(client_id, client_secret) + token_response = globus_client.oauth2_client_credentials_tokens( + requested_scopes=" ".join(sorted(required_scopes)) + ) + auth_data = token_response.by_resource_server[resource_server] granted = set(auth_data.get("scope", "").split()) missing = required_scopes - granted @@ -220,16 +412,6 @@ def get_access_token( ) save_token_file(resolved_token_file, auth_data) - logger.info(f"Globus token saved to {resolved_token_file}.") + logger.info(f"New Globus token saved to {resolved_token_file}.") return auth_data["access_token"] - - -def _ensure_private_parent_dir(path: Path) -> None: - """Create parent directories for path with owner-only permissions. - - Args: - path: The file path whose parent directory should be created. - """ - path.parent.mkdir(parents=True, exist_ok=True) - os.chmod(path.parent, 0o700) diff --git a/scripts/get_globus_token.py b/scripts/get_globus_token.py new file mode 100644 index 00000000..6b615378 --- /dev/null +++ b/scripts/get_globus_token.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +import argparse +import json +import os +import stat +import time +import urllib.error +import urllib.request +from pathlib import Path + +import globus_sdk +from globus_sdk.exc import GlobusAPIError + +CLIENT_ID = "fae5c579-490a-4d76-b6eb-d78f65caeb63" +RESOURCE_SERVER = "auth.globus.org" +IRI_SCOPE = ( + "https://auth.globus.org/scopes/" + "ed3e577d-f7f3-4639-b96e-ff5a8445d699/iri_api" +) +REQUIRED_SCOPES = { + "openid", + "profile", + "email", + "urn:globus:auth:scope:auth.globus.org:view_identities", +} +REQUESTED_SCOPES = REQUIRED_SCOPES | {IRI_SCOPE} +DEFAULT_IRI_VALIDATE_URL = "https://api.iri.nersc.gov/api/v1/account/projects" + + +def parse_args() -> argparse.Namespace: + default_token_file = Path.home() / ".globus" / "auth_tokens.json" + parser = argparse.ArgumentParser( + description=( + "Get Globus Auth tokens with required scopes. " + "Tokens are saved to a secure local file by default." + ) + ) + parser.add_argument( + "--token-file", + type=Path, + default=default_token_file, + help=f"Path for saved token JSON (default: {default_token_file})", + ) + parser.add_argument( + "--print-token", + action="store_true", + help="Print the access token to stdout (off by default).", + ) + parser.add_argument( + "--force-login", + action="store_true", + help="Skip refresh and force interactive browser login.", + ) + parser.add_argument( + "--refresh-only", + action="store_true", + help="Refresh saved tokens only; do not fall back to interactive login.", + ) + parser.add_argument( + "--prompt-login", + action="store_true", + help="Add prompt=login to the Globus authorize URL to force re-authentication.", + ) + parser.add_argument( + "--validate-iri", + action="store_true", + help="Validate the IRI token by calling the IRI account/projects endpoint.", + ) + parser.add_argument( + "--iri-validate-url", + default=DEFAULT_IRI_VALIDATE_URL, + help=( + "IRI endpoint used by --validate-iri " + f"(default: {DEFAULT_IRI_VALIDATE_URL})" + ), + ) + return parser.parse_args() + + +def parse_scope_string(scope_string: str) -> set[str]: + return set(scope_string.split()) if scope_string else set() + + +def ensure_private_parent_dir(path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + os.chmod(path.parent, 0o700) + + +def load_tokens(token_file: Path) -> dict | None: + if not token_file.exists(): + return None + with token_file.open("r", encoding="utf-8") as f: + return json.load(f) + + +def save_tokens(token_file: Path, tokens: dict) -> None: + ensure_private_parent_dir(token_file) + tmp = token_file.with_suffix(".tmp") + with os.fdopen( + os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), + "w", + encoding="utf-8", + ) as f: + json.dump(tokens, f, indent=2) + os.replace(tmp, token_file) + os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) + + +def get_refresh_token(stored_tokens: dict) -> str | None: + if "refresh_token" in stored_tokens: + return stored_tokens.get("refresh_token") + + auth_tokens = stored_tokens.get(RESOURCE_SERVER) + if isinstance(auth_tokens, dict): + return auth_tokens.get("refresh_token") + + return None + + +def get_iri_token(token_response_data: dict) -> dict: + for token_data in token_response_data.get("other_tokens", []): + if IRI_SCOPE in parse_scope_string(token_data.get("scope", "")): + return token_data + raise RuntimeError(f"Missing token for required IRI scope: {IRI_SCOPE}") + + +def get_iri_refresh_token(stored_tokens: dict) -> str | None: + try: + return get_iri_token(stored_tokens).get("refresh_token") + except RuntimeError: + return None + + +def replace_iri_token(token_response_data: dict, iri_token_data: dict) -> dict: + merged = dict(token_response_data) + other_tokens = list(merged.get("other_tokens", [])) + for index, token_data in enumerate(other_tokens): + if IRI_SCOPE in parse_scope_string(token_data.get("scope", "")): + other_tokens[index] = iri_token_data + break + else: + other_tokens.append(iri_token_data) + merged["other_tokens"] = other_tokens + return merged + + +def validate_auth_data(auth_data: dict) -> dict: + if auth_data.get("resource_server") != RESOURCE_SERVER: + raise RuntimeError( + f"Missing token for required resource server: {RESOURCE_SERVER}" + ) + + granted = parse_scope_string(auth_data.get("scope", "")) + missing = REQUIRED_SCOPES - granted + if missing: + raise RuntimeError(f"Missing required scopes: {sorted(missing)}") + + return get_iri_token(auth_data) + + +def validate_iri_token(iri_token_data: dict, validate_url: str) -> dict | list: + request = urllib.request.Request( + validate_url, + headers={ + "accept": "application/json", + "Authorization": f"Bearer {iri_token_data['access_token']}", + }, + method="GET", + ) + try: + with urllib.request.urlopen(request) as response: + body = response.read().decode("utf-8") + data = json.loads(body) if body else {} + except urllib.error.HTTPError as exc: + body = exc.read().decode("utf-8") + details = body.strip() or exc.reason + raise RuntimeError( + f"IRI validation failed with HTTP {exc.code} from {validate_url}: {details}" + ) from exc + except urllib.error.URLError as exc: + raise RuntimeError( + f"IRI validation request failed for {validate_url}: {exc.reason}" + ) from exc + except json.JSONDecodeError as exc: + raise RuntimeError( + f"IRI validation returned non-JSON data from {validate_url}" + ) from exc + + if isinstance(data, dict): + session_info = data.get("session_info") + if isinstance(session_info, dict): + authentications = session_info.get("authentications") + if isinstance(authentications, dict) and not authentications: + raise RuntimeError( + "IRI validation succeeded but session_info.authentications is empty. " + "Re-run with --force-login --prompt-login and use a Chrome incognito window." + ) + + return data + + +def interactive_login( + client: globus_sdk.NativeAppAuthClient, + *, + prompt_login: bool = False, +) -> dict: + client.oauth2_start_flow( + requested_scopes=" ".join(sorted(REQUESTED_SCOPES)), + refresh_tokens=True, + ) + print("Open this URL, login, and consent:") + prompt = "login" if prompt_login else globus_sdk.MISSING + print(client.oauth2_get_authorize_url(prompt=prompt)) + code = input("\nEnter authorization code: ").strip() + if not code: + raise RuntimeError( + "No authorization code entered. Re-run the script and paste the code " + "shown by Globus after login." + ) + try: + token_response = client.oauth2_exchange_code_for_tokens(code) + except GlobusAPIError as exc: + if exc.http_status == 400: + raise RuntimeError( + "Authorization code exchange failed. The code was empty, invalid, " + "expired, or already used. Re-run the script and complete the " + "Globus login flow again." + ) from exc + raise RuntimeError( + f"Authorization code exchange failed with HTTP {exc.http_status}. " + "Re-run the script and try again." + ) from exc + return token_response.data + + +def refresh_tokens( + client: globus_sdk.NativeAppAuthClient, refresh_token: str +) -> dict | None: + try: + token_response = client.oauth2_refresh_token(refresh_token) + return token_response.data + except GlobusAPIError as exc: + print( + f"Refresh failed ({exc.http_status}); switching to interactive login." + ) + return None + + +def refresh_stored_tokens( + client: globus_sdk.NativeAppAuthClient, stored_tokens: dict +) -> tuple[dict | None, bool]: + iri_refresh_token = get_iri_refresh_token(stored_tokens) + if iri_refresh_token: + iri_token_data = refresh_tokens(client, iri_refresh_token) + if iri_token_data is not None: + return replace_iri_token(stored_tokens, iri_token_data), True + + auth_refresh_token = get_refresh_token(stored_tokens) + if auth_refresh_token: + auth_data = refresh_tokens(client, auth_refresh_token) + if auth_data is not None: + return auth_data, True + + return None, False + + +def main() -> None: + args = parse_args() + if args.force_login and args.refresh_only: + raise RuntimeError("Choose only one of --force-login or --refresh-only") + + client = globus_sdk.NativeAppAuthClient(CLIENT_ID) + + auth_data = None + used_refresh = False + if not args.force_login: + stored = load_tokens(args.token_file) + if stored: + auth_data, used_refresh = refresh_stored_tokens(client, stored) + + if auth_data is None: + if args.refresh_only: + raise RuntimeError( + "Refresh-only mode failed. No usable saved refresh token was found " + "or token refresh did not return the required IRI token." + ) + auth_data = interactive_login(client, prompt_login=args.prompt_login) + + try: + iri_token_data = validate_auth_data(auth_data) + except RuntimeError as exc: + if used_refresh and "Missing token for required IRI scope" in str(exc): + print( + "Refreshed tokens did not include the IRI token; " + "switching to interactive login." + ) + auth_data = interactive_login(client, prompt_login=args.prompt_login) + iri_token_data = validate_auth_data(auth_data) + else: + raise + + save_tokens(args.token_file, auth_data) + + if args.validate_iri: + validation_data = validate_iri_token(iri_token_data, args.iri_validate_url) + print(f"IRI validation succeeded against {args.iri_validate_url}") + if isinstance(validation_data, dict): + session_info = validation_data.get("session_info") + if isinstance(session_info, dict): + session_id = session_info.get("session_id") + if session_id: + print(f"IRI session_id: {session_id}") + elif isinstance(validation_data, list): + print(f"IRI validation response items: {len(validation_data)}") + + expires_at = iri_token_data.get("expires_at_seconds") + if expires_at: + ttl = int(expires_at - time.time()) + print(f"\nIRI access token valid for ~{max(ttl, 0)} seconds.") + + print(f"Saved token data to {args.token_file}") + print(f"Granted Globus Auth scopes: {auth_data.get('scope', '')}") + print(f"IRI token resource server: {iri_token_data.get('resource_server')}") + print(f"IRI token scopes: {iri_token_data.get('scope', '')}") + + if args.print_token: + print("\nIRI access token:") + print(iri_token_data["access_token"]) + else: + print( + "IRI access token not printed " + "(use --print-token to display it for the NERSC IRI API)." + ) + + +if __name__ == "__main__": + main() From 6b8c843424bdfd4995416f5ef8cc32bb69a3a467 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:21:14 -0700 Subject: [PATCH 08/29] removing token.py and moving the logic to get_globus_token.py --- orchestration/globus/token.py | 417 ---------------------------------- 1 file changed, 417 deletions(-) delete mode 100644 orchestration/globus/token.py diff --git a/orchestration/globus/token.py b/orchestration/globus/token.py deleted file mode 100644 index 4970eaa7..00000000 --- a/orchestration/globus/token.py +++ /dev/null @@ -1,417 +0,0 @@ -# orchestration/globus/token.py -import json -import logging -import os -from pathlib import Path -import stat -import time - -import globus_sdk -from globus_sdk.exc import GlobusAPIError - -logger = logging.getLogger(__name__) - -# Default token file location, matching the Globus SDK convention. -DEFAULT_TOKEN_FILE: Path = Path.home() / ".globus" / "auth_tokens.json" - -# IRI API Globus scope and resource server. -# The IRI access token lives in other_tokens under this scope, not at the -# top level of the auth.globus.org response. -IRI_SCOPE: str = ( - "https://auth.globus.org/scopes/" - "ed3e577d-f7f3-4639-b96e-ff5a8445d699/iri_api" -) -IRI_RESOURCE_SERVER: str = "ed3e577d-f7f3-4639-b96e-ff5a8445d699" - - -# --------------------------------------------------------------------------- -# File I/O -# --------------------------------------------------------------------------- - -def load_token_file(token_file: Path) -> dict | None: - """Load saved Globus token data from disk. - - Args: - token_file: Path to the JSON token file. - - Returns: - Parsed token dict, or None if the file does not exist. - """ - if not token_file.exists(): - return None - with token_file.open("r", encoding="utf-8") as f: - return json.load(f) - - -def save_token_file(token_file: Path, tokens: dict) -> None: - """Atomically save Globus token data to disk with owner-only permissions. - - Writes to a temporary file then renames to avoid partial writes. - - Args: - token_file: Destination path for the JSON token file. - tokens: Token dict to serialise. - """ - _ensure_private_parent_dir(token_file) - tmp = token_file.with_suffix(".tmp") - with os.fdopen( - os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), - "w", - encoding="utf-8", - ) as f: - json.dump(tokens, f, indent=2) - os.replace(tmp, token_file) - os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) - - -def _ensure_private_parent_dir(path: Path) -> None: - """Create parent directories for path with owner-only permissions. - - Args: - path: The file path whose parent directory should be created. - """ - path.parent.mkdir(parents=True, exist_ok=True) - os.chmod(path.parent, 0o700) - - -# --------------------------------------------------------------------------- -# IRI token helpers -# --------------------------------------------------------------------------- - -def _parse_scope_string(scope_string: str) -> set[str]: - """Split a space-separated scope string into a set. - - Args: - scope_string: Space-separated OAuth2 scope string. - - Returns: - Set of individual scope strings. - """ - return set(scope_string.split()) if scope_string else set() - - -def extract_iri_token(token_response_data: dict) -> dict: - """Extract the IRI access token entry from a Globus token response. - - The IRI token is not returned at the top level — it lives inside - ``other_tokens``, identified by :data:`IRI_SCOPE`. - - Args: - token_response_data: Full token response dict as returned by the - Globus SDK (i.e. ``token_response.data``). - - Returns: - Token dict for the IRI resource server. - - Raises: - RuntimeError: If no token matching the IRI scope is found. - """ - for token_data in token_response_data.get("other_tokens", []): - if IRI_SCOPE in _parse_scope_string(token_data.get("scope", "")): - return token_data - raise RuntimeError( - f"Missing token for required IRI scope: {IRI_SCOPE}. " - "Re-run with --force-login and ensure consent is granted for the IRI scope." - ) - - -def _replace_iri_token(token_response_data: dict, iri_token_data: dict) -> dict: - """Return a copy of token_response_data with the IRI entry replaced. - - Args: - token_response_data: Full stored token response dict. - iri_token_data: Updated IRI token dict to splice in. - - Returns: - Updated token response dict. - """ - merged = dict(token_response_data) - other_tokens = list(merged.get("other_tokens", [])) - for i, token_data in enumerate(other_tokens): - if IRI_SCOPE in _parse_scope_string(token_data.get("scope", "")): - other_tokens[i] = iri_token_data - break - else: - other_tokens.append(iri_token_data) - merged["other_tokens"] = other_tokens - return merged - - -def _get_iri_refresh_token(stored_tokens: dict) -> str | None: - """Extract the IRI refresh token from stored token data, if present. - - Args: - stored_tokens: Full stored token response dict. - - Returns: - The IRI refresh token string, or None if absent. - """ - try: - return extract_iri_token(stored_tokens).get("refresh_token") - except RuntimeError: - return None - - -def _get_auth_refresh_token(stored_tokens: dict) -> str | None: - """Extract the top-level Globus Auth refresh token from stored data. - - Args: - stored_tokens: Full stored token response dict. - - Returns: - The auth refresh token string, or None if absent. - """ - if "refresh_token" in stored_tokens: - return stored_tokens["refresh_token"] - auth_tokens = stored_tokens.get("auth.globus.org") - if isinstance(auth_tokens, dict): - return auth_tokens.get("refresh_token") - return None - - -# --------------------------------------------------------------------------- -# NativeApp flow (interactive) -# --------------------------------------------------------------------------- - -def interactive_login( - client: globus_sdk.NativeAppAuthClient, - requested_scopes: frozenset[str], - prompt_login: bool = False, -) -> dict: - """Run an interactive browser-based Globus login flow. - - Prints an authorization URL, waits for the user to paste an auth code, - and returns the full token response data including ``other_tokens``. - - Args: - client: Globus NativeAppAuthClient to drive the flow. - requested_scopes: Set of OAuth2 scopes to request. Should include - :data:`IRI_SCOPE` to obtain an IRI API token. - prompt_login: If True, add ``prompt=login`` to the authorize URL to - force a fresh identity-provider login. - - Returns: - Full token response dict (``token_response.data``), including - ``other_tokens``. - - Raises: - RuntimeError: If no authorization code is entered, or if the code - exchange fails. - """ - client.oauth2_start_flow( - requested_scopes=" ".join(sorted(requested_scopes)), - refresh_tokens=True, - ) - logger.info("Open this URL in your browser to authenticate with Globus:") - prompt = "login" if prompt_login else globus_sdk.MISSING - logger.info(client.oauth2_get_authorize_url(prompt=prompt)) - code = input("\nEnter authorization code: ").strip() - if not code: - raise RuntimeError( - "No authorization code entered. Re-run the script and paste the " - "code shown by Globus after login." - ) - try: - token_response = client.oauth2_exchange_code_for_tokens(code) - except GlobusAPIError as e: - if e.http_status == 400: - raise RuntimeError( - "Authorization code exchange failed — the code was empty, " - "invalid, expired, or already used. Re-run and try again." - ) from e - raise RuntimeError( - f"Authorization code exchange failed with HTTP {e.http_status}." - ) from e - return token_response.data - - -def _refresh_single_token( - client: globus_sdk.NativeAppAuthClient, - refresh_token: str, -) -> dict | None: - """Attempt a single Globus token refresh, returning raw response data. - - Args: - client: NativeAppAuthClient to drive the refresh. - refresh_token: The stored refresh token. - - Returns: - Raw token response data dict, or None if the refresh failed. - """ - try: - token_response = client.oauth2_refresh_token(refresh_token) - return token_response.data - except GlobusAPIError as e: - logger.warning( - f"Globus token refresh failed ({e.http_status}); " - "will fall back to interactive login." - ) - return None - - -def _refresh_stored_tokens( - client: globus_sdk.NativeAppAuthClient, - stored_tokens: dict, -) -> tuple[dict | None, bool]: - """Try to refresh stored tokens, preferring the IRI refresh token. - - Attempts the IRI-specific refresh token first, then falls back to the - top-level Globus Auth refresh token. - - Args: - client: NativeAppAuthClient to drive the refresh. - stored_tokens: Full stored token response dict. - - Returns: - Tuple of ``(updated_token_data, success)``. On failure both values - are ``(None, False)``. - """ - iri_refresh = _get_iri_refresh_token(stored_tokens) - if iri_refresh: - iri_token_data = _refresh_single_token(client, iri_refresh) - if iri_token_data is not None: - return _replace_iri_token(stored_tokens, iri_token_data), True - - auth_refresh = _get_auth_refresh_token(stored_tokens) - if auth_refresh: - auth_data = _refresh_single_token(client, auth_refresh) - if auth_data is not None: - return auth_data, True - - return None, False - - -def get_access_token( - client_id: str, - requested_scopes: frozenset[str], - token_file: Path | None = None, - force_login: bool = False, - prompt_login: bool = False, -) -> str: - """Get a valid IRI API access token via the NativeApp interactive flow. - - Attempts a silent refresh from the saved token file first. Falls back to - interactive browser login if no saved tokens exist, the refresh token is - absent, or the refresh fails. Saves the resulting tokens back to disk. - - The IRI token is extracted from ``other_tokens`` in the response — it is - not the top-level Globus Auth token. - - Args: - client_id: Globus NativeApp client ID. - requested_scopes: Set of OAuth2 scopes to request. Must include - :data:`IRI_SCOPE` to obtain a usable IRI API token. - token_file: Path to the JSON token file. Defaults to - ``~/.globus/auth_tokens.json``. - force_login: If True, skip refresh and force interactive login. - prompt_login: If True, add ``prompt=login`` to the authorize URL. - - Returns: - A valid IRI API access token string. - - Raises: - RuntimeError: If the IRI scope token is missing from the response. - """ - resolved_token_file = token_file or DEFAULT_TOKEN_FILE - globus_client = globus_sdk.NativeAppAuthClient(client_id) - - token_response_data: dict | None = None - used_refresh = False - - if not force_login: - stored = load_token_file(resolved_token_file) - if stored: - token_response_data, used_refresh = _refresh_stored_tokens( - globus_client, stored - ) - - if token_response_data is None: - logger.info("Initiating interactive Globus login.") - token_response_data = interactive_login( - globus_client, requested_scopes, prompt_login=prompt_login - ) - - # Extract IRI token — if a refresh ran but didn't return the IRI token, - # fall back to interactive login before raising. - try: - iri_token = extract_iri_token(token_response_data) - except RuntimeError: - if used_refresh: - logger.warning( - "Refreshed tokens did not include the IRI token; " - "falling back to interactive login." - ) - token_response_data = interactive_login( - globus_client, requested_scopes, prompt_login=prompt_login - ) - iri_token = extract_iri_token(token_response_data) - else: - raise - - save_token_file(resolved_token_file, token_response_data) - logger.info(f"Globus token saved to {resolved_token_file}.") - - return iri_token["access_token"] - - -# --------------------------------------------------------------------------- -# Confidential Client flow (machine-to-machine) -# --------------------------------------------------------------------------- - -def get_access_token_confidential( - client_id: str, - client_secret: str, - required_scopes: frozenset[str], - resource_server: str, - token_file: Path | None = None, -) -> str: - """Get a valid Globus access token using a Confidential Client. - - No browser or user interaction required. If a valid unexpired token exists - on disk it is reused; otherwise a new one is minted via the client - credentials grant and saved. - - Args: - client_id: Globus Confidential App client ID. - client_secret: Globus Confidential App client secret. - required_scopes: Set of OAuth2 scopes that must be present on the token. - resource_server: Resource server key to extract from the token response. - token_file: Path to the JSON token cache file. Defaults to - ``~/.globus/auth_tokens.json``. - - Returns: - A valid Globus access token string. - - Raises: - RuntimeError: If the acquired token is missing required scopes. - KeyError: If ``access_token`` is absent from the token response. - """ - resolved_token_file = token_file or DEFAULT_TOKEN_FILE - - stored = load_token_file(resolved_token_file) - if stored: - expires_at = stored.get("expires_at_seconds") - if expires_at and time.time() < expires_at: - logger.info("Using cached Globus token (still valid).") - return stored["access_token"] - logger.info("Cached Globus token is expired; minting a new one.") - else: - logger.info("No cached Globus token found; minting a new one.") - - globus_client = globus_sdk.ConfidentialAppAuthClient(client_id, client_secret) - token_response = globus_client.oauth2_client_credentials_tokens( - requested_scopes=" ".join(sorted(required_scopes)) - ) - auth_data = token_response.by_resource_server[resource_server] - - granted = set(auth_data.get("scope", "").split()) - missing = required_scopes - granted - if missing: - raise RuntimeError( - f"Globus token is missing required scopes: {sorted(missing)}" - ) - - save_token_file(resolved_token_file, auth_data) - logger.info(f"New Globus token saved to {resolved_token_file}.") - - return auth_data["access_token"] From 27ea5b2f0a5cc97e9d901507cbc8e35f1283f2f1 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:21:55 -0700 Subject: [PATCH 09/29] moving get_globus_token.py to orchestration/globus/ to be used as a module --- .../globus}/get_globus_token.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) rename {scripts => orchestration/globus}/get_globus_token.py (84%) diff --git a/scripts/get_globus_token.py b/orchestration/globus/get_globus_token.py similarity index 84% rename from scripts/get_globus_token.py rename to orchestration/globus/get_globus_token.py index 6b615378..c47057e8 100644 --- a/scripts/get_globus_token.py +++ b/orchestration/globus/get_globus_token.py @@ -11,6 +11,7 @@ import globus_sdk from globus_sdk.exc import GlobusAPIError +DEFAULT_TOKEN_FILE: Path = Path.home() / ".globus" / "auth_tokens.json" CLIENT_ID = "fae5c579-490a-4d76-b6eb-d78f65caeb63" RESOURCE_SERVER = "auth.globus.org" IRI_SCOPE = ( @@ -264,6 +265,52 @@ def refresh_stored_tokens( return None, False +def get_iri_access_token( + token_file: Path = DEFAULT_TOKEN_FILE, + force_login: bool = False, + prompt_login: bool = False, +) -> str: + """ + Get a valid IRI access token, refreshing or prompting for login as needed. + Tokens are saved to the specified token_file path (default: ~/.globus/auth_tokens.json). + By default, the function will attempt to refresh saved tokens before falling back + to interactive login. Use force_login=True to skip refresh and require interactive login. + Use prompt_login=True to add prompt=login to the authorization URL, which forces + re-authentication even if the user has an active Globus session in their browser. + + Args: + token_file: Path to save and load token data (default: ~/.globus/auth_tokens.json) + force_login: If True, skip token refresh and require interactive login + prompt_login: If True, add prompt=login to the authorization URL to force re-authentication + + Returns: + A valid IRI access token string with the required scopes. + + Raises: + RuntimeError: If token refresh fails and interactive login is not allowed or fails, + or if the resulting tokens do not include a valid IRI access token. + """ + client = globus_sdk.NativeAppAuthClient(CLIENT_ID) + auth_data = None + used_refresh = False + if not force_login: + stored = load_tokens(token_file) + if stored: + auth_data, used_refresh = refresh_stored_tokens(client, stored) + if auth_data is None: + auth_data = interactive_login(client, prompt_login=prompt_login) + try: + iri_token_data = validate_auth_data(auth_data) + except RuntimeError as exc: + if used_refresh and "Missing token for required IRI scope" in str(exc): + auth_data = interactive_login(client, prompt_login=prompt_login) + iri_token_data = validate_auth_data(auth_data) + else: + raise + save_tokens(token_file, auth_data) + return iri_token_data["access_token"] + + def main() -> None: args = parse_args() if args.force_login and args.refresh_only: From bad1db503eed89a669ffc075d9e89a3bf0cbbf9c Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:22:56 -0700 Subject: [PATCH 10/29] Cleaning up nersc.py --- orchestration/flows/bl832/nersc.py | 115 ++++++++++++++--------------- 1 file changed, 55 insertions(+), 60 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 2aad35de..629fac1b 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -23,10 +23,9 @@ NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block ) from orchestration.mlflow import get_checkpoint_info -from orchestration.globus.token import ( - get_access_token, +from orchestration.globus.get_globus_token import ( + get_iri_access_token, DEFAULT_TOKEN_FILE, - IRI_SCOPE, ) from orchestration.prefect import schedule_prefect_flow from orchestration.prune_controller import get_prune_controller, PruneMethod @@ -39,17 +38,7 @@ # Applies only to NERSCLoginMethod.IRIAPI _IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" _IRI_COMPUTE_RESOURCE: str = "compute" -_IRI_SCRATCH_RESOURCE: str = "scratch" -# _IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client _IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" -_IRIAPI_GLOBUS_RESOURCE_SERVER: str = "auth.globus.org" -_IRIAPI_GLOBUS_REQUIRED_SCOPES: frozenset[str] = frozenset({ - "openid", - "profile", - "email", - "urn:globus:auth:scope:auth.globus.org:view_identities", - IRI_SCOPE, -}) _API_BASE_URLS: dict[NERSCLoginMethod, str] = { NERSCLoginMethod.SFAPI: "https://api.nersc.gov/api/v1.2", @@ -263,26 +252,19 @@ def _create_iriapi_client() -> Client: RuntimeError: If the acquired token is missing required scopes. """ client_id = "fae5c579-490a-4d76-b6eb-d78f65caeb63" # os.getenv(_IRIAPI_GLOBUS_CLIENT_ID_ENV) - # client_secret = os.getenv(_IRIAPI_GLOBUS_CLIENT_SECRET_ENV) if not client_id: raise ValueError( f"Globus client ID is unset. Set {_IRIAPI_GLOBUS_CLIENT_ID_ENV}." ) - # if not client_secret: - # raise ValueError( - # f"Globus client secret is unset. Set {_IRIAPI_GLOBUS_CLIENT_SECRET_ENV}. " - # "A Globus Confidential App client is required for automated IRI API auth." - # ) token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV) token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE - access_token = get_access_token( - client_id=client_id, - requested_scopes=_IRIAPI_GLOBUS_REQUIRED_SCOPES, + access_token = get_iri_access_token( token_file=token_file, force_login=False, + prompt_login=False ) return httpx.Client( @@ -832,43 +814,6 @@ def build_multi_resolution( except Exception as e: logger.error(f"Error during multiresolution job submission or completion: {e}") return False - # try: - # logger.info("Submitting Tiff to Zarr job script to Perlmutter.") - # perlmutter = self.client.compute(Machine.perlmutter) - # job = perlmutter.submit_job(job_script) - # logger.info(f"Submitted job ID: {job.jobid}") - - # try: - # job.update() - # except Exception as update_err: - # logger.warning(f"Initial job update failed, continuing: {update_err}") - - # time.sleep(60) - # logger.info(f"Job {job.jobid} current state: {job.state}") - - # job.complete() # Wait until the job completes - # logger.info("Reconstruction job completed successfully.") - - # return True - - # except Exception as e: - # logger.warning(f"Error during job submission or completion: {e}") - # match = re.search(r"Job not found:\s*(\d+)", str(e)) - - # if match: - # jobid = match.group(1) - # logger.info(f"Attempting to recover job {jobid}.") - # try: - # job = self.client.perlmutter.job(jobid=jobid) - # time.sleep(30) - # job.complete() - # logger.info("Reconstruction job completed successfully after recovery.") - # return True - # except Exception as recovery_err: - # logger.error(f"Failed to recover job {jobid}: {recovery_err}") - # return False - # else: - # return False def segmentation_sam3( self, @@ -1847,7 +1792,8 @@ def nersc_recon_flow( logger.info(f"Starting NERSC reconstruction flow for {file_path=}") controller = get_controller( hpc_type=HPC.NERSC, - config=config + config=config, + login_method=NERSCLoginMethod.SFAPI ) logger.info("NERSC reconstruction controller initialized") @@ -2433,6 +2379,55 @@ def nersc_moon_segment_flow( return False +@flow(name="nersc_recon_test_iriapi_flow", flow_run_name="nersc_recon-{file_path}") +def nersc_recon_test_iriapi_flow( + file_path: str, + config: Optional[Config832] = None, +) -> bool: + """ + Perform tomography reconstruction on NERSC. + + :param file_path: Path to the file to reconstruct. + :param config: Configuration object (if None, a default Config832 will be created). + :return: True if successful, False otherwise. + """ + logger.info(f"Starting NERSC reconstruction flow for {file_path=}") + controller = get_controller( + hpc_type=HPC.NERSC, + config=config, + login_method=NERSCLoginMethod.IRIAPI + ) + logger.info("NERSC reconstruction controller initialized") + + nersc_reconstruction_success = controller.reconstruct( + file_path=file_path, + ) + logger.info(f"NERSC reconstruction success: {nersc_reconstruction_success}") + + nersc_multi_res_success = controller.build_multi_resolution( + file_path=file_path, + ) + logger.info(f"NERSC multi-resolution success: {nersc_multi_res_success}") + + path = Path(file_path) + folder_name = path.parent.name + file_name = path.stem + + tiff_file_path = f"{folder_name}/rec{file_name}" + zarr_file_path = f"{folder_name}/rec{file_name}.zarr" + + logger.info(f"{tiff_file_path=}") + logger.info(f"{zarr_file_path=}") + + # Transfers and pruning omitted from test flow. + + # TODO: Ingest into SciCat + if nersc_reconstruction_success: + return True + else: + return False + + @flow(name="nersc_streaming_flow", on_cancellation=[cancellation_hook]) def nersc_streaming_flow( walltime: datetime.timedelta = datetime.timedelta(minutes=5), From da163416f2ed1236be56393210ed5f88d8e6dbb5 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:28:37 -0700 Subject: [PATCH 11/29] cleaning up old commented code --- orchestration/flows/bl832/nersc.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 629fac1b..e1d49fa9 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -473,7 +473,6 @@ def reconstruct( """ logger.info("Starting NERSC reconstruction process.") - # user = self.client.user() username = self._get_nersc_username() raw_path = self.config.nersc832_alsdev_raw.root_path @@ -737,7 +736,6 @@ def build_multi_resolution( logger.info("Starting NERSC multiresolution process.") - # user = self.client.user() username = self._get_nersc_username() multires_image = self.config.ghcr_images832["multires_image"] From d1d65ad162516a39544bf47cc5bef1a0af0d638d Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:34:53 -0700 Subject: [PATCH 12/29] Updating unit tests --- orchestration/_tests/test_bl832/test_nersc.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 7a8ca07a..5b0f1a3d 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -181,20 +181,20 @@ def mock_iriapi_client(mocker): client = mocker.MagicMock() submit_response = mocker.MagicMock() - submit_response.json.return_value = {"job_id": "99999"} + submit_response.json.return_value = {"id": "99999"} client.post.return_value = submit_response status_response = mocker.MagicMock() - status_response.json.return_value = {"state": "COMPLETED"} + status_response.json.return_value = {"status": {"state": "completed"}} client.get.return_value = status_response return client - # --------------------------------------------------------------------------- # _create_sfapi_client # --------------------------------------------------------------------------- + def test_create_sfapi_client_success(mocker): """Valid credentials produce a Client instance.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController @@ -472,7 +472,7 @@ def test_reconstruct_iriapi_job_failed(mocker, mock_iriapi_client, mock_config83 monkeypatch.setenv("NERSC_USERNAME", "alsdev") mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mock_iriapi_client.get.return_value.json.return_value = {"state": "FAILED"} + mock_iriapi_client.get.return_value.json.return_value = {"status": {"state": "failed"}} # was {"state": "FAILED"} controller = NERSCTomographyHPCController( client=mock_iriapi_client, @@ -970,17 +970,17 @@ def test_build_multi_resolution_iriapi_success(mocker, mock_iriapi_client, mock_ assert result is True mock_iriapi_client.post.assert_called_once() mock_iriapi_client.get.assert_called_once_with( - "/api/v1/compute/status/perlmutter/99999" + "/api/v1/compute/status/compute/99999" ) def test_build_multi_resolution_iriapi_failure(mocker, mock_iriapi_client, mock_config, monkeypatch): - """IRIAPI build_multi_resolution returns False when job state is FAILED.""" + """IRIAPI build_multi_resolution returns False when job state is failed.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod monkeypatch.setenv("NERSC_USERNAME", "alsdev") mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mock_iriapi_client.get.return_value.json.return_value = {"state": "FAILED"} + mock_iriapi_client.get.return_value.json.return_value = {"status": {"state": "failed"}} controller = NERSCTomographyHPCController( client=mock_iriapi_client, From dda78c534c35636eb40c8db8f12434ff1fa13e19 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 7 Apr 2026 11:20:19 -0700 Subject: [PATCH 13/29] updating login script --- scripts/login_to_globus_and_prefect.sh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/login_to_globus_and_prefect.sh b/scripts/login_to_globus_and_prefect.sh index a38b629f..b8f60bde 100755 --- a/scripts/login_to_globus_and_prefect.sh +++ b/scripts/login_to_globus_and_prefect.sh @@ -18,4 +18,7 @@ export GLOBUS_CLI_CLIENT_SECRET="$GLOBUS_CLIENT_SECRET" export GLOBUS_COMPUTE_CLIENT_ID="$GLOBUS_CLIENT_ID" export GLOBUS_COMPUTE_CLIENT_SECRET="$GLOBUS_CLIENT_SECRET" export PREFECT_API_KEY="$PREFECT_API_KEY" -export PREFECT_API_URL="$PREFECT_API_URL" \ No newline at end of file +export PREFECT_API_URL="$PREFECT_API_URL" +export NERSC_USERNAME="$NERSC_USERNAME" +export PATH_NERSC_CLIENT_ID="$PATH_NERSC_CLIENT_ID" +export PATH_NERSC_PRI_KEY="$PATH_NERSC_PRI_KEY" \ No newline at end of file From 596106aad41d18ebffdef03cf5db6acc80801a42 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 7 Apr 2026 14:29:40 -0700 Subject: [PATCH 14/29] Rebasing and including segmentation flows as part of iri/sfapi abstraction --- orchestration/_tests/test_bl832/test_nersc.py | 39 ++- orchestration/flows/bl832/nersc.py | 286 +++++++----------- 2 files changed, 130 insertions(+), 195 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 5b0f1a3d..3ba9742f 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -23,16 +23,6 @@ def prefect_test_fixture(): # Shared fixtures # --------------------------------------------------------------------------- -@pytest.fixture -def mock_config(mocker): - config = mocker.MagicMock() - config.ghcr_images832 = { - "recon_image": "mock_recon_image", - "multires_image": "mock_multires_image", - } - return config - - @pytest.fixture def mock_sfapi_client(mocker): """sfapi_client.Client mock with user, compute, submit_job, and job chained.""" @@ -257,9 +247,10 @@ def test_build_multi_resolution_success(mocker, mock_sfapi_client, mock_config83 result = controller.build_multi_resolution(file_path="folder/file.h5") - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) + assert mock_sfapi_client.compute.call_count == 2 + mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() assert result is True @@ -292,7 +283,7 @@ def test_segmentation_sam3_success(mocker, mock_sfapi_client, mock_config832): mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() assert isinstance(result, dict) assert result["success"] is True assert result["job_id"] == "12345" @@ -370,7 +361,7 @@ def test_segmentation_dinov3_success(mocker, mock_sfapi_client, mock_config832): mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() assert result is True @@ -454,11 +445,15 @@ def test_reconstruct_iriapi_success(mocker, mock_iriapi_client, mock_config832, assert result["success"] is True assert result["job_id"] == "99999" mock_iriapi_client.post.assert_called_once() - assert mock_iriapi_client.post.call_args.args[0] == "/api/v1/compute/job/perlmutter" - assert "script" in mock_iriapi_client.post.call_args.kwargs["json"] + assert mock_iriapi_client.post.call_args.args[0] == "/api/v1/compute/job/compute" + posted_json = mock_iriapi_client.post.call_args.kwargs["json"] + assert posted_json["executable"] == "/bin/bash" + assert posted_json["arguments"][0] == "-c" + assert isinstance(posted_json["arguments"][1], str) # the script body + assert "tomo_recon" in posted_json["arguments"][1] # sanity check it's the right script assert mock_iriapi_client.get.call_count == 2 mock_iriapi_client.get.assert_any_call( - "/api/v1/compute/status/perlmutter/99999" + "/api/v1/compute/status/compute/99999" ) mock_iriapi_client.get.assert_any_call( "/api/v1/filesystem/file/perlmutter", @@ -589,7 +584,7 @@ def test_combine_segmentations_success(mocker, mock_sfapi_client, mock_config832 mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() assert result is True @@ -952,7 +947,7 @@ def test_moon_segment_flow_no_sam3_no_combine(mocker, mock_config832, mock_recon # --------------------------------------------------------------------------- -def test_build_multi_resolution_iriapi_success(mocker, mock_iriapi_client, mock_config, monkeypatch): +def test_build_multi_resolution_iriapi_success(mocker, mock_iriapi_client, mock_config832, monkeypatch): """IRIAPI build_multi_resolution POSTs and polls successfully.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod @@ -961,7 +956,7 @@ def test_build_multi_resolution_iriapi_success(mocker, mock_iriapi_client, mock_ controller = NERSCTomographyHPCController( client=mock_iriapi_client, - config=mock_config, + config=mock_config832, login_method=NERSCLoginMethod.IRIAPI, ) @@ -974,7 +969,7 @@ def test_build_multi_resolution_iriapi_success(mocker, mock_iriapi_client, mock_ ) -def test_build_multi_resolution_iriapi_failure(mocker, mock_iriapi_client, mock_config, monkeypatch): +def test_build_multi_resolution_iriapi_failure(mocker, mock_iriapi_client, mock_config832, monkeypatch): """IRIAPI build_multi_resolution returns False when job state is failed.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod @@ -984,7 +979,7 @@ def test_build_multi_resolution_iriapi_failure(mocker, mock_iriapi_client, mock_ controller = NERSCTomographyHPCController( client=mock_iriapi_client, - config=mock_config, + config=mock_config832, login_method=NERSCLoginMethod.IRIAPI, ) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index e1d49fa9..c2d5f031 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -459,6 +459,55 @@ def _wait_for_job(self, job_id: str) -> bool: else: raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + def _mkdir_remote(self, path: str) -> None: + """Create a directory on Perlmutter remotely. + + Args: + path: Absolute path to create. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + perlmutter.run(f"mkdir -p {path}") + elif self.login_method is NERSCLoginMethod.IRIAPI: + response = self.client.post( + "/api/v1/filesystem/mkdir/perlmutter", + json={"path": path, "parents": True}, + ) + response.raise_for_status() + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + + def _read_remote_file(self, path: str) -> str: + """Read a remote file on Perlmutter and return its contents. + + Args: + path: Absolute path to the file on Perlmutter. + + Returns: + File contents as a string. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + result = perlmutter.run(f"cat {path}") + if isinstance(result, str): + return result + elif hasattr(result, 'output'): + return result.output + elif hasattr(result, 'stdout'): + return result.stdout + return str(result) + + elif self.login_method is NERSCLoginMethod.IRIAPI: + response = self.client.get( + "/api/v1/filesystem/file/perlmutter", + params={"path": path}, + ) + response.raise_for_status() + return response.text + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + def reconstruct( self, file_path: str = "", @@ -660,27 +709,7 @@ def _fetch_timing_data(self, pscratch_path: str, job_id: str) -> dict: timing_file = f"{pscratch_path}/tomo_recon_logs/timing_{job_id}.txt" try: - # Use SFAPI to read the timing file - if self.login_method is NERSCLoginMethod.SFAPI: - perlmutter = self.client.compute(Machine.perlmutter) - result = perlmutter.run(f"cat {timing_file}") - - # result might be a string directly, or an object with .output - if isinstance(result, str): - output = result - elif hasattr(result, 'output'): - output = result.output - elif hasattr(result, 'stdout'): - output = result.stdout - else: - output = str(result) - elif self.login_method is NERSCLoginMethod.IRIAPI: - response = self.client.get( - "/api/v1/filesystem/file/perlmutter", - params={"path": timing_file}, - ) - response.raise_for_status() - output = response.text + output = self._read_remote_file(timing_file) logger.info(f"Timing file contents:\n{output}") @@ -823,8 +852,8 @@ def segmentation_sam3( """ logger.info("Starting NERSC segmentation process (inference_v6).") - user = self.client.user() - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" opts = _load_job_options( variable_name="nersc-segmentation-options", @@ -1014,34 +1043,21 @@ def segmentation_sam3( """ try: - logger.info("Submitting segmentation job to Perlmutter (v6).") - perlmutter = self.client.compute(Machine.perlmutter) + logger.info("Submitting segmentation job to Perlmutter.") # Ensure directories exist logger.info("Creating necessary directories...") - perlmutter.run(f"mkdir -p {pscratch_path}/tomo_seg_logs") - perlmutter.run(f"mkdir -p {output_dir}") + self._mkdir_remote(f"{pscratch_path}/tomo_seg_logs") + self._mkdir_remote(output_dir) # Submit job - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - # Initial update - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - - # Wait briefly before polling + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - # Wait for completion - job.complete() + success = self._wait_for_job(job_id) logger.info("Segmentation job completed successfully.") - # Fetch timing data from output file - timing = self._fetch_seg_timing_from_output(perlmutter, pscratch_path, job.jobid, job_name) + timing = self._fetch_seg_timing_from_output(pscratch_path, job_id, job_name) if timing: logger.info("=" * 60) @@ -1055,43 +1071,21 @@ def segmentation_sam3( logger.info("=" * 60) return { - "success": True, - "job_id": job.jobid, + "success": success, + "job_id": job_id, "timing": timing, - "output_dir": output_dir + "output_dir": output_dir, } except Exception as e: logger.error(f"Error during segmentation job: {e}") import traceback logger.error(traceback.format_exc()) - - # Attempt recovery - match = re.search(r"Job not found:\s*(\d+)", str(e)) - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.compute(Machine.perlmutter).job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Segmentation job completed after recovery.") - - timing = self._fetch_seg_timing_from_output(perlmutter, pscratch_path, jobid, job_name) - return { - "success": True, - "job_id": jobid, - "timing": timing, - "output_dir": output_dir - } - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return { "success": False, "job_id": None, "timing": None, - "output_dir": None + "output_dir": None, } def segmentation_dinov3( @@ -1109,8 +1103,8 @@ def segmentation_dinov3( """ logger.info("Starting NERSC DINOv3 segmentation process.") - user = self.client.user() - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" # Load from config spec = self._get_segmentation_spec("dinov3", project) @@ -1245,39 +1239,15 @@ def segmentation_dinov3( """ try: logger.info("Submitting DINOv3 segmentation job to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() - logger.info("DINOv3 segmentation job completed successfully.") - return True - + success = self._wait_for_job(job_id) + logger.info(f"DINOv3 segmentation job {'completed successfully' if success else 'failed'}.") + return success except Exception as e: logger.error(f"Error during DINOv3 segmentation job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.compute(Machine.perlmutter).job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("DINOv3 segmentation job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False + return False def combine_segmentations( self, @@ -1285,7 +1255,7 @@ def combine_segmentations( ) -> bool: """ Run CPU-based combination of SAM3+DINOv3 segmentation results - at NERSC Perlmutter via SFAPI Slurm job. + at NERSC Perlmutter via Slurm job. :param recon_folder_path: Relative path to the reconstructed data folder, e.g. 'folder_name/recYYYYMMDD_hhmmss_scanname/' @@ -1293,8 +1263,8 @@ def combine_segmentations( """ logger.info("Starting NERSC segmentation combination process.") - user = self.client.user() - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" opts = _load_job_options( "nersc-combine-seg-options", self.config.nersc_combine_segmentation_settings @@ -1393,45 +1363,20 @@ def combine_segmentations( """ try: logger.info("Submitting segmentation combination job to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() - logger.info("Segmentation combination job completed successfully.") - return True - + success = self._wait_for_job(job_id) + logger.info(f"Segmentation combination job {'completed successfully' if success else 'failed'}.") + return success except Exception as e: logger.error(f"Error during segmentation combination job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.compute(Machine.perlmutter).job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Segmentation combination job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False + return False - def _fetch_seg_timing_from_output(self, perlmutter, pscratch_path: str, job_id: str, job_name: str) -> dict: + def _fetch_seg_timing_from_output(self, pscratch_path: str, job_id: str, job_name: str) -> dict: """ Fetch and parse timing data from the SLURM output file. - :param perlmutter: SFAPI compute object for Perlmutter :param pscratch_path: Path to the user's pscratch directory :param job_id: SLURM job ID :param job_name: Job name for finding output file @@ -1440,18 +1385,7 @@ def _fetch_seg_timing_from_output(self, perlmutter, pscratch_path: str, job_id: output_file = f"{pscratch_path}/tomo_seg_logs/{job_name}_{job_id}.out" try: - # Use SFAPI to read the output file - result = perlmutter.run(f"cat {output_file}") - - # Handle different result types - if isinstance(result, str): - output = result - elif hasattr(result, 'output'): - output = result.output - elif hasattr(result, 'stdout'): - output = result.stdout - else: - output = str(result) + output = self._read_remote_file(output_file) logger.info("Job output file contents (last 50 lines):") lines = output.strip().split('\n') @@ -1528,8 +1462,8 @@ def pull_shifter_image( """ logger.info("Starting Shifter image pull.") - user = self.client.user() - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" if image is None: image = self.config.ghcr_images832["recon_image"] @@ -1576,24 +1510,16 @@ def pull_shifter_image( try: logger.info("Submitting Shifter image pull job to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") if wait: - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - time.sleep(30) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() - logger.info("Shifter image pull completed successfully.") - return True + success = self._wait_for_job(job_id) + logger.info(f"Shifter image pull {'completed successfully' if success else 'failed'}.") + return success else: - logger.info(f"Job submitted. Check status with job ID: {job.jobid}") + logger.info(f"Job submitted. Check status with job ID: {job_id}") return True except Exception as e: @@ -1616,17 +1542,31 @@ def check_shifter_image( image = self.config.ghcr_images832["recon_image"] try: - perlmutter = self.client.compute(Machine.perlmutter) - # Run shifterimg images command - result = perlmutter.run(f"shifterimg images | grep -E \"$(echo {image} | sed 's/:/.*/g')\"") + if self.login_method is NERSCLoginMethod.SFAPI: + # synchronous via utilities/command + perlmutter = self.client.compute(Machine.perlmutter) + result = perlmutter.run(f"shifterimg images | grep -E \"$(echo {image} | sed 's/:/.*/g')\"") + output = result if isinstance(result, str) else getattr(result, 'output', str(result)) - if isinstance(result, str): - output = result - elif hasattr(result, 'output'): - output = result.output - else: - output = str(result) + elif self.login_method is NERSCLoginMethod.IRIAPI: + # async: submit job → wait → read stdout file + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" + output_file = f"{pscratch_path}/tomo_recon_logs/shifter_check.txt" + check_script = f"""#!/bin/bash + #SBATCH -q debug + #SBATCH -A als + #SBATCH -C cpu + #SBATCH -N 1 + #SBATCH --ntasks=1 + #SBATCH --cpus-per-task=1 + #SBATCH --time=0:05:00 + shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/g')" > {output_file} 2>&1 || true + """ + job_id = self._submit_job(check_script) + self._wait_for_job(job_id) + output = self._read_remote_file(output_file) if output.strip(): logger.info(f"Image found in Shifter cache: {output.strip()}") From 9da5e6e30a2d9fe5fc3a816be4d4e92b0e95480f Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 13 Apr 2026 13:59:39 -0700 Subject: [PATCH 15/29] commenting out petiole segmentation prune block for now, while testing --- orchestration/flows/bl832/nersc.py | 122 ++++++++++++++--------------- 1 file changed, 61 insertions(+), 61 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index c2d5f031..13345343 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -2047,67 +2047,67 @@ def nersc_petiole_segment_flow( ) # ── STEP 6: Pruning ─────────────────────────────────────────────────────── - logger.info("Scheduling file pruning tasks.") - prune_controller = get_prune_controller(prune_type=PruneMethod.GLOBUS, config=config) - - try: - prune_controller.prune( - file_path=f"{folder_name}/{path.name}", - source_endpoint=config.nersc832_alsdev_pscratch_raw, - check_endpoint=None, - days_from_now=1.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule raw data pruning: {e}") - - if nersc_reconstruction_success: - try: - prune_controller.prune( - file_path=scratch_path_tiff, - source_endpoint=config.nersc832_alsdev_pscratch_scratch, - check_endpoint=config.data832_scratch if data832_tiff_transfer_success else None, - days_from_now=1.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule reconstruction data pruning: {e}") - - if any_seg_success: - try: - prune_controller.prune( - file_path=scratch_path_segment, - source_endpoint=config.nersc832_alsdev_pscratch_scratch, - check_endpoint=config.data832_scratch if any([ - data832_sam3_transfer_success, - data832_dinov3_transfer_success, - ]) else None, - days_from_now=1.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule segmentation data pruning: {e}") - - if data832_tiff_transfer_success: - try: - prune_controller.prune( - file_path=scratch_path_tiff, - source_endpoint=config.data832_scratch, - check_endpoint=None, - days_from_now=30.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule data832 tiff pruning: {e}") - - if any([data832_sam3_transfer_success, - data832_dinov3_transfer_success, - data832_combined_transfer_success]): - try: - prune_controller.prune( - file_path=scratch_path_segment, - source_endpoint=config.data832_scratch, - check_endpoint=None, - days_from_now=30.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule data832 segment pruning: {e}") + # logger.info("Scheduling file pruning tasks.") + # prune_controller = get_prune_controller(prune_type=PruneMethod.GLOBUS, config=config) + + # try: + # prune_controller.prune( + # file_path=f"{folder_name}/{path.name}", + # source_endpoint=config.nersc832_alsdev_pscratch_raw, + # check_endpoint=None, + # days_from_now=1.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule raw data pruning: {e}") + + # if nersc_reconstruction_success: + # try: + # prune_controller.prune( + # file_path=scratch_path_tiff, + # source_endpoint=config.nersc832_alsdev_pscratch_scratch, + # check_endpoint=config.data832_scratch if data832_tiff_transfer_success else None, + # days_from_now=1.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule reconstruction data pruning: {e}") + + # if any_seg_success: + # try: + # prune_controller.prune( + # file_path=scratch_path_segment, + # source_endpoint=config.nersc832_alsdev_pscratch_scratch, + # check_endpoint=config.data832_scratch if any([ + # data832_sam3_transfer_success, + # data832_dinov3_transfer_success, + # ]) else None, + # days_from_now=1.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule segmentation data pruning: {e}") + + # if data832_tiff_transfer_success: + # try: + # prune_controller.prune( + # file_path=scratch_path_tiff, + # source_endpoint=config.data832_scratch, + # check_endpoint=None, + # days_from_now=30.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule data832 tiff pruning: {e}") + + # if any([data832_sam3_transfer_success, + # data832_dinov3_transfer_success, + # data832_combined_transfer_success]): + # try: + # prune_controller.prune( + # file_path=scratch_path_segment, + # source_endpoint=config.data832_scratch, + # check_endpoint=None, + # days_from_now=30.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule data832 segment pruning: {e}") if nersc_reconstruction_success and any_seg_success: logger.info("NERSC reconstruction + multi-segmentation flow completed successfully.") From ef227af6598ae1042fd8119e2f462f1e6519ecfc Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 13 Apr 2026 14:43:44 -0700 Subject: [PATCH 16/29] Making reconstruction run as a task --- orchestration/flows/bl832/nersc.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 13345343..85e85bed 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1873,7 +1873,6 @@ def nersc_petiole_segment_flow( logger.info(f"Reconstructed TIFFs will be at: {scratch_path_tiff}") logger.info(f"Segmented output will be at: {scratch_path_segment}") - controller = get_controller(hpc_type=HPC.NERSC, config=config) logger.info("NERSC controller initialized") if num_nodes is None: @@ -1894,9 +1893,10 @@ def nersc_petiole_segment_flow( # ── STEP 1: Multinode Reconstruction ───────────────────────────────────── logger.info(f"Using multi-node reconstruction with {num_nodes} nodes") - recon_result = controller.reconstruct( + recon_result = nersc_reconstruction_task( file_path=file_path, - num_nodes=num_nodes + num_nodes=num_nodes, + config=config, ) if isinstance(recon_result, dict): @@ -2427,6 +2427,30 @@ def pull_shifter_image_flow( return success +@task(name="nersc_reconstruction_task") +def nersc_reconstruction_task( + file_path: str, + num_nodes: int = 4, + config: Optional[Config832] = None, +) -> dict: + """ + Run tomography reconstruction at NERSC Perlmutter. + + :param file_path: Path to the raw HDF5 file to reconstruct. + :param num_nodes: Number of nodes to use for reconstruction. + :param config: Configuration object for the flow. + :return: Dict with keys 'success', 'job_id', 'timing'. + """ + logger = get_run_logger() + if config is None: + config = Config832() + + logger.info("Initializing NERSC Tomography HPC Controller.") + controller = get_controller(hpc_type=HPC.NERSC, config=config) + logger.info(f"Starting NERSC reconstruction task for {file_path=}") + return controller.reconstruct(file_path=file_path, num_nodes=num_nodes) + + @task(name="nersc_multiresolution_task") def nersc_multiresolution_task( file_path: str, From b4558bef7389ca6eaab99524961932e6ac2e882c Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 13 Apr 2026 14:50:03 -0700 Subject: [PATCH 17/29] Making IRIAPI the default login method for now --- orchestration/flows/bl832/nersc.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 85e85bed..29062308 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1847,6 +1847,7 @@ def nersc_petiole_segment_flow( file_path: str, config: Optional[Config832] = None, num_nodes: Optional[int] = None, + login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> bool: """ Transfer raw data to NERSC, run reconstruction, then run SAM3 and DINOv3 @@ -1897,6 +1898,7 @@ def nersc_petiole_segment_flow( file_path=file_path, num_nodes=num_nodes, config=config, + login_method=login_method ) if isinstance(recon_result, dict): @@ -1950,10 +1952,10 @@ def nersc_petiole_segment_flow( logger.info("Submitting SAM3 and DINOv3 segmentation tasks concurrently.") sam3_future = nersc_segmentation_sam3_task.submit( - recon_folder_path=scratch_path_tiff, config=config + recon_folder_path=scratch_path_tiff, config=config, login_method=login_method ) dinov3_future = nersc_segmentation_dinov3_task.submit( - recon_folder_path=scratch_path_tiff, config=config, project="petiole" + recon_folder_path=scratch_path_tiff, config=config, project="petiole", login_method=login_method ) # ── STEP 4: Transfer each model's output as it completes ───────────────── @@ -1999,7 +2001,7 @@ def nersc_petiole_segment_flow( logger.info("Running segmentation combination.") combine_future = nersc_combine_segmentations_task.submit( - recon_folder_path=scratch_path_tiff, config=config + recon_folder_path=scratch_path_tiff, config=config, login_method=login_method ) combine_success = combine_future.result() @@ -2432,6 +2434,7 @@ def nersc_reconstruction_task( file_path: str, num_nodes: int = 4, config: Optional[Config832] = None, + login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> dict: """ Run tomography reconstruction at NERSC Perlmutter. @@ -2446,7 +2449,7 @@ def nersc_reconstruction_task( config = Config832() logger.info("Initializing NERSC Tomography HPC Controller.") - controller = get_controller(hpc_type=HPC.NERSC, config=config) + controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=login_method) logger.info(f"Starting NERSC reconstruction task for {file_path=}") return controller.reconstruct(file_path=file_path, num_nodes=num_nodes) @@ -2455,6 +2458,7 @@ def nersc_reconstruction_task( def nersc_multiresolution_task( file_path: str, config: Optional[Config832] = None, + login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> bool: """ Run multiresolution task at NERSC. @@ -2472,7 +2476,8 @@ def nersc_multiresolution_task( logger.info("Initializing NERSC Tomography HPC Controller.") tomography_controller = get_controller( hpc_type=HPC.NERSC, - config=config + config=config, + login_method=login_method ) logger.info(f"Starting NERSC multiresolution task for {file_path=}") nersc_multiresolution_success = tomography_controller.build_multi_resolution( @@ -2507,6 +2512,7 @@ def nersc_multiresolution_integration_test() -> bool: def nersc_segmentation_sam3_task( recon_folder_path: str, config: Optional[Config832] = None, + login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> bool: """ Run segmentation task at NERSC. @@ -2524,7 +2530,8 @@ def nersc_segmentation_sam3_task( logger.info("Initializing NERSC Tomography HPC Controller.") tomography_controller = get_controller( hpc_type=HPC.NERSC, - config=config + config=config, + login_method=login_method ) logger.info(f"Starting NERSC segmentation task for {recon_folder_path=}") nersc_segmentation_success = tomography_controller.segmentation_sam3( @@ -2545,12 +2552,13 @@ def nersc_segmentation_dinov3_task( recon_folder_path: str, config: Optional[Config832] = None, project: Optional[str] = "petiole", + login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> bool: logger = get_run_logger() if config is None: logger.info("No config provided, using default Config832.") config = Config832() - tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config) + tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=login_method) logger.info(f"Starting NERSC DINOv3 segmentation task for {recon_folder_path=}, {project=}") success = tomography_controller.segmentation_dinov3(recon_folder_path=recon_folder_path, project=project) if not success: @@ -2564,12 +2572,13 @@ def nersc_segmentation_dinov3_task( def nersc_combine_segmentations_task( recon_folder_path: str, config: Optional[Config832] = None, + login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> bool: logger = get_run_logger() if config is None: logger.info("No config provided, using default Config832.") config = Config832() - tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config) + tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=login_method) logger.info(f"Starting NERSC combine segmentations task for {recon_folder_path=}") success = tomography_controller.combine_segmentations(recon_folder_path=recon_folder_path) if not success: @@ -2591,7 +2600,8 @@ def nersc_segmentation_sam3_integration_test() -> bool: recon_folder_path = 'synaps-i/rec20211222_125057_petiole4' # 'test' # flow_success = nersc_segmentation_sam3_task( recon_folder_path=recon_folder_path, - config=Config832() + config=Config832(), + login_method=NERSCLoginMethod.IRIAPI ) logger.info(f"Flow success: {flow_success}") return flow_success From 241c889f91ffdb0743f0fecedfc8ac7ca413423d Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 13 Apr 2026 19:09:44 -0700 Subject: [PATCH 18/29] adjusting queue name and account --- orchestration/flows/bl832/nersc.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 29062308..32e370ca 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -365,7 +365,7 @@ def _get_nersc_username(self) -> str: ) return username - def _submit_job(self, job_script: str) -> str: + def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: """Submit a Slurm job script and return the job ID. Dispatches to the appropriate submission mechanism based on @@ -373,6 +373,7 @@ def _submit_job(self, job_script: str) -> str: Args: job_script: The full Slurm batch script to submit. + num_nodes: The number of nodes to request for the job. Returns: The submitted job ID as a string. @@ -400,15 +401,15 @@ def _submit_job(self, job_script: str) -> str: "stdout_path": f"{pscratch_path}/tomo_recon_logs/iri_job.out", "stderr_path": f"{pscratch_path}/tomo_recon_logs/iri_job.err", "resources": { - "node_count": 1, + "node_count": num_nodes, "processes_per_node": 1, "cpu_cores_per_process": 64, "exclusive_node_use": True, }, "attributes": { "duration": 1800, - "queue_name": "realtime", - "account": "als", + "queue_name": "regular", # change to dynamic + "account": "dabramov", # change to dynamic "custom_attributes": {"constraint": "cpu"}, }, } From c9e7b14c330b2506caebd42faff03b00c70f053e Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 13 Apr 2026 19:21:45 -0700 Subject: [PATCH 19/29] Making the IRI job submission read sbatch settings --- orchestration/flows/bl832/nersc.py | 34 +++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 32e370ca..01f66dd4 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -387,8 +387,23 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: return str(job.jobid) elif self.login_method is NERSCLoginMethod.IRIAPI: - username = self._get_nersc_username() - pscratch_path = f"/pscratch/sd/{username[0]}/{username}" + # Parse SBATCH directives before stripping them + sbatch_values = {} + for line in job_script.splitlines(): + if line.startswith("#SBATCH"): + if "-q " in line: + sbatch_values["queue_name"] = line.split("-q ")[-1].strip() + elif "-A " in line: + sbatch_values["account"] = line.split("-A ")[-1].strip() + elif "--time=" in line: + t = line.split("--time=")[-1].strip() + # convert HH:MM:SS to seconds + parts = t.split(":") + sbatch_values["duration"] = int(parts[0])*3600 + int(parts[1])*60 + int(parts[2]) + elif "-N " in line: + sbatch_values["node_count"] = int(line.split("-N ")[-1].strip()) + elif "-C " in line: + sbatch_values["constraint"] = line.split("-C ")[-1].strip() script_body = "\n".join( line for line in job_script.splitlines() @@ -398,22 +413,21 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: job_spec = { "executable": "/bin/bash", "arguments": ["-c", script_body], - "stdout_path": f"{pscratch_path}/tomo_recon_logs/iri_job.out", - "stderr_path": f"{pscratch_path}/tomo_recon_logs/iri_job.err", "resources": { - "node_count": num_nodes, + "node_count": sbatch_values.get("node_count", 1), "processes_per_node": 1, "cpu_cores_per_process": 64, "exclusive_node_use": True, }, "attributes": { - "duration": 1800, - "queue_name": "regular", # change to dynamic - "account": "dabramov", # change to dynamic - "custom_attributes": {"constraint": "cpu"}, + "duration": sbatch_values.get("duration", 1800), + "queue_name": sbatch_values.get("queue_name", "realtime"), + "account": sbatch_values.get("account", "als"), + "custom_attributes": { + "constraint": sbatch_values.get("constraint", "cpu") + }, }, } - response = self.client.post( f"/api/v1/compute/job/{_IRI_COMPUTE_RESOURCE}", json=job_spec, From 698d243cd74bd7754e6e2e23d993ce48c490281a Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 14 Apr 2026 14:39:58 -0700 Subject: [PATCH 20/29] Switching to debug queue/2 nodes for the IRI demo --- config.yml | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/config.yml b/config.yml index ea576b54..2d6d2ab3 100644 --- a/config.yml +++ b/config.yml @@ -178,15 +178,15 @@ hpc_submission_settings832: # ── RECON + MULTIRES SETTINGS ─────────────────────────────────────────────── nersc_reconstruction: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: realtime + qos: debug account: als - reservation: "_CAP_TOMO_MOON_CPU" - num_nodes: 16 + reservation: "" + num_nodes: 2 cpus-per-task: 128 walltime: "0:30:00" nersc_multiresolution: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: realtime + qos: debug account: als reservation: "" cpus-per-task: 128 @@ -195,15 +195,15 @@ hpc_submission_settings832: # ── PETIOLE SEGMENTATION SETTINGS ─────────────────────────────────────────── nersc_segmentation_sam3: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: regular + qos: debug account: als constraint: gpu reservation: "" - num_nodes: 4 + num_nodes: 2 ntasks-per-node: 1 gpus-per-node: 4 cpus-per-task: 128 - walltime: "00:59:00" + walltime: "00:30:00" # ── Inference parameters ────────────────────────────────────────────────── script_name: "src/inference_v6.py" batch_size: 1 @@ -226,16 +226,16 @@ hpc_submission_settings832: finetuned_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/checkpoint_v6.pt nersc_segmentation_dinov3: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: regular + qos: debug account: als constraint: gpu reservation: "" - num_nodes: 4 + num_nodes: 2 ntasks-per-node: 1 nproc_per_node: 4 gpus-per-node: 4 cpus-per-task: 128 - walltime: "00:59:00" + walltime: "00:30:00" # ── Inference parameters ────────────────────────────────────────────────── script_name: "src.inference_dino_v1" batch_size: 4 @@ -246,14 +246,14 @@ hpc_submission_settings832: dino_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt nersc_combine_segmentations: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: regular + qos: debug account: als constraint: cpu reservation: "" - num_nodes: 4 + num_nodes: 2 ntasks: 128 cpus-per-task: 1 - walltime: "01:00:00" + walltime: "00:30:00" # ── Combination parameters ──────────────────────────────────────────────── script_name: "src.combine_sam_dino_v3" dilate_px: 5 From 6e01f8fcb71b3aab2de26019d390ae221d688feb Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 14 Apr 2026 14:40:59 -0700 Subject: [PATCH 21/29] check globus token expiration before minting a new one. avoids race condition when submitting concurrent jobs --- orchestration/globus/get_globus_token.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/orchestration/globus/get_globus_token.py b/orchestration/globus/get_globus_token.py index c47057e8..f740a034 100644 --- a/orchestration/globus/get_globus_token.py +++ b/orchestration/globus/get_globus_token.py @@ -291,6 +291,19 @@ def get_iri_access_token( or if the resulting tokens do not include a valid IRI access token. """ client = globus_sdk.NativeAppAuthClient(CLIENT_ID) + + # Fast path: if token exists and is not expired, return it directly without refreshing or saving + if not force_login: + stored = load_tokens(token_file) + if stored: + try: + iri_token = get_iri_token(stored) + expires_at = iri_token.get("expires_at_seconds", 0) + if expires_at and time.time() < expires_at - 60: # 60s buffer + return iri_token["access_token"] + except RuntimeError: + pass # fall through to refresh + auth_data = None used_refresh = False if not force_login: From f4388e8581aeb93dd2f8d7170672e0db060f42b5 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 14 Apr 2026 14:41:33 -0700 Subject: [PATCH 22/29] Fixing IRIAPI bugs, also commenting out Globus transfers for now --- orchestration/flows/bl832/nersc.py | 219 ++++++++++++++++++++--------- 1 file changed, 152 insertions(+), 67 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 01f66dd4..5ea74f9b 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -45,6 +45,24 @@ NERSCLoginMethod.IRIAPI: "https://api.iri.nersc.gov", } +# NERSC resource IDs (from status/resources endpoint) +RESOURCE_IDS = { + # Perlmutter compute + "perlmutter_compute": "94351904-6dba-4c16-b5cd-fbd280d8615b", + "perlmutter_login": "e525a224-61c1-419f-9642-91168c792e39", + "perlmutter_realtime": "3776417d-747c-4753-895a-6323c17b9c98", + "perlmutter_job_submit": "3cf3c048-855e-4dd8-a189-065a483954bb", + # Storage + "scratch": "43d8f6c0-f900-48ce-b267-73714103f4ac", + "homes": "65b28619-c3b6-4942-8da1-044a3b3a2a9e", + "common": "7e07a611-f927-4a39-a44d-b1d6e307accd", + "cfs": "59e80c79-4dfd-4c53-9c07-7405685fcd37", + "archive": "f4916c65-9001-49c2-b0bf-6fe4276b564c", + # Services + "globus": "0a207df3-4bec-45b8-9060-13505d269da9", + "dtns": "a762cbdc-af7a-4b2b-9463-67f0189dd2ae", +} + @dataclass class SegmentationModelSpec: @@ -270,7 +288,7 @@ def _create_iriapi_client() -> Client: return httpx.Client( base_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], headers={"Authorization": f"Bearer {access_token}"}, - timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0), + timeout=httpx.Timeout(connect=10.0, read=120.0, write=30.0, pool=10.0), ) @staticmethod @@ -387,7 +405,6 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: return str(job.jobid) elif self.login_method is NERSCLoginMethod.IRIAPI: - # Parse SBATCH directives before stripping them sbatch_values = {} for line in job_script.splitlines(): if line.startswith("#SBATCH"): @@ -397,44 +414,79 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: sbatch_values["account"] = line.split("-A ")[-1].strip() elif "--time=" in line: t = line.split("--time=")[-1].strip() - # convert HH:MM:SS to seconds parts = t.split(":") sbatch_values["duration"] = int(parts[0])*3600 + int(parts[1])*60 + int(parts[2]) elif "-N " in line: sbatch_values["node_count"] = int(line.split("-N ")[-1].strip()) elif "-C " in line: sbatch_values["constraint"] = line.split("-C ")[-1].strip() + elif "--output=" in line: + sbatch_values["stdout_path"] = line.split("--output=")[-1].strip() + elif "--error=" in line: + sbatch_values["stderr_path"] = line.split("--error=")[-1].strip() + # Strip shebang and SBATCH headers, keep the script body script_body = "\n".join( line for line in job_script.splitlines() if not line.startswith("#SBATCH") and not line.startswith("#!/") ).strip() + constraint = sbatch_values.get("constraint", "cpu") + is_gpu = "gpu" in constraint.lower() + + resources = { + "node_count": sbatch_values.get("node_count", 1), + "processes_per_node": 1, + "exclusive_node_use": True, + } + if is_gpu: + resources["gpu_cores_per_process"] = 4 + else: + resources["cpu_cores_per_process"] = 128 + job_spec = { "executable": "/bin/bash", - "arguments": ["-c", script_body], - "resources": { - "node_count": sbatch_values.get("node_count", 1), - "processes_per_node": 1, - "cpu_cores_per_process": 64, - "exclusive_node_use": True, - }, + "arguments": ["-s"], # read script from stdin isn't supported, so... + "pre_launch": script_body, # run the body here before the executable + "resources": resources, + # { + # "node_count": sbatch_values.get("node_count", 1), + # "processes_per_node": 1, + # "cpu_cores_per_process": 64, + # "exclusive_node_use": True, + # }, "attributes": { "duration": sbatch_values.get("duration", 1800), - "queue_name": sbatch_values.get("queue_name", "realtime"), + "queue_name": sbatch_values.get("queue_name", "regular"), "account": sbatch_values.get("account", "als"), "custom_attributes": { - "constraint": sbatch_values.get("constraint", "cpu") + "constraint": constraint # sbatch_values.get("constraint", "cpu") }, }, } + + if "stdout_path" in sbatch_values: + job_spec["stdout_path"] = sbatch_values["stdout_path"] + if "stderr_path" in sbatch_values: + job_spec["stderr_path"] = sbatch_values["stderr_path"] + response = self.client.post( - f"/api/v1/compute/job/{_IRI_COMPUTE_RESOURCE}", + "/api/v1/compute/job/3cf3c048-855e-4dd8-a189-065a483954bb", json=job_spec, ) + if not response.is_success: + logger.error(f"Job submission failed: {response.status_code} {response.text}") + logger.error(f"Job spec was: {json.dumps(job_spec, indent=2)}") response.raise_for_status() return str(response.json()["id"]) + # response = self.client.post( + # "/api/v1/compute/job/3cf3c048-855e-4dd8-a189-065a483954bb", + # json=job_spec, + # ) + # response.raise_for_status() + # return str(response.json()["id"]) + else: raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") @@ -485,7 +537,7 @@ def _mkdir_remote(self, path: str) -> None: perlmutter.run(f"mkdir -p {path}") elif self.login_method is NERSCLoginMethod.IRIAPI: response = self.client.post( - "/api/v1/filesystem/mkdir/perlmutter", + f"/api/v1/filesystem/mkdir/{RESOURCE_IDS["perlmutter_login"]}", json={"path": path, "parents": True}, ) response.raise_for_status() @@ -514,11 +566,32 @@ def _read_remote_file(self, path: str) -> str: elif self.login_method is NERSCLoginMethod.IRIAPI: response = self.client.get( - "/api/v1/filesystem/file/perlmutter", + f"/api/v1/filesystem/view/{RESOURCE_IDS['perlmutter_login']}", params={"path": path}, ) response.raise_for_status() - return response.text + task_id = response.json().get("task_id") + if not task_id: + return response.text + + for _ in range(40): + task_response = self.client.get(f"/api/v1/task/{task_id}") + task_response.raise_for_status() + task = task_response.json() + status = task.get("status") + if status == "completed": + result = task.get("result", "") + if isinstance(result, dict): + output = result.get("output", result) + if isinstance(output, dict): + return output.get("content", str(output)) + return str(output) + return str(result) + elif status == "failed": + raise RuntimeError(f"File read task {task_id} failed: {task.get('result')}") + time.sleep(3) + + raise TimeoutError(f"File read task {task_id} did not complete") else: raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") @@ -567,6 +640,8 @@ def reconstruct( opts = _load_job_options("nersc-reconstruction-options", self.config.nersc_recon_settings) + logger.info(f"Resolved options: {opts}") + num_nodes = opts.get("num_nodes", num_nodes) cpus_per_task = opts["cpus-per-task"] qos = opts["qos"] @@ -631,6 +706,7 @@ def reconstruct( echo "METADATA_START=$(date +%s)" >> $TIMING_FILE NUM_SLICES=$(shifter \ + --image={recon_image} \ --volume={pscratch_path}/8.3.2:/alsdata \ python -c " import h5py @@ -675,6 +751,7 @@ def reconstruct( fi srun --nodes=1 --ntasks=1 --exclusive shifter \ + --image={recon_image} \ --env=NUMEXPR_MAX_THREADS=128 \ --env=NUMEXPR_NUM_THREADS=128 \ --env=OMP_NUM_THREADS=128 \ @@ -933,7 +1010,7 @@ def segmentation_sam3( #SBATCH -A {account} {reservation_line} #SBATCH -N {num_nodes} -#SBATCH -C {constraint} # gpu +#SBATCH -C {constraint} #SBATCH --job-name={job_name} #SBATCH --time={walltime} #SBATCH --ntasks-per-node={ntasks_per_node} @@ -1913,7 +1990,6 @@ def nersc_petiole_segment_flow( file_path=file_path, num_nodes=num_nodes, config=config, - login_method=login_method ) if isinstance(recon_result, dict): @@ -1950,24 +2026,24 @@ def nersc_petiole_segment_flow( logger.info("Reconstruction Successful.") # ── STEP 2: Transfer TIFFs to data832 ──────────────────────────────────── - logger.info("Transferring reconstructed TIFFs from NERSC pscratch to data832") - try: - data832_tiff_future = globus_transfer_task.submit( - file_path=scratch_path_tiff, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch, - config=config, - ) - logger.info("TIFF transfer to data832 submitted.") - except Exception as e: - logger.error(f"Failed to transfer TIFFs to data832: {e}") - data832_tiff_transfer_success = False + # logger.info("Transferring reconstructed TIFFs from NERSC pscratch to data832") + # try: + # data832_tiff_future = globus_transfer_task.submit( + # file_path=scratch_path_tiff, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch, + # config=config, + # ) + # logger.info("TIFF transfer to data832 submitted.") + # except Exception as e: + # logger.error(f"Failed to transfer TIFFs to data832: {e}") + # data832_tiff_transfer_success = False # ── STEP 3: SAM3 / DINOv3 ────────────────────────── logger.info("Submitting SAM3 and DINOv3 segmentation tasks concurrently.") sam3_future = nersc_segmentation_sam3_task.submit( - recon_folder_path=scratch_path_tiff, config=config, login_method=login_method + recon_folder_path=scratch_path_tiff, config=config ) dinov3_future = nersc_segmentation_dinov3_task.submit( recon_folder_path=scratch_path_tiff, config=config, project="petiole", login_method=login_method @@ -1979,15 +2055,17 @@ def nersc_petiole_segment_flow( logger.info(f"SAM3 segmentation result: {sam3_success}") if sam3_success: logger.info("Transferring SAM3 segmentation outputs to data832") - sam3_segment_path = f"{folder_name}/seg{file_name}/sam3" + # sam3_segment_path = f"{folder_name}/seg{file_name}/sam3" try: - data832_sam3_future = globus_transfer_task.submit( - file_path=sam3_segment_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch, - config=config, - ) - logger.info("SAM3 transfer to data832 submitted") + # data832_sam3_future = globus_transfer_task.submit( + # file_path=sam3_segment_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch, + # config=config, + # ) + # logger.info("SAM3 transfer to data832 submitted") + data832_sam3_transfer_success = True + logger.info(f"SAM3 transfer to data832 success: {data832_sam3_transfer_success}") except Exception as e: logger.error(f"Failed to transfer SAM3 outputs to data832: {e}") @@ -1995,15 +2073,17 @@ def nersc_petiole_segment_flow( logger.info(f"DINOv3 segmentation result: {dinov3_success}") if dinov3_success: logger.info("Transferring DINOv3 segmentation outputs to data832") - dinov3_segment_path = f"{folder_name}/seg{file_name}/dino" + # dinov3_segment_path = f"{folder_name}/seg{file_name}/dino" try: - data832_dinov3_future = globus_transfer_task.submit( - file_path=dinov3_segment_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch, - config=config, - ) - logger.info("DINOv3 transfer to data832 submitted") + # data832_dinov3_future = globus_transfer_task.submit( + # file_path=dinov3_segment_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch, + # config=config, + # ) + # logger.info("DINOv3 transfer to data832 submitted") + data832_dinov3_transfer_success = True + logger.info(f"DINOv3 transfer to data832 success: {data832_dinov3_transfer_success}") except Exception as e: logger.error(f"Failed to transfer DINOv3 outputs to data832: {e}") @@ -2016,22 +2096,24 @@ def nersc_petiole_segment_flow( logger.info("Running segmentation combination.") combine_future = nersc_combine_segmentations_task.submit( - recon_folder_path=scratch_path_tiff, config=config, login_method=login_method + recon_folder_path=scratch_path_tiff, config=config ) combine_success = combine_future.result() logger.info(f"Combination result: {combine_success}") if combine_success: logger.info("Transferring combined segmentation outputs to data832") - combined_segment_path = f"{folder_name}/seg{file_name}/combined/sam_dino" + # combined_segment_path = f"{folder_name}/seg{file_name}/combined/sam_dino" try: - data832_combined_future = globus_transfer_task.submit( - file_path=combined_segment_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch, - config=config, - ) - logger.info("Combined transfer to data832 submitted") + # data832_combined_future = globus_transfer_task.submit( + # file_path=combined_segment_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch, + # config=config, + # ) + # logger.info("Combined transfer to data832 submitted") + data832_combined_transfer_success = True + logger.info(f"Combined transfer to data832 success: {data832_combined_transfer_success}") except Exception as e: logger.error(f"Failed to transfer combined outputs to data832: {e}") @@ -2527,7 +2609,6 @@ def nersc_multiresolution_integration_test() -> bool: def nersc_segmentation_sam3_task( recon_folder_path: str, config: Optional[Config832] = None, - login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> bool: """ Run segmentation task at NERSC. @@ -2546,7 +2627,7 @@ def nersc_segmentation_sam3_task( tomography_controller = get_controller( hpc_type=HPC.NERSC, config=config, - login_method=login_method + login_method=NERSCLoginMethod.IRIAPI ) logger.info(f"Starting NERSC segmentation task for {recon_folder_path=}") nersc_segmentation_success = tomography_controller.segmentation_sam3( @@ -2573,7 +2654,7 @@ def nersc_segmentation_dinov3_task( if config is None: logger.info("No config provided, using default Config832.") config = Config832() - tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=login_method) + tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=NERSCLoginMethod.IRIAPI) logger.info(f"Starting NERSC DINOv3 segmentation task for {recon_folder_path=}, {project=}") success = tomography_controller.segmentation_dinov3(recon_folder_path=recon_folder_path, project=project) if not success: @@ -2587,13 +2668,12 @@ def nersc_segmentation_dinov3_task( def nersc_combine_segmentations_task( recon_folder_path: str, config: Optional[Config832] = None, - login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> bool: logger = get_run_logger() if config is None: logger.info("No config provided, using default Config832.") config = Config832() - tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=login_method) + tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=NERSCLoginMethod.IRIAPI) logger.info(f"Starting NERSC combine segmentations task for {recon_folder_path=}") success = tomography_controller.combine_segmentations(recon_folder_path=recon_folder_path) if not success: @@ -2622,9 +2702,14 @@ def nersc_segmentation_sam3_integration_test() -> bool: return flow_success -if __name__ == "__main__": - nersc_segmentation_dinov3_task( - recon_folder_path='dabramov/recmoon/', - config=Config832(), - project="moon" - ) +# if __name__ == "__main__": + # nersc_segmentation_dinov3_task( + # recon_folder_path='dabramov/recmoon/', + # config=Config832(), + # project="moon" + # ) + # nersc_petiole_segment_flow( + # file_path='dabramov/20260221_143000_petiole28', + # num_nodes=2, + # login_method=NERSCLoginMethod.IRIAPI + # ) From a490bfee39ac6f8344b33851a8d393f21dd3df1f Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 15 Apr 2026 16:03:38 -0700 Subject: [PATCH 23/29] removing IRIAPI client ID from nersc.py, since it is only used in globus/get_globus_token.py --- orchestration/flows/bl832/nersc.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 5ea74f9b..808a2a4c 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -269,13 +269,6 @@ def _create_iriapi_client() -> Client: ValueError: If ``GLOBUS_CLIENT_ID`` or ``GLOBUS_CLIENT_SECRET`` are unset. RuntimeError: If the acquired token is missing required scopes. """ - client_id = "fae5c579-490a-4d76-b6eb-d78f65caeb63" # os.getenv(_IRIAPI_GLOBUS_CLIENT_ID_ENV) - - if not client_id: - raise ValueError( - f"Globus client ID is unset. Set {_IRIAPI_GLOBUS_CLIENT_ID_ENV}." - ) - token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV) token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE From 041f33688536035b6858a4d62959f20499400099 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Thu, 23 Apr 2026 11:04:53 -0700 Subject: [PATCH 24/29] Updating logger comments --- orchestration/flows/bl832/nersc.py | 42 ++++++++++-------------------- 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 808a2a4c..39bd4a02 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -209,7 +209,7 @@ def __init__( @staticmethod def create_nersc_client( login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, - ) -> Client: + ) -> Client | httpx.Client: """Create and return a NERSC client for the requested login method. Two fundamentally different auth strategies are supported: @@ -255,7 +255,7 @@ def create_nersc_client( return client @staticmethod - def _create_iriapi_client() -> Client: + def _create_iriapi_client() -> httpx.Client: """Create a NERSC client for the IRI API using a Globus bearer token. Requires ``GLOBUS_CLIENT_ID`` and ``GLOBUS_CLIENT_SECRET`` in the @@ -263,7 +263,7 @@ def _create_iriapi_client() -> Client: via the client credentials grant. No browser or user interaction. Returns: - An authenticated :class:`sfapi_client.Client` targeting the IRI API. + An authenticated :class:`httpx.Client` targeting the IRI API. Raises: ValueError: If ``GLOBUS_CLIENT_ID`` or ``GLOBUS_CLIENT_SECRET`` are unset. @@ -442,18 +442,12 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: "arguments": ["-s"], # read script from stdin isn't supported, so... "pre_launch": script_body, # run the body here before the executable "resources": resources, - # { - # "node_count": sbatch_values.get("node_count", 1), - # "processes_per_node": 1, - # "cpu_cores_per_process": 64, - # "exclusive_node_use": True, - # }, "attributes": { "duration": sbatch_values.get("duration", 1800), "queue_name": sbatch_values.get("queue_name", "regular"), "account": sbatch_values.get("account", "als"), "custom_attributes": { - "constraint": constraint # sbatch_values.get("constraint", "cpu") + "constraint": constraint }, }, } @@ -464,7 +458,7 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: job_spec["stderr_path"] = sbatch_values["stderr_path"] response = self.client.post( - "/api/v1/compute/job/3cf3c048-855e-4dd8-a189-065a483954bb", + f"/api/v1/compute/job/{RESOURCE_IDS['perlmutter_job_submit']}", json=job_spec, ) if not response.is_success: @@ -473,13 +467,6 @@ def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: response.raise_for_status() return str(response.json()["id"]) - # response = self.client.post( - # "/api/v1/compute/job/3cf3c048-855e-4dd8-a189-065a483954bb", - # json=job_spec, - # ) - # response.raise_for_status() - # return str(response.json()["id"]) - else: raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") @@ -1640,15 +1627,15 @@ def check_shifter_image( pscratch_path = f"/pscratch/sd/{username[0]}/{username}" output_file = f"{pscratch_path}/tomo_recon_logs/shifter_check.txt" check_script = f"""#!/bin/bash - #SBATCH -q debug - #SBATCH -A als - #SBATCH -C cpu - #SBATCH -N 1 - #SBATCH --ntasks=1 - #SBATCH --cpus-per-task=1 - #SBATCH --time=0:05:00 - shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/g')" > {output_file} 2>&1 || true - """ +#SBATCH -q debug +#SBATCH -A als +#SBATCH -C cpu +#SBATCH -N 1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=1 +#SBATCH --time=0:05:00 +shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/g')" > {output_file} 2>&1 || true +""" job_id = self._submit_job(check_script) self._wait_for_job(job_id) output = self._read_remote_file(output_file) @@ -2689,7 +2676,6 @@ def nersc_segmentation_sam3_integration_test() -> bool: flow_success = nersc_segmentation_sam3_task( recon_folder_path=recon_folder_path, config=Config832(), - login_method=NERSCLoginMethod.IRIAPI ) logger.info(f"Flow success: {flow_success}") return flow_success From 863b24e94dd2cab4e7ef5b92fdd78faa53fe261f Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 24 Apr 2026 12:37:31 -0700 Subject: [PATCH 25/29] connecting to AmSC MLflow service --- .env.example | 7 +- config.yml | 4 + orchestration/flows/bl832/config.py | 2 +- orchestration/flows/bl832/register_mlflow.py | 188 ++++++++++++++++--- orchestration/mlflow.py | 86 ++++++++- 5 files changed, 262 insertions(+), 25 deletions(-) diff --git a/.env.example b/.env.example index e3728e89..b7e54812 100644 --- a/.env.example +++ b/.env.example @@ -1,7 +1,12 @@ +BEAMLINE=8.3.2 GLOBUS_CLIENT_ID= GLOBUS_CLIENT_SECRET= PREFECT_API_URL= PREFECT_API_KEY= PUSHGATEWAY_URL= JOB_NAME= -INSTANCE_LABEL= \ No newline at end of file +INSTANCE_LABEL= +PATH_NERSC_CLIENT_ID= +PATH_NERSC_PRI_KEY= +NERSC_USERNAME= +AMSC_API_KEY= # found here: https://profile.american-science-cloud.org/ \ No newline at end of file diff --git a/config.yml b/config.yml index 2d6d2ab3..6ee1cbb7 100644 --- a/config.yml +++ b/config.yml @@ -173,6 +173,10 @@ mlflow: staging: tracking_uri: https://mlflow-staging.computing.als.lbl.gov registry_uri: https://mlflow-staging.computing.als.lbl.gov + amsc: + tracking_uri: https://mlflow.american-science-cloud.org/ + registry_uri: https://mlflow.american-science-cloud.org/ + experiment_name: als-bl832-models hpc_submission_settings832: # ── RECON + MULTIRES SETTINGS ─────────────────────────────────────────────── diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index 8bbbf78c..281ba167 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -30,7 +30,7 @@ def _beam_specific_config(self) -> None: # SciCat self.scicat = self.config["scicat"] # MLflow - self.mlflow = self.config["mlflow"]["local"] + self.mlflow = self.config["mlflow"]["amsc"] # NERSC HPC submission settings self.ghcr_images832 = self.config["ghcr_images832"] self.nersc_recon_settings = self.config["hpc_submission_settings832"]["nersc_reconstruction"] diff --git a/orchestration/flows/bl832/register_mlflow.py b/orchestration/flows/bl832/register_mlflow.py index 31fa3760..d540ea67 100644 --- a/orchestration/flows/bl832/register_mlflow.py +++ b/orchestration/flows/bl832/register_mlflow.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)-7s | %(name)s - %(message)s") def register_mlflow_checkpoints(): @@ -16,14 +17,18 @@ def register_mlflow_checkpoints(): register_checkpoint( model_name="sam3-petiole", nersc_path=f"{scripts_dir}sam3_finetune/sam3/checkpoint_v6.pt", - alcf_path="/eagle/IRIBeta/als/seg_models/sam3/checkpoint_v6.pt", + alcf_path="/eagle/SYNAPS-I/segmentation/sam3_finetune/sam3/checkpoint_v6.pt", config=config, alias="production", description="SAM3 v6 fine-tuned on petiole micro-CT data.", inference_params={ - # ── paths ────────────────────────────────────────────────────────── - "original_checkpoint_path": - f"{scripts_dir}sam3_finetune/sam3/sam3.pt", + # ── site-specific HF caches ───────────────────────────────────────── + "nersc_hf_home": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + "nersc_hf_hub_cache": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + "alcf_hf_home": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + "alcf_hf_hub_cache": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # ── paths ─────────────────────────────────────────────────────────── + "original_checkpoint_path": f"{scripts_dir}sam3_finetune/sam3/sam3.pt", "bpe_path": f"{scripts_dir}sam3_finetune/sam3/bpe_simple_vocab_16e6.txt.gz", "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/sam3-py311", "seg_scripts_dir": f"{scripts_dir}inference_latest/forge_feb_seg_model_demo/", @@ -32,9 +37,9 @@ def register_mlflow_checkpoints(): "script_name": "src/inference_v6.py", "batch_size": 1, "patch_size": 400, - "confidence": [0.5], # list → JSON-encoded automatically + "confidence": [0.5], "overlap": 0.25, - "prompts": [ # list → JSON-encoded automatically + "prompts": [ "Phloem Fibers", "Hydrated Xylem vessels", "Air-based Pith cells", @@ -46,12 +51,17 @@ def register_mlflow_checkpoints(): register_checkpoint( model_name="dinov3-petiole", nersc_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt", - alcf_path="/eagle/IRIBeta/als/seg_models/dino/best.ckpt", + alcf_path="/eagle/SYNAPS-I/segmentation/dino/best.ckpt", config=config, alias="production", description="DINOv3 fine-tuned on petiole micro-CT data.", inference_params={ - # ── paths ────────────────────────────────────────────────────────── + # ── site-specific HF caches ───────────────────────────────────────── + "nersc_hf_home": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + "nersc_hf_hub_cache": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + "alcf_hf_home": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + "alcf_hf_hub_cache": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # ── paths ─────────────────────────────────────────────────────────── "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", "seg_scripts_dir": f"{scripts_dir}inference_v5_multiseg/forge_feb_seg_model_demo/", # ── inference hyperparameters ─────────────────────────────────────── @@ -64,19 +74,102 @@ def register_mlflow_checkpoints(): register_checkpoint( model_name="dinov3-moon", nersc_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best_moon.ckpt", - alcf_path="/eagle/IRIBeta/als/seg_models/dino/best_moon.ckpt", + alcf_path="/eagle/SYNAPS-I/segmentation/seg_models/dino/best_moon.ckpt", config=config, alias="production", description="DINOv3 fine-tuned on lunar regolith micro-CT data (ice, particles, pores).", inference_params={ + # ── site-specific HF caches ───────────────────────────────────────── + "nersc_hf_home": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + "nersc_hf_hub_cache": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + "alcf_hf_home": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + "alcf_hf_hub_cache": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # ── paths ─────────────────────────────────────────────────────────── "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", "seg_scripts_dir": f"{scripts_dir}moon_seg/forge_feb_seg_model_demo/", + # ── inference hyperparameters ─────────────────────────────────────── "script_name": "src.inference_dino_v2", "batch_size": 4, "nproc_per_node": 4, }, ) + # register_checkpoint( + # model_name="sam3-petiole", + # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + # nersc_path=f"{scripts_dir}sam3_finetune/sam3/checkpoint_v6.pt", + # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_path="/eagle/SYNAPS-I/segmentation/sam3_finetune/sam3/checkpoint_v6.pt", + # config=config, + # alias="production", + # description="SAM3 v6 fine-tuned on petiole micro-CT data.", + # inference_params={ + # # ── paths ────────────────────────────────────────────────────────── + # "original_checkpoint_path": + # f"{scripts_dir}sam3_finetune/sam3/sam3.pt", + # "bpe_path": f"{scripts_dir}sam3_finetune/sam3/bpe_simple_vocab_16e6.txt.gz", + # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/sam3-py311", + # "seg_scripts_dir": f"{scripts_dir}inference_latest/forge_feb_seg_model_demo/", + # "checkpoints_dir": f"{scripts_dir}sam3_finetune/sam3/", + # # ── inference hyperparameters ─────────────────────────────────────── + # "script_name": "src/inference_v6.py", + # "batch_size": 1, + # "patch_size": 400, + # "confidence": [0.5], # list → JSON-encoded automatically + # "overlap": 0.25, + # "prompts": [ # list → JSON-encoded automatically + # "Phloem Fibers", + # "Hydrated Xylem vessels", + # "Air-based Pith cells", + # "Dehydrated Xylem vessels", + # ], + # }, + # ) + + # register_checkpoint( + # model_name="dinov3-petiole", + # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + # nersc_checkpoint_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt", + # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_path="/eagle/SYNAPS-I/segmentation/dino/best.ckpt", + # config=config, + # alias="production", + # description="DINOv3 fine-tuned on petiole micro-CT data.", + # inference_params={ + # # ── paths ────────────────────────────────────────────────────────── + # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", + # "seg_scripts_dir": f"{scripts_dir}inference_v5_multiseg/forge_feb_seg_model_demo/", + # # ── inference hyperparameters ─────────────────────────────────────── + # "script_name": "src.inference_dino_v1", + # "batch_size": 4, + # "nproc_per_node": 4, + # }, + # ) + + # register_checkpoint( + # model_name="dinov3-moon", + # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + # nersc_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best_moon.ckpt", + # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # alcf_path="/eagle/SYNAPS-I/segmentation/seg_models/dino/best_moon.ckpt", + # config=config, + # alias="production", + # description="DINOv3 fine-tuned on lunar regolith micro-CT data (ice, particles, pores).", + # inference_params={ + # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", + # "seg_scripts_dir": f"{scripts_dir}moon_seg/forge_feb_seg_model_demo/", + # "script_name": "src.inference_dino_v2", + # "batch_size": 4, + # "nproc_per_node": 4, + # }, + # ) + def retrieve_mlflow_params_test() -> bool: """Test that _load_job_options correctly pulls inference params from the MLflow registry. @@ -106,28 +199,56 @@ def retrieve_mlflow_params_test() -> bool: ) sam3_checks = { - # MLflow should have overridden these + # ── MLflow should have overridden these ────────────────────────────── "finetuned_checkpoint_path": ( lambda v: "checkpoint" in v, "finetuned_checkpoint_path should contain 'checkpoint'" ), + "original_checkpoint_path": ( + lambda v: v.endswith(".pt") and "sam3" in v.lower(), + "original_checkpoint_path should point at a sam3 .pt file" + ), + "bpe_path": ( + lambda v: v.endswith(".txt.gz"), + "bpe_path should point at a .txt.gz vocab file" + ), "conda_env_path": ( lambda v: "sam3" in v, "conda_env_path should reference sam3 env" ), + "seg_scripts_dir": ( + lambda v: isinstance(v, str) and len(v) > 0, + "seg_scripts_dir should be a non-empty path" + ), + "checkpoints_dir": ( + lambda v: isinstance(v, str) and len(v) > 0, + "checkpoints_dir should be a non-empty path" + ), + "script_name": ( + lambda v: "inference" in v.lower(), + "script_name should reference an inference script" + ), "prompts": ( lambda v: isinstance(v, list) and len(v) > 0, "prompts should be a non-empty list (JSON-deserialized)" ), "confidence": ( - lambda v: isinstance(v, list), - "confidence should be a list (JSON-deserialized)" + lambda v: isinstance(v, list) and len(v) > 0, + "confidence should be a non-empty list (JSON-deserialized)" ), "batch_size": ( - lambda v: isinstance(v, int), - "batch_size should be an int" + lambda v: isinstance(v, int) and v > 0, + "batch_size should be a positive int" ), - # SLURM params should still come from config + "patch_size": ( + lambda v: isinstance(v, int) and v > 0, + "patch_size should be a positive int" + ), + "overlap": ( + lambda v: isinstance(v, float) and 0.0 <= v < 1.0, + "overlap should be a float in [0.0, 1.0)" + ), + # ── SLURM params should still come from config ─────────────────────── "qos": ( lambda v: v == config.nersc_segment_sam3_settings["qos"], "qos should be unchanged from config" @@ -160,23 +281,32 @@ def retrieve_mlflow_params_test() -> bool: ) dino_checks = { + # ── MLflow-overridden ──────────────────────────────────────────────── "dino_checkpoint_path": ( lambda v: v.endswith(".ckpt"), "dino_checkpoint_path should end with .ckpt" ), "conda_env_path": ( - lambda v: len(v) > 0, + lambda v: isinstance(v, str) and len(v) > 0, "conda_env_path should be non-empty" ), - "batch_size": ( - lambda v: isinstance(v, int) and v > 0, - "batch_size should be a positive int" + "seg_scripts_dir": ( + lambda v: isinstance(v, str) and len(v) > 0, + "seg_scripts_dir should be a non-empty path" ), "script_name": ( lambda v: "dino" in v.lower(), "script_name should reference dino" ), - # SLURM params unchanged + "batch_size": ( + lambda v: isinstance(v, int) and v > 0, + "batch_size should be a positive int" + ), + "nproc_per_node": ( + lambda v: isinstance(v, int) and v > 0, + "nproc_per_node should be a positive int" + ), + # ── SLURM params unchanged ─────────────────────────────────────────── "qos": ( lambda v: v == config.nersc_segment_dinov3_settings["qos"], "qos should be unchanged from config" @@ -209,9 +339,18 @@ def retrieve_mlflow_params_test() -> bool: ) moon_checks = { + # ── MLflow-overridden ──────────────────────────────────────────────── "dino_checkpoint_path": ( - lambda v: v.endswith(".ckpt"), - "dino_checkpoint_path should end with .ckpt" + lambda v: v.endswith(".ckpt") and "moon" in v.lower(), + "dino_checkpoint_path should end with .ckpt and reference moon" + ), + "conda_env_path": ( + lambda v: isinstance(v, str) and len(v) > 0, + "conda_env_path should be non-empty" + ), + "seg_scripts_dir": ( + lambda v: isinstance(v, str) and "moon" in v.lower(), + "seg_scripts_dir should reference moon_seg" ), "script_name": ( lambda v: "v2" in v.lower(), @@ -221,6 +360,11 @@ def retrieve_mlflow_params_test() -> bool: lambda v: isinstance(v, int) and v > 0, "batch_size should be a positive int" ), + "nproc_per_node": ( + lambda v: isinstance(v, int) and v > 0, + "nproc_per_node should be a positive int" + ), + # ── SLURM params unchanged ─────────────────────────────────────────── "qos": ( lambda v: v == config.nersc_segment_dinov3_moon_settings["qos"], "qos should be unchanged from config" diff --git a/orchestration/mlflow.py b/orchestration/mlflow.py index cff8487c..c337a2ba 100644 --- a/orchestration/mlflow.py +++ b/orchestration/mlflow.py @@ -1,16 +1,23 @@ import logging from dataclasses import dataclass, field +from dotenv import load_dotenv import json +import os import requests from typing import Any import mlflow from mlflow.tracking import MlflowClient +import mlflow.utils.rest_utils as rest_utils + from orchestration.config import BeamlineConfig logger = logging.getLogger(__name__) +_AMSC_PATCH_FLAG: str = "_amsc_x_api_key_patched" +load_dotenv() + @dataclass class ModelCheckpointInfo: @@ -48,13 +55,63 @@ def _is_mlflow_reachable(tracking_uri: str, timeout: float = 2.0) -> bool: Returns: True if the server responds with HTTP 200, False otherwise. """ + headers = {} + api_key = os.environ.get("AMSC_API_KEY") + if api_key: + headers["X-Api-Key"] = api_key try: - response = requests.get(f"{tracking_uri}/health", timeout=timeout) + response = requests.get( + f"{tracking_uri}/health", headers=headers, timeout=timeout + ) return response.status_code == 200 except Exception: return False +def _enable_amsc_x_api_key() -> bool: + """Patch mlflow.utils.rest_utils.http_request to inject X-Api-Key. + + Required by the American Science Cloud MLflow server, which enforces + API-key auth on all REST calls. Standard MLflow does not send custom + headers, so we wrap ``http_request`` at import time. + + Idempotent: repeat calls are no-ops thanks to a sentinel attribute on + the wrapper. Silently skips patching if ``AMSC_API_KEY`` is unset, + which lets the same codebase target non-AMSC MLflow servers. + + Returns: + True if the patch is (or was already) active, False if the API + key env var is unset. + """ + + api_key = os.environ.get("AMSC_API_KEY") + if not api_key: + return False + + if getattr(rest_utils.http_request, _AMSC_PATCH_FLAG, False): + return True + + original = rest_utils.http_request + + def patched(host_creds, endpoint, method, *args, **kwargs): + # MLflow internals call http_request with either `headers` or + # `extra_headers` depending on the code path — handle both. + if "headers" in kwargs and kwargs["headers"] is not None: + h = dict(kwargs["headers"]) + h["X-Api-Key"] = api_key + kwargs["headers"] = h + else: + h = dict(kwargs.get("extra_headers") or {}) + h["X-Api-Key"] = api_key + kwargs["extra_headers"] = h + return original(host_creds, endpoint, method, *args, **kwargs) + + setattr(patched, _AMSC_PATCH_FLAG, True) + rest_utils.http_request = patched + logger.info("AMSC X-Api-Key injection enabled for MLflow REST calls.") + return True + + def get_mlflow_client(config: BeamlineConfig) -> MlflowClient: """Construct an MlflowClient pointed at the configured tracking server. @@ -65,6 +122,7 @@ def get_mlflow_client(config: BeamlineConfig) -> MlflowClient: An authenticated MlflowClient instance. """ tracking_uri = config.mlflow["tracking_uri"] + _enable_amsc_x_api_key() # Idempotent patch for AMSC API key injection mlflow.set_tracking_uri(tracking_uri) return MlflowClient(tracking_uri=tracking_uri) @@ -183,7 +241,19 @@ def register_checkpoint( client.create_registered_model(model_name) mlflow.set_tracking_uri(config.mlflow["tracking_uri"]) + + # Use a dedicated experiment so the creator (this user) gets MANAGE + # permission automatically — avoids 403 on the default experiment. + experiment_name = config.mlflow.get("experiment_name", "als-model-registration") + experiment = mlflow.get_experiment_by_name(experiment_name) + if experiment is None: + experiment_id = mlflow.create_experiment(experiment_name) + logger.info(f"Created MLflow experiment '{experiment_name}' (id={experiment_id}).") + else: + experiment_id = experiment.experiment_id + with mlflow.start_run( + experiment_id=experiment_id, run_name=f"register_{model_name}", tags={"mlflow.note.content": description}, ) as run: @@ -247,7 +317,21 @@ def log_segmentation_metrics( run_tags: dict[str, str] = {"model": model_name, "slurm_job_id": job_id} + tracking_uri = config.mlflow["tracking_uri"] + mlflow.set_tracking_uri(tracking_uri) + _enable_amsc_x_api_key() # ensure AMSC auth patch is active for this entrypoint too + + experiment_name = config.mlflow.get("experiment_name", "als-model-registration") + experiment = mlflow.get_experiment_by_name(experiment_name) + if experiment is None: + experiment_id = mlflow.create_experiment(experiment_name) + else: + experiment_id = experiment.experiment_id + + run_tags: dict[str, str] = {"model": model_name, "slurm_job_id": job_id} + with mlflow.start_run( + experiment_id=experiment_id, run_name=run_name, nested=parent_run_id is not None, parent_run_id=parent_run_id, From 0144f525f55cb4feaa4a1fa24b8163ed8ec0e805 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 24 Apr 2026 13:20:29 -0700 Subject: [PATCH 26/29] removing old commented code --- orchestration/flows/bl832/register_mlflow.py | 76 -------------------- 1 file changed, 76 deletions(-) diff --git a/orchestration/flows/bl832/register_mlflow.py b/orchestration/flows/bl832/register_mlflow.py index d540ea67..93603295 100644 --- a/orchestration/flows/bl832/register_mlflow.py +++ b/orchestration/flows/bl832/register_mlflow.py @@ -94,82 +94,6 @@ def register_mlflow_checkpoints(): }, ) - # register_checkpoint( - # model_name="sam3-petiole", - # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", - # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", - # nersc_path=f"{scripts_dir}sam3_finetune/sam3/checkpoint_v6.pt", - # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_path="/eagle/SYNAPS-I/segmentation/sam3_finetune/sam3/checkpoint_v6.pt", - # config=config, - # alias="production", - # description="SAM3 v6 fine-tuned on petiole micro-CT data.", - # inference_params={ - # # ── paths ────────────────────────────────────────────────────────── - # "original_checkpoint_path": - # f"{scripts_dir}sam3_finetune/sam3/sam3.pt", - # "bpe_path": f"{scripts_dir}sam3_finetune/sam3/bpe_simple_vocab_16e6.txt.gz", - # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/sam3-py311", - # "seg_scripts_dir": f"{scripts_dir}inference_latest/forge_feb_seg_model_demo/", - # "checkpoints_dir": f"{scripts_dir}sam3_finetune/sam3/", - # # ── inference hyperparameters ─────────────────────────────────────── - # "script_name": "src/inference_v6.py", - # "batch_size": 1, - # "patch_size": 400, - # "confidence": [0.5], # list → JSON-encoded automatically - # "overlap": 0.25, - # "prompts": [ # list → JSON-encoded automatically - # "Phloem Fibers", - # "Hydrated Xylem vessels", - # "Air-based Pith cells", - # "Dehydrated Xylem vessels", - # ], - # }, - # ) - - # register_checkpoint( - # model_name="dinov3-petiole", - # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", - # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", - # nersc_checkpoint_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt", - # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_path="/eagle/SYNAPS-I/segmentation/dino/best.ckpt", - # config=config, - # alias="production", - # description="DINOv3 fine-tuned on petiole micro-CT data.", - # inference_params={ - # # ── paths ────────────────────────────────────────────────────────── - # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", - # "seg_scripts_dir": f"{scripts_dir}inference_v5_multiseg/forge_feb_seg_model_demo/", - # # ── inference hyperparameters ─────────────────────────────────────── - # "script_name": "src.inference_dino_v1", - # "batch_size": 4, - # "nproc_per_node": 4, - # }, - # ) - - # register_checkpoint( - # model_name="dinov3-moon", - # nersc_hf_home="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", - # nersc_hf_hub_cache="/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", - # nersc_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best_moon.ckpt", - # alcf_hf_home="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_hf_hub_cache="/eagle/SYNAPS-I/segmentation/.cache/huggingface", - # alcf_path="/eagle/SYNAPS-I/segmentation/seg_models/dino/best_moon.ckpt", - # config=config, - # alias="production", - # description="DINOv3 fine-tuned on lunar regolith micro-CT data (ice, particles, pores).", - # inference_params={ - # "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", - # "seg_scripts_dir": f"{scripts_dir}moon_seg/forge_feb_seg_model_demo/", - # "script_name": "src.inference_dino_v2", - # "batch_size": 4, - # "nproc_per_node": 4, - # }, - # ) - def retrieve_mlflow_params_test() -> bool: """Test that _load_job_options correctly pulls inference params from the MLflow registry. From 0ad03aca2729c628c9c9b479e4db3f9eddd87803 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 24 Apr 2026 13:20:38 -0700 Subject: [PATCH 27/29] updating pytest --- orchestration/_tests/test_bl832/test_nersc.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 3ba9742f..9952335d 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -5,6 +5,8 @@ from prefect.blocks.system import Secret from prefect.testing.utilities import prefect_test_harness +from orchestration.flows.bl832.nersc import RESOURCE_IDS, _IRI_COMPUTE_RESOURCE + # ────────────────────────────────────────────────────────────────────────────── # Session fixture @@ -445,18 +447,21 @@ def test_reconstruct_iriapi_success(mocker, mock_iriapi_client, mock_config832, assert result["success"] is True assert result["job_id"] == "99999" mock_iriapi_client.post.assert_called_once() - assert mock_iriapi_client.post.call_args.args[0] == "/api/v1/compute/job/compute" + assert ( + mock_iriapi_client.post.call_args.args[0] + == f"/api/v1/compute/job/{RESOURCE_IDS['perlmutter_job_submit']}" + ) posted_json = mock_iriapi_client.post.call_args.kwargs["json"] assert posted_json["executable"] == "/bin/bash" - assert posted_json["arguments"][0] == "-c" - assert isinstance(posted_json["arguments"][1], str) # the script body - assert "tomo_recon" in posted_json["arguments"][1] # sanity check it's the right script - assert mock_iriapi_client.get.call_count == 2 + assert posted_json["arguments"] == ["-s"] # matches nersc.py + assert "pre_launch" in posted_json # script body lives here + assert "tomo_recon" in posted_json["pre_launch"] # sanity check it's the right script + mock_iriapi_client.get.assert_any_call( - "/api/v1/compute/status/compute/99999" + f"/api/v1/compute/status/{_IRI_COMPUTE_RESOURCE}/99999" ) mock_iriapi_client.get.assert_any_call( - "/api/v1/filesystem/file/perlmutter", + f"/api/v1/filesystem/view/{RESOURCE_IDS['perlmutter_login']}", params={"path": mocker.ANY}, ) From 9d8e2c113470070f754cbdfeb19e170d70623a28 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 24 Apr 2026 13:25:48 -0700 Subject: [PATCH 28/29] linting --- orchestration/flows/bl832/nersc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 39bd4a02..6e4292d0 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -517,7 +517,7 @@ def _mkdir_remote(self, path: str) -> None: perlmutter.run(f"mkdir -p {path}") elif self.login_method is NERSCLoginMethod.IRIAPI: response = self.client.post( - f"/api/v1/filesystem/mkdir/{RESOURCE_IDS["perlmutter_login"]}", + f"/api/v1/filesystem/mkdir/{RESOURCE_IDS['perlmutter_login']}", json={"path": path, "parents": True}, ) response.raise_for_status() From e4f4e08d72656bc4ce55714be0eec4be131d66ff Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 24 Apr 2026 13:31:46 -0700 Subject: [PATCH 29/29] adjusting import in pytest to avoid error on github that did not occur locally --- orchestration/_tests/test_bl832/test_nersc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 9952335d..1ec12264 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -5,8 +5,6 @@ from prefect.blocks.system import Secret from prefect.testing.utilities import prefect_test_harness -from orchestration.flows.bl832.nersc import RESOURCE_IDS, _IRI_COMPUTE_RESOURCE - # ────────────────────────────────────────────────────────────────────────────── # Session fixture @@ -432,6 +430,7 @@ def test_reconstruct_sfapi_submission_failure(mocker, mock_sfapi_client, mock_co def test_reconstruct_iriapi_success(mocker, mock_iriapi_client, mock_config832, monkeypatch): """IRIAPI reconstruct POSTs a job and polls for COMPLETED state.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + from orchestration.flows.bl832.nersc import RESOURCE_IDS, _IRI_COMPUTE_RESOURCE monkeypatch.setenv("NERSC_USERNAME", "alsdev") mocker.patch("orchestration.flows.bl832.nersc.time.sleep")