Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix asset name uri handling #43774

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
10 changes: 6 additions & 4 deletions airflow/assets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def as_expression(self) -> Any:

:meta private:
"""
return self.uri
return {"asset": {"uri": self.uri, "name": self.name, "group": self.group}}

def iter_assets(self) -> Iterator[tuple[str, Asset]]:
yield self.uri, self
Expand Down Expand Up @@ -390,7 +390,8 @@ def __init__(self, *objects: BaseAsset) -> None:
raise TypeError("expect asset expressions in condition")

self.objects = [
_AssetAliasCondition(obj.name) if isinstance(obj, AssetAlias) else obj for obj in objects
_AssetAliasCondition(name=obj.name, group=obj.group) if isinstance(obj, AssetAlias) else obj
for obj in objects
]

def evaluate(self, statuses: dict[str, bool]) -> bool:
Expand Down Expand Up @@ -450,8 +451,9 @@ class _AssetAliasCondition(AssetAny):
:meta private:
"""

def __init__(self, name: str) -> None:
def __init__(self, name: str, group: str) -> None:
self.name = name
self.group = group
self.objects = expand_alias_to_assets(name)

def __repr__(self) -> str:
Expand All @@ -463,7 +465,7 @@ def as_expression(self) -> Any:

:meta private:
"""
return {"alias": self.name}
return {"alias": {"name": self.name, "group": self.group}}

def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
yield self.name, AssetAlias(self.name)
Expand Down
6 changes: 3 additions & 3 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def encode_asset_condition(var: BaseAsset) -> dict[str, Any]:
:meta private:
"""
if isinstance(var, Asset):
return {"__type": DAT.ASSET, "name": var.name, "uri": var.uri, "extra": var.extra}
return {"__type": DAT.ASSET, "name": var.name, "uri": var.uri, "group": var.group, "extra": var.extra}
if isinstance(var, AssetAlias):
return {"__type": DAT.ASSET_ALIAS, "name": var.name}
if isinstance(var, AssetAll):
Expand All @@ -273,7 +273,7 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:
"""
dat = var["__type"]
if dat == DAT.ASSET:
return Asset(uri=var["uri"], name=var["name"], extra=var["extra"])
return Asset(name=var["name"], uri=var["uri"], group=var["group"], extra=var["extra"])
if dat == DAT.ASSET_ALL:
return AssetAll(*(decode_asset_condition(x) for x in var["objects"]))
if dat == DAT.ASSET_ANY:
Expand Down Expand Up @@ -1053,7 +1053,7 @@ def detect_task_dependencies(task: Operator) -> list[DagDependency]:
)
)
elif isinstance(obj, AssetAlias):
cond = _AssetAliasCondition(obj.name)
cond = _AssetAliasCondition(name=obj.name, group=obj.group)

deps.extend(cond.iter_dag_dependencies(source=task.dag_id, target=""))
return deps
Expand Down
4 changes: 3 additions & 1 deletion airflow/timetables/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def __init__(self, assets: BaseAsset) -> None:
super().__init__()
self.asset_condition = assets
if isinstance(self.asset_condition, AssetAlias):
self.asset_condition = _AssetAliasCondition(self.asset_condition.name)
self.asset_condition = _AssetAliasCondition(
name=self.asset_condition.name, group=self.asset_condition.group
)

if not next(self.asset_condition.iter_assets(), False):
self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY
Expand Down
30 changes: 24 additions & 6 deletions providers/tests/openlineage/plugins/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,9 +333,9 @@ def test_serialize_timetable():
from airflow.timetables.simple import AssetTriggeredTimetable

asset = AssetAny(
Asset("2"),
AssetAlias("example-alias"),
Asset("3"),
Asset(name="2", uri="test://2", group="test-group"),
AssetAlias(name="example-alias", group="test-group"),
Asset(name="3", uri="test://3", group="test-group"),
AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")),
)
dag = MagicMock()
Expand All @@ -346,14 +346,32 @@ def test_serialize_timetable():
"asset_condition": {
"__type": DagAttributeTypes.ASSET_ANY,
"objects": [
{"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "2"},
{
"__type": DagAttributeTypes.ASSET,
"extra": {},
"uri": "test://2/",
"name": "2",
"group": "test-group",
},
{"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
{"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "3"},
{
"__type": DagAttributeTypes.ASSET,
"extra": {},
"uri": "test://3/",
"name": "3",
"group": "test-group",
},
{
"__type": DagAttributeTypes.ASSET_ALL,
"objects": [
{"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
{"__type": DagAttributeTypes.ASSET, "extra": {}, "uri": "4"},
{
"__type": DagAttributeTypes.ASSET,
"extra": {},
"uri": "4",
"name": "4",
"group": "",
},
],
},
],
Expand Down
16 changes: 13 additions & 3 deletions tests/api_fastapi/core_api/routes/ui/test_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def cleanup():


def test_next_run_assets(test_client, dag_maker):
with dag_maker(dag_id="upstream", schedule=[Asset(uri="s3://bucket/key/1")], serialized=True):
with dag_maker(dag_id="upstream", schedule=[Asset(uri="s3://bucket/next-run-asset/1")], serialized=True):
EmptyOperator(task_id="task1")

dag_maker.create_dagrun()
Expand All @@ -46,6 +46,16 @@ def test_next_run_assets(test_client, dag_maker):

assert response.status_code == 200
assert response.json() == {
"asset_expression": {"all": ["s3://bucket/key/1"]},
"events": [{"id": 20, "uri": "s3://bucket/key/1", "lastUpdate": None}],
"asset_expression": {
"all": [
{
"asset": {
"uri": "s3://bucket/next-run-asset/1",
"name": "s3://bucket/next-run-asset/1",
"group": "asset",
}
}
]
},
"events": [{"id": 20, "uri": "s3://bucket/next-run-asset/1", "lastUpdate": None}],
}
12 changes: 6 additions & 6 deletions tests/assets/test_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,22 +597,22 @@ def resolved_asset_alias_2(self, session, asset_1):
return asset_alias_2

def test_init(self, asset_alias_1, asset_1, resolved_asset_alias_2):
cond = _AssetAliasCondition(name=asset_alias_1.name)
cond = _AssetAliasCondition(name=asset_alias_1.name, group=asset_alias_1.group)
assert cond.objects == []

cond = _AssetAliasCondition(name=resolved_asset_alias_2.name)
cond = _AssetAliasCondition(name=resolved_asset_alias_2.name, group=resolved_asset_alias_2.group)
assert cond.objects == [Asset(uri=asset_1.uri)]

def test_as_expression(self, asset_alias_1, resolved_asset_alias_2):
for assset_alias in (asset_alias_1, resolved_asset_alias_2):
cond = _AssetAliasCondition(assset_alias.name)
assert cond.as_expression() == {"alias": assset_alias.name}
cond = _AssetAliasCondition(name=assset_alias.name, group=assset_alias.group)
assert cond.as_expression() == {"alias": {"name": assset_alias.name, "group": ""}}

def test_evalute(self, asset_alias_1, resolved_asset_alias_2, asset_1):
cond = _AssetAliasCondition(asset_alias_1.name)
cond = _AssetAliasCondition(name=asset_alias_1.name, group=asset_alias_1.group)
assert cond.evaluate({asset_1.uri: True}) is False

cond = _AssetAliasCondition(resolved_asset_alias_2.name)
cond = _AssetAliasCondition(name=resolved_asset_alias_2.name, group=resolved_asset_alias_2.group)
assert cond.evaluate({asset_1.uri: True}) is True


Expand Down
3 changes: 2 additions & 1 deletion tests/decorators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,12 +989,13 @@ def test_task_decorator_asset(dag_maker, session):

result = None
uri = "s3://bucket/name"
asset_name = "test_asset"

with dag_maker(session=session) as dag:

@dag.task()
def up1() -> Asset:
return Asset(uri)
return Asset(uri=uri, name=asset_name)

@dag.task()
def up2(src: Asset) -> str:
Expand Down
29 changes: 14 additions & 15 deletions tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -3961,8 +3961,8 @@ def test_create_dag_runs_assets(self, session, dag_maker):
- That dag_model has next_dagrun
"""

asset1 = Asset(uri="ds1")
asset2 = Asset(uri="ds2")
asset1 = Asset(uri="test://asset1", name="test_asset", group="test_group")
asset2 = Asset(uri="test://asset2", name="test_asset_2", group="test_group")

with dag_maker(dag_id="assets-1", start_date=timezone.utcnow(), session=session):
BashOperator(task_id="task", bash_command="echo 1", outlets=[asset1])
Expand Down Expand Up @@ -4057,15 +4057,14 @@ def dict_from_obj(obj):
],
)
def test_no_create_dag_runs_when_dag_disabled(self, session, dag_maker, disable, enable):
ds = Asset("ds")
with dag_maker(dag_id="consumer", schedule=[ds], session=session):
asset = Asset(uri="test://asset_1", name="test_asset_1", group="test_group")
with dag_maker(dag_id="consumer", schedule=[asset], session=session):
pass
with dag_maker(dag_id="producer", schedule="@daily", session=session):
BashOperator(task_id="task", bash_command="echo 1", outlets=ds)
BashOperator(task_id="task", bash_command="echo 1", outlets=asset)
asset_manger = AssetManager()

asset_id = session.scalars(select(AssetModel.id).filter_by(uri=ds.uri)).one()

asset_id = session.scalars(select(AssetModel.id).filter_by(uri=asset.uri, name=asset.name)).one()
ase_q = select(AssetEvent).where(AssetEvent.asset_id == asset_id).order_by(AssetEvent.timestamp)
adrq_q = select(AssetDagRunQueue).where(
AssetDagRunQueue.asset_id == asset_id, AssetDagRunQueue.target_dag_id == "consumer"
Expand All @@ -4078,7 +4077,7 @@ def test_no_create_dag_runs_when_dag_disabled(self, session, dag_maker, disable,
dr1: DagRun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
asset_manger.register_asset_change(
task_instance=dr1.get_task_instance("task", session=session),
asset=ds,
asset=asset,
session=session,
)
session.flush()
Expand All @@ -4092,7 +4091,7 @@ def test_no_create_dag_runs_when_dag_disabled(self, session, dag_maker, disable,
dr2: DagRun = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED)
asset_manger.register_asset_change(
task_instance=dr2.get_task_instance("task", session=session),
asset=ds,
asset=asset,
session=session,
)
session.flush()
Expand Down Expand Up @@ -6180,11 +6179,11 @@ def _find_assets_activation(session) -> tuple[list[AssetModel], list[AssetModel]
def test_asset_orphaning(self, dag_maker, session):
self.job_runner = SchedulerJobRunner(job=Job(), subdir=os.devnull)

asset1 = Asset(uri="ds1")
asset2 = Asset(uri="ds2")
asset3 = Asset(uri="ds3")
asset4 = Asset(uri="ds4")
asset5 = Asset(uri="ds5")
asset1 = Asset(uri="test://asset_1", name="test_asset_1", group="test_group")
asset2 = Asset(uri="test://asset_2", name="test_asset_2", group="test_group")
asset3 = Asset(uri="test://asset_3", name="test_asset_3", group="test_group")
asset4 = Asset(uri="test://asset_4", name="test_asset_4", group="test_group")
asset5 = Asset(uri="test://asset_5", name="test_asset_5", group="test_group")

with dag_maker(dag_id="assets-1", schedule=[asset1, asset2], session=session):
BashOperator(task_id="task", bash_command="echo 1", outlets=[asset3, asset4])
Expand Down Expand Up @@ -6223,7 +6222,7 @@ def test_asset_orphaning(self, dag_maker, session):
def test_asset_orphaning_ignore_orphaned_assets(self, dag_maker, session):
self.job_runner = SchedulerJobRunner(job=Job(), subdir=os.devnull)

asset1 = Asset(uri="ds1")
asset1 = Asset(uri="test://asset_1", name="test_asset_1", group="test_group")

with dag_maker(dag_id="assets-1", schedule=[asset1], session=session):
BashOperator(task_id="task", bash_command="echo 1")
Expand Down
10 changes: 7 additions & 3 deletions tests/listeners/test_asset_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ def clean_listener_manager():
@pytest.mark.db_test
@provide_session
def test_asset_listener_on_asset_changed_gets_calls(create_task_instance_of_operator, session):
asset_uri = "test_asset_uri"
asset = Asset(uri=asset_uri)
asset_model = AssetModel(uri=asset_uri)
asset_uri = "test://asset/"
asset_name = "test_asset_uri"
asset_group = "test-group"
asset = Asset(uri=asset_uri, name=asset_name, group=asset_group)
asset_model = AssetModel(uri=asset_uri, name=asset_name)
session.add(asset_model)

session.flush()
Expand All @@ -60,3 +62,5 @@ def test_asset_listener_on_asset_changed_gets_calls(create_task_instance_of_oper

assert len(asset_listener.changed) == 1
assert asset_listener.changed[0].uri == asset_uri
assert asset_listener.changed[0].name == asset_name
assert asset_listener.changed[0].group == asset_group
Loading
Loading