Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions app/db/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,50 @@ async def build_conditions(

return where_clause, query_params

def get_query(columns_to_select: str, where_clause: str) -> str:
return f"""
SORTABLE_COLUMNS = {
"accession": "pua.accession",
"amino_acids": "pua.amino_acids",
"organism": "pua.organism",
"curation_status": "pua.curation_status",
"predicted_ec": "puace.max_clean_ec_confidence",
}

DEFAULT_ORDER_BY = "puace.max_clean_ec_confidence DESC, pua.amino_acids ASC, pua.predictions_uniprot_annot_id ASC"


def parse_ordering(ordering: str | None) -> str:
"""Parse an ordering string like '-accession' into a SQL ORDER BY clause.

Returns the default ordering if ordering is None or invalid.
"""
if not ordering:
return DEFAULT_ORDER_BY

descending = ordering.startswith("-")
field = ordering.lstrip("-")

if field not in SORTABLE_COLUMNS:
return DEFAULT_ORDER_BY

direction = "DESC" if descending else "ASC"
col = SORTABLE_COLUMNS[field]
return f"{col} {direction}, pua.predictions_uniprot_annot_id ASC"


def get_query(columns_to_select: str, where_clause: str, include_order_by: bool = True, ordering: str | None = None) -> str:
query = f"""
SELECT
{columns_to_select}
FROM cleandb.predictions_uniprot_annot pua
INNER JOIN cleandb.predictions_uniprot_annot_clean_ec_mv01 puace
ON puace.predictions_uniprot_annot_id = pua.predictions_uniprot_annot_id
LEFT JOIN cleandb.predictions_uniprot_annot_ec_mv01 puae
ON puae.predictions_uniprot_annot_id = pua.predictions_uniprot_annot_id
WHERE {where_clause}
ORDER BY puace.max_clean_ec_confidence DESC, pua.amino_acids ASC, pua.predictions_uniprot_annot_id ASC
"""
WHERE {where_clause}"""
if include_order_by:
query += f"""
ORDER BY {parse_ordering(ordering)}"""
return query

async def get_filtered_data(
db: Database, params: CLEANSearchQueryParams
Expand All @@ -120,7 +152,7 @@ async def get_filtered_data(
"""

# Build the main query
query = get_query(columns_to_select, where_clause)
query = get_query(columns_to_select, where_clause, ordering=params.ordering)

# Add pagination
if params.limit is not None:
Expand All @@ -142,7 +174,7 @@ async def get_total_count(db: Database, params: CLEANSearchQueryParams) -> int:
"""Get total count of records matching the filters."""
where_clause, query_params = await build_conditions(params)

query = get_query("COUNT(*)", where_clause)
query = get_query("COUNT(*)", where_clause, include_order_by=False)

# Extract query parameters from the dictionary
query_args = list(query_params.values())
Expand Down
5 changes: 5 additions & 0 deletions app/models/query_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ class CLEANSearchQueryParams(BaseModel):
None, description="Maximum number of records to return"
)
offset: Optional[int] = Field(0, description="Number of records to skip")
ordering: Optional[str] = Field(
None,
description="Column to sort by. Prefix with '-' for descending order. "
"Allowed values: accession, amino_acids, organism, curation_status, predicted_ec",
)

class CLEANTypeaheadQueryParams(BaseModel):
"""Query parameters for CLEAN typeahead suggestions."""
Expand Down
119 changes: 56 additions & 63 deletions app/routers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def parse_query_params(
None, description="Maximum number of records to return"
),
offset: Optional[int] = Query(0, description="Number of records to skip"),
ordering: Optional[str] = Query(
None,
description="Column to sort by. Prefix with '-' for descending order. "
"Allowed values: accession, amino_acids, organism, curation_status, predicted_ec",
),
) -> CLEANSearchQueryParams:
"""Parse and validate query parameters."""
try:
Expand All @@ -80,6 +85,7 @@ def parse_query_params(
format=format,
limit=limit,
offset=offset,
ordering=ordering,
)
except Exception as e:
logger.error(f"Error parsing query parameters: {e}")
Expand All @@ -99,19 +105,14 @@ async def get_data(
"""

try:
# # Get total count for the query (without pagination)
# total_count = await get_total_count(db, params)

# # Apply automatic pagination if results exceed threshold and no explicit limit provided
# if total_count > settings.AUTO_PAGINATION_THRESHOLD and params.limit is None:
# params.auto_paginated = True
# params.limit = settings.AUTO_PAGINATION_THRESHOLD
# logger.info(
# f"Auto-pagination applied. Results limited to {params.limit} records."
# )

params.limit = 500
# Get data from database (now with potential auto-pagination applied)
# Apply default page size if no explicit limit provided
if params.limit is None:
params.limit = settings.AUTO_PAGINATION_THRESHOLD

# Get total count for the query (without pagination)
total_count = await get_total_count(db, params)

# Get data from database
data = await get_filtered_data(db, params)

# Handle response format
Expand Down Expand Up @@ -143,12 +144,10 @@ async def get_data(
headers={"Content-Disposition": "attachment; filename=CLEAN_data.csv"},
)
else:
# TODO don't we want total_count to be the value returned by get_total_count?
total_count = len(data)
response = CLEANSearchResponse(
total=total_count,
offset=params.offset,
limit=total_count if total_count < params.limit else params.limit,
limit=params.limit,
data=[CLEANDataBase(
predictions_uniprot_annot_id=record["predictions_uniprot_annot_id"],
uniprot=record["uniprot_id"],
Expand All @@ -172,54 +171,48 @@ async def get_data(
) for record in data],
)

# Add pagination links if automatic pagination was applied
if params.auto_paginated:
# Add flag indicating automatic pagination was applied
response.auto_paginated = True

if request:
base_url = str(request.url).split("?")[0]

# Prepare query parameters for pagination links
# For Pydantic v2 compatibility
query_params = {
k: v
for k, v in params.model_dump().items()
if k not in ["auto_paginated", "offset", "limit"]
and v is not None
# Add pagination links
if request:
base_url = str(request.url).split("?")[0]

# Prepare query parameters for pagination links
query_params = {
k: v
for k, v in params.model_dump().items()
if k not in ["auto_paginated", "offset", "limit"]
and v is not None
}

# Set format explicitly if it was provided
if params.format != ResponseFormat.JSON:
query_params["format"] = params.format

current_offset = params.offset or 0
current_limit = params.limit

# Next page link if there are more records
if current_offset + current_limit < total_count:
next_offset = current_offset + current_limit
next_params = {
**query_params,
"offset": next_offset,
"limit": current_limit,
}

# Set format explicitly if it was provided
if params.format != ResponseFormat.JSON:
query_params["format"] = params.format

# Calculate next page link if there are more records
current_offset = params.offset or 0
current_limit = params.limit or total_count
if current_offset + current_limit < total_count:
next_offset = current_offset + current_limit
next_params = {
**query_params,
"offset": next_offset,
"limit": current_limit,
}
response.next = (
f"{base_url}?{urlencode(next_params, doseq=True)}"
)

# Calculate previous page link if not on first page
current_offset = params.offset or 0
current_limit = params.limit or total_count
if current_offset > 0:
prev_offset = max(0, current_offset - current_limit)
prev_params = {
**query_params,
"offset": prev_offset,
"limit": current_limit,
}
response.previous = (
f"{base_url}?{urlencode(prev_params, doseq=True)}"
)
response.next = (
f"{base_url}?{urlencode(next_params, doseq=True)}"
)

# Previous page link if not on first page
if current_offset > 0:
prev_offset = max(0, current_offset - current_limit)
prev_params = {
**query_params,
"offset": prev_offset,
"limit": current_limit,
}
response.previous = (
f"{base_url}?{urlencode(prev_params, doseq=True)}"
)

return response

Expand Down