Skip to content

Commit

Permalink
Clean up auth manager model
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck committed Oct 9, 2024
1 parent 00d6ae7 commit 9d15d92
Show file tree
Hide file tree
Showing 22 changed files with 252 additions and 308 deletions.
9 changes: 5 additions & 4 deletions airflow/api_connexion/endpoints/asset_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
)
from airflow.assets import Asset
from airflow.assets.manager import asset_manager
from airflow.auth.managers.base_auth_manager import ResourceSetAccess
from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel
from airflow.utils import timezone
from airflow.utils.db import get_query_count
Expand Down Expand Up @@ -158,7 +159,7 @@ def _generate_queued_event_where_clause(
dag_id: str | None = None,
uri: str | None = None,
before: str | None = None,
permitted_dag_ids: set[str] | None = None,
permitted_dag_ids: set[str] | ResourceSetAccess | None = None,
) -> list:
"""Get AssetDagRunQueue where clause."""
where_clause = []
Expand All @@ -172,7 +173,7 @@ def _generate_queued_event_where_clause(
)
if before is not None:
where_clause.append(AssetDagRunQueue.created_at < format_datetime(before))
if permitted_dag_ids is not None:
if permitted_dag_ids is not None and permitted_dag_ids != ResourceSetAccess.ALL:
where_clause.append(AssetDagRunQueue.target_dag_id.in_(permitted_dag_ids))
return where_clause

Expand Down Expand Up @@ -272,7 +273,7 @@ def get_asset_queued_events(
*, uri: str, before: str | None = None, session: Session = NEW_SESSION
) -> APIResponse:
"""Get queued asset events for an asset."""
permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"])
permitted_dag_ids = get_auth_manager().get_accessible_dag_ids(method="GET")
where_clause = _generate_queued_event_where_clause(
uri=uri, before=before, permitted_dag_ids=permitted_dag_ids
)
Expand Down Expand Up @@ -304,7 +305,7 @@ def delete_asset_queued_events(
*, uri: str, before: str | None = None, session: Session = NEW_SESSION
) -> APIResponse:
"""Delete queued asset events for an asset."""
permitted_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"])
permitted_dag_ids = get_auth_manager().get_accessible_dag_ids(method="GET")
where_clause = _generate_queued_event_where_clause(
uri=uri, before=before, permitted_dag_ids=permitted_dag_ids
)
Expand Down
12 changes: 7 additions & 5 deletions airflow/api_connexion/endpoints/dag_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
dag_schema,
dags_collection_schema,
)
from airflow.auth.managers.base_auth_manager import ResourceSetAccess
from airflow.exceptions import AirflowException, DagNotFound
from airflow.models.dag import DagModel, DagTag
from airflow.utils.airflow_flask_app import get_airflow_app
Expand Down Expand Up @@ -120,9 +121,9 @@ def get_dags(
if dag_id_pattern:
dags_query = dags_query.where(DagModel.dag_id.ilike(f"%{dag_id_pattern}%"))

readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user)

dags_query = dags_query.where(DagModel.dag_id.in_(readable_dags))
readable_dags = get_auth_manager().get_accessible_dag_ids(method="GET", user=g.user)
if readable_dags != ResourceSetAccess.ALL:
dags_query = dags_query.where(DagModel.dag_id.in_(readable_dags))
if tags:
cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
dags_query = dags_query.where(or_(*cond))
Expand Down Expand Up @@ -191,9 +192,10 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat
if dag_id_pattern == "~":
dag_id_pattern = "%"
dags_query = dags_query.where(DagModel.dag_id.ilike(f"%{dag_id_pattern}%"))
editable_dags = get_auth_manager().get_permitted_dag_ids(methods=["PUT"], user=g.user)

dags_query = dags_query.where(DagModel.dag_id.in_(editable_dags))
editable_dags = get_auth_manager().get_accessible_dag_ids(method="PUT", user=g.user)
if editable_dags != ResourceSetAccess.ALL:
dags_query = dags_query.where(DagModel.dag_id.in_(editable_dags))
if tags:
cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
dags_query = dags_query.where(or_(*cond))
Expand Down
13 changes: 7 additions & 6 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
TaskInstanceReferenceCollection,
task_instance_reference_collection_schema,
)
from airflow.auth.managers.base_auth_manager import ResourceSetAccess
from airflow.auth.managers.models.resource_details import DagAccessEntity
from airflow.exceptions import ParamValidationError
from airflow.models import DagModel, DagRun
Expand Down Expand Up @@ -225,9 +226,9 @@ def get_dag_runs(

# This endpoint allows specifying ~ as the dag_id to retrieve DAG Runs for all DAGs.
if dag_id == "~":
query = query.where(
DagRun.dag_id.in_(get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=g.user))
)
accessible_dags = get_auth_manager().get_accessible_dag_ids(method="GET", user=g.user)
if accessible_dags != ResourceSetAccess.ALL:
query = query.where(DagRun.dag_id.in_(accessible_dags))
else:
query = query.where(DagRun.dag_id == dag_id)

Expand Down Expand Up @@ -268,12 +269,12 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
except ValidationError as err:
raise BadRequest(detail=str(err.messages))

readable_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=g.user)
readable_dag_ids = get_auth_manager().get_accessible_dag_ids(method="GET", user=g.user)
query = select(DagRun)
if data.get("dag_ids"):
if data.get("dag_ids") and readable_dag_ids != ResourceSetAccess.ALL:
dag_ids = set(data["dag_ids"]) & set(readable_dag_ids)
query = query.where(DagRun.dag_id.in_(dag_ids))
else:
elif readable_dag_ids != ResourceSetAccess.ALL:
query = query.where(DagRun.dag_id.in_(readable_dag_ids))

states = data.get("states")
Expand Down
11 changes: 7 additions & 4 deletions airflow/api_connexion/endpoints/dag_stats_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from airflow.api_connexion.schemas.dag_stats_schema import (
dag_stats_collection_schema,
)
from airflow.auth.managers.base_auth_manager import ResourceSetAccess
from airflow.auth.managers.models.resource_details import DagAccessEntity
from airflow.models.dag import DagRun
from airflow.utils.session import NEW_SESSION, provide_session
Expand All @@ -41,21 +42,23 @@
@provide_session
def get_dag_stats(*, dag_ids: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get Dag statistics."""
allowed_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=g.user)
allowed_dag_ids = get_auth_manager().get_accessible_dag_ids(method="GET", user=g.user)
dags_list = set(dag_ids.split(","))
filter_dag_ids = dags_list.intersection(allowed_dag_ids)

if allowed_dag_ids != ResourceSetAccess.ALL:
dags_list = dags_list.intersection(allowed_dag_ids)

query = (
select(DagRun.dag_id, DagRun.state, func.count(DagRun.state))
.group_by(DagRun.dag_id, DagRun.state)
.where(DagRun.dag_id.in_(filter_dag_ids))
.where(DagRun.dag_id.in_(dags_list))
)
dag_state_stats = session.execute(query)

dag_state_data = {(dag_id, state): count for dag_id, state, count in dag_state_stats}
dag_stats = {
dag_id: [{"state": state, "count": dag_state_data.get((dag_id, state), 0)} for state in DagRunState]
for dag_id in filter_dag_ids
for dag_id in dags_list
}

dags = [{"dag_id": stat, "stats": dag_stats[stat]} for stat in dag_stats]
Expand Down
8 changes: 5 additions & 3 deletions airflow/api_connexion/endpoints/dag_warning_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@
DagWarningCollection,
dag_warning_collection_schema,
)
from airflow.api_connexion.security import get_readable_dags
from airflow.auth.managers.base_auth_manager import ResourceSetAccess
from airflow.auth.managers.models.resource_details import DagAccessEntity
from airflow.models.dagwarning import DagWarning as DagWarningModel
from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.www.extensions.init_auth_manager import get_auth_manager

if TYPE_CHECKING:
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -61,8 +62,9 @@ def get_dag_warnings(
if dag_id:
query = query.where(DagWarningModel.dag_id == dag_id)
else:
readable_dags = get_readable_dags()
query = query.where(DagWarningModel.dag_id.in_(readable_dags))
readable_dags = get_auth_manager().get_accessible_dag_ids(method="GET")
if readable_dags != ResourceSetAccess.ALL:
query = query.where(DagWarningModel.dag_id.in_(readable_dags))
if warning_type:
query = query.where(DagWarningModel.warning_type == warning_type)
total_entries = get_query_count(query, session=session)
Expand Down
14 changes: 6 additions & 8 deletions airflow/api_connexion/endpoints/import_error_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import_error_collection_schema,
import_error_schema,
)
from airflow.auth.managers.base_auth_manager import ResourceSetAccess
from airflow.auth.managers.models.resource_details import AccessView, DagDetails
from airflow.models.dag import DagModel
from airflow.models.errors import ParseImportError
Expand All @@ -53,9 +54,8 @@ def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) ->
)
session.expunge(error)

can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")
if not can_read_all_dags:
readable_dag_ids = security.get_readable_dags()
readable_dag_ids = get_auth_manager().get_accessible_dag_ids(method="GET")
if readable_dag_ids != ResourceSetAccess.ALL:
file_dag_ids = {
dag_id[0]
for dag_id in session.query(DagModel.dag_id).filter(DagModel.fileloc == error.filename).all()
Expand Down Expand Up @@ -89,19 +89,17 @@ def get_import_errors(
query = select(ParseImportError)
query = apply_sorting(query, order_by, to_replace, allowed_sort_attrs)

can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")

if not can_read_all_dags:
readable_dag_ids = get_auth_manager().get_accessible_dag_ids(method="GET")
if readable_dag_ids != ResourceSetAccess.ALL:
# if the user doesn't have access to all DAGs, only display errors from visible DAGs
readable_dag_ids = security.get_readable_dags()
dagfiles_stmt = select(DagModel.fileloc).distinct().where(DagModel.dag_id.in_(readable_dag_ids))
query = query.where(ParseImportError.filename.in_(dagfiles_stmt))
count_query = count_query.where(ParseImportError.filename.in_(dagfiles_stmt))

total_entries = session.scalars(count_query).one()
import_errors = session.scalars(query.offset(offset).limit(limit)).all()

if not can_read_all_dags:
if readable_dag_ids != ResourceSetAccess.ALL:
for import_error in import_errors:
# Check if user has read access to all the DAGs defined in the file
file_dag_ids = (
Expand Down
13 changes: 8 additions & 5 deletions airflow/api_connexion/endpoints/task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
task_instance_reference_schema,
task_instance_schema,
)
from airflow.api_connexion.security import get_readable_dags
from airflow.auth.managers.base_auth_manager import ResourceSetAccess
from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails
from airflow.exceptions import TaskNotFound
from airflow.models import SlaMiss
Expand Down Expand Up @@ -359,8 +359,10 @@ def get_task_instances(

if dag_id != "~":
base_query = base_query.where(TI.dag_id == dag_id)
else:
base_query = base_query.where(TI.dag_id.in_(get_readable_dags()))
elif (
accessible_dags := get_auth_manager().get_accessible_dag_ids(method="GET")
) != ResourceSetAccess.ALL:
base_query = base_query.where(TI.dag_id.in_(accessible_dags))
if dag_run_id != "~":
base_query = base_query.where(TI.run_id == dag_run_id)
base_query = _apply_range_filter(
Expand Down Expand Up @@ -432,12 +434,13 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
if not get_auth_manager().batch_is_authorized_dag(requests):
raise PermissionDenied(detail=f"User not allowed to access some of these DAGs: {list(dag_ids)}")
else:
dag_ids = get_auth_manager().get_permitted_dag_ids(user=g.user)
dag_ids = get_auth_manager().get_accessible_dag_ids(method="GET", user=g.user)

states = _convert_ti_states(data["state"])
base_query = select(TI).join(TI.dag_run)

base_query = _apply_array_filter(base_query, key=TI.dag_id, values=dag_ids)
if dag_ids != ResourceSetAccess.ALL:
base_query = _apply_array_filter(base_query, key=TI.dag_id, values=dag_ids)
base_query = _apply_array_filter(base_query, key=TI.run_id, values=data["dag_run_ids"])
base_query = _apply_array_filter(base_query, key=TI.task_id, values=data["task_ids"])
base_query = _apply_range_filter(
Expand Down
6 changes: 4 additions & 2 deletions airflow/api_connexion/endpoints/xcom_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
xcom_schema_native,
xcom_schema_string,
)
from airflow.auth.managers.base_auth_manager import ResourceSetAccess
from airflow.auth.managers.models.resource_details import DagAccessEntity
from airflow.models import DagRun as DR, XCom
from airflow.settings import conf
Expand Down Expand Up @@ -61,8 +62,9 @@ def get_xcom_entries(
"""Get all XCom values."""
query = select(XCom)
if dag_id == "~":
readable_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=g.user)
query = query.where(XCom.dag_id.in_(readable_dag_ids))
readable_dag_ids = get_auth_manager().get_accessible_dag_ids(method="GET", user=g.user)
if readable_dag_ids != ResourceSetAccess.ALL:
query = query.where(XCom.dag_id.in_(readable_dag_ids))
query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id))
else:
query = query.where(XCom.dag_id == dag_id)
Expand Down
36 changes: 6 additions & 30 deletions airflow/api_connexion/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from functools import wraps
from typing import TYPE_CHECKING, Callable, TypeVar, cast

from flask import Response, g
from flask import Response

from airflow.api_connexion.exceptions import PermissionDenied, Unauthenticated
from airflow.auth.managers.models.resource_details import (
Expand Down Expand Up @@ -114,31 +114,11 @@ def requires_access_dag(
) -> Callable[[T], T]:
def _is_authorized_callback(dag_id: str):
def callback() -> bool | DagAccessEntity:
if dag_id:
# a DAG id is provided; is the user authorized to access this DAG?
return get_auth_manager().is_authorized_dag(
method=method,
access_entity=access_entity,
details=DagDetails(id=dag_id),
)
else:
# here we know dag_id is not provided.
# check is the user authorized to access all DAGs?
if get_auth_manager().is_authorized_dag(
method=method,
access_entity=access_entity,
):
return True
elif access_entity:
# no dag_id provided, and user does not have access to all dags
return False

# dag_id is not provided, and the user is not authorized to access *all* DAGs
# so we check that the user can access at least *one* dag
# but we leave it to the endpoint function to properly restrict access beyond that
if method not in ("GET", "PUT"):
return False
return any(get_auth_manager().get_permitted_dag_ids(methods=[method]))
return get_auth_manager().is_authorized_dag(
method=method,
access_entity=access_entity,
details=DagDetails(id=dag_id),
)

return callback

Expand Down Expand Up @@ -250,7 +230,3 @@ def decorated(*args, **kwargs):
return cast(T, decorated)

return requires_access_decorator


def get_readable_dags() -> set[str]:
return get_auth_manager().get_permitted_dag_ids(user=g.user)
2 changes: 1 addition & 1 deletion airflow/api_fastapi/db/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def paginated_select(

# TODO: Re-enable when permissions are handled. Readable / writable entities,
# for instance:
# readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user)
# readable_dags = get_auth_manager().get_accessible_dag_ids(user=g.user)
# dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags))

base_select = apply_filters_to_select(base_select, [order_by, offset, limit])
Expand Down
Loading

0 comments on commit 9d15d92

Please sign in to comment.