From f924ecb6e3d6e85a83b261cc7adc3a3c05c2b1ee Mon Sep 17 00:00:00 2001 From: Bugra Ozturk Date: Tue, 12 Nov 2024 18:00:16 +0100 Subject: [PATCH] AIP-84 Migrate test a connection to FastAPI API (#43766) * Migrate test a connection to fastapi * Remove NotFound error from docs * Remove Validation error which is handled by fastapi * Remove async for complying convention for methods * Fix pre-commit ruff --- .../endpoints/connection_endpoint.py | 1 + .../core_api/datamodels/connections.py | 7 ++ .../core_api/openapi/v1-generated.yaml | 61 +++++++++++++++ .../core_api/routes/public/connections.py | 46 ++++++++++++ airflow/ui/openapi-gen/queries/common.ts | 3 + airflow/ui/openapi-gen/queries/queries.ts | 43 +++++++++++ .../ui/openapi-gen/requests/schemas.gen.ts | 17 +++++ .../ui/openapi-gen/requests/services.gen.ts | 30 ++++++++ airflow/ui/openapi-gen/requests/types.gen.ts | 37 +++++++++ .../routes/public/test_connections.py | 75 +++++++++++++++---- 10 files changed, 307 insertions(+), 13 deletions(-) diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py b/airflow/api_connexion/endpoints/connection_endpoint.py index a07e13b35255..ed7e896fa7ae 100644 --- a/airflow/api_connexion/endpoints/connection_endpoint.py +++ b/airflow/api_connexion/endpoints/connection_endpoint.py @@ -182,6 +182,7 @@ def post_connection(*, session: Session = NEW_SESSION) -> APIResponse: raise AlreadyExists(detail=f"Connection already exist. ID: {conn_id}") +@mark_fastapi_migration_done @security.requires_access_connection("POST") def test_connection() -> APIResponse: """ diff --git a/airflow/api_fastapi/core_api/datamodels/connections.py b/airflow/api_fastapi/core_api/datamodels/connections.py index c5956b6ec517..7b23682cc8ef 100644 --- a/airflow/api_fastapi/core_api/datamodels/connections.py +++ b/airflow/api_fastapi/core_api/datamodels/connections.py @@ -67,6 +67,13 @@ class ConnectionCollectionResponse(BaseModel): total_entries: int +class ConnectionTestResponse(BaseModel): + """Connection Test serializer for responses.""" + + status: bool + message: str + + # Request Models class ConnectionBody(BaseModel): """Connection Serializer for requests body.""" diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index cc73ad9a820c..b9b33b35e243 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -1213,6 +1213,53 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + /public/connections/test: + post: + tags: + - Connection + summary: Test Connection + description: 'Test an API connection. + + + This method first creates an in-memory transient conn_id & exports that to + an env var, + + as some hook classes tries to find out the `conn` from their __init__ method + & errors out if not found. + + It also deletes the conn id env variable after the test.' + operationId: test_connection + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ConnectionBody' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/ConnectionTestResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + '403': + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /public/dags/{dag_id}/dagRuns/{dag_run_id}: get: tags: @@ -3568,6 +3615,20 @@ components: - extra title: ConnectionResponse description: Connection serializer for responses. + ConnectionTestResponse: + properties: + status: + type: boolean + title: Status + message: + type: string + title: Message + type: object + required: + - status + - message + title: ConnectionTestResponse + description: Connection Test serializer for responses. DAGCollectionResponse: properties: dags: diff --git a/airflow/api_fastapi/core_api/routes/public/connections.py b/airflow/api_fastapi/core_api/routes/public/connections.py index dbd2091b25ef..4797d17487f5 100644 --- a/airflow/api_fastapi/core_api/routes/public/connections.py +++ b/airflow/api_fastapi/core_api/routes/public/connections.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import os from typing import Annotated from fastapi import Depends, HTTPException, Query, status @@ -29,10 +30,14 @@ ConnectionBody, ConnectionCollectionResponse, ConnectionResponse, + ConnectionTestResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc +from airflow.configuration import conf from airflow.models import Connection +from airflow.secrets.environment_variables import CONN_ENV_PREFIX from airflow.utils import helpers +from airflow.utils.strings import get_random_string connections_router = AirflowRouter(tags=["Connection"], prefix="/connections") @@ -181,3 +186,44 @@ def patch_connection( for key, val in data.items(): setattr(connection, key, val) return ConnectionResponse.model_validate(connection, from_attributes=True) + + +@connections_router.post( + "/test", + responses=create_openapi_http_exception_doc( + [ + status.HTTP_401_UNAUTHORIZED, + status.HTTP_403_FORBIDDEN, + ] + ), +) +def test_connection( + test_body: ConnectionBody, +) -> ConnectionTestResponse: + """ + Test an API connection. + + This method first creates an in-memory transient conn_id & exports that to an env var, + as some hook classes tries to find out the `conn` from their __init__ method & errors out if not found. + It also deletes the conn id env variable after the test. + """ + if conf.get("core", "test_connection", fallback="Disabled").lower().strip() != "enabled": + raise HTTPException( + 403, + "Testing connections is disabled in Airflow configuration. " + "Contact your deployment admin to enable it.", + ) + + transient_conn_id = get_random_string() + conn_env_var = f"{CONN_ENV_PREFIX}{transient_conn_id.upper()}" + try: + data = test_body.model_dump(by_alias=True) + data["conn_id"] = transient_conn_id + conn = Connection(**data) + os.environ[conn_env_var] = conn.get_uri() + test_status, test_message = conn.test_connection() + return ConnectionTestResponse.model_validate( + {"status": test_status, "message": test_message}, from_attributes=True + ) + finally: + os.environ.pop(conn_env_var, None) diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 164410068443..fd9fefda7dfb 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -972,6 +972,9 @@ export type BackfillServiceCreateBackfillMutationResult = Awaited< export type ConnectionServicePostConnectionMutationResult = Awaited< ReturnType >; +export type ConnectionServiceTestConnectionMutationResult = Awaited< + ReturnType +>; export type PoolServicePostPoolMutationResult = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index 421288e2c55b..968419632ae7 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -1646,6 +1646,49 @@ export const useConnectionServicePostConnection = < }) as unknown as Promise, ...options, }); +/** + * Test Connection + * Test an API connection. + * + * This method first creates an in-memory transient conn_id & exports that to an env var, + * as some hook classes tries to find out the `conn` from their __init__ method & errors out if not found. + * It also deletes the conn id env variable after the test. + * @param data The data for the request. + * @param data.requestBody + * @returns ConnectionTestResponse Successful Response + * @throws ApiError + */ +export const useConnectionServiceTestConnection = < + TData = Common.ConnectionServiceTestConnectionMutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + requestBody: ConnectionBody; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + requestBody: ConnectionBody; + }, + TContext + >({ + mutationFn: ({ requestBody }) => + ConnectionService.testConnection({ + requestBody, + }) as unknown as Promise, + ...options, + }); /** * Post Pool * Create a Pool. diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index 5700d3899240..106991531db6 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -474,6 +474,23 @@ export const $ConnectionResponse = { description: "Connection serializer for responses.", } as const; +export const $ConnectionTestResponse = { + properties: { + status: { + type: "boolean", + title: "Status", + }, + message: { + type: "string", + title: "Message", + }, + }, + type: "object", + required: ["status", "message"], + title: "ConnectionTestResponse", + description: "Connection Test serializer for responses.", +} as const; + export const $DAGCollectionResponse = { properties: { dags: { diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 56b2a3bf3d6d..12897f9d4934 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -47,6 +47,8 @@ import type { GetConnectionsResponse, PostConnectionData, PostConnectionResponse, + TestConnectionData, + TestConnectionResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, @@ -772,6 +774,34 @@ export class ConnectionService { }, }); } + + /** + * Test Connection + * Test an API connection. + * + * This method first creates an in-memory transient conn_id & exports that to an env var, + * as some hook classes tries to find out the `conn` from their __init__ method & errors out if not found. + * It also deletes the conn id env variable after the test. + * @param data The data for the request. + * @param data.requestBody + * @returns ConnectionTestResponse Successful Response + * @throws ApiError + */ + public static testConnection( + data: TestConnectionData, + ): CancelablePromise { + return __request(OpenAPI, { + method: "POST", + url: "/public/connections/test", + body: data.requestBody, + mediaType: "application/json", + errors: { + 401: "Unauthorized", + 403: "Forbidden", + 422: "Validation Error", + }, + }); + } } export class DagRunService { diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index e200e70037b8..efd8f602528e 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -113,6 +113,14 @@ export type ConnectionResponse = { extra: string | null; }; +/** + * Connection Test serializer for responses. + */ +export type ConnectionTestResponse = { + status: boolean; + message: string; +}; + /** * DAG Collection serializer for responses. */ @@ -1005,6 +1013,12 @@ export type PostConnectionData = { export type PostConnectionResponse = ConnectionResponse; +export type TestConnectionData = { + requestBody: ConnectionBody; +}; + +export type TestConnectionResponse = ConnectionTestResponse; + export type GetDagRunData = { dagId: string; dagRunId: string; @@ -1845,6 +1859,29 @@ export type $OpenApiTs = { }; }; }; + "/public/connections/test": { + post: { + req: TestConnectionData; + res: { + /** + * Successful Response + */ + 200: ConnectionTestResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; "/public/dags/{dag_id}/dagRuns/{dag_run_id}": { get: { req: GetDagRunData; diff --git a/tests/api_fastapi/core_api/routes/public/test_connections.py b/tests/api_fastapi/core_api/routes/public/test_connections.py index 67b58007f531..599068d659ae 100644 --- a/tests/api_fastapi/core_api/routes/public/test_connections.py +++ b/tests/api_fastapi/core_api/routes/public/test_connections.py @@ -16,9 +16,13 @@ # under the License. from __future__ import annotations +import os +from unittest import mock + import pytest from airflow.models import Connection +from airflow.secrets.environment_variables import CONN_ENV_PREFIX from airflow.utils.session import provide_session from tests_common.test_utils.db import clear_db_connections @@ -296,7 +300,7 @@ def test_post_should_response_201_redacted_password(self, test_client, body, exp class TestPatchConnection(TestConnectionEndpoint): @pytest.mark.parametrize( - "payload", + "body", [ {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE, "extra": '{"key": "var"}'}, {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE, "host": "test_host_patch"}, @@ -317,14 +321,14 @@ class TestPatchConnection(TestConnectionEndpoint): ], ) @provide_session - def test_patch_should_respond_200(self, test_client, payload, session): + def test_patch_should_respond_200(self, test_client, body, session): self.create_connection() - response = test_client.patch(f"/public/connections/{TEST_CONN_ID}", json=payload) + response = test_client.patch(f"/public/connections/{TEST_CONN_ID}", json=body) assert response.status_code == 200 @pytest.mark.parametrize( - "payload, updated_connection, update_mask", + "body, updated_connection, update_mask", [ ( {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE, "extra": '{"key": "var"}'}, @@ -414,17 +418,17 @@ def test_patch_should_respond_200(self, test_client, payload, session): ], ) def test_patch_should_respond_200_with_update_mask( - self, test_client, session, payload, updated_connection, update_mask + self, test_client, session, body, updated_connection, update_mask ): self.create_connection() - response = test_client.patch(f"/public/connections/{TEST_CONN_ID}", json=payload, params=update_mask) + response = test_client.patch(f"/public/connections/{TEST_CONN_ID}", json=body, params=update_mask) assert response.status_code == 200 connection = session.query(Connection).filter_by(conn_id=TEST_CONN_ID).first() assert connection.password is None assert response.json() == updated_connection @pytest.mark.parametrize( - "payload", + "body", [ { "connection_id": "i_am_not_a_connection", @@ -456,9 +460,9 @@ def test_patch_should_respond_200_with_update_mask( }, ], ) - def test_patch_should_respond_400(self, test_client, payload): + def test_patch_should_respond_400(self, test_client, body): self.create_connection() - response = test_client.patch(f"/public/connections/{TEST_CONN_ID}", json=payload) + response = test_client.patch(f"/public/connections/{TEST_CONN_ID}", json=body) assert response.status_code == 400 print(response.json()) assert { @@ -466,7 +470,7 @@ def test_patch_should_respond_400(self, test_client, payload): } == response.json() @pytest.mark.parametrize( - "payload", + "body", [ { "connection_id": "i_am_not_a_connection", @@ -498,11 +502,11 @@ def test_patch_should_respond_400(self, test_client, payload): }, ], ) - def test_patch_should_respond_404(self, test_client, payload): - response = test_client.patch(f"/public/connections/{payload['connection_id']}", json=payload) + def test_patch_should_respond_404(self, test_client, body): + response = test_client.patch(f"/public/connections/{body['connection_id']}", json=body) assert response.status_code == 404 assert { - "detail": f"The Connection with connection_id: `{payload['connection_id']}` was not found", + "detail": f"The Connection with connection_id: `{body['connection_id']}` was not found", } == response.json() @pytest.mark.enable_redact @@ -563,3 +567,48 @@ def test_patch_should_response_200_redacted_password(self, test_client, session, response = test_client.patch(f"/public/connections/{TEST_CONN_ID}", json=body) assert response.status_code == 200 assert response.json() == expected_response + + +class TestConnection(TestConnectionEndpoint): + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + @pytest.mark.parametrize( + "body", + [ + {"connection_id": TEST_CONN_ID, "conn_type": "sqlite"}, + {"connection_id": TEST_CONN_ID, "conn_type": "ftp"}, + ], + ) + def test_should_respond_200(self, test_client, body): + response = test_client.post("/public/connections/test", json=body) + assert response.status_code == 200 + assert response.json() == { + "status": True, + "message": "Connection successfully tested", + } + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + @pytest.mark.parametrize( + "body", + [ + {"connection_id": TEST_CONN_ID, "conn_type": "sqlite"}, + {"connection_id": TEST_CONN_ID, "conn_type": "ftp"}, + ], + ) + def test_connection_env_is_cleaned_after_run(self, test_client, body): + test_client.post("/public/connections/test", json=body) + assert not any([key.startswith(CONN_ENV_PREFIX) for key in os.environ.keys()]) + + @pytest.mark.parametrize( + "body", + [ + {"connection_id": TEST_CONN_ID, "conn_type": "sqlite"}, + {"connection_id": TEST_CONN_ID, "conn_type": "ftp"}, + ], + ) + def test_should_respond_403_by_default(self, test_client, body): + response = test_client.post("/public/connections/test", json=body) + assert response.status_code == 403 + assert response.json() == { + "detail": "Testing connections is disabled in Airflow configuration. " + "Contact your deployment admin to enable it." + }