forked from Isomaniac/Python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhelpers.py
More file actions
138 lines (107 loc) · 4.55 KB
/
helpers.py
File metadata and controls
138 lines (107 loc) · 4.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
Athena query utilities, SSM parameter lookup, and environment configuration for findingsvalidator
"""
import os
import time
import boto3
from functools import lru_cache
from botocore.exceptions import ClientError
from logger import logger
from config import CONFIG
# ---------------------------------------------------------------------------
# SSM Parameter Store
# ---------------------------------------------------------------------------
@lru_cache(maxsize=None)
def get_ssm_parameter(param_name: str) -> str:
    """Return the decrypted value of an SSM Parameter Store entry.

    Successful lookups are memoized per parameter name by ``lru_cache``.

    Args:
        param_name: Name of the parameter to read.

    Returns:
        The ``Value`` field of the parameter returned by SSM.

    Raises:
        ClientError: Re-raised, after being logged, when the SSM call fails.
    """
    try:
        client = boto3.client("ssm", region_name="us-east-1")
        reply = client.get_parameter(Name=param_name, WithDecryption=True)
    except ClientError as exc:
        logger.error(f"Fetching SSM parameter {param_name}: {exc}")
        raise
    else:
        return reply["Parameter"]["Value"]
# ---------------------------------------------------------------------------
# Environment / Region / S3 bucket
# ---------------------------------------------------------------------------
# ENV selects the deployment environment. It is read from the local CONFIG for
# testing; for production, replace it with the SSM lookup on the commented line.
ENV = CONFIG["ENV"]  # Use local config for testing
# ENV = get_ssm_parameter("/eqa/udcrm-control-automation/config/env")  # Uncomment for production

# Only "dev" honours the AWS_REGION environment override; every other
# environment is pinned to us-east-1. The comparison uses the resolved ENV
# (not CONFIG["ENV"]) so it stays correct when ENV is sourced from SSM.
if ENV == "dev":
    REGION = os.environ.get("AWS_REGION", "us-east-1")
else:
    REGION = "us-east-1"

# The us-east-1 bucket name has no region suffix; other regions append one.
if REGION == "us-east-1":
    S3_BUCKET = f"s3://eqa-udcrm-datalake-bucket-{ENV}"
else:
    S3_BUCKET = f"s3://eqa-udcrm-datalake-bucket-{ENV}-{REGION}"

# ---------------------------------------------------------------------------
# Athena utilities
# ---------------------------------------------------------------------------
# Module-level Athena client, created once at import time for REGION.
athena_client = boto3.client("athena", region_name=REGION)
def execute_athena_query(database: str, query: str, workgroup: str = "primary") -> str:
    """Submit a query to Athena and return immediately with its execution id.

    This does not wait for the query to finish.

    Args:
        database: Athena database the query runs against.
        query: SQL text to execute.
        workgroup: Athena workgroup to run under (default: "primary").

    Returns:
        The QueryExecutionId of the started query.
    """
    started = athena_client.start_query_execution(
        QueryString=query,
        QueryExecutionContext={"Database": database},
        # Query output is written to the module-level S3 bucket location.
        ResultConfiguration={"OutputLocation": S3_BUCKET},
        WorkGroup=workgroup,
    )
    return started["QueryExecutionId"]
def wait_for_query(query_execution_id: str, max_wait: int = 300, poll_interval: int = 5) -> None:
    """Block until the Athena query succeeds, or raise on failure / timeout.

    The timeout is measured against a monotonic wall clock, so time spent in
    the status calls themselves counts toward ``max_wait``. The status is
    always checked at least once, and is re-checked after the final sleep
    before a timeout is reported.

    Args:
        query_execution_id: Athena query ID.
        max_wait: Maximum wait time in seconds.
        poll_interval: Polling interval in seconds.

    Raises:
        RuntimeError: If the query state is FAILED or CANCELLED.
        TimeoutError: If the query has not succeeded within ``max_wait`` seconds.
    """
    deadline = time.monotonic() + max_wait
    while True:
        response = athena_client.get_query_execution(QueryExecutionId=query_execution_id)
        status = response["QueryExecution"]["Status"]
        state = status["State"]
        if state == "SUCCEEDED":
            return
        if state in ("FAILED", "CANCELLED"):
            reason = status.get("StateChangeReason", "Unknown")
            raise RuntimeError(f"Athena query {query_execution_id} {state}: {reason}")

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(f"Athena query {query_execution_id} did not complete in {max_wait}s")
        # Never sleep past the deadline; this also guarantees termination
        # when poll_interval is 0.
        time.sleep(min(poll_interval, remaining))
def get_query_results(query_execution_id: str) -> list[dict]:
    """Fetch all result rows from a completed Athena query as a list of dicts.

    Pages are followed via ``NextToken`` until none is returned. Each row is
    mapped from column name to the cell's ``VarCharValue``; cells without that
    key are returned as empty strings.

    Args:
        query_execution_id: Athena query ID.

    Returns:
        List of result rows as dictionaries keyed by column name.
    """
    results = []
    next_token = None
    # Track the first page explicitly: inferring it from an empty ``results``
    # list would drop a real data row whenever the first page yields no data
    # rows but more pages follow.
    first_page = True
    while True:
        kwargs = {"QueryExecutionId": query_execution_id}
        if next_token:
            kwargs["NextToken"] = next_token
        response = athena_client.get_query_results(**kwargs)
        result_set = response["ResultSet"]
        columns = [col["Name"] for col in result_set["ResultSetMetadata"]["ColumnInfo"]]
        # Skip the header row on the first page only.
        # NOTE(review): assumes the first row of the first page is always a
        # column header -- confirm for non-SELECT statements.
        rows = result_set["Rows"][1:] if first_page else result_set["Rows"]
        first_page = False
        results.extend(
            dict(zip(columns, (cell.get("VarCharValue", "") for cell in row["Data"])))
            for row in rows
        )
        next_token = response.get("NextToken")
        if not next_token:
            return results