Patch notes (free-form preamble; patch tools start reading at the first `diff --git` line).

* providers/common/compat/__init__.py: expose the AIRFLOW_V_*_PLUS version gates at package
  level. The base version is parsed once and the minimum-version guard reuses AIRFLOW_V_2_8_PLUS.
* providers/common/compat/assets/__init__.py, lineage/hook.py: import the gates from the package
  instead of re-deriving them (assets) or importing them via the assets module (hook).
* tests/decorators/test_assets.py: new tests for the @asset decorator. Imports are isort-ordered
  and the test name `test_serialize` is spelled correctly.

NOTE(review): the post-image blob ids on the `index` lines were carried over from the original
patch and have not been recomputed for the edited content.

diff --git a/providers/src/airflow/providers/common/compat/__init__.py b/providers/src/airflow/providers/common/compat/__init__.py
index ef51cb422e513..dfd3e347626e7 100644
--- a/providers/src/airflow/providers/common/compat/__init__.py
+++ b/providers/src/airflow/providers/common/compat/__init__.py
@@ -23,17 +23,24 @@
 #
 from __future__ import annotations
 
-import packaging.version
+from packaging.version import Version
 
-from airflow import __version__ as airflow_version
+from airflow import __version__ as AIRFLOW_VERSION
 
 __all__ = ["__version__"]
 
 __version__ = "1.2.1"
 
-if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
-    "2.8.0"
-):
+
+# Compare on base_version so pre-release/dev suffixes (e.g. "3.0.0.dev0") do not defeat the gates.
+_AIRFLOW_BASE_VERSION = Version(Version(AIRFLOW_VERSION).base_version)
+
+AIRFLOW_V_3_0_PLUS = _AIRFLOW_BASE_VERSION >= Version("3.0.0")
+AIRFLOW_V_2_10_PLUS = _AIRFLOW_BASE_VERSION >= Version("2.10.0")
+AIRFLOW_V_2_9_PLUS = _AIRFLOW_BASE_VERSION >= Version("2.9.0")
+AIRFLOW_V_2_8_PLUS = _AIRFLOW_BASE_VERSION >= Version("2.8.0")
+
+if not AIRFLOW_V_2_8_PLUS:
     raise RuntimeError(
         f"The package `apache-airflow-providers-common-compat:{__version__}` needs Apache Airflow 2.8.0+"
     )
diff --git a/providers/src/airflow/providers/common/compat/assets/__init__.py b/providers/src/airflow/providers/common/compat/assets/__init__.py
index ea073840fe006..66178cf0c68db 100644
--- a/providers/src/airflow/providers/common/compat/assets/__init__.py
+++ b/providers/src/airflow/providers/common/compat/assets/__init__.py
@@ -19,7 +19,12 @@
 
 from typing import TYPE_CHECKING
 
-from airflow import __version__ as AIRFLOW_VERSION
+from airflow.providers.common.compat import (
+    AIRFLOW_V_2_8_PLUS,
+    AIRFLOW_V_2_9_PLUS,
+    AIRFLOW_V_2_10_PLUS,
+    AIRFLOW_V_3_0_PLUS,
+)
 
 if TYPE_CHECKING:
     from airflow.auth.managers.models.resource_details import AssetDetails
@@ -32,13 +37,6 @@
         expand_alias_to_assets,
     )
 else:
-    from packaging.version import Version
-
-    AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0")
-    AIRFLOW_V_2_10_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")
-    AIRFLOW_V_2_9_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.9.0")
-    AIRFLOW_V_2_8_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.8.0")
-
     if AIRFLOW_V_3_0_PLUS:
         from airflow.auth.managers.models.resource_details import AssetDetails
         from airflow.sdk.definitions.asset import (
diff --git a/providers/src/airflow/providers/common/compat/lineage/hook.py b/providers/src/airflow/providers/common/compat/lineage/hook.py
index bf080de37ffb2..63214a9051c11 100644
--- a/providers/src/airflow/providers/common/compat/lineage/hook.py
+++ b/providers/src/airflow/providers/common/compat/lineage/hook.py
@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations
 
-from airflow.providers.common.compat.assets import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
+from airflow.providers.common.compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
 
 
 def _get_asset_compat_hook_lineage_collector():
diff --git a/tests/decorators/test_assets.py b/tests/decorators/test_assets.py
new file mode 100644
index 0000000000000..9410efc0a34a8
--- /dev/null
+++ b/tests/decorators/test_assets.py
@@ -0,0 +1,171 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import ANY
+
+import pytest
+
+from airflow.decorators.assets import AssetRef, _AssetMainOperator, asset
+from airflow.models.asset import AssetActive, AssetModel
+from airflow.sdk.definitions.asset import Asset
+
+pytestmark = pytest.mark.db_test
+
+
+@pytest.fixture
+def example_asset_func(request):
+    name = "example_asset_func"
+    if getattr(request, "param", None) is not None:
+        name = request.param
+
+    def _example_asset_func():
+        return "This is example_asset"
+
+    _example_asset_func.__name__ = name
+    _example_asset_func.__qualname__ = name
+    return _example_asset_func
+
+
+@pytest.fixture
+def example_asset_definition(example_asset_func):
+    return asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})(
+        example_asset_func
+    )
+
+
+@pytest.fixture
+def example_asset_func_with_valid_arg_as_inlet_asset():
+    def _example_asset_func(self, context, inlet_asset_1, inlet_asset_2):
+        return "This is example_asset"
+
+    _example_asset_func.__name__ = "example_asset_func"
+    _example_asset_func.__qualname__ = "example_asset_func"
+    return _example_asset_func
+
+
+class TestAssetDecorator:
+    def test_without_uri(self, example_asset_func):
+        asset_definition = asset(schedule=None)(example_asset_func)
+
+        assert asset_definition.name == "example_asset_func"
+        assert asset_definition.uri == "example_asset_func"
+        assert asset_definition.group == ""
+        assert asset_definition.extra == {}
+        assert asset_definition.function == example_asset_func
+        assert asset_definition.schedule is None
+
+    def test_with_uri(self, example_asset_func):
+        asset_definition = asset(schedule=None, uri="s3://bucket/object")(example_asset_func)
+
+        assert asset_definition.name == "example_asset_func"
+        assert asset_definition.uri == "s3://bucket/object"
+        assert asset_definition.group == ""
+        assert asset_definition.extra == {}
+        assert asset_definition.function == example_asset_func
+        assert asset_definition.schedule is None
+
+    def test_with_group_and_extra(self, example_asset_func):
+        asset_definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})(
+            example_asset_func
+        )
+        assert asset_definition.name == "example_asset_func"
+        assert asset_definition.uri == "s3://bucket/object"
+        assert asset_definition.group == "MLModel"
+        assert asset_definition.extra == {"k": "v"}
+        assert asset_definition.function == example_asset_func
+        assert asset_definition.schedule is None
+
+    def test_nested_function(self):
+        def root_func():
+            @asset(schedule=None)
+            def asset_func():
+                pass
+
+        with pytest.raises(ValueError) as err:
+            root_func()
+
+        assert err.value.args[0] == "nested function not supported"
+
+    @pytest.mark.parametrize("example_asset_func", ("self", "context"), indirect=True)
+    def test_with_invalid_asset_name(self, example_asset_func):
+        with pytest.raises(ValueError) as err:
+            asset(schedule=None)(example_asset_func)
+
+        assert err.value.args[0].startswith("prohibited name for asset: ")
+
+
+class TestAssetDefinition:
+    def test_serialize(self, example_asset_definition):
+        assert example_asset_definition.serialize() == {
+            "extra": {"k": "v"},
+            "group": "MLModel",
+            "name": "example_asset_func",
+            "uri": "s3://bucket/object",
+        }
+
+    @mock.patch("airflow.decorators.assets._AssetMainOperator")
+    @mock.patch("airflow.decorators.assets.DAG")
+    def test__attrs_post_init__(
+        self, DAG, _AssetMainOperator, example_asset_func_with_valid_arg_as_inlet_asset
+    ):
+        asset_definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})(
+            example_asset_func_with_valid_arg_as_inlet_asset
+        )
+
+        DAG.assert_called_once_with(dag_id="example_asset_func", schedule=None, auto_register=True)
+        _AssetMainOperator.assert_called_once_with(
+            task_id="__main__",
+            inlets=[
+                AssetRef(name="inlet_asset_1"),
+                AssetRef(name="inlet_asset_2"),
+            ],
+            outlets=[asset_definition.to_asset()],
+            python_callable=ANY,
+            definition_name="example_asset_func",
+            uri="s3://bucket/object",
+        )
+
+        python_callable = _AssetMainOperator.call_args.kwargs["python_callable"]
+        assert python_callable.__wrapped__ == example_asset_func_with_valid_arg_as_inlet_asset
+
+
+class Test_AssetMainOperator:
+    def test_determine_kwargs(self, example_asset_func_with_valid_arg_as_inlet_asset, session):
+        example_asset = AssetModel(uri="s3://bucket/object1", name="inlet_asset_1")
+        session.add(example_asset)
+        session.add(AssetActive.for_asset(example_asset))
+        session.commit()
+
+        asset_definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})(
+            example_asset_func_with_valid_arg_as_inlet_asset
+        )
+
+        op = _AssetMainOperator(
+            task_id="__main__",
+            inlets=[AssetRef(name="inlet_asset_1"), AssetRef(name="inlet_asset_2")],
+            outlets=[asset_definition],
+            python_callable=example_asset_func_with_valid_arg_as_inlet_asset,
+            definition_name="example_asset_func",
+        )
+        assert op.determine_kwargs(context={"k": "v"}) == {
+            "self": Asset(name="example_asset_func"),
+            "context": {"k": "v"},
+            "inlet_asset_1": Asset(name="inlet_asset_1", uri="s3://bucket/object1"),
+            "inlet_asset_2": Asset(name="inlet_asset_2"),
+        }