Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-82 Save references between assets and triggers #43826

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions airflow/assets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@

from sqlalchemy.orm.session import Session

from airflow.triggers.base import BaseTrigger


__all__ = ["Asset", "AssetAll", "AssetAny", "Dataset"]


Expand Down Expand Up @@ -266,20 +269,43 @@ class Asset(os.PathLike, BaseAsset):
uri: str
group: str
extra: dict[str, Any]
watchers: list[BaseTrigger] = []

asset_type: ClassVar[str] = ""
__version__: ClassVar[int] = 1

@overload
def __init__(self, name: str, uri: str, *, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
name: str,
uri: str,
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""Canonical; both name and uri are provided."""

@overload
def __init__(self, name: str, *, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
name: str,
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""It's possible to only provide the name, either by keyword or as the only positional argument."""

@overload
def __init__(self, *, uri: str, group: str = "", extra: dict | None = None) -> None:
def __init__(
self,
*,
uri: str,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
"""It's possible to only provide the URI as a keyword argument."""

def __init__(
Expand All @@ -289,6 +315,7 @@ def __init__(
*,
group: str = "",
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
if name is None and uri is None:
raise TypeError("Asset() requires either 'name' or 'uri'")
Expand All @@ -301,10 +328,14 @@ def __init__(
self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri))
self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type
self.extra = _set_extra_default(extra)
self.watchers = watchers or []

def __fspath__(self) -> str:
return self.uri

def __hash__(self) -> int:
return hash(self.uri)

@property
def normalized_uri(self) -> str | None:
"""
Expand Down
50 changes: 50 additions & 0 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

from airflow.assets import Asset, AssetAlias
from airflow.assets.manager import asset_manager
from airflow.models import Trigger
from airflow.models.asset import (
AssetAliasModel,
AssetModel,
Expand All @@ -55,6 +56,7 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from airflow.triggers.base import BaseTrigger
from airflow.typing_compat import Self

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -425,3 +427,51 @@ def add_task_asset_references(
for task_id, asset_id in referenced_outlets
if (task_id, asset_id) not in orm_refs
)

def add_asset_trigger_references(
self, assets: dict[tuple[str, str], AssetModel], *, session: Session
) -> None:
# Update references from assets being used
for name_uri, asset in self.assets.items():
asset_model = assets[name_uri]
trigger_class_path_to_asset_dict: dict[str, BaseTrigger] = {
trigger.serialize()[0]: trigger for trigger in asset.watchers
}

trigger_class_paths_from_asset: set[str] = set(trigger_class_path_to_asset_dict.keys())
trigger_class_paths_from_asset_model: set[str] = {
trigger.classpath for trigger in asset_model.triggers
}

# Optimization: no diff between the DB and DAG definitions, no update needed
if trigger_class_paths_from_asset == trigger_class_paths_from_asset_model:
continue

refs_to_add = trigger_class_paths_from_asset - trigger_class_paths_from_asset_model
refs_to_remove = trigger_class_paths_from_asset_model - trigger_class_paths_from_asset

# Remove old references
asset_model.triggers = [
trigger for trigger in asset_model.triggers if trigger.classpath not in refs_to_remove
]

# Add new references
for trigger_class_path in refs_to_add:
trigger_model = session.scalar(
select(Trigger).where(Trigger.classpath == trigger_class_path).limit(1)
)

# Create the trigger in the DB if it does not exist
if not trigger_model:
trigger_model = Trigger.from_object(trigger_class_path_to_asset_dict[trigger_class_path])
session.add(trigger_model)

asset_model.triggers.append(trigger_model)

# Remove references from assets no longer used
all_assets = session.scalars(select(AssetModel))
# orphan_assets = set()
for asset_model in all_assets:
if (asset_model.name, asset_model.uri) not in self.assets:
asset_model.triggers = []
# orphan_assets.add(asset_model.id)
1 change: 1 addition & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1844,6 +1844,7 @@ def bulk_write_to_db(
asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session)
asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session)
asset_op.add_task_asset_references(orm_dags, orm_assets, session=session)
asset_op.add_asset_trigger_references(orm_assets, session=session)
session.flush()

@provide_session
Expand Down
6 changes: 3 additions & 3 deletions task_sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,9 +557,9 @@ def __lt__(self, other):
def __hash__(self):
hash_components: list[Any] = [type(self)]
for c in _DAG_HASH_ATTRS:
# task_ids returns a list and lists can't be hashed
if c == "task_ids":
val = tuple(self.task_dict)
# If it is a list, convert to tuple because lists can't be hashed
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ashb I made this more generic since watchers is also a list and I thought we could just say "if it is a list, then let's convert it to a tuple". Let me know if you have any concerns

if isinstance(getattr(self, c, None), list):
val = tuple(getattr(self, c))
else:
val = getattr(self, c, None)
try:
Expand Down