diff --git a/examples/job_scripts/pearl_fast_jobs.py b/examples/job_scripts/pearl_fast_jobs.py new file mode 100644 index 00000000..76b83335 --- /dev/null +++ b/examples/job_scripts/pearl_fast_jobs.py @@ -0,0 +1,291 @@ +import argparse +import logging +import os +import sys +import time +from pathlib import Path +from typing import Any + +import requests + +from fia_api.core.models import State + +PEARL_SCRIPT = """ +from mantid.simpleapi import * +import numpy as np +import json + +Cycles2Run=['15_2','25_3','25_4'] +Path2Data = r'/archive/NDXPEARL/Instrument/data' + +CycleDict = { + "start_15_2": 90482, + "end_15_2": 91528, + "start_25_3": 124935, + "end_25_3": 124946, + "start_25_4": 124987, + "end_25_4": 125000, +} + +output = "" + +for cycle in Cycles2Run: + reject=[] + peak_centres=[] + peak_centres_error=[] + peak_intensity=[] + peak_intensity_error=[] + uAmps=[] + RunNo=[] + index=0 + start=CycleDict['start_'+cycle] + end=CycleDict['end_'+cycle] + for i in range(start,end+1): + if i == 95382: + continue + Load(Filename=Path2Data+'/cycle_'+cycle+'/PEARL00'+ str(i)+'.nxs', OutputWorkspace=str(i)) + ws = mtd[str(i)] + run = ws.getRun() + pcharge = run.getProtonCharge() + if pcharge <1.0: + reject.append(str(i)) + DeleteWorkspace(str(i)) + continue + NormaliseByCurrent(InputWorkspace=str(i), OutputWorkspace=str(i)) + ExtractSingleSpectrum(InputWorkspace=str(i),WorkspaceIndex=index, + OutputWorkspace=str(i)+ '_' + str(index)) + CropWorkspace(InputWorkspace=str(i)+ '_' + str(index), Xmin=1100, + Xmax=19990, OutputWorkspace=str(i)+ '_' + str(index)) + DeleteWorkspace(str(i)) + + fit_output = Fit(Function='name=Gaussian,Height=19.2327,\\ + PeakCentre=4843.8,Sigma=1532.64,\\ + constraints=(4600 5200.0: + DeleteWorkspace(str(i)+'_0_fit_Parameters') + DeleteWorkspace(str(i)+'_0_fit_Workspace') + DeleteWorkspace(str(i)+'_0') + DeleteWorkspace(str(i)+'_0_fit_NormalisedCovarianceMatrix') + reject.append(str(i)) + continue + else: + uAmps.append(pcharge) + peak_centres.append(paramTable.column(1)[1]) + peak_centres_error.append(paramTable.column(2)[1]) + peak_intensity.append(paramTable.column(1)[0]) + peak_intensity_error.append(paramTable.column(2)[0]) + RunNo.append(str(i)) + DeleteWorkspace(str(i)+'_0') + DeleteWorkspace(str(i)+'_0_fit_Parameters') + DeleteWorkspace(str(i)+'_0_fit_Workspace') + DeleteWorkspace(str(i)+'_0_fit_NormalisedCovarianceMatrix') + + combined_data=np.column_stack( + (RunNo, uAmps, peak_intensity, peak_intensity_error, peak_centres, peak_centres_error) + ) + + output += f"peak_centres_{cycle}.csv, " + print(f"combined data for {cycle}: ") + print(combined_data) + np.savetxt('/output/peak_centres_'+cycle+'.csv', combined_data, delimiter=", ", fmt='% s',) + +print("Outputting files") +print(json.dumps({"status": "Successful", + "status_message":"Simple job run successfully.", + "output_files": output, "stacktrace": ""})) +""" + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +class PearlFastStart: + def __init__( + self, + fia_url: str, + auth_url: str, + username: str | None, + password: str | None, + output_dir: str | Path, + token_refresh_interval: int = 3600, # Default to 1 hour + ) -> None: + self.fia_url = fia_url.rstrip("/") + self.auth_url = auth_url.rstrip("/") + self.username = username + self.password = password + self.output_dir = Path(output_dir) + self.token: str | None = None + self.token_refresh_interval = token_refresh_interval + self._token_acquired_at: float = 0.0 + + def authenticate(self) -> None: + logger.info(f"Authenticating user {self.username} at {self.auth_url}") + try: + response = requests.post( + f"{self.auth_url}/login", json={"username": self.username, "password": self.password}, timeout=30 + ) + response.raise_for_status() + body = response.json() + self.token = body if isinstance(body, str) else body.get("token") + if not self.token: + raise ValueError("No token found in login response") + self._token_acquired_at = time.monotonic() + logger.info("Authentication successful") + except Exception as e: + logger.error(f"Authentication failed: {e}") + raise + + def _is_token_expiring(self) -> bool: + """Check if the token is due for a proactive refresh.""" + return (time.monotonic() - self._token_acquired_at) >= self.token_refresh_interval + + def _refresh_token_if_needed(self) -> None: + """Re-authenticate proactively if the token is close to expiring.""" + if self._is_token_expiring(): + logger.info("Token nearing expiry, refreshing authentication") + self.authenticate() + + def get_headers(self) -> dict[str, str]: + return {"Authorization": f"Bearer {self.token}"} + + def submit_job(self, script: str) -> int: + logger.info(f"Submitting fast-start job script to {self.fia_url}") + payload = {"script": script} + # post /job/fast-start + response = requests.post(f"{self.fia_url}/execute", json=payload, headers=self.get_headers(), timeout=30) + response.raise_for_status() + job_id = int(response.json()) + logger.info(f"Job submitted successfully. Job ID: {job_id}") + return job_id + + def _poll_job_status(self, job_id: int) -> dict[str, Any]: + """Poll the job status endpoint, re-authenticating on auth-related HTTP errors.""" + self._refresh_token_if_needed() + response = requests.get(f"{self.fia_url}/job/{job_id}", headers=self.get_headers(), timeout=30) + + if response.status_code in (401, 403, 404): + logger.warning( + f"Received HTTP {response.status_code} while polling job {job_id}, re-authenticating and retrying" + ) + self.authenticate() + response = requests.get(f"{self.fia_url}/job/{job_id}", headers=self.get_headers(), timeout=30) + + response.raise_for_status() + return response.json() + + def monitor_job(self, job_id: int, poll_interval: int = 5) -> dict[str, Any]: + logger.info(f"Monitoring job {job_id}") + while True: + # this won't work until FIA-API supports job status for fast-start jobs, + # but we want to include it here for when that is implemented + job_data: dict[str, Any] = self._poll_job_status(job_id) + state = job_data.get("state") + + logger.info(f"Job {job_id} current state: {state}") + + if state == State.SUCCESSFUL.value: + logger.info(f"Job {job_id} completed successfully") + return job_data + if state in [State.ERROR.value, State.UNSUCCESSFUL.value]: + error_msg = job_data.get("status_message", "No error message provided") + logger.error(f"Job {job_id} failed with state {state}: {error_msg}") + raise RuntimeError(f"Job {job_id} failed: {error_msg}") + + time.sleep(poll_interval) + + def download_results(self, job_id: int, outputs: str | list[str] | None) -> None: + if not outputs: + logger.warning(f"No outputs found for job {job_id}") + return + + # Outputs is expected to be a string or list of filenames + filenames = outputs.split(",") if isinstance(outputs, str) else outputs + + self.output_dir.mkdir(parents=True, exist_ok=True) + + for file in filenames: + filename = file.strip() + if not filename: + continue + + logger.info(f"Downloading {filename} for job {job_id}") + response = requests.get( + f"{self.fia_url}/job/{job_id}/filename/{filename}", headers=self.get_headers(), timeout=30, stream=True + ) + response.raise_for_status() + + file_path = self.output_dir / filename + with Path.open(file_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + logger.info(f"Downloaded {filename} to {file_path}") + + def run(self) -> None: + try: + self.authenticate() + job_id = self.submit_job(PEARL_SCRIPT) + job_data = self.monitor_job(job_id) + self.download_results(job_id, job_data.get("outputs")) + logger.info("PEARL automation completed successfully") + except Exception as e: + logger.error(f"PEARL automation failed: {e}") + sys.exit(1) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Automate PEARL Mantid jobs via FIA API") + parser.add_argument("--fia-url", default=os.environ.get("FIA_API_URL", "http://localhost:8000"), help="FIA API URL") + parser.add_argument( + "--auth-url", default=os.environ.get("AUTH_API_URL", "http://localhost:8001"), help="Auth API URL" + ) + parser.add_argument("--username", default=os.environ.get("PEARL_USERNAME"), help="Auth Username") + parser.add_argument("--password", default=os.environ.get("PEARL_PASSWORD"), help="Auth Password") + parser.add_argument( + "--output-dir", default=os.environ.get("OUTPUT_DIRECTORY", "./output"), help="Output directory for results" + ) + parser.add_argument( + "--runner", default=os.environ.get("MANTID_RUNNER_IMAGE"), help="Specific Mantid runner image to use" + ) + parser.add_argument( + "--token-refresh-interval", + type=int, + default=3600, + help="Interval (in seconds) to refresh the authentication token", + ) + + args = parser.parse_args() + + if not args.username or not args.password: + err_msg = ( + "Username and password must be provided via " + "arguments or environment variables (PEARL_USERNAME, PEARL_PASSWORD)" + ) + logger.error(err_msg) + sys.exit(1) + + automation = PearlFastStart( + args.fia_url, + args.auth_url, + args.username, + args.password, + args.output_dir, + args.token_refresh_interval, + ) + automation.run() + + +if __name__ == "__main__": + main() diff --git a/test/examples/test_pearl_fast_jobs.py b/test/examples/test_pearl_fast_jobs.py new file mode 100644 index 00000000..fe20806c --- /dev/null +++ b/test/examples/test_pearl_fast_jobs.py @@ -0,0 +1,255 @@ +import os +import subprocess +import sys +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +# Add the project root to sys.path to import the script +sys.path.append(str(Path(__file__).parent.parent.parent)) + +from examples.job_scripts.pearl_fast_jobs import PearlFastStart, main +from fia_api.core.models import State + + +@pytest.fixture(scope="session") +def get_fast_start(): + fia_url = "http://fia-api" + auth_url = "http://auth-api" + username = "test_user" + password = "test_pass" # noqa S105 + output_dir = "./test_output" + return PearlFastStart(fia_url, auth_url, username, password, output_dir) + + +@patch("examples.job_scripts.pearl_fast_jobs.requests.post") +def test_authenticate_success(mock_post, get_fast_start): + automation = get_fast_start + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"token": "valid_token"} + mock_post.return_value = mock_response + + automation.authenticate() + assert automation.token == "valid_token" # noqa S105 + mock_post.assert_called_once_with( + f"{automation.auth_url}/login", + json={"username": automation.username, "password": automation.password}, + timeout=30, + ) + + +@patch("examples.job_scripts.pearl_fast_jobs.requests.post") +def test_authenticate_success_string_token(mock_post, get_fast_start): + """Auth APIs that return a bare string token (not a dict) are also handled.""" + automation = get_fast_start + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = "valid_token" + mock_post.return_value = mock_response + + automation.authenticate() + assert automation.token == "valid_token" # noqa: S105 + + +@patch("examples.job_scripts.pearl_fast_jobs.requests.post") +def test_authenticate_no_token_raises_error(mock_post, get_fast_start): + automation = get_fast_start + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {} # Missing token + mock_post.return_value = mock_response + + with pytest.raises(ValueError, match="No token found in login response"): + automation.authenticate() + + +@patch("examples.job_scripts.pearl_fast_jobs.requests.post") +def test_submit_job_success(mock_post, get_fast_start): + automation = get_fast_start + automation.token = "valid_token" # noqa S105 + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = 12345 + mock_post.return_value = mock_response + + job_id = automation.submit_job("print('hello')") + expected_job_id = 12345 + assert job_id == expected_job_id + mock_post.assert_called_once() + + +@patch("examples.job_scripts.pearl_fast_jobs.requests.get") +@patch("examples.job_scripts.pearl_fast_jobs.time.sleep", return_value=None) +def test_monitor_job_success(mock_sleep, mock_get, get_fast_start): + automation = get_fast_start + automation.token = "valid_token" # noqa S105 + + # Mock responses for polling: 1st NOT_STARTED, 2nd SUCCESSFUL + mock_response_1 = MagicMock() + mock_response_1.status_code = 200 + mock_response_1.json.return_value = {"state": State.NOT_STARTED.value} + + mock_response_2 = MagicMock() + mock_response_2.status_code = 200 + mock_response_2.json.return_value = {"state": State.SUCCESSFUL.value, "outputs": "file1.csv,file2.csv"} + + mock_get.side_effect = [mock_response_1, mock_response_2] + + job_data = automation.monitor_job(12345, poll_interval=0) + assert job_data["state"] == State.SUCCESSFUL.value + expected_call_count = 2 + assert mock_get.call_count == expected_call_count + + +@pytest.mark.parametrize("state", [State.ERROR.value, State.UNSUCCESSFUL.value]) +@patch("examples.job_scripts.pearl_fast_jobs.requests.get") +def test_monitor_job_failure_raises_error(mock_get, get_fast_start, state): + automation = get_fast_start + automation.token = "valid_token" # noqa S105 + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"state": state, "status_message": "Something went wrong"} + mock_get.return_value = mock_response + + with pytest.raises(RuntimeError, match="Something went wrong"): + automation.monitor_job(12345) + + +@patch("examples.job_scripts.pearl_fast_jobs.requests.get") +@patch("examples.job_scripts.pearl_fast_jobs.Path.open", new_callable=unittest.mock.mock_open) +def test_download_results(mock_open, mock_get, get_fast_start): + automation = get_fast_start + automation.token = "valid_token" # noqa S105 + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.iter_content.return_value = [b"data1", b"data2"] + mock_get.return_value = mock_response + + automation.download_results(12345, "file1.csv, file2.csv, ") # Added empty entry to test filter + expected_call_count = 2 + + assert mock_get.call_count == expected_call_count + assert mock_open.call_count == expected_call_count + + +def test_download_results_no_outputs(get_fast_start): + automation = get_fast_start + with patch("examples.job_scripts.pearl_fast_jobs.logger.warning") as mock_log: + automation.download_results(12345, None) + mock_log.assert_called_with("No outputs found for job 12345") + + +@patch("examples.job_scripts.pearl_fast_jobs.PearlFastStart.authenticate") +@patch("examples.job_scripts.pearl_fast_jobs.time.monotonic") +def test_refresh_token_if_needed_refreshes_when_expiring(mock_monotonic, mock_auth, get_fast_start): + automation = get_fast_start + automation.token_refresh_interval = 300 + automation._token_acquired_at = 100.0 + mock_monotonic.return_value = 500.0 # Token is expiring + automation._refresh_token_if_needed() + mock_auth.assert_called_once() + + +@patch("examples.job_scripts.pearl_fast_jobs.requests.post") +@patch("examples.job_scripts.pearl_fast_jobs.requests.get") +@patch("examples.job_scripts.pearl_fast_jobs.time.sleep", return_value=None) +def test_monitor_job_reauths_on_401(mock_sleep, mock_get, mock_post, get_fast_start): + """When a 401 is received, monitor_job re-authenticates and retries the request.""" + automation = get_fast_start + automation.token = "old_token" # noqa: S105 + automation._token_acquired_at = 9999999999.0 + + # First GET returns 401, then re-auth succeeds, retry GET returns success + mock_401 = MagicMock() + mock_401.status_code = 401 + + mock_success = MagicMock() + mock_success.status_code = 200 + mock_success.json.return_value = {"state": State.SUCCESSFUL.value, "outputs": "out.csv"} + + mock_get.side_effect = [mock_401, mock_success] + + # Mock re-authentication + mock_auth_response = MagicMock() + mock_auth_response.status_code = 200 + mock_auth_response.json.return_value = {"token": "new_token"} + mock_post.return_value = mock_auth_response + + job_data = automation.monitor_job(12345, poll_interval=0) + assert job_data["state"] == State.SUCCESSFUL.value + # One re-auth POST should have been made + mock_post.assert_called_once() + # Two GETs: the 401 and the retry + expected_get_count = 2 + assert mock_get.call_count == expected_get_count + + +@patch("examples.job_scripts.pearl_fast_jobs.PearlFastStart.authenticate") +@patch("examples.job_scripts.pearl_fast_jobs.PearlFastStart.submit_job") +@patch("examples.job_scripts.pearl_fast_jobs.PearlFastStart.monitor_job") +@patch("examples.job_scripts.pearl_fast_jobs.PearlFastStart.download_results") +def test_run_success(mock_dl, mock_mon, mock_sub, mock_auth, get_fast_start): + automation = get_fast_start + mock_sub.return_value = 1 + mock_mon.return_value = {"outputs": "out"} + + automation.run() + + mock_auth.assert_called_once() + mock_sub.assert_called_once() + mock_mon.assert_called_once_with(1) + mock_dl.assert_called_once_with(1, "out") + + +@patch("examples.job_scripts.pearl_fast_jobs.PearlFastStart.authenticate", side_effect=Exception("Auth fail")) +@patch("examples.job_scripts.pearl_fast_jobs.sys.exit") +def test_run_failure(mock_exit: MagicMock, mock_auth: MagicMock, get_fast_start: PearlFastStart) -> None: + automation = get_fast_start + automation.run() + mock_exit.assert_called_once_with(1) + + +@patch("examples.job_scripts.pearl_fast_jobs.sys.argv", ["pearl_fast_jobs.py", "--username", "u", "--password", "p"]) +@patch("examples.job_scripts.pearl_fast_jobs.PearlFastStart.run") +def test_main_success(mock_run: MagicMock) -> None: + main() + mock_run.assert_called_once() + + +@patch("examples.job_scripts.pearl_fast_jobs.sys.argv", ["pearl_fast_jobs.py", "--username", "", "--password", ""]) +@patch("examples.job_scripts.pearl_fast_jobs.sys.exit", side_effect=SystemExit) +def test_main_no_creds_exits(mock_exit: MagicMock) -> None: + with patch.dict(os.environ, {}, clear=True), pytest.raises(SystemExit): + main() + mock_exit.assert_called_once_with(1) + + +@patch("examples.job_scripts.pearl_fast_jobs.requests.get") +@patch("examples.job_scripts.pearl_fast_jobs.Path.open", new_callable=unittest.mock.mock_open) +def test_download_results_list_input(mock_open: MagicMock, mock_get: MagicMock, get_fast_start: PearlFastStart) -> None: + automation = get_fast_start + automation.token = "valid_token" # noqa: S105 + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.iter_content.return_value = [b"data"] + mock_get.return_value = mock_response + + automation.download_results(12345, ["file1.csv"]) + assert mock_get.call_count == 1 + assert mock_open.call_count == 1 + + +def test_main_entry_point() -> None: + # Run the script as a subprocess to cover the if __name__ == "__main__": block + # We provide invalid args so it exits quickly + result = subprocess.run( + [sys.executable, "-m", "examples.job_scripts.pearl_fast_jobs", "--username", ""], + capture_output=True, + text=True, + check=False, + ) + assert result.returncode == 1 + assert "Username and password must be provided" in result.stderr