Skip to content

Commit

Permalink
AIP-72: Add "Get Variable" endpoint for Execution API
Browse files Browse the repository at this point in the history
This commit introduces a new endpoint, `/execution/variable/{variable_key}`, in the Execution API to retrieve Variables details.

Same as the Connections PR, it uses a placeholder `check_connection_access` function to validate task permissions for each request.
  • Loading branch information
kaxil committed Nov 8, 2024
1 parent 6c30fc5 commit 4d56ec4
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 0 deletions.
9 changes: 9 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ class ConnectionResponse(BaseModel):
extra: str | None


class VariableResponse(BaseModel):
"""Variable schema for responses with fields that are needed for Runtime."""

model_config = ConfigDict(from_attributes=True)

key: str
val: str | None = Field(alias="value")


# TODO: This is a placeholder for Task Identity Token schema.
class TIToken(BaseModel):
"""Task Identity Token."""
Expand Down
2 changes: 2 additions & 0 deletions airflow/api_fastapi/execution_api/routes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
from airflow.api_fastapi.execution_api.routes.connections import connection_router
from airflow.api_fastapi.execution_api.routes.health import health_router
from airflow.api_fastapi.execution_api.routes.task_instance import ti_router
from airflow.api_fastapi.execution_api.routes.variables import variable_router

execution_api_router = AirflowRouter()
execution_api_router.include_router(connection_router)
execution_api_router.include_router(health_router)
execution_api_router.include_router(ti_router)
execution_api_router.include_router(variable_router)
1 change: 1 addition & 0 deletions airflow/api_fastapi/execution_api/routes/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def get_task_token() -> datamodels.TIToken:

@connection_router.get(
"/{connection_id}",
status_code=status.HTTP_200_OK,
responses={
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the connection"},
Expand Down
90 changes: 90 additions & 0 deletions airflow/api_fastapi/execution_api/routes/variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging

from fastapi import Depends, HTTPException, status
from typing_extensions import Annotated

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api import datamodels
from airflow.models.variable import Variable

# TODO: Add dependency on JWT token
variable_router = AirflowRouter(
prefix="/variable",
tags=["Variable"],
responses={status.HTTP_404_NOT_FOUND: {"description": "Variable not found"}},
)

log = logging.getLogger(__name__)


def get_task_token() -> datamodels.TIToken:
"""TODO: Placeholder for task identity authentication. This should be replaced with actual JWT decoding and validation."""
return datamodels.TIToken(ti_key="test_key")


@variable_router.get(
"/{variable_key}",
status_code=status.HTTP_200_OK,
responses={
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"},
},
)
def get_variable(
variable_key: str,
token: Annotated[datamodels.TIToken, Depends(get_task_token)],
) -> datamodels.VariableResponse:
"""Get an Airflow Variable."""
if not has_variable_access(variable_key, token):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"reason": "access_denied",
"message": f"Task does not have access to variable {variable_key}",
},
)

try:
variable_value = Variable.get(variable_key)
except KeyError:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": f"Variable with key '{variable_key}' not found",
},
)

return datamodels.VariableResponse(key=variable_key, value=variable_value)


def has_variable_access(variable_key: str, token: datamodels.TIToken) -> bool:
"""Check if the task has access to the variable."""
# TODO: Placeholder for actual implementation

ti_key = token.ti_key
log.debug(
"Checking access for task instance with key '%s' to variable '%s'",
ti_key,
variable_key,
)
return True
76 changes: 76 additions & 0 deletions tests/api_fastapi/execution_api/routes/test_variable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from unittest import mock

import pytest

from airflow.models.variable import Variable


class TestGetVariable:
@pytest.mark.db_test
def test_variable_get_from_db(self, client, session):
Variable.set(key="var1", value="value", session=session)
session.commit()

response = client.get("/execution/variable/var1")

assert response.status_code == 200
assert response.json() == {"key": "var1", "value": "value"}

# Remove connection
Variable.delete(key="var1", session=session)
session.commit()

@mock.patch.dict(
"os.environ",
{"AIRFLOW_VAR_KEY1": "VALUE"},
)
def test_variable_get_from_env_var(self, client, session):
response = client.get("/execution/variable/key1")

assert response.status_code == 200
assert response.json() == {"key": "key1", "value": "VALUE"}

def test_variable_get_not_found(self, client):
response = client.get("/execution/variable/non_existent_var")

assert response.status_code == 404
assert response.json() == {
"detail": {
"message": "Variable with key 'non_existent_var' not found",
"reason": "not_found",
}
}

def test_variable_get_access_denied(self, client):
with mock.patch(
"airflow.api_fastapi.execution_api.routes.variables.has_variable_access", return_value=False
):
response = client.get("/execution/variable/key1")

# Assert response status code and detail for access denied
assert response.status_code == 403
assert response.json() == {
"detail": {
"reason": "access_denied",
"message": "Task does not have access to variable key1",
}
}

0 comments on commit 4d56ec4

Please sign in to comment.