diff --git a/app/db/queries.py b/app/db/queries.py index dfd9f08..4cf3326 100644 --- a/app/db/queries.py +++ b/app/db/queries.py @@ -83,8 +83,38 @@ async def build_conditions( return where_clause, query_params -def get_query(columns_to_select: str, where_clause: str) -> str: - return f""" +SORTABLE_COLUMNS = { + "accession": "pua.accession", + "amino_acids": "pua.amino_acids", + "organism": "pua.organism", + "curation_status": "pua.curation_status", + "predicted_ec": "puace.max_clean_ec_confidence", +} + +DEFAULT_ORDER_BY = "puace.max_clean_ec_confidence DESC, pua.amino_acids ASC, pua.predictions_uniprot_annot_id ASC" + + +def parse_ordering(ordering: str | None) -> str: + """Parse an ordering string like '-accession' into a SQL ORDER BY clause. + + Returns the default ordering if ordering is None or invalid. + """ + if not ordering: + return DEFAULT_ORDER_BY + + descending = ordering.startswith("-") + field = ordering.lstrip("-") + + if field not in SORTABLE_COLUMNS: + return DEFAULT_ORDER_BY + + direction = "DESC" if descending else "ASC" + col = SORTABLE_COLUMNS[field] + return f"{col} {direction}, pua.predictions_uniprot_annot_id ASC" + + +def get_query(columns_to_select: str, where_clause: str, include_order_by: bool = True, ordering: str | None = None) -> str: + query = f""" SELECT {columns_to_select} FROM cleandb.predictions_uniprot_annot pua @@ -92,9 +122,11 @@ def get_query(columns_to_select: str, where_clause: str) -> str: ON puace.predictions_uniprot_annot_id = pua.predictions_uniprot_annot_id LEFT JOIN cleandb.predictions_uniprot_annot_ec_mv01 puae ON puae.predictions_uniprot_annot_id = pua.predictions_uniprot_annot_id - WHERE {where_clause} - ORDER BY puace.max_clean_ec_confidence DESC, pua.amino_acids ASC, pua.predictions_uniprot_annot_id ASC - """ + WHERE {where_clause}""" + if include_order_by: + query += f""" + ORDER BY {parse_ordering(ordering)}""" + return query async def get_filtered_data( db: Database, params: CLEANSearchQueryParams @@ -120,7 +152,7 @@ async def get_filtered_data( """ # Build the main query - query = get_query(columns_to_select, where_clause) + query = get_query(columns_to_select, where_clause, ordering=params.ordering) # Add pagination if params.limit is not None: @@ -142,7 +174,7 @@ async def get_total_count(db: Database, params: CLEANSearchQueryParams) -> int: """Get total count of records matching the filters.""" where_clause, query_params = await build_conditions(params) - query = get_query("COUNT(*)", where_clause) + query = get_query("COUNT(*)", where_clause, include_order_by=False) # Extract query parameters from the dictionary query_args = list(query_params.values()) diff --git a/app/models/query_params.py b/app/models/query_params.py index f9b33ea..3c00753 100644 --- a/app/models/query_params.py +++ b/app/models/query_params.py @@ -64,6 +64,11 @@ class CLEANSearchQueryParams(BaseModel): None, description="Maximum number of records to return" ) offset: Optional[int] = Field(0, description="Number of records to skip") + ordering: Optional[str] = Field( + None, + description="Column to sort by. Prefix with '-' for descending order. " + "Allowed values: accession, amino_acids, organism, curation_status, predicted_ec", + ) class CLEANTypeaheadQueryParams(BaseModel): """Query parameters for CLEAN typeahead suggestions.""" diff --git a/app/routers/search.py b/app/routers/search.py index 9e0089b..f1dfce3 100644 --- a/app/routers/search.py +++ b/app/routers/search.py @@ -60,6 +60,11 @@ def parse_query_params( None, description="Maximum number of records to return" ), offset: Optional[int] = Query(0, description="Number of records to skip"), + ordering: Optional[str] = Query( + None, + description="Column to sort by. Prefix with '-' for descending order. " + "Allowed values: accession, amino_acids, organism, curation_status, predicted_ec", + ), ) -> CLEANSearchQueryParams: """Parse and validate query parameters.""" try: @@ -80,6 +85,7 @@ def parse_query_params( format=format, limit=limit, offset=offset, + ordering=ordering, ) except Exception as e: logger.error(f"Error parsing query parameters: {e}") @@ -99,19 +105,14 @@ async def get_data( """ try: - # # Get total count for the query (without pagination) - # total_count = await get_total_count(db, params) - - # # Apply automatic pagination if results exceed threshold and no explicit limit provided - # if total_count > settings.AUTO_PAGINATION_THRESHOLD and params.limit is None: - # params.auto_paginated = True - # params.limit = settings.AUTO_PAGINATION_THRESHOLD - # logger.info( - # f"Auto-pagination applied. Results limited to {params.limit} records." - # ) - - params.limit = 500 - # Get data from database (now with potential auto-pagination applied) + # Apply default page size if no explicit limit provided + if params.limit is None: + params.limit = settings.AUTO_PAGINATION_THRESHOLD + + # Get total count for the query (without pagination) + total_count = await get_total_count(db, params) + + # Get data from database data = await get_filtered_data(db, params) # Handle response format @@ -143,12 +144,10 @@ async def get_data( headers={"Content-Disposition": "attachment; filename=CLEAN_data.csv"}, ) else: - # TODO don't we want total_count to be the value returned by get_total_count? - total_count = len(data) response = CLEANSearchResponse( total=total_count, offset=params.offset, - limit=total_count if total_count < params.limit else params.limit, + limit=params.limit, data=[CLEANDataBase( predictions_uniprot_annot_id=record["predictions_uniprot_annot_id"], uniprot=record["uniprot_id"], @@ -172,54 +171,48 @@ async def get_data( ) for record in data], ) - # Add pagination links if automatic pagination was applied - if params.auto_paginated: - # Add flag indicating automatic pagination was applied - response.auto_paginated = True - - if request: - base_url = str(request.url).split("?")[0] - - # Prepare query parameters for pagination links - # For Pydantic v2 compatibility - query_params = { - k: v - for k, v in params.model_dump().items() - if k not in ["auto_paginated", "offset", "limit"] - and v is not None + # Add pagination links + if request: + base_url = str(request.url).split("?")[0] + + # Prepare query parameters for pagination links + query_params = { + k: v + for k, v in params.model_dump().items() + if k not in ["auto_paginated", "offset", "limit"] + and v is not None + } + + # Set format explicitly if it was provided + if params.format != ResponseFormat.JSON: + query_params["format"] = params.format + + current_offset = params.offset or 0 + current_limit = params.limit + + # Next page link if there are more records + if current_offset + current_limit < total_count: + next_offset = current_offset + current_limit + next_params = { + **query_params, + "offset": next_offset, + "limit": current_limit, } - - # Set format explicitly if it was provided - if params.format != ResponseFormat.JSON: - query_params["format"] = params.format - - # Calculate next page link if there are more records - current_offset = params.offset or 0 - current_limit = params.limit or total_count - if current_offset + current_limit < total_count: - next_offset = current_offset + current_limit - next_params = { - **query_params, - "offset": next_offset, - "limit": current_limit, - } - response.next = ( - f"{base_url}?{urlencode(next_params, doseq=True)}" - ) - - # Calculate previous page link if not on first page - current_offset = params.offset or 0 - current_limit = params.limit or total_count - if current_offset > 0: - prev_offset = max(0, current_offset - current_limit) - prev_params = { - **query_params, - "offset": prev_offset, - "limit": current_limit, - } - response.previous = ( - f"{base_url}?{urlencode(prev_params, doseq=True)}" - ) + response.next = ( + f"{base_url}?{urlencode(next_params, doseq=True)}" + ) + + # Previous page link if not on first page + if current_offset > 0: + prev_offset = max(0, current_offset - current_limit) + prev_params = { + **query_params, + "offset": prev_offset, + "limit": current_limit, + } + response.previous = ( + f"{base_url}?{urlencode(prev_params, doseq=True)}" + ) return response