diff --git a/airflow/api_fastapi/common/parameters.py b/airflow/api_fastapi/common/parameters.py index 4784a03396cb..9b265c7583a7 100644 --- a/airflow/api_fastapi/common/parameters.py +++ b/airflow/api_fastapi/common/parameters.py @@ -19,7 +19,7 @@ from abc import ABC, abstractmethod from datetime import datetime -from typing import TYPE_CHECKING, Any, Callable, Generic, List, Literal, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Generic, List, TypeVar from fastapi import Depends, HTTPException, Query from pendulum.parsing.exceptions import ParserError @@ -128,26 +128,6 @@ def transform_aliases(self, value: str | None) -> str | None: return value -class _OrderByParam(BaseParam[str]): - """Order result by specified attribute ascending or descending.""" - - def __init__(self, attribute: ColumnElement, skip_none: bool = True) -> None: - super().__init__(skip_none) - self.attribute: ColumnElement = attribute - self.value: Literal["asc", "desc"] | None = None - - def to_orm(self, select: Select) -> Select: - if self.value is None and self.skip_none: - return select - asc_stmt = select.order_by(self.attribute.asc()) - if self.value is None: - return asc_stmt - return asc_stmt if self.value == "asc" else select.order_by(self.attribute.desc()) - - def depends(self, order_by: str = "asc") -> _OrderByParam: - return self.set_value(order_by) - - class _DagIdPatternSearch(_SearchParam): """Search on dag_id.""" @@ -331,5 +311,4 @@ def _safe_parse_datetime(date_to_check: str) -> datetime: # DagRun QueryLastDagRunStateFilter = Annotated[_LastDagRunStateFilter, Depends(_LastDagRunStateFilter().depends)] # DAGTags -QueryDagTagOrderBy = Annotated[_OrderByParam, Depends(_OrderByParam(DagTag.name, skip_none=False).depends)] QueryDagTagPatternSearch = Annotated[_DagTagNamePatternSearch, Depends(_DagTagNamePatternSearch().depends)] diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 165a0ae123b7..3ca0192a3ca1 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -318,7 +318,7 @@ paths: required: false schema: type: string - default: asc + default: name title: Order By - name: tag_name_pattern in: query diff --git a/airflow/api_fastapi/core_api/routes/public/dags.py b/airflow/api_fastapi/core_api/routes/public/dags.py index a514fd84f386..c7b753b5cdbd 100644 --- a/airflow/api_fastapi/core_api/routes/public/dags.py +++ b/airflow/api_fastapi/core_api/routes/public/dags.py @@ -18,7 +18,7 @@ from __future__ import annotations from fastapi import Depends, HTTPException, Query, Request, Response -from sqlalchemy import distinct, select, update +from sqlalchemy import select, update from sqlalchemy.orm import Session from typing_extensions import Annotated @@ -32,7 +32,6 @@ QueryDagDisplayNamePatternSearch, QueryDagIdPatternSearch, QueryDagIdPatternSearchWithNone, - QueryDagTagOrderBy, QueryDagTagPatternSearch, QueryLastDagRunStateFilter, QueryLimit, @@ -105,12 +104,20 @@ async def get_dags( async def get_dag_tags( limit: QueryLimit, offset: QueryOffset, - order_by: QueryDagTagOrderBy, + order_by: Annotated[ + SortParam, + Depends( + SortParam( + ["name"], + DagTag, + ).dynamic_depends() + ), + ], tag_name_pattern: QueryDagTagPatternSearch, session: Annotated[Session, Depends(get_session)], ) -> DAGTagCollectionResponse: """Get all DAG tags.""" - base_select = select(distinct(DagTag.name)) + base_select = select(DagTag.name).group_by(DagTag.name) dag_tags_select, total_entries = paginated_select( base_select=base_select, filters=[tag_name_pattern], diff --git a/tests/api_fastapi/core_api/routes/public/test_dags.py b/tests/api_fastapi/core_api/routes/public/test_dags.py index ff9993705eb3..a48040482023 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dags.py +++ b/tests/api_fastapi/core_api/routes/public/test_dags.py @@ -457,14 +457,14 @@ class TestGetDagTags(TestDagEndpoint): ), # test order_by ( - {"order_by": "desc"}, + {"order_by": "-name"}, 200, ["tag_2", "tag_1", "example"], 3, ), # test all query params ( - {"tag_name_pattern": "t%", "order_by": "desc", "offset": 1, "limit": 1}, + {"tag_name_pattern": "t%", "order_by": "-name", "offset": 1, "limit": 1}, 200, ["tag_1"], 2, @@ -475,6 +475,19 @@ class TestGetDagTags(TestDagEndpoint): ["tag_1", "tag_2"], 3, ), + # test invalid query params + ( + {"order_by": "dag_id"}, + 400, + None, + None, + ), + ( + {"order_by": "-dag_id"}, + 400, + None, + None, + ), ], ) def test_get_dag_tags(