diff --git a/.env.example b/.env.example index e3728e89..b7e54812 100644 --- a/.env.example +++ b/.env.example @@ -1,7 +1,12 @@ +BEAMLINE=8.3.2 GLOBUS_CLIENT_ID= GLOBUS_CLIENT_SECRET= PREFECT_API_URL= PREFECT_API_KEY= PUSHGATEWAY_URL= JOB_NAME= -INSTANCE_LABEL= \ No newline at end of file +INSTANCE_LABEL= +PATH_NERSC_CLIENT_ID= +PATH_NERSC_PRI_KEY= +NERSC_USERNAME= +AMSC_API_KEY= # found here: https://profile.american-science-cloud.org/ \ No newline at end of file diff --git a/config.yml b/config.yml index ea576b54..6ee1cbb7 100644 --- a/config.yml +++ b/config.yml @@ -173,20 +173,24 @@ mlflow: staging: tracking_uri: https://mlflow-staging.computing.als.lbl.gov registry_uri: https://mlflow-staging.computing.als.lbl.gov + amsc: + tracking_uri: https://mlflow.american-science-cloud.org/ + registry_uri: https://mlflow.american-science-cloud.org/ + experiment_name: als-bl832-models hpc_submission_settings832: # ── RECON + MULTIRES SETTINGS ─────────────────────────────────────────────── nersc_reconstruction: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: realtime + qos: debug account: als - reservation: "_CAP_TOMO_MOON_CPU" - num_nodes: 16 + reservation: "" + num_nodes: 2 cpus-per-task: 128 walltime: "0:30:00" nersc_multiresolution: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: realtime + qos: debug account: als reservation: "" cpus-per-task: 128 @@ -195,15 +199,15 @@ hpc_submission_settings832: # ── PETIOLE SEGMENTATION SETTINGS ─────────────────────────────────────────── nersc_segmentation_sam3: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: regular + qos: debug account: als constraint: gpu reservation: "" - num_nodes: 4 + num_nodes: 2 ntasks-per-node: 1 gpus-per-node: 4 cpus-per-task: 128 - walltime: "00:59:00" + walltime: "00:30:00" # ── Inference parameters ────────────────────────────────────────────────── script_name: "src/inference_v6.py" batch_size: 
1 @@ -226,16 +230,16 @@ hpc_submission_settings832: finetuned_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/checkpoint_v6.pt nersc_segmentation_dinov3: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: regular + qos: debug account: als constraint: gpu reservation: "" - num_nodes: 4 + num_nodes: 2 ntasks-per-node: 1 nproc_per_node: 4 gpus-per-node: 4 cpus-per-task: 128 - walltime: "00:59:00" + walltime: "00:30:00" # ── Inference parameters ────────────────────────────────────────────────── script_name: "src.inference_dino_v1" batch_size: 4 @@ -246,14 +250,14 @@ hpc_submission_settings832: dino_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt nersc_combine_segmentations: # ── SLURM resource allocation ───────────────────────────────────────────── - qos: regular + qos: debug account: als constraint: cpu reservation: "" - num_nodes: 4 + num_nodes: 2 ntasks: 128 cpus-per-task: 1 - walltime: "01:00:00" + walltime: "00:30:00" # ── Combination parameters ──────────────────────────────────────────────── script_name: "src.combine_sam_dino_v3" dilate_px: 5 diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 8d7056a8..1ec12264 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -1,5 +1,4 @@ -# orchestration/_tests/bl832/test_nersc.py - +# orchestration/_tests/test_bl832/test_nersc.py import pytest from uuid import uuid4 @@ -20,18 +19,18 @@ def prefect_test_fixture(): yield -# ────────────────────────────────────────────────────────────────────────────── +# --------------------------------------------------------------------------- # Shared fixtures -# ────────────────────────────────────────────────────────────────────────────── +# --------------------------------------------------------------------------- 
@pytest.fixture def mock_sfapi_client(mocker): - """Mock sfapi_client.Client with a completed job on Perlmutter.""" - mock_client = mocker.MagicMock() + """sfapi_client.Client mock with user, compute, submit_job, and job chained.""" + client = mocker.MagicMock() mock_user = mocker.MagicMock() mock_user.name = "testuser" - mock_client.user.return_value = mock_user + client.user.return_value = mock_user mock_job = mocker.MagicMock() mock_job.jobid = "12345" @@ -39,10 +38,9 @@ def mock_sfapi_client(mocker): mock_compute = mocker.MagicMock() mock_compute.submit_job.return_value = mock_job - mock_client.compute.return_value = mock_compute - - mocker.patch("orchestration.flows.bl832.nersc.Client", return_value=mock_client) - return mock_client + client.compute.return_value = mock_compute + mocker.patch("orchestration.flows.bl832.nersc.Client", return_value=client) + return client @pytest.fixture @@ -167,11 +165,28 @@ def _make_future(mocker, value): return f -# ────────────────────────────────────────────────────────────────────────────── -# create_sfapi_client -# ────────────────────────────────────────────────────────────────────────────── +@pytest.fixture +def mock_iriapi_client(mocker): + """httpx.Client mock for IRI API responses.""" + client = mocker.MagicMock() + + submit_response = mocker.MagicMock() + submit_response.json.return_value = {"id": "99999"} + client.post.return_value = submit_response + + status_response = mocker.MagicMock() + status_response.json.return_value = {"status": {"state": "completed"}} + client.get.return_value = status_response + + return client + +# --------------------------------------------------------------------------- +# _create_sfapi_client +# --------------------------------------------------------------------------- + def test_create_sfapi_client_success(mocker): + """Valid credentials produce a Client instance.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController 
mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { @@ -179,29 +194,34 @@ def test_create_sfapi_client_success(mocker): "PATH_NERSC_PRI_KEY": "/path/to/client_secret", }.get(x)) mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=True) - mocker.patch("builtins.open", side_effect=[ - mocker.mock_open(read_data="client_id_value")(), - mocker.mock_open(read_data='{"key": "value"}')(), - ]) + mocker.patch( + "builtins.open", + side_effect=[ + mocker.mock_open(read_data="my-client-id")(), + mocker.mock_open(read_data='{"kty": "RSA", "n": "x", "e": "y"}')(), + ] + ) mocker.patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key", return_value="mock_secret") mock_client_cls = mocker.patch("orchestration.flows.bl832.nersc.Client") - client = NERSCTomographyHPCController.create_sfapi_client() + client = NERSCTomographyHPCController._create_sfapi_client() - mock_client_cls.assert_called_once_with("client_id_value", "mock_secret") - assert client == mock_client_cls.return_value + mock_client_cls.assert_called_once_with("my-client-id", "mock_secret") + assert client is mock_client_cls.return_value def test_create_sfapi_client_missing_paths(mocker): + """Unset env vars raise ValueError.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController mocker.patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None) with pytest.raises(ValueError, match="Missing NERSC credentials paths."): - NERSCTomographyHPCController.create_sfapi_client() + NERSCTomographyHPCController._create_sfapi_client() def test_create_sfapi_client_missing_files(mocker): + """Env vars set but files absent raise FileNotFoundError.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { @@ -211,40 +231,7 @@ def test_create_sfapi_client_missing_files(mocker): mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", 
return_value=False) with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): - NERSCTomographyHPCController.create_sfapi_client() - - -# ────────────────────────────────────────────────────────────────────────────── -# reconstruct -# ────────────────────────────────────────────────────────────────────────────── - -def test_reconstruct_success(mocker, mock_sfapi_client, mock_config832): - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - - result = controller.reconstruct(file_path="folder/file.h5") - - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() - assert isinstance(result, dict) - assert result["success"] is True - assert result["job_id"] == "12345" - - -def test_reconstruct_submission_failure(mocker, mock_sfapi_client, mock_config832): - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - - result = controller.reconstruct(file_path="folder/file.h5") - - assert result is False + NERSCTomographyHPCController._create_sfapi_client() # ────────────────────────────────────────────────────────────────────────────── @@ -260,9 +247,10 @@ def test_build_multi_resolution_success(mocker, mock_sfapi_client, mock_config83 result = controller.build_multi_resolution(file_path="folder/file.h5") - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) + assert 
mock_sfapi_client.compute.call_count == 2 + mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() assert result is True @@ -295,7 +283,7 @@ def test_segmentation_sam3_success(mocker, mock_sfapi_client, mock_config832): mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() assert isinstance(result, dict) assert result["success"] is True assert result["job_id"] == "12345" @@ -373,7 +361,7 @@ def test_segmentation_dinov3_success(mocker, mock_sfapi_client, mock_config832): mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() assert result is True @@ -386,6 +374,170 @@ def test_segmentation_dinov3_submission_failure(mocker, mock_sfapi_client, mock_ controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) result = controller.segmentation_dinov3(recon_folder_path="folder/recfile") + assert result is False + +# --------------------------------------------------------------------------- +# reconstruct — SFAPI +# --------------------------------------------------------------------------- + + +def test_reconstruct_sfapi_success(mocker, mock_sfapi_client, mock_config832): + """SFAPI reconstruct submits a job and waits for completion.""" + from 
orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + from sfapi_client.compute import Machine + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result["success"] is True + assert result["job_id"] == "12345" + assert mock_sfapi_client.compute.call_count == 3 # 1 _submit_job() + 1 _wait_for_job() + 1 _fetch_timing_data() + mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) + mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + mock_sfapi_client.compute.return_value.job.assert_called_once_with(jobid="12345") + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() + + +def test_reconstruct_sfapi_submission_failure(mocker, mock_sfapi_client, mock_config832): + """SFAPI reconstruct returns False when submission raises.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("SFAPI error") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result["success"] is False + + +# --------------------------------------------------------------------------- +# reconstruct — IRIAPI +# --------------------------------------------------------------------------- + +def test_reconstruct_iriapi_success(mocker, mock_iriapi_client, mock_config832, monkeypatch): + """IRIAPI reconstruct POSTs a job and polls for COMPLETED state.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + 
from orchestration.flows.bl832.nersc import RESOURCE_IDS, _IRI_COMPUTE_RESOURCE + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result["success"] is True + assert result["job_id"] == "99999" + mock_iriapi_client.post.assert_called_once() + assert ( + mock_iriapi_client.post.call_args.args[0] + == f"/api/v1/compute/job/{RESOURCE_IDS['perlmutter_job_submit']}" + ) + posted_json = mock_iriapi_client.post.call_args.kwargs["json"] + assert posted_json["executable"] == "/bin/bash" + assert posted_json["arguments"] == ["-s"] # matches nersc.py + assert "pre_launch" in posted_json # script body lives here + assert "tomo_recon" in posted_json["pre_launch"] # sanity check it's the right script + + mock_iriapi_client.get.assert_any_call( + f"/api/v1/compute/status/{_IRI_COMPUTE_RESOURCE}/99999" + ) + mock_iriapi_client.get.assert_any_call( + f"/api/v1/filesystem/view/{RESOURCE_IDS['perlmutter_login']}", + params={"path": mocker.ANY}, + ) + + +def test_reconstruct_iriapi_job_failed(mocker, mock_iriapi_client, mock_config832, monkeypatch): + """IRIAPI reconstruct returns False when job state is FAILED.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_iriapi_client.get.return_value.json.return_value = {"status": {"state": "failed"}} # was {"state": "FAILED"} + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result["success"] is False + + +def 
test_reconstruct_iriapi_missing_username(mocker, mock_iriapi_client, mock_config832, monkeypatch): + """IRIAPI reconstruct raises ValueError when NERSC_USERNAME is unset.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.delenv("NERSC_USERNAME", raising=False) + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.IRIAPI, + ) + + with pytest.raises(ValueError, match="NERSC_USERNAME"): + controller.reconstruct(file_path="folder/scan.h5") + + +# --------------------------------------------------------------------------- +# build_multi_resolution — SFAPI +# --------------------------------------------------------------------------- + +def test_build_multi_resolution_sfapi_success(mocker, mock_sfapi_client, mock_config832): + """SFAPI build_multi_resolution submits and waits successfully.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + from sfapi_client.compute import Machine + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is True + assert mock_sfapi_client.compute.call_count == 2 + mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) + + +def test_build_multi_resolution_sfapi_failure(mocker, mock_sfapi_client, mock_config832): + """SFAPI build_multi_resolution returns False when submission raises.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("error") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + 
config=mock_config832, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") assert result is False @@ -436,7 +588,7 @@ def test_combine_segmentations_success(mocker, mock_sfapi_client, mock_config832 mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() assert result is True @@ -794,3 +946,47 @@ def test_moon_segment_flow_no_sam3_no_combine(mocker, mock_config832, mock_recon mock_sam3_task.submit.assert_not_called() mock_combine_task.submit.assert_not_called() +# --------------------------------------------------------------------------- +# build_multi_resolution — IRIAPI +# --------------------------------------------------------------------------- + + +def test_build_multi_resolution_iriapi_success(mocker, mock_iriapi_client, mock_config832, monkeypatch): + """IRIAPI build_multi_resolution POSTs and polls successfully.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is True + mock_iriapi_client.post.assert_called_once() + mock_iriapi_client.get.assert_called_once_with( + "/api/v1/compute/status/compute/99999" + ) + + +def test_build_multi_resolution_iriapi_failure(mocker, mock_iriapi_client, mock_config832, monkeypatch): + """IRIAPI build_multi_resolution returns False when job state is failed.""" + from orchestration.flows.bl832.nersc import 
NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_iriapi_client.get.return_value.json.return_value = {"status": {"state": "failed"}} + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config832, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is False diff --git a/orchestration/_tests/test_sfapi_flow.py b/orchestration/_tests/test_sfapi_flow.py index 6e9bf225..e0d4a854 100644 --- a/orchestration/_tests/test_sfapi_flow.py +++ b/orchestration/_tests/test_sfapi_flow.py @@ -1,8 +1,5 @@ # orchestration/_tests/test_sfapi_flow.py - -from pathlib import Path import pytest -from unittest.mock import MagicMock, patch, mock_open from uuid import uuid4 from prefect.blocks.system import Secret @@ -11,307 +8,56 @@ @pytest.fixture(autouse=True, scope="session") def prefect_test_fixture(): - """ - A pytest fixture that automatically sets up and tears down the Prefect test harness - for the entire test session. It creates and saves test secrets and configurations - required for Globus integration. - - Yields: - None - """ with prefect_test_harness(): - globus_client_id = Secret(value=str(uuid4())) - globus_client_id.save(name="globus-client-id", overwrite=True) - globus_client_secret = Secret(value=str(uuid4())) - globus_client_secret.save(name="globus-client-secret", overwrite=True) - + Secret(value=str(uuid4())).save(name="globus-client-id", overwrite=True) + Secret(value=str(uuid4())).save(name="globus-client-secret", overwrite=True) yield -# ---------------------------- -# Tests for create_sfapi_client -# ---------------------------- - - -def test_create_sfapi_client_success(): - """ - Test successful creation of the SFAPI client. 
- """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - # Mock data for client_id and client_secret files - mock_client_id = 'value' - mock_client_secret = '{"key": "value"}' - - # Create separate mock_open instances for each file - mock_open_client_id = mock_open(read_data=mock_client_id) - mock_open_client_secret = mock_open(read_data=mock_client_secret) - - with patch("orchestration.flows.bl832.nersc.os.getenv") as mock_getenv, \ - patch("orchestration.flows.bl832.nersc.os.path.isfile") as mock_isfile, \ - patch("builtins.open", side_effect=[ - mock_open_client_id.return_value, - mock_open_client_secret.return_value - ]), \ - patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key") as mock_import_key, \ - patch("orchestration.flows.bl832.nersc.Client") as MockClient: - - # Mock environment variables - mock_getenv.side_effect = lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret" - }.get(x, None) - - # Mock file existence - mock_isfile.return_value = True - - # Mock JsonWebKey.import_key to return a mock secret - mock_import_key.return_value = "mock_secret" - - # Create the client - client = NERSCTomographyHPCController.create_sfapi_client() - - # Assert that Client was instantiated with 'value' and 'mock_secret' - MockClient.assert_called_once_with("value", "mock_secret") - - # Assert that the returned client is the mocked client - assert client == MockClient.return_value, "Client should be the mocked sfapi_client.Client instance" - - -def test_create_sfapi_client_missing_paths(): - """ - Test creation of the SFAPI client with missing credential paths. 
- """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - with patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None): - with pytest.raises(ValueError, match="Missing NERSC credentials paths."): - NERSCTomographyHPCController.create_sfapi_client() - - -def test_create_sfapi_client_missing_files(): - """ - Test creation of the SFAPI client with missing credential files. - """ - with ( - # Mock environment variables - patch( - "orchestration.flows.bl832.nersc.os.getenv", - side_effect=lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret" - }.get(x, None) - ), - - # Mock file existence to simulate missing files - patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) - ): - # Import the module after applying patches to ensure mocks are in place - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - # Expect a FileNotFoundError due to missing credential files - with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): - NERSCTomographyHPCController.create_sfapi_client() - -# ---------------------------- -# Fixture for Mocking SFAPI Client -# ---------------------------- - - -@pytest.fixture -def mock_sfapi_client(): - """ - Mock the sfapi_client.Client class with necessary methods. 
- """ - with patch("orchestration.flows.bl832.nersc.Client") as MockClient: - mock_client_instance = MockClient.return_value - - # Mock the user method - mock_user = MagicMock() - mock_user.name = "testuser" - mock_client_instance.user.return_value = mock_user - - # Mock the compute method to return a mocked compute object - mock_compute = MagicMock() - mock_job = MagicMock() - mock_job.jobid = "12345" - mock_job.state = "COMPLETED" - mock_compute.submit_job.return_value = mock_job - mock_client_instance.compute.return_value = mock_compute - - yield mock_client_instance - - -# ---------------------------- -# Fixture for Mocking Config832 -# ---------------------------- - -@pytest.fixture -def mock_config832(): - """ - Mock the Config832 class to provide necessary configurations. - - All settings dicts must be fully populated to match the config YAML schema, - because _load_job_options() passes config_settings directly as the defaults - dict and then accesses keys by name. - """ - with patch("orchestration.flows.bl832.nersc.Config832") as MockConfig: - mock_config = MockConfig.return_value - mock_config.ghcr_images832 = { - "recon_image": "mock_recon_image", - "multires_image": "mock_multires_image", - } - mock_config.nersc_recon_settings = { - "qos": "realtime", - "account": "mock_account", - "reservation": "", - "num_nodes": 4, - "cpus-per-task": 128, - "walltime": "0:30:00", - } - mock_config.nersc_multiresolution_settings = { - "qos": "realtime", - "account": "mock_account", - "reservation": "", - "cpus-per-task": 128, - "walltime": "0:15:00", - } - mock_config.apps = {"als_transfer": "some_config"} - yield mock_config - - -# ---------------------------- -# Tests for NERSCTomographyHPCController -# ---------------------------- - -def test_reconstruct_success(mock_sfapi_client, mock_config832): - """ - Test successful reconstruction job submission. 
- """ +def test_create_sfapi_client_success(mocker): + """Valid credentials produce a Client instance.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.reconstruct(file_path=file_path) - - # Verify that compute was called with Machine.perlmutter - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - - # Verify that submit_job was called once - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - - # Verify that complete was called on the job - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() - - # Assert that the method returns True - assert isinstance(result, dict) - assert result["success"] is True - assert result["job_id"] == "12345" - -def test_reconstruct_submission_failure(mock_sfapi_client, mock_config832): - """ - Test reconstruction job submission failure. 
- """ + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret", + }.get(x)) + mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=True) + mocker.patch( + "builtins.open", + side_effect=[ + mocker.mock_open(read_data="my-client-id")(), + mocker.mock_open(read_data='{"kty": "RSA", "n": "x", "e": "y"}')(), + ] + ) + mocker.patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key", return_value="mock_secret") + mock_client_cls = mocker.patch("orchestration.flows.bl832.nersc.Client") + + client = NERSCTomographyHPCController._create_sfapi_client() + + mock_client_cls.assert_called_once_with("my-client-id", "mock_secret") + assert client is mock_client_cls.return_value + + +def test_create_sfapi_client_missing_paths(mocker): + """Unset env vars raise ValueError.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - # Simulate submission failure - mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None) - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.reconstruct(file_path=file_path) + with pytest.raises(ValueError, match="Missing NERSC credentials paths."): + NERSCTomographyHPCController._create_sfapi_client() - # Assert that the method returns False - assert result is False, "reconstruct should return False on submission failure." - -def test_build_multi_resolution_success(mock_sfapi_client, mock_config832): - """ - Test successful multi-resolution job submission. 
- """ +def test_create_sfapi_client_missing_files(mocker): + """Env vars set but files absent raise FileNotFoundError.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.build_multi_resolution(file_path=file_path) - - # Verify that compute was called with Machine.perlmutter - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - - # Verify that submit_job was called once - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - - # Verify that complete was called on the job - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() - - # Assert that the method returns True - assert result is True, "build_multi_resolution should return True on successful job completion." - - -def test_build_multi_resolution_submission_failure(mock_sfapi_client, mock_config832): - """ - Test multi-resolution job submission failure. - """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - # Simulate submission failure - mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.build_multi_resolution(file_path=file_path) - - # Assert that the method returns False - assert result is False, "build_multi_resolution should return False on submission failure." - - -def test_job_submission(mock_sfapi_client): - """ - Test job submission and status updates. 
- """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - mock_config = MagicMock() - mock_config.nersc_recon_settings = { - "qos": "realtime", - "account": "mock_account", - "reservation": "", - "num_nodes": 4, - "cpus-per-task": 128, - "walltime": "0:30:00", - } - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config) - file_path = "path/to/file.h5" - - # Mock Path to extract file and folder names - with patch.object(Path, 'parent', new_callable=MagicMock) as mock_parent, \ - patch.object(Path, 'stem', new_callable=MagicMock) as mock_stem: - mock_parent.name = "to" - mock_stem.return_value = "file" - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - controller.reconstruct(file_path=file_path) - - # Verify that compute was called with Machine.perlmutter - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - # Verify that submit_job was called once - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret", + }.get(x)) + mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) - # Verify the returned job has the expected attributes - submitted_job = mock_sfapi_client.compute.return_value.submit_job.return_value - assert submitted_job.jobid == "12345", "Job ID should match the mock job ID." - assert submitted_job.state == "COMPLETED", "Job state should be COMPLETED." 
+ with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): + NERSCTomographyHPCController._create_sfapi_client() diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index 8bbbf78c..281ba167 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -30,7 +30,7 @@ def _beam_specific_config(self) -> None: # SciCat self.scicat = self.config["scicat"] # MLflow - self.mlflow = self.config["mlflow"]["local"] + self.mlflow = self.config["mlflow"]["amsc"] # NERSC HPC submission settings self.ghcr_images832 = self.config["ghcr_images832"] self.nersc_recon_settings = self.config["hpc_submission_settings832"]["nersc_reconstruction"] diff --git a/orchestration/flows/bl832/job_controller.py b/orchestration/flows/bl832/job_controller.py index b2ff064b..1a23d02a 100644 --- a/orchestration/flows/bl832/job_controller.py +++ b/orchestration/flows/bl832/job_controller.py @@ -10,6 +10,19 @@ load_dotenv() +class NERSCLoginMethod(Enum): + """Selects which NERSC API login method to use when creating a NERSC client. + + Each method corresponds to a different set of credentials and API base URL. + """ + + SFAPI = "sfapi" + """Standard Superfacility API via Iris-registered OAuth2 credentials.""" + + IRIAPI = "iriapi" + """Integrated Research Infrastructure API via IRI-registered OAuth2 credentials.""" + + class TomographyHPCController(ABC): """ Abstract class for tomography HPC controllers. @@ -65,7 +78,8 @@ class HPC(Enum): def get_controller( hpc_type: HPC, - config: Config832 + config: Config832, + login_method: "NERSCLoginMethod | None" = None, ) -> TomographyHPCController: """ Factory function that returns an HPC controller instance for the given HPC environment. 
@@ -86,10 +100,14 @@ def get_controller( config=config ) elif hpc_type == HPC.NERSC: - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + resolved_login_method = login_method if isinstance(login_method, NERSCLoginMethod) else NERSCLoginMethod.SFAPI return NERSCTomographyHPCController( - client=NERSCTomographyHPCController.create_sfapi_client(), - config=config + client=NERSCTomographyHPCController.create_nersc_client( + login_method=resolved_login_method + ), + config=config, + login_method=resolved_login_method, ) elif hpc_type == HPC.OLCF: # TODO: Implement OLCF controller diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index f4ffc9fb..6e4292d0 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1,6 +1,8 @@ from dataclasses import dataclass, field import datetime from dotenv import load_dotenv +import enum +import httpx import json import logging import os @@ -16,19 +18,51 @@ from typing import Any, Optional from orchestration.flows.bl832.config import Config832 -from orchestration.flows.bl832.job_controller import get_controller, HPC, TomographyHPCController -from orchestration.mlflow import get_checkpoint_info -from orchestration.prune_controller import get_prune_controller, PruneMethod -from orchestration.transfer_controller import globus_transfer_task +from orchestration.flows.bl832.job_controller import get_controller, HPC, NERSCLoginMethod, TomographyHPCController from orchestration.flows.bl832.streaming_mixin import ( NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block ) +from orchestration.mlflow import get_checkpoint_info +from orchestration.globus.get_globus_token import ( + get_iri_access_token, + DEFAULT_TOKEN_FILE, +) from orchestration.prefect import schedule_prefect_flow +from orchestration.prune_controller import 
get_prune_controller, PruneMethod +from orchestration.transfer_controller import globus_transfer_task logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) load_dotenv() +# Applies only to NERSCLoginMethod.IRIAPI +_IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" +_IRI_COMPUTE_RESOURCE: str = "compute" +_IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" + +_API_BASE_URLS: dict[NERSCLoginMethod, str] = { + NERSCLoginMethod.SFAPI: "https://api.nersc.gov/api/v1.2", + NERSCLoginMethod.IRIAPI: "https://api.iri.nersc.gov", +} + +# NERSC resource IDs (from status/resources endpoint) +RESOURCE_IDS = { + # Perlmutter compute + "perlmutter_compute": "94351904-6dba-4c16-b5cd-fbd280d8615b", + "perlmutter_login": "e525a224-61c1-419f-9642-91168c792e39", + "perlmutter_realtime": "3776417d-747c-4753-895a-6323c17b9c98", + "perlmutter_job_submit": "3cf3c048-855e-4dd8-a189-065a483954bb", + # Storage + "scratch": "43d8f6c0-f900-48ce-b267-73714103f4ac", + "homes": "65b28619-c3b6-4942-8da1-044a3b3a2a9e", + "common": "7e07a611-f927-4a39-a44d-b1d6e307accd", + "cfs": "59e80c79-4dfd-4c53-9c07-7405685fcd37", + "archive": "f4916c65-9001-49c2-b0bf-6fe4276b564c", + # Services + "globus": "0a207df3-4bec-45b8-9060-13505d269da9", + "dtns": "a762cbdc-af7a-4b2b-9463-67f0189dd2ae", +} + @dataclass class SegmentationModelSpec: @@ -142,6 +176,19 @@ def _load_job_options( return {**opts, **overrides} +class NERSCLoginMethod(enum.Enum): + """Selects which NERSC API login method to use when creating a NERSC client. + + Each method corresponds to a different set of credentials and API base URL. + """ + + SFAPI = "sfapi" + """Standard Superfacility API via Iris-registered OAuth2 credentials.""" + + IRIAPI = "iriapi" + """Integrated Research Infrastructure API via IRI-registered OAuth2 credentials.""" + + class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin): """ Implementation for a NERSC-based tomography HPC controller. 
@@ -151,14 +198,94 @@ class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin)
 
     def __init__(
         self,
-        client: Client,
-        config: Config832
+        config: Config832,
+        client: Client | httpx.Client | None = None,
+        login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI,
     ) -> None:
         TomographyHPCController.__init__(self, config)
         self.client = client
+        self.login_method = login_method
 
     @staticmethod
-    def create_sfapi_client() -> Client:
+    def create_nersc_client(
+        login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI,
+    ) -> Client | httpx.Client:
+        """Create and return a NERSC client for the requested login method.
+
+        Two fundamentally different auth strategies are supported:
+
+        - :attr:`NERSCLoginMethod.SFAPI`: uses an Iris-registered OAuth2
+          client ID + private key (NERSC OIDC flow). Set ``PATH_NERSC_CLIENT_ID``
+          and ``PATH_NERSC_PRI_KEY`` to the paths of those files.
+
+        - :attr:`NERSCLoginMethod.IRIAPI`: uses a Globus bearer token written
+          by ``globus_token.py``. Set ``PATH_GLOBUS_TOKEN_FILE`` to the token
+          file path, or rely on the default (``~/.globus/auth_tokens.json``).
+
+        Args:
+            login_method: Which NERSC API to authenticate against.
+                Defaults to :attr:`NERSCLoginMethod.SFAPI`.
+
+        Returns:
+            An authenticated :class:`sfapi_client.Client` or :class:`httpx.Client`, per ``login_method``.
+
+        Raises:
+            ValueError: If SFAPI credential environment variables are unset.
+            FileNotFoundError: If credential or token files are absent.
+            RuntimeError: If the Globus token is expired.
+            Exception: If the underlying client construction fails.
+ """ + logger.info(f"Creating NERSC client using login method: {login_method.value}") + api_url = _API_BASE_URLS[login_method] + logger.info(f"Targeting API base URL: {api_url}") + + if login_method is NERSCLoginMethod.SFAPI: + client = NERSCTomographyHPCController._create_sfapi_client() + + elif login_method is NERSCLoginMethod.IRIAPI: + client = NERSCTomographyHPCController._create_iriapi_client() + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {login_method}") + + logger.info( + f"NERSC client created successfully " + f"(method={login_method.value}, api_url={api_url})." + ) + return client + + @staticmethod + def _create_iriapi_client() -> httpx.Client: + """Create a NERSC client for the IRI API using a Globus bearer token. + + Requires ``GLOBUS_CLIENT_ID`` and ``GLOBUS_CLIENT_SECRET`` in the + environment. Reuses a cached token if valid; otherwise mints a new one + via the client credentials grant. No browser or user interaction. + + Returns: + An authenticated :class:`httpx.Client` targeting the IRI API. + + Raises: + ValueError: If ``GLOBUS_CLIENT_ID`` or ``GLOBUS_CLIENT_SECRET`` are unset. + RuntimeError: If the acquired token is missing required scopes. + """ + token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV) + token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE + + access_token = get_iri_access_token( + token_file=token_file, + force_login=False, + prompt_login=False + ) + + return httpx.Client( + base_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], + headers={"Authorization": f"Bearer {access_token}"}, + timeout=httpx.Timeout(connect=10.0, read=120.0, write=30.0, pool=10.0), + ) + + @staticmethod + def _create_sfapi_client() -> Client: """Create and return an NERSC client instance""" # When generating the SFAPI Key in Iris, make sure to select "asldev" as the user! 
@@ -227,6 +354,228 @@ def _get_segmentation_spec(self, model: str, project: str) -> SegmentationModelS ) return registry[key] + def _get_nersc_username(self) -> str: + """Get the NERSC username for constructing pscratch paths. + + Uses the sfapi_client user endpoint for SFAPI, or reads + ``NERSC_USERNAME`` from the environment for IRIAPI. + + Returns: + NERSC username string. + + Raises: + ValueError: If IRIAPI is selected and NERSC_USERNAME is unset. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + return self.client.user().name + else: + username = os.getenv("NERSC_USERNAME") + if not username: + raise ValueError( + "NERSC_USERNAME must be set in the environment when using IRIAPI." + ) + return username + + def _submit_job(self, job_script: str, num_nodes: int = 1) -> str: + """Submit a Slurm job script and return the job ID. + + Dispatches to the appropriate submission mechanism based on + ``self.login_method``. + + Args: + job_script: The full Slurm batch script to submit. + num_nodes: The number of nodes to request for the job. + + Returns: + The submitted job ID as a string. + + Raises: + RuntimeError: If job submission fails. 
+ """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.submit_job(job_script) + return str(job.jobid) + + elif self.login_method is NERSCLoginMethod.IRIAPI: + sbatch_values = {} + for line in job_script.splitlines(): + if line.startswith("#SBATCH"): + if "-q " in line: + sbatch_values["queue_name"] = line.split("-q ")[-1].strip() + elif "-A " in line: + sbatch_values["account"] = line.split("-A ")[-1].strip() + elif "--time=" in line: + t = line.split("--time=")[-1].strip() + parts = t.split(":") + sbatch_values["duration"] = int(parts[0])*3600 + int(parts[1])*60 + int(parts[2]) + elif "-N " in line: + sbatch_values["node_count"] = int(line.split("-N ")[-1].strip()) + elif "-C " in line: + sbatch_values["constraint"] = line.split("-C ")[-1].strip() + elif "--output=" in line: + sbatch_values["stdout_path"] = line.split("--output=")[-1].strip() + elif "--error=" in line: + sbatch_values["stderr_path"] = line.split("--error=")[-1].strip() + + # Strip shebang and SBATCH headers, keep the script body + script_body = "\n".join( + line for line in job_script.splitlines() + if not line.startswith("#SBATCH") and not line.startswith("#!/") + ).strip() + + constraint = sbatch_values.get("constraint", "cpu") + is_gpu = "gpu" in constraint.lower() + + resources = { + "node_count": sbatch_values.get("node_count", 1), + "processes_per_node": 1, + "exclusive_node_use": True, + } + if is_gpu: + resources["gpu_cores_per_process"] = 4 + else: + resources["cpu_cores_per_process"] = 128 + + job_spec = { + "executable": "/bin/bash", + "arguments": ["-s"], # read script from stdin isn't supported, so... 
+ "pre_launch": script_body, # run the body here before the executable + "resources": resources, + "attributes": { + "duration": sbatch_values.get("duration", 1800), + "queue_name": sbatch_values.get("queue_name", "regular"), + "account": sbatch_values.get("account", "als"), + "custom_attributes": { + "constraint": constraint + }, + }, + } + + if "stdout_path" in sbatch_values: + job_spec["stdout_path"] = sbatch_values["stdout_path"] + if "stderr_path" in sbatch_values: + job_spec["stderr_path"] = sbatch_values["stderr_path"] + + response = self.client.post( + f"/api/v1/compute/job/{RESOURCE_IDS['perlmutter_job_submit']}", + json=job_spec, + ) + if not response.is_success: + logger.error(f"Job submission failed: {response.status_code} {response.text}") + logger.error(f"Job spec was: {json.dumps(job_spec, indent=2)}") + response.raise_for_status() + return str(response.json()["id"]) + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + + def _wait_for_job(self, job_id: str) -> bool: + """Block until a submitted job completes. + + Dispatches to the appropriate polling mechanism based on + ``self.login_method``. + + Args: + job_id: The job ID returned by `_submit_job`. + + Returns: + True if the job completed successfully, False otherwise. 
+ """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.job(jobid=job_id) + job.complete() + return True + + elif self.login_method is NERSCLoginMethod.IRIAPI: + while True: + response = self.client.get( + f"/api/v1/compute/status/{_IRI_COMPUTE_RESOURCE}/{job_id}" # ← was "perlmutter" + ) + response.raise_for_status() + state = response.json().get("status", {}).get("state") + logger.info(f"Job {job_id} state: {state}") + if state == "completed": + return True + if state in ("failed", "canceled", "timeout"): + logger.error(f"Job {job_id} ended with state: {state}") + return False + time.sleep(60) + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + + def _mkdir_remote(self, path: str) -> None: + """Create a directory on Perlmutter remotely. + + Args: + path: Absolute path to create. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + perlmutter.run(f"mkdir -p {path}") + elif self.login_method is NERSCLoginMethod.IRIAPI: + response = self.client.post( + f"/api/v1/filesystem/mkdir/{RESOURCE_IDS['perlmutter_login']}", + json={"path": path, "parents": True}, + ) + response.raise_for_status() + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + + def _read_remote_file(self, path: str) -> str: + """Read a remote file on Perlmutter and return its contents. + + Args: + path: Absolute path to the file on Perlmutter. + + Returns: + File contents as a string. 
+ """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + result = perlmutter.run(f"cat {path}") + if isinstance(result, str): + return result + elif hasattr(result, 'output'): + return result.output + elif hasattr(result, 'stdout'): + return result.stdout + return str(result) + + elif self.login_method is NERSCLoginMethod.IRIAPI: + response = self.client.get( + f"/api/v1/filesystem/view/{RESOURCE_IDS['perlmutter_login']}", + params={"path": path}, + ) + response.raise_for_status() + task_id = response.json().get("task_id") + if not task_id: + return response.text + + for _ in range(40): + task_response = self.client.get(f"/api/v1/task/{task_id}") + task_response.raise_for_status() + task = task_response.json() + status = task.get("status") + if status == "completed": + result = task.get("result", "") + if isinstance(result, dict): + output = result.get("output", result) + if isinstance(output, dict): + return output.get("content", str(output)) + return str(output) + return str(result) + elif status == "failed": + raise RuntimeError(f"File read task {task_id} failed: {task.get('result')}") + time.sleep(3) + + raise TimeoutError(f"File read task {task_id} did not complete") + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + def reconstruct( self, file_path: str = "", @@ -241,7 +590,7 @@ def reconstruct( """ logger.info("Starting NERSC reconstruction process.") - user = self.client.user() + username = self._get_nersc_username() raw_path = self.config.nersc832_alsdev_raw.root_path logger.info(f"{raw_path=}") @@ -255,7 +604,7 @@ def reconstruct( scratch_path = self.config.nersc832_alsdev_scratch.root_path logger.info(f"{scratch_path=}") - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" logger.info(f"{pscratch_path=}") path = Path(file_path) @@ -271,6 +620,8 @@ def reconstruct( opts = 
_load_job_options("nersc-reconstruction-options", self.config.nersc_recon_settings) + logger.info(f"Resolved options: {opts}") + num_nodes = opts.get("num_nodes", num_nodes) cpus_per_task = opts["cpus-per-task"] qos = opts["qos"] @@ -335,6 +686,7 @@ def reconstruct( echo "METADATA_START=$(date +%s)" >> $TIMING_FILE NUM_SLICES=$(shifter \ + --image={recon_image} \ --volume={pscratch_path}/8.3.2:/alsdata \ python -c " import h5py @@ -379,6 +731,7 @@ def reconstruct( fi srun --nodes=1 --ntasks=1 --exclusive shifter \ + --image={recon_image} \ --env=NUMEXPR_MAX_THREADS=128 \ --env=NUMEXPR_NUM_THREADS=128 \ --env=OMP_NUM_THREADS=128 \ @@ -404,55 +757,23 @@ def reconstruct( echo "JOB_STATUS=SUCCESS" >> $TIMING_FILE echo "JOB_END=$(date +%s)" >> $TIMING_FILE """ + job_id = None try: - logger.info("Submitting reconstruction job script to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - + logger.info("Submitting reconstruction job to Perlmutter.") + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() # Wait until the job completes - logger.info("Reconstruction job completed successfully.") - # Fetch timing data - timing = self._fetch_timing_data(perlmutter, pscratch_path, job.jobid) - - return { - "success": True, - "job_id": job.jobid, - "timing": timing - } - + success = self._wait_for_job(job_id) + timing = self._fetch_timing_data(pscratch_path, job_id) if success else None + return {"success": success, "job_id": job_id, "timing": timing} except Exception as e: - logger.info(f"Error during job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - - if match: - jobid = 
match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.perlmutter.job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Reconstruction job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False + logger.error(f"Error during reconstruction job submission or completion: {e}") + return {"success": False, "job_id": job_id, "timing": None} - def _fetch_timing_data(self, perlmutter, pscratch_path: str, job_id: str) -> dict: + def _fetch_timing_data(self, pscratch_path: str, job_id: str) -> dict: """ Fetch and parse timing data from the SLURM job. - :param perlmutter: SFAPI compute object for Perlmutter :param pscratch_path: Path to the user's pscratch directory :param job_id: SLURM job ID :return: Dictionary with timing breakdown @@ -460,18 +781,7 @@ def _fetch_timing_data(self, perlmutter, pscratch_path: str, job_id: str) -> dic timing_file = f"{pscratch_path}/tomo_recon_logs/timing_{job_id}.txt" try: - # Use SFAPI to read the timing file - result = perlmutter.run(f"cat {timing_file}") - - # result might be a string directly, or an object with .output - if isinstance(result, str): - output = result - elif hasattr(result, 'output'): - output = result.output - elif hasattr(result, 'stdout'): - output = result.stdout - else: - output = str(result) + output = self._read_remote_file(timing_file) logger.info(f"Timing file contents:\n{output}") @@ -527,7 +837,7 @@ def build_multi_resolution( logger.info("Starting NERSC multiresolution process.") - user = self.client.user() + username = self._get_nersc_username() multires_image = self.config.ghcr_images832["multires_image"] logger.info(f"{multires_image=}") @@ -538,7 +848,7 @@ def build_multi_resolution( scratch_path = self.config.nersc832_alsdev_scratch.root_path logger.info(f"{scratch_path=}") - pscratch_path = 
f"/pscratch/sd/{user.name[0]}/{user.name}" + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" logger.info(f"{pscratch_path=}") path = Path(file_path) @@ -593,42 +903,16 @@ def build_multi_resolution( date """ try: - logger.info("Submitting Tiff to Zarr job script to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - + logger.info("Submitting Tiff to Zarr job to Perlmutter.") + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() # Wait until the job completes - logger.info("Reconstruction job completed successfully.") - - return True - + success = self._wait_for_job(job_id) + logger.info(f"Multiresolution job {'completed' if success else 'failed'}.") + return success except Exception as e: - logger.warning(f"Error during job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.perlmutter.job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Reconstruction job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False + logger.error(f"Error during multiresolution job submission or completion: {e}") + return False def segmentation_sam3( self, @@ -640,8 +924,8 @@ def segmentation_sam3( """ logger.info("Starting NERSC segmentation process (inference_v6).") - user = self.client.user() - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + username = self._get_nersc_username() + pscratch_path = 
f"/pscratch/sd/{username[0]}/{username}" opts = _load_job_options( variable_name="nersc-segmentation-options", @@ -706,7 +990,7 @@ def segmentation_sam3( #SBATCH -A {account} {reservation_line} #SBATCH -N {num_nodes} -#SBATCH -C {constraint} # gpu +#SBATCH -C {constraint} #SBATCH --job-name={job_name} #SBATCH --time={walltime} #SBATCH --ntasks-per-node={ntasks_per_node} @@ -831,34 +1115,21 @@ def segmentation_sam3( """ try: - logger.info("Submitting segmentation job to Perlmutter (v6).") - perlmutter = self.client.compute(Machine.perlmutter) + logger.info("Submitting segmentation job to Perlmutter.") # Ensure directories exist logger.info("Creating necessary directories...") - perlmutter.run(f"mkdir -p {pscratch_path}/tomo_seg_logs") - perlmutter.run(f"mkdir -p {output_dir}") + self._mkdir_remote(f"{pscratch_path}/tomo_seg_logs") + self._mkdir_remote(output_dir) # Submit job - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - # Initial update - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - - # Wait briefly before polling + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - # Wait for completion - job.complete() + success = self._wait_for_job(job_id) logger.info("Segmentation job completed successfully.") - # Fetch timing data from output file - timing = self._fetch_seg_timing_from_output(perlmutter, pscratch_path, job.jobid, job_name) + timing = self._fetch_seg_timing_from_output(pscratch_path, job_id, job_name) if timing: logger.info("=" * 60) @@ -872,43 +1143,21 @@ def segmentation_sam3( logger.info("=" * 60) return { - "success": True, - "job_id": job.jobid, + "success": success, + "job_id": job_id, "timing": timing, - "output_dir": output_dir + "output_dir": output_dir, } except Exception as e: logger.error(f"Error during 
segmentation job: {e}") import traceback logger.error(traceback.format_exc()) - - # Attempt recovery - match = re.search(r"Job not found:\s*(\d+)", str(e)) - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.compute(Machine.perlmutter).job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Segmentation job completed after recovery.") - - timing = self._fetch_seg_timing_from_output(perlmutter, pscratch_path, jobid, job_name) - return { - "success": True, - "job_id": jobid, - "timing": timing, - "output_dir": output_dir - } - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return { "success": False, "job_id": None, "timing": None, - "output_dir": None + "output_dir": None, } def segmentation_dinov3( @@ -926,8 +1175,8 @@ def segmentation_dinov3( """ logger.info("Starting NERSC DINOv3 segmentation process.") - user = self.client.user() - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" # Load from config spec = self._get_segmentation_spec("dinov3", project) @@ -1062,39 +1311,15 @@ def segmentation_dinov3( """ try: logger.info("Submitting DINOv3 segmentation job to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() - logger.info("DINOv3 segmentation job completed successfully.") - return True - + success = self._wait_for_job(job_id) + logger.info(f"DINOv3 segmentation job {'completed successfully' if success else 
'failed'}.") + return success except Exception as e: logger.error(f"Error during DINOv3 segmentation job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.compute(Machine.perlmutter).job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("DINOv3 segmentation job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False + return False def combine_segmentations( self, @@ -1102,7 +1327,7 @@ def combine_segmentations( ) -> bool: """ Run CPU-based combination of SAM3+DINOv3 segmentation results - at NERSC Perlmutter via SFAPI Slurm job. + at NERSC Perlmutter via Slurm job. :param recon_folder_path: Relative path to the reconstructed data folder, e.g. 'folder_name/recYYYYMMDD_hhmmss_scanname/' @@ -1110,8 +1335,8 @@ def combine_segmentations( """ logger.info("Starting NERSC segmentation combination process.") - user = self.client.user() - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" opts = _load_job_options( "nersc-combine-seg-options", self.config.nersc_combine_segmentation_settings @@ -1210,45 +1435,20 @@ def combine_segmentations( """ try: logger.info("Submitting segmentation combination job to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() - 
logger.info("Segmentation combination job completed successfully.") - return True - + success = self._wait_for_job(job_id) + logger.info(f"Segmentation combination job {'completed successfully' if success else 'failed'}.") + return success except Exception as e: logger.error(f"Error during segmentation combination job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.compute(Machine.perlmutter).job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Segmentation combination job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False + return False - def _fetch_seg_timing_from_output(self, perlmutter, pscratch_path: str, job_id: str, job_name: str) -> dict: + def _fetch_seg_timing_from_output(self, pscratch_path: str, job_id: str, job_name: str) -> dict: """ Fetch and parse timing data from the SLURM output file. 
- :param perlmutter: SFAPI compute object for Perlmutter :param pscratch_path: Path to the user's pscratch directory :param job_id: SLURM job ID :param job_name: Job name for finding output file @@ -1257,18 +1457,7 @@ def _fetch_seg_timing_from_output(self, perlmutter, pscratch_path: str, job_id: output_file = f"{pscratch_path}/tomo_seg_logs/{job_name}_{job_id}.out" try: - # Use SFAPI to read the output file - result = perlmutter.run(f"cat {output_file}") - - # Handle different result types - if isinstance(result, str): - output = result - elif hasattr(result, 'output'): - output = result.output - elif hasattr(result, 'stdout'): - output = result.stdout - else: - output = str(result) + output = self._read_remote_file(output_file) logger.info("Job output file contents (last 50 lines):") lines = output.strip().split('\n') @@ -1345,8 +1534,8 @@ def pull_shifter_image( """ logger.info("Starting Shifter image pull.") - user = self.client.user() - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" if image is None: image = self.config.ghcr_images832["recon_image"] @@ -1393,24 +1582,16 @@ def pull_shifter_image( try: logger.info("Submitting Shifter image pull job to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") if wait: - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - time.sleep(30) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() - logger.info("Shifter image pull completed successfully.") - return True + success = self._wait_for_job(job_id) + logger.info(f"Shifter image pull {'completed successfully' if success else 'failed'}.") + return success else: - 
logger.info(f"Job submitted. Check status with job ID: {job.jobid}") + logger.info(f"Job submitted. Check status with job ID: {job_id}") return True except Exception as e: @@ -1433,17 +1614,31 @@ def check_shifter_image( image = self.config.ghcr_images832["recon_image"] try: - perlmutter = self.client.compute(Machine.perlmutter) - # Run shifterimg images command - result = perlmutter.run(f"shifterimg images | grep -E \"$(echo {image} | sed 's/:/.*/g')\"") - - if isinstance(result, str): - output = result - elif hasattr(result, 'output'): - output = result.output - else: - output = str(result) + if self.login_method is NERSCLoginMethod.SFAPI: + # synchronous via utilities/command + perlmutter = self.client.compute(Machine.perlmutter) + result = perlmutter.run(f"shifterimg images | grep -E \"$(echo {image} | sed 's/:/.*/g')\"") + output = result if isinstance(result, str) else getattr(result, 'output', str(result)) + + elif self.login_method is NERSCLoginMethod.IRIAPI: + # async: submit job → wait → read stdout file + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" + output_file = f"{pscratch_path}/tomo_recon_logs/shifter_check.txt" + check_script = f"""#!/bin/bash +#SBATCH -q debug +#SBATCH -A als +#SBATCH -C cpu +#SBATCH -N 1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=1 +#SBATCH --time=0:05:00 +shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/g')" > {output_file} 2>&1 || true +""" + job_id = self._submit_job(check_script) + self._wait_for_job(job_id) + output = self._read_remote_file(output_file) if output.strip(): logger.info(f"Image found in Shifter cache: {output.strip()}") @@ -1607,7 +1802,8 @@ def nersc_recon_flow( logger.info(f"Starting NERSC reconstruction flow for {file_path=}") controller = get_controller( hpc_type=HPC.NERSC, - config=config + config=config, + login_method=NERSCLoginMethod.SFAPI ) logger.info("NERSC reconstruction controller initialized") @@ -1723,6 +1919,7 @@ def 
nersc_petiole_segment_flow( file_path: str, config: Optional[Config832] = None, num_nodes: Optional[int] = None, + login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> bool: """ Transfer raw data to NERSC, run reconstruction, then run SAM3 and DINOv3 @@ -1749,7 +1946,6 @@ def nersc_petiole_segment_flow( logger.info(f"Reconstructed TIFFs will be at: {scratch_path_tiff}") logger.info(f"Segmented output will be at: {scratch_path_segment}") - controller = get_controller(hpc_type=HPC.NERSC, config=config) logger.info("NERSC controller initialized") if num_nodes is None: @@ -1770,9 +1966,10 @@ def nersc_petiole_segment_flow( # ── STEP 1: Multinode Reconstruction ───────────────────────────────────── logger.info(f"Using multi-node reconstruction with {num_nodes} nodes") - recon_result = controller.reconstruct( + recon_result = nersc_reconstruction_task( file_path=file_path, - num_nodes=num_nodes + num_nodes=num_nodes, + config=config, ) if isinstance(recon_result, dict): @@ -1809,18 +2006,18 @@ def nersc_petiole_segment_flow( logger.info("Reconstruction Successful.") # ── STEP 2: Transfer TIFFs to data832 ──────────────────────────────────── - logger.info("Transferring reconstructed TIFFs from NERSC pscratch to data832") - try: - data832_tiff_future = globus_transfer_task.submit( - file_path=scratch_path_tiff, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch, - config=config, - ) - logger.info("TIFF transfer to data832 submitted.") - except Exception as e: - logger.error(f"Failed to transfer TIFFs to data832: {e}") - data832_tiff_transfer_success = False + # logger.info("Transferring reconstructed TIFFs from NERSC pscratch to data832") + # try: + # data832_tiff_future = globus_transfer_task.submit( + # file_path=scratch_path_tiff, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch, + # config=config, + # ) + # logger.info("TIFF transfer to data832 submitted.") + # except 
Exception as e: + # logger.error(f"Failed to transfer TIFFs to data832: {e}") + # data832_tiff_transfer_success = False # ── STEP 3: SAM3 / DINOv3 ────────────────────────── logger.info("Submitting SAM3 and DINOv3 segmentation tasks concurrently.") @@ -1829,7 +2026,7 @@ def nersc_petiole_segment_flow( recon_folder_path=scratch_path_tiff, config=config ) dinov3_future = nersc_segmentation_dinov3_task.submit( - recon_folder_path=scratch_path_tiff, config=config, project="petiole" + recon_folder_path=scratch_path_tiff, config=config, project="petiole", login_method=login_method ) # ── STEP 4: Transfer each model's output as it completes ───────────────── @@ -1838,15 +2035,17 @@ def nersc_petiole_segment_flow( logger.info(f"SAM3 segmentation result: {sam3_success}") if sam3_success: logger.info("Transferring SAM3 segmentation outputs to data832") - sam3_segment_path = f"{folder_name}/seg{file_name}/sam3" + # sam3_segment_path = f"{folder_name}/seg{file_name}/sam3" try: - data832_sam3_future = globus_transfer_task.submit( - file_path=sam3_segment_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch, - config=config, - ) - logger.info("SAM3 transfer to data832 submitted") + # data832_sam3_future = globus_transfer_task.submit( + # file_path=sam3_segment_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch, + # config=config, + # ) + # logger.info("SAM3 transfer to data832 submitted") + data832_sam3_transfer_success = True + logger.info(f"SAM3 transfer to data832 success: {data832_sam3_transfer_success}") except Exception as e: logger.error(f"Failed to transfer SAM3 outputs to data832: {e}") @@ -1854,15 +2053,17 @@ def nersc_petiole_segment_flow( logger.info(f"DINOv3 segmentation result: {dinov3_success}") if dinov3_success: logger.info("Transferring DINOv3 segmentation outputs to data832") - dinov3_segment_path = f"{folder_name}/seg{file_name}/dino" + # dinov3_segment_path = 
f"{folder_name}/seg{file_name}/dino" try: - data832_dinov3_future = globus_transfer_task.submit( - file_path=dinov3_segment_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch, - config=config, - ) - logger.info("DINOv3 transfer to data832 submitted") + # data832_dinov3_future = globus_transfer_task.submit( + # file_path=dinov3_segment_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch, + # config=config, + # ) + # logger.info("DINOv3 transfer to data832 submitted") + data832_dinov3_transfer_success = True + logger.info(f"DINOv3 transfer to data832 success: {data832_dinov3_transfer_success}") except Exception as e: logger.error(f"Failed to transfer DINOv3 outputs to data832: {e}") @@ -1882,15 +2083,17 @@ def nersc_petiole_segment_flow( logger.info(f"Combination result: {combine_success}") if combine_success: logger.info("Transferring combined segmentation outputs to data832") - combined_segment_path = f"{folder_name}/seg{file_name}/combined/sam_dino" + # combined_segment_path = f"{folder_name}/seg{file_name}/combined/sam_dino" try: - data832_combined_future = globus_transfer_task.submit( - file_path=combined_segment_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch, - config=config, - ) - logger.info("Combined transfer to data832 submitted") + # data832_combined_future = globus_transfer_task.submit( + # file_path=combined_segment_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch, + # config=config, + # ) + # logger.info("Combined transfer to data832 submitted") + data832_combined_transfer_success = True + logger.info(f"Combined transfer to data832 success: {data832_combined_transfer_success}") except Exception as e: logger.error(f"Failed to transfer combined outputs to data832: {e}") @@ -1923,67 +2126,67 @@ def nersc_petiole_segment_flow( ) # ── STEP 6: Pruning 
─────────────────────────────────────────────────────── - logger.info("Scheduling file pruning tasks.") - prune_controller = get_prune_controller(prune_type=PruneMethod.GLOBUS, config=config) - - try: - prune_controller.prune( - file_path=f"{folder_name}/{path.name}", - source_endpoint=config.nersc832_alsdev_pscratch_raw, - check_endpoint=None, - days_from_now=1.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule raw data pruning: {e}") - - if nersc_reconstruction_success: - try: - prune_controller.prune( - file_path=scratch_path_tiff, - source_endpoint=config.nersc832_alsdev_pscratch_scratch, - check_endpoint=config.data832_scratch if data832_tiff_transfer_success else None, - days_from_now=1.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule reconstruction data pruning: {e}") - - if any_seg_success: - try: - prune_controller.prune( - file_path=scratch_path_segment, - source_endpoint=config.nersc832_alsdev_pscratch_scratch, - check_endpoint=config.data832_scratch if any([ - data832_sam3_transfer_success, - data832_dinov3_transfer_success, - ]) else None, - days_from_now=1.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule segmentation data pruning: {e}") - - if data832_tiff_transfer_success: - try: - prune_controller.prune( - file_path=scratch_path_tiff, - source_endpoint=config.data832_scratch, - check_endpoint=None, - days_from_now=30.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule data832 tiff pruning: {e}") - - if any([data832_sam3_transfer_success, - data832_dinov3_transfer_success, - data832_combined_transfer_success]): - try: - prune_controller.prune( - file_path=scratch_path_segment, - source_endpoint=config.data832_scratch, - check_endpoint=None, - days_from_now=30.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule data832 segment pruning: {e}") + # logger.info("Scheduling file pruning tasks.") + # prune_controller = 
get_prune_controller(prune_type=PruneMethod.GLOBUS, config=config) + + # try: + # prune_controller.prune( + # file_path=f"{folder_name}/{path.name}", + # source_endpoint=config.nersc832_alsdev_pscratch_raw, + # check_endpoint=None, + # days_from_now=1.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule raw data pruning: {e}") + + # if nersc_reconstruction_success: + # try: + # prune_controller.prune( + # file_path=scratch_path_tiff, + # source_endpoint=config.nersc832_alsdev_pscratch_scratch, + # check_endpoint=config.data832_scratch if data832_tiff_transfer_success else None, + # days_from_now=1.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule reconstruction data pruning: {e}") + + # if any_seg_success: + # try: + # prune_controller.prune( + # file_path=scratch_path_segment, + # source_endpoint=config.nersc832_alsdev_pscratch_scratch, + # check_endpoint=config.data832_scratch if any([ + # data832_sam3_transfer_success, + # data832_dinov3_transfer_success, + # ]) else None, + # days_from_now=1.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule segmentation data pruning: {e}") + + # if data832_tiff_transfer_success: + # try: + # prune_controller.prune( + # file_path=scratch_path_tiff, + # source_endpoint=config.data832_scratch, + # check_endpoint=None, + # days_from_now=30.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule data832 tiff pruning: {e}") + + # if any([data832_sam3_transfer_success, + # data832_dinov3_transfer_success, + # data832_combined_transfer_success]): + # try: + # prune_controller.prune( + # file_path=scratch_path_segment, + # source_endpoint=config.data832_scratch, + # check_endpoint=None, + # days_from_now=30.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule data832 segment pruning: {e}") if nersc_reconstruction_success and any_seg_success: logger.info("NERSC reconstruction + multi-segmentation flow completed successfully.") 
@@ -2193,6 +2396,55 @@ def nersc_moon_segment_flow( return False +@flow(name="nersc_recon_test_iriapi_flow", flow_run_name="nersc_recon-{file_path}") +def nersc_recon_test_iriapi_flow( + file_path: str, + config: Optional[Config832] = None, +) -> bool: + """ + Perform tomography reconstruction on NERSC. + + :param file_path: Path to the file to reconstruct. + :param config: Configuration object (if None, a default Config832 will be created). + :return: True if successful, False otherwise. + """ + logger.info(f"Starting NERSC reconstruction flow for {file_path=}") + controller = get_controller( + hpc_type=HPC.NERSC, + config=config, + login_method=NERSCLoginMethod.IRIAPI + ) + logger.info("NERSC reconstruction controller initialized") + + nersc_reconstruction_success = controller.reconstruct( + file_path=file_path, + ) + logger.info(f"NERSC reconstruction success: {nersc_reconstruction_success}") + + nersc_multi_res_success = controller.build_multi_resolution( + file_path=file_path, + ) + logger.info(f"NERSC multi-resolution success: {nersc_multi_res_success}") + + path = Path(file_path) + folder_name = path.parent.name + file_name = path.stem + + tiff_file_path = f"{folder_name}/rec{file_name}" + zarr_file_path = f"{folder_name}/rec{file_name}.zarr" + + logger.info(f"{tiff_file_path=}") + logger.info(f"{zarr_file_path=}") + + # Transfers and pruning omitted from test flow. 
+ + # TODO: Ingest into SciCat + if nersc_reconstruction_success: + return True + else: + return False + + @flow(name="nersc_streaming_flow", on_cancellation=[cancellation_hook]) def nersc_streaming_flow( walltime: datetime.timedelta = datetime.timedelta(minutes=5), @@ -2254,10 +2506,36 @@ def pull_shifter_image_flow( return success +@task(name="nersc_reconstruction_task") +def nersc_reconstruction_task( + file_path: str, + num_nodes: int = 4, + config: Optional[Config832] = None, + login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI +) -> dict: + """ + Run tomography reconstruction at NERSC Perlmutter. + + :param file_path: Path to the raw HDF5 file to reconstruct. + :param num_nodes: Number of nodes to use for reconstruction. + :param config: Configuration object for the flow. + :return: Dict with keys 'success', 'job_id', 'timing'. + """ + logger = get_run_logger() + if config is None: + config = Config832() + + logger.info("Initializing NERSC Tomography HPC Controller.") + controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=login_method) + logger.info(f"Starting NERSC reconstruction task for {file_path=}") + return controller.reconstruct(file_path=file_path, num_nodes=num_nodes) + + @task(name="nersc_multiresolution_task") def nersc_multiresolution_task( file_path: str, config: Optional[Config832] = None, + login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI ) -> bool: """ Run multiresolution task at NERSC. 
@@ -2275,7 +2553,8 @@ def nersc_multiresolution_task(
     logger.info("Initializing NERSC Tomography HPC Controller.")
     tomography_controller = get_controller(
         hpc_type=HPC.NERSC,
-        config=config
+        config=config,
+        login_method=login_method
     )
     logger.info(f"Starting NERSC multiresolution task for {file_path=}")
     nersc_multiresolution_success = tomography_controller.build_multi_resolution(
@@ -2327,7 +2606,8 @@ def nersc_segmentation_sam3_task(
     logger.info("Initializing NERSC Tomography HPC Controller.")
     tomography_controller = get_controller(
         hpc_type=HPC.NERSC,
-        config=config
+        config=config,
+        login_method=NERSCLoginMethod.IRIAPI
     )
     logger.info(f"Starting NERSC segmentation task for {recon_folder_path=}")
     nersc_segmentation_success = tomography_controller.segmentation_sam3(
@@ -2348,12 +2628,13 @@ def nersc_segmentation_dinov3_task(
     recon_folder_path: str,
     config: Optional[Config832] = None,
     project: Optional[str] = "petiole",
+    login_method: Optional[NERSCLoginMethod] = NERSCLoginMethod.IRIAPI
 ) -> bool:
     logger = get_run_logger()
     if config is None:
         logger.info("No config provided, using default Config832.")
         config = Config832()
-    tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config)
+    tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=login_method)
     logger.info(f"Starting NERSC DINOv3 segmentation task for {recon_folder_path=}, {project=}")
     success = tomography_controller.segmentation_dinov3(recon_folder_path=recon_folder_path, project=project)
     if not success:
@@ -2372,7 +2653,7 @@ def nersc_combine_segmentations_task(
     if config is None:
         logger.info("No config provided, using default Config832.")
        config = Config832()
-    tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config)
+    tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config, login_method=NERSCLoginMethod.IRIAPI)
     logger.info(f"Starting NERSC combine segmentations task for {recon_folder_path=}")
     success = 
tomography_controller.combine_segmentations(recon_folder_path=recon_folder_path) if not success: @@ -2394,15 +2675,20 @@ def nersc_segmentation_sam3_integration_test() -> bool: recon_folder_path = 'synaps-i/rec20211222_125057_petiole4' # 'test' # flow_success = nersc_segmentation_sam3_task( recon_folder_path=recon_folder_path, - config=Config832() + config=Config832(), ) logger.info(f"Flow success: {flow_success}") return flow_success -if __name__ == "__main__": - nersc_segmentation_dinov3_task( - recon_folder_path='dabramov/recmoon/', - config=Config832(), - project="moon" - ) +# if __name__ == "__main__": + # nersc_segmentation_dinov3_task( + # recon_folder_path='dabramov/recmoon/', + # config=Config832(), + # project="moon" + # ) + # nersc_petiole_segment_flow( + # file_path='dabramov/20260221_143000_petiole28', + # num_nodes=2, + # login_method=NERSCLoginMethod.IRIAPI + # ) diff --git a/orchestration/flows/bl832/register_mlflow.py b/orchestration/flows/bl832/register_mlflow.py index 31fa3760..93603295 100644 --- a/orchestration/flows/bl832/register_mlflow.py +++ b/orchestration/flows/bl832/register_mlflow.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)-7s | %(name)s - %(message)s") def register_mlflow_checkpoints(): @@ -16,14 +17,18 @@ def register_mlflow_checkpoints(): register_checkpoint( model_name="sam3-petiole", nersc_path=f"{scripts_dir}sam3_finetune/sam3/checkpoint_v6.pt", - alcf_path="/eagle/IRIBeta/als/seg_models/sam3/checkpoint_v6.pt", + alcf_path="/eagle/SYNAPS-I/segmentation/sam3_finetune/sam3/checkpoint_v6.pt", config=config, alias="production", description="SAM3 v6 fine-tuned on petiole micro-CT data.", inference_params={ - # ── paths ────────────────────────────────────────────────────────── - "original_checkpoint_path": - f"{scripts_dir}sam3_finetune/sam3/sam3.pt", + # ── site-specific HF caches 
───────────────────────────────────────── + "nersc_hf_home": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + "nersc_hf_hub_cache": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + "alcf_hf_home": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + "alcf_hf_hub_cache": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # ── paths ─────────────────────────────────────────────────────────── + "original_checkpoint_path": f"{scripts_dir}sam3_finetune/sam3/sam3.pt", "bpe_path": f"{scripts_dir}sam3_finetune/sam3/bpe_simple_vocab_16e6.txt.gz", "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/sam3-py311", "seg_scripts_dir": f"{scripts_dir}inference_latest/forge_feb_seg_model_demo/", @@ -32,9 +37,9 @@ def register_mlflow_checkpoints(): "script_name": "src/inference_v6.py", "batch_size": 1, "patch_size": 400, - "confidence": [0.5], # list → JSON-encoded automatically + "confidence": [0.5], "overlap": 0.25, - "prompts": [ # list → JSON-encoded automatically + "prompts": [ "Phloem Fibers", "Hydrated Xylem vessels", "Air-based Pith cells", @@ -46,12 +51,17 @@ def register_mlflow_checkpoints(): register_checkpoint( model_name="dinov3-petiole", nersc_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt", - alcf_path="/eagle/IRIBeta/als/seg_models/dino/best.ckpt", + alcf_path="/eagle/SYNAPS-I/segmentation/dino/best.ckpt", config=config, alias="production", description="DINOv3 fine-tuned on petiole micro-CT data.", inference_params={ - # ── paths ────────────────────────────────────────────────────────── + # ── site-specific HF caches ───────────────────────────────────────── + "nersc_hf_home": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + "nersc_hf_hub_cache": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + "alcf_hf_home": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + "alcf_hf_hub_cache": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # ── paths 
─────────────────────────────────────────────────────────── "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", "seg_scripts_dir": f"{scripts_dir}inference_v5_multiseg/forge_feb_seg_model_demo/", # ── inference hyperparameters ─────────────────────────────────────── @@ -64,13 +74,20 @@ def register_mlflow_checkpoints(): register_checkpoint( model_name="dinov3-moon", nersc_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best_moon.ckpt", - alcf_path="/eagle/IRIBeta/als/seg_models/dino/best_moon.ckpt", + alcf_path="/eagle/SYNAPS-I/segmentation/seg_models/dino/best_moon.ckpt", config=config, alias="production", description="DINOv3 fine-tuned on lunar regolith micro-CT data (ice, particles, pores).", inference_params={ + # ── site-specific HF caches ───────────────────────────────────────── + "nersc_hf_home": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface", + "nersc_hf_hub_cache": "/global/cfs/cdirs/als/data_mover/8.3.2/.cache/huggingface/hub", + "alcf_hf_home": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + "alcf_hf_hub_cache": "/eagle/SYNAPS-I/segmentation/.cache/huggingface", + # ── paths ─────────────────────────────────────────────────────────── "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", "seg_scripts_dir": f"{scripts_dir}moon_seg/forge_feb_seg_model_demo/", + # ── inference hyperparameters ─────────────────────────────────────── "script_name": "src.inference_dino_v2", "batch_size": 4, "nproc_per_node": 4, @@ -106,28 +123,56 @@ def retrieve_mlflow_params_test() -> bool: ) sam3_checks = { - # MLflow should have overridden these + # ── MLflow should have overridden these ────────────────────────────── "finetuned_checkpoint_path": ( lambda v: "checkpoint" in v, "finetuned_checkpoint_path should contain 'checkpoint'" ), + "original_checkpoint_path": ( + lambda v: v.endswith(".pt") and "sam3" in v.lower(), + "original_checkpoint_path should point at a sam3 .pt file" 
+ ), + "bpe_path": ( + lambda v: v.endswith(".txt.gz"), + "bpe_path should point at a .txt.gz vocab file" + ), "conda_env_path": ( lambda v: "sam3" in v, "conda_env_path should reference sam3 env" ), + "seg_scripts_dir": ( + lambda v: isinstance(v, str) and len(v) > 0, + "seg_scripts_dir should be a non-empty path" + ), + "checkpoints_dir": ( + lambda v: isinstance(v, str) and len(v) > 0, + "checkpoints_dir should be a non-empty path" + ), + "script_name": ( + lambda v: "inference" in v.lower(), + "script_name should reference an inference script" + ), "prompts": ( lambda v: isinstance(v, list) and len(v) > 0, "prompts should be a non-empty list (JSON-deserialized)" ), "confidence": ( - lambda v: isinstance(v, list), - "confidence should be a list (JSON-deserialized)" + lambda v: isinstance(v, list) and len(v) > 0, + "confidence should be a non-empty list (JSON-deserialized)" ), "batch_size": ( - lambda v: isinstance(v, int), - "batch_size should be an int" + lambda v: isinstance(v, int) and v > 0, + "batch_size should be a positive int" + ), + "patch_size": ( + lambda v: isinstance(v, int) and v > 0, + "patch_size should be a positive int" ), - # SLURM params should still come from config + "overlap": ( + lambda v: isinstance(v, float) and 0.0 <= v < 1.0, + "overlap should be a float in [0.0, 1.0)" + ), + # ── SLURM params should still come from config ─────────────────────── "qos": ( lambda v: v == config.nersc_segment_sam3_settings["qos"], "qos should be unchanged from config" @@ -160,23 +205,32 @@ def retrieve_mlflow_params_test() -> bool: ) dino_checks = { + # ── MLflow-overridden ──────────────────────────────────────────────── "dino_checkpoint_path": ( lambda v: v.endswith(".ckpt"), "dino_checkpoint_path should end with .ckpt" ), "conda_env_path": ( - lambda v: len(v) > 0, + lambda v: isinstance(v, str) and len(v) > 0, "conda_env_path should be non-empty" ), - "batch_size": ( - lambda v: isinstance(v, int) and v > 0, - "batch_size should be a positive int" + 
"seg_scripts_dir": ( + lambda v: isinstance(v, str) and len(v) > 0, + "seg_scripts_dir should be a non-empty path" ), "script_name": ( lambda v: "dino" in v.lower(), "script_name should reference dino" ), - # SLURM params unchanged + "batch_size": ( + lambda v: isinstance(v, int) and v > 0, + "batch_size should be a positive int" + ), + "nproc_per_node": ( + lambda v: isinstance(v, int) and v > 0, + "nproc_per_node should be a positive int" + ), + # ── SLURM params unchanged ─────────────────────────────────────────── "qos": ( lambda v: v == config.nersc_segment_dinov3_settings["qos"], "qos should be unchanged from config" @@ -209,9 +263,18 @@ def retrieve_mlflow_params_test() -> bool: ) moon_checks = { + # ── MLflow-overridden ──────────────────────────────────────────────── "dino_checkpoint_path": ( - lambda v: v.endswith(".ckpt"), - "dino_checkpoint_path should end with .ckpt" + lambda v: v.endswith(".ckpt") and "moon" in v.lower(), + "dino_checkpoint_path should end with .ckpt and reference moon" + ), + "conda_env_path": ( + lambda v: isinstance(v, str) and len(v) > 0, + "conda_env_path should be non-empty" + ), + "seg_scripts_dir": ( + lambda v: isinstance(v, str) and "moon" in v.lower(), + "seg_scripts_dir should reference moon_seg" ), "script_name": ( lambda v: "v2" in v.lower(), @@ -221,6 +284,11 @@ def retrieve_mlflow_params_test() -> bool: lambda v: isinstance(v, int) and v > 0, "batch_size should be a positive int" ), + "nproc_per_node": ( + lambda v: isinstance(v, int) and v > 0, + "nproc_per_node should be a positive int" + ), + # ── SLURM params unchanged ─────────────────────────────────────────── "qos": ( lambda v: v == config.nersc_segment_dinov3_moon_settings["qos"], "qos should be unchanged from config" diff --git a/orchestration/globus/get_globus_token.py b/orchestration/globus/get_globus_token.py new file mode 100644 index 00000000..f740a034 --- /dev/null +++ b/orchestration/globus/get_globus_token.py @@ -0,0 +1,397 @@ +#!/usr/bin/env python3 
+import argparse +import json +import os +import stat +import time +import urllib.error +import urllib.request +from pathlib import Path + +import globus_sdk +from globus_sdk.exc import GlobusAPIError + +DEFAULT_TOKEN_FILE: Path = Path.home() / ".globus" / "auth_tokens.json" +CLIENT_ID = "fae5c579-490a-4d76-b6eb-d78f65caeb63" +RESOURCE_SERVER = "auth.globus.org" +IRI_SCOPE = ( + "https://auth.globus.org/scopes/" + "ed3e577d-f7f3-4639-b96e-ff5a8445d699/iri_api" +) +REQUIRED_SCOPES = { + "openid", + "profile", + "email", + "urn:globus:auth:scope:auth.globus.org:view_identities", +} +REQUESTED_SCOPES = REQUIRED_SCOPES | {IRI_SCOPE} +DEFAULT_IRI_VALIDATE_URL = "https://api.iri.nersc.gov/api/v1/account/projects" + + +def parse_args() -> argparse.Namespace: + default_token_file = Path.home() / ".globus" / "auth_tokens.json" + parser = argparse.ArgumentParser( + description=( + "Get Globus Auth tokens with required scopes. " + "Tokens are saved to a secure local file by default." + ) + ) + parser.add_argument( + "--token-file", + type=Path, + default=default_token_file, + help=f"Path for saved token JSON (default: {default_token_file})", + ) + parser.add_argument( + "--print-token", + action="store_true", + help="Print the access token to stdout (off by default).", + ) + parser.add_argument( + "--force-login", + action="store_true", + help="Skip refresh and force interactive browser login.", + ) + parser.add_argument( + "--refresh-only", + action="store_true", + help="Refresh saved tokens only; do not fall back to interactive login.", + ) + parser.add_argument( + "--prompt-login", + action="store_true", + help="Add prompt=login to the Globus authorize URL to force re-authentication.", + ) + parser.add_argument( + "--validate-iri", + action="store_true", + help="Validate the IRI token by calling the IRI account/projects endpoint.", + ) + parser.add_argument( + "--iri-validate-url", + default=DEFAULT_IRI_VALIDATE_URL, + help=( + "IRI endpoint used by --validate-iri " + 
f"(default: {DEFAULT_IRI_VALIDATE_URL})" + ), + ) + return parser.parse_args() + + +def parse_scope_string(scope_string: str) -> set[str]: + return set(scope_string.split()) if scope_string else set() + + +def ensure_private_parent_dir(path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + os.chmod(path.parent, 0o700) + + +def load_tokens(token_file: Path) -> dict | None: + if not token_file.exists(): + return None + with token_file.open("r", encoding="utf-8") as f: + return json.load(f) + + +def save_tokens(token_file: Path, tokens: dict) -> None: + ensure_private_parent_dir(token_file) + tmp = token_file.with_suffix(".tmp") + with os.fdopen( + os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), + "w", + encoding="utf-8", + ) as f: + json.dump(tokens, f, indent=2) + os.replace(tmp, token_file) + os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) + + +def get_refresh_token(stored_tokens: dict) -> str | None: + if "refresh_token" in stored_tokens: + return stored_tokens.get("refresh_token") + + auth_tokens = stored_tokens.get(RESOURCE_SERVER) + if isinstance(auth_tokens, dict): + return auth_tokens.get("refresh_token") + + return None + + +def get_iri_token(token_response_data: dict) -> dict: + for token_data in token_response_data.get("other_tokens", []): + if IRI_SCOPE in parse_scope_string(token_data.get("scope", "")): + return token_data + raise RuntimeError(f"Missing token for required IRI scope: {IRI_SCOPE}") + + +def get_iri_refresh_token(stored_tokens: dict) -> str | None: + try: + return get_iri_token(stored_tokens).get("refresh_token") + except RuntimeError: + return None + + +def replace_iri_token(token_response_data: dict, iri_token_data: dict) -> dict: + merged = dict(token_response_data) + other_tokens = list(merged.get("other_tokens", [])) + for index, token_data in enumerate(other_tokens): + if IRI_SCOPE in parse_scope_string(token_data.get("scope", "")): + other_tokens[index] = iri_token_data + break + else: + 
other_tokens.append(iri_token_data) + merged["other_tokens"] = other_tokens + return merged + + +def validate_auth_data(auth_data: dict) -> dict: + if auth_data.get("resource_server") != RESOURCE_SERVER: + raise RuntimeError( + f"Missing token for required resource server: {RESOURCE_SERVER}" + ) + + granted = parse_scope_string(auth_data.get("scope", "")) + missing = REQUIRED_SCOPES - granted + if missing: + raise RuntimeError(f"Missing required scopes: {sorted(missing)}") + + return get_iri_token(auth_data) + + +def validate_iri_token(iri_token_data: dict, validate_url: str) -> dict | list: + request = urllib.request.Request( + validate_url, + headers={ + "accept": "application/json", + "Authorization": f"Bearer {iri_token_data['access_token']}", + }, + method="GET", + ) + try: + with urllib.request.urlopen(request) as response: + body = response.read().decode("utf-8") + data = json.loads(body) if body else {} + except urllib.error.HTTPError as exc: + body = exc.read().decode("utf-8") + details = body.strip() or exc.reason + raise RuntimeError( + f"IRI validation failed with HTTP {exc.code} from {validate_url}: {details}" + ) from exc + except urllib.error.URLError as exc: + raise RuntimeError( + f"IRI validation request failed for {validate_url}: {exc.reason}" + ) from exc + except json.JSONDecodeError as exc: + raise RuntimeError( + f"IRI validation returned non-JSON data from {validate_url}" + ) from exc + + if isinstance(data, dict): + session_info = data.get("session_info") + if isinstance(session_info, dict): + authentications = session_info.get("authentications") + if isinstance(authentications, dict) and not authentications: + raise RuntimeError( + "IRI validation succeeded but session_info.authentications is empty. " + "Re-run with --force-login --prompt-login and use a Chrome incognito window." 
+ ) + + return data + + +def interactive_login( + client: globus_sdk.NativeAppAuthClient, + *, + prompt_login: bool = False, +) -> dict: + client.oauth2_start_flow( + requested_scopes=" ".join(sorted(REQUESTED_SCOPES)), + refresh_tokens=True, + ) + print("Open this URL, login, and consent:") + prompt = "login" if prompt_login else globus_sdk.MISSING + print(client.oauth2_get_authorize_url(prompt=prompt)) + code = input("\nEnter authorization code: ").strip() + if not code: + raise RuntimeError( + "No authorization code entered. Re-run the script and paste the code " + "shown by Globus after login." + ) + try: + token_response = client.oauth2_exchange_code_for_tokens(code) + except GlobusAPIError as exc: + if exc.http_status == 400: + raise RuntimeError( + "Authorization code exchange failed. The code was empty, invalid, " + "expired, or already used. Re-run the script and complete the " + "Globus login flow again." + ) from exc + raise RuntimeError( + f"Authorization code exchange failed with HTTP {exc.http_status}. " + "Re-run the script and try again." + ) from exc + return token_response.data + + +def refresh_tokens( + client: globus_sdk.NativeAppAuthClient, refresh_token: str +) -> dict | None: + try: + token_response = client.oauth2_refresh_token(refresh_token) + return token_response.data + except GlobusAPIError as exc: + print( + f"Refresh failed ({exc.http_status}); switching to interactive login." 
+ ) + return None + + +def refresh_stored_tokens( + client: globus_sdk.NativeAppAuthClient, stored_tokens: dict +) -> tuple[dict | None, bool]: + iri_refresh_token = get_iri_refresh_token(stored_tokens) + if iri_refresh_token: + iri_token_data = refresh_tokens(client, iri_refresh_token) + if iri_token_data is not None: + return replace_iri_token(stored_tokens, iri_token_data), True + + auth_refresh_token = get_refresh_token(stored_tokens) + if auth_refresh_token: + auth_data = refresh_tokens(client, auth_refresh_token) + if auth_data is not None: + return auth_data, True + + return None, False + + +def get_iri_access_token( + token_file: Path = DEFAULT_TOKEN_FILE, + force_login: bool = False, + prompt_login: bool = False, +) -> str: + """ + Get a valid IRI access token, refreshing or prompting for login as needed. + Tokens are saved to the specified token_file path (default: ~/.globus/auth_tokens.json). + By default, the function will attempt to refresh saved tokens before falling back + to interactive login. Use force_login=True to skip refresh and require interactive login. + Use prompt_login=True to add prompt=login to the authorization URL, which forces + re-authentication even if the user has an active Globus session in their browser. + + Args: + token_file: Path to save and load token data (default: ~/.globus/auth_tokens.json) + force_login: If True, skip token refresh and require interactive login + prompt_login: If True, add prompt=login to the authorization URL to force re-authentication + + Returns: + A valid IRI access token string with the required scopes. + + Raises: + RuntimeError: If token refresh fails and interactive login is not allowed or fails, + or if the resulting tokens do not include a valid IRI access token. 
+ """ + client = globus_sdk.NativeAppAuthClient(CLIENT_ID) + + # Fast path: if token exists and is not expired, return it directly without refreshing or saving + if not force_login: + stored = load_tokens(token_file) + if stored: + try: + iri_token = get_iri_token(stored) + expires_at = iri_token.get("expires_at_seconds", 0) + if expires_at and time.time() < expires_at - 60: # 60s buffer + return iri_token["access_token"] + except RuntimeError: + pass # fall through to refresh + + auth_data = None + used_refresh = False + if not force_login: + stored = load_tokens(token_file) + if stored: + auth_data, used_refresh = refresh_stored_tokens(client, stored) + if auth_data is None: + auth_data = interactive_login(client, prompt_login=prompt_login) + try: + iri_token_data = validate_auth_data(auth_data) + except RuntimeError as exc: + if used_refresh and "Missing token for required IRI scope" in str(exc): + auth_data = interactive_login(client, prompt_login=prompt_login) + iri_token_data = validate_auth_data(auth_data) + else: + raise + save_tokens(token_file, auth_data) + return iri_token_data["access_token"] + + +def main() -> None: + args = parse_args() + if args.force_login and args.refresh_only: + raise RuntimeError("Choose only one of --force-login or --refresh-only") + + client = globus_sdk.NativeAppAuthClient(CLIENT_ID) + + auth_data = None + used_refresh = False + if not args.force_login: + stored = load_tokens(args.token_file) + if stored: + auth_data, used_refresh = refresh_stored_tokens(client, stored) + + if auth_data is None: + if args.refresh_only: + raise RuntimeError( + "Refresh-only mode failed. No usable saved refresh token was found " + "or token refresh did not return the required IRI token." 
+ ) + auth_data = interactive_login(client, prompt_login=args.prompt_login) + + try: + iri_token_data = validate_auth_data(auth_data) + except RuntimeError as exc: + if used_refresh and "Missing token for required IRI scope" in str(exc): + print( + "Refreshed tokens did not include the IRI token; " + "switching to interactive login." + ) + auth_data = interactive_login(client, prompt_login=args.prompt_login) + iri_token_data = validate_auth_data(auth_data) + else: + raise + + save_tokens(args.token_file, auth_data) + + if args.validate_iri: + validation_data = validate_iri_token(iri_token_data, args.iri_validate_url) + print(f"IRI validation succeeded against {args.iri_validate_url}") + if isinstance(validation_data, dict): + session_info = validation_data.get("session_info") + if isinstance(session_info, dict): + session_id = session_info.get("session_id") + if session_id: + print(f"IRI session_id: {session_id}") + elif isinstance(validation_data, list): + print(f"IRI validation response items: {len(validation_data)}") + + expires_at = iri_token_data.get("expires_at_seconds") + if expires_at: + ttl = int(expires_at - time.time()) + print(f"\nIRI access token valid for ~{max(ttl, 0)} seconds.") + + print(f"Saved token data to {args.token_file}") + print(f"Granted Globus Auth scopes: {auth_data.get('scope', '')}") + print(f"IRI token resource server: {iri_token_data.get('resource_server')}") + print(f"IRI token scopes: {iri_token_data.get('scope', '')}") + + if args.print_token: + print("\nIRI access token:") + print(iri_token_data["access_token"]) + else: + print( + "IRI access token not printed " + "(use --print-token to display it for the NERSC IRI API)." 
+ ) + + +if __name__ == "__main__": + main() diff --git a/orchestration/mlflow.py b/orchestration/mlflow.py index cff8487c..c337a2ba 100644 --- a/orchestration/mlflow.py +++ b/orchestration/mlflow.py @@ -1,16 +1,23 @@ import logging from dataclasses import dataclass, field +from dotenv import load_dotenv import json +import os import requests from typing import Any import mlflow from mlflow.tracking import MlflowClient +import mlflow.utils.rest_utils as rest_utils + from orchestration.config import BeamlineConfig logger = logging.getLogger(__name__) +_AMSC_PATCH_FLAG: str = "_amsc_x_api_key_patched" +load_dotenv() + @dataclass class ModelCheckpointInfo: @@ -48,13 +55,63 @@ def _is_mlflow_reachable(tracking_uri: str, timeout: float = 2.0) -> bool: Returns: True if the server responds with HTTP 200, False otherwise. """ + headers = {} + api_key = os.environ.get("AMSC_API_KEY") + if api_key: + headers["X-Api-Key"] = api_key try: - response = requests.get(f"{tracking_uri}/health", timeout=timeout) + response = requests.get( + f"{tracking_uri}/health", headers=headers, timeout=timeout + ) return response.status_code == 200 except Exception: return False +def _enable_amsc_x_api_key() -> bool: + """Patch mlflow.utils.rest_utils.http_request to inject X-Api-Key. + + Required by the American Science Cloud MLflow server, which enforces + API-key auth on all REST calls. Standard MLflow does not send custom + headers, so we wrap ``http_request`` at import time. + + Idempotent: repeat calls are no-ops thanks to a sentinel attribute on + the wrapper. Silently skips patching if ``AMSC_API_KEY`` is unset, + which lets the same codebase target non-AMSC MLflow servers. + + Returns: + True if the patch is (or was already) active, False if the API + key env var is unset. 
+ """ + + api_key = os.environ.get("AMSC_API_KEY") + if not api_key: + return False + + if getattr(rest_utils.http_request, _AMSC_PATCH_FLAG, False): + return True + + original = rest_utils.http_request + + def patched(host_creds, endpoint, method, *args, **kwargs): + # MLflow internals call http_request with either `headers` or + # `extra_headers` depending on the code path — handle both. + if "headers" in kwargs and kwargs["headers"] is not None: + h = dict(kwargs["headers"]) + h["X-Api-Key"] = api_key + kwargs["headers"] = h + else: + h = dict(kwargs.get("extra_headers") or {}) + h["X-Api-Key"] = api_key + kwargs["extra_headers"] = h + return original(host_creds, endpoint, method, *args, **kwargs) + + setattr(patched, _AMSC_PATCH_FLAG, True) + rest_utils.http_request = patched + logger.info("AMSC X-Api-Key injection enabled for MLflow REST calls.") + return True + + def get_mlflow_client(config: BeamlineConfig) -> MlflowClient: """Construct an MlflowClient pointed at the configured tracking server. @@ -65,6 +122,7 @@ def get_mlflow_client(config: BeamlineConfig) -> MlflowClient: An authenticated MlflowClient instance. """ tracking_uri = config.mlflow["tracking_uri"] + _enable_amsc_x_api_key() # Idempotent patch for AMSC API key injection mlflow.set_tracking_uri(tracking_uri) return MlflowClient(tracking_uri=tracking_uri) @@ -183,7 +241,19 @@ def register_checkpoint( client.create_registered_model(model_name) mlflow.set_tracking_uri(config.mlflow["tracking_uri"]) + + # Use a dedicated experiment so the creator (this user) gets MANAGE + # permission automatically — avoids 403 on the default experiment. 
+    experiment_name = config.mlflow.get("experiment_name", "als-model-registration")
+    experiment = mlflow.get_experiment_by_name(experiment_name)
+    if experiment is None:
+        experiment_id = mlflow.create_experiment(experiment_name)
+        logger.info(f"Created MLflow experiment '{experiment_name}' (id={experiment_id}).")
+    else:
+        experiment_id = experiment.experiment_id
+
     with mlflow.start_run(
+        experiment_id=experiment_id,
         run_name=f"register_{model_name}",
         tags={"mlflow.note.content": description},
     ) as run:
@@ -247,5 +317,17 @@ def log_segmentation_metrics(
     run_tags: dict[str, str] = {"model": model_name, "slurm_job_id": job_id}
+    tracking_uri = config.mlflow["tracking_uri"]
+    mlflow.set_tracking_uri(tracking_uri)
+    _enable_amsc_x_api_key()  # ensure AMSC auth patch is active for this entrypoint too
+
+    experiment_name = config.mlflow.get("experiment_name", "als-model-registration")
+    experiment = mlflow.get_experiment_by_name(experiment_name)
+    if experiment is None:
+        experiment_id = mlflow.create_experiment(experiment_name)
+    else:
+        experiment_id = experiment.experiment_id
+
     with mlflow.start_run(
+        experiment_id=experiment_id,
         run_name=run_name,
         nested=parent_run_id is not None,
         parent_run_id=parent_run_id,
diff --git a/scripts/login_to_globus_and_prefect.sh b/scripts/login_to_globus_and_prefect.sh
index a38b629f..b8f60bde 100755
--- a/scripts/login_to_globus_and_prefect.sh
+++ b/scripts/login_to_globus_and_prefect.sh
@@ -18,4 +18,7 @@ export GLOBUS_CLI_CLIENT_SECRET="$GLOBUS_CLIENT_SECRET"
 export GLOBUS_COMPUTE_CLIENT_ID="$GLOBUS_CLIENT_ID"
 export GLOBUS_COMPUTE_CLIENT_SECRET="$GLOBUS_CLIENT_SECRET"
 export PREFECT_API_KEY="$PREFECT_API_KEY"
-export PREFECT_API_URL="$PREFECT_API_URL"
\ No newline at end of file
+export PREFECT_API_URL="$PREFECT_API_URL"
+export NERSC_USERNAME="$NERSC_USERNAME"
+export PATH_NERSC_CLIENT_ID="$PATH_NERSC_CLIENT_ID"
+export PATH_NERSC_PRI_KEY="$PATH_NERSC_PRI_KEY"
\ No newline at end of file