Skip to content

Commit

Permalink
refactor(providers/common/compat): extract airflow version to __init__
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W committed Nov 8, 2024
1 parent e70bc8f commit befbbbf
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 14 deletions.
14 changes: 9 additions & 5 deletions providers/src/airflow/providers/common/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,21 @@
#
from __future__ import annotations

import packaging.version
from packaging.version import Version

from airflow import __version__ as airflow_version
from airflow import __version__ as AIRFLOW_VERSION

__all__ = ["__version__"]

__version__ = "1.2.1"

if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
"2.8.0"
):

AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0")
AIRFLOW_V_2_10_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")
AIRFLOW_V_2_9_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.9.0")
AIRFLOW_V_2_8_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.8.0")

if Version(Version(AIRFLOW_VERSION).base_version) < Version("2.8.0"):
raise RuntimeError(
f"The package `apache-airflow-providers-common-compat:{__version__}` needs Apache Airflow 2.8.0+"
)
14 changes: 6 additions & 8 deletions providers/src/airflow/providers/common/compat/assets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@

from typing import TYPE_CHECKING

from airflow import __version__ as AIRFLOW_VERSION
from airflow.providers.common.compat import (
AIRFLOW_V_2_8_PLUS,
AIRFLOW_V_2_9_PLUS,
AIRFLOW_V_2_10_PLUS,
AIRFLOW_V_3_0_PLUS,
)

if TYPE_CHECKING:
from airflow.auth.managers.models.resource_details import AssetDetails
Expand All @@ -32,13 +37,6 @@
expand_alias_to_assets,
)
else:
from packaging.version import Version

AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0")
AIRFLOW_V_2_10_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")
AIRFLOW_V_2_9_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.9.0")
AIRFLOW_V_2_8_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.8.0")

if AIRFLOW_V_3_0_PLUS:
from airflow.auth.managers.models.resource_details import AssetDetails
from airflow.sdk.definitions.asset import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from airflow.providers.common.compat.assets import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
from airflow.providers.common.compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS


def _get_asset_compat_hook_lineage_collector():
Expand Down
171 changes: 171 additions & 0 deletions tests/decorators/test_assets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest import mock
from unittest.mock import ANY

import pytest

from airflow.sdk.definitions.asset import Asset
from airflow.decorators.assets import AssetRef, _AssetMainOperator, asset
from airflow.models.asset import AssetActive, AssetModel

pytestmark = pytest.mark.db_test


@pytest.fixture
def example_asset_func(request):
name = "example_asset_func"
if getattr(request, "param", None) is not None:
name = request.param

def _example_asset_func():
return "This is example_asset"

_example_asset_func.__name__ = name
_example_asset_func.__qualname__ = name
return _example_asset_func


@pytest.fixture
def example_asset_definition(example_asset_func):
return asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})(
example_asset_func
)


@pytest.fixture
def example_asset_func_with_valid_arg_as_inlet_asset():
def _example_asset_func(self, context, inlet_asset_1, inlet_asset_2):
return "This is example_asset"

_example_asset_func.__name__ = "example_asset_func"
_example_asset_func.__qualname__ = "example_asset_func"
return _example_asset_func


class TestAssetDecorator:
def test_without_uri(self, example_asset_func):
asset_definition = asset(schedule=None)(example_asset_func)

assert asset_definition.name == "example_asset_func"
assert asset_definition.uri == "example_asset_func"
assert asset_definition.group == ""
assert asset_definition.extra == {}
assert asset_definition.function == example_asset_func
assert asset_definition.schedule is None

def test_with_uri(self, example_asset_func):
asset_definition = asset(schedule=None, uri="s3://bucket/object")(example_asset_func)

assert asset_definition.name == "example_asset_func"
assert asset_definition.uri == "s3://bucket/object"
assert asset_definition.group == ""
assert asset_definition.extra == {}
assert asset_definition.function == example_asset_func
assert asset_definition.schedule is None

def test_with_group_and_extra(self, example_asset_func):
asset_definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})(
example_asset_func
)
assert asset_definition.name == "example_asset_func"
assert asset_definition.uri == "s3://bucket/object"
assert asset_definition.group == "MLModel"
assert asset_definition.extra == {"k": "v"}
assert asset_definition.function == example_asset_func
assert asset_definition.schedule is None

def test_nested_function(self):
def root_func():
@asset(schedule=None)
def asset_func():
pass

with pytest.raises(ValueError) as err:
root_func()

assert err.value.args[0] == "nested function not supported"

@pytest.mark.parametrize("example_asset_func", ("self", "context"), indirect=True)
def test_with_invalid_asset_name(self, example_asset_func):
with pytest.raises(ValueError) as err:
asset(schedule=None)(example_asset_func)

assert err.value.args[0].startswith("prohibited name for asset: ")


class TestAssetDefinition:
def test_serialzie(self, example_asset_definition):
assert example_asset_definition.serialize() == {
"extra": {"k": "v"},
"group": "MLModel",
"name": "example_asset_func",
"uri": "s3://bucket/object",
}

@mock.patch("airflow.decorators.assets._AssetMainOperator")
@mock.patch("airflow.decorators.assets.DAG")
def test__attrs_post_init__(
self, DAG, _AssetMainOperator, example_asset_func_with_valid_arg_as_inlet_asset
):
asset_definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})(
example_asset_func_with_valid_arg_as_inlet_asset
)

DAG.assert_called_once_with(dag_id="example_asset_func", schedule=None, auto_register=True)
_AssetMainOperator.assert_called_once_with(
task_id="__main__",
inlets=[
AssetRef(name="inlet_asset_1"),
AssetRef(name="inlet_asset_2"),
],
outlets=[asset_definition.to_asset()],
python_callable=ANY,
definition_name="example_asset_func",
uri="s3://bucket/object",
)

python_callable = _AssetMainOperator.call_args.kwargs["python_callable"]
assert python_callable.__wrapped__ == example_asset_func_with_valid_arg_as_inlet_asset


class Test_AssetMainOperator:
def test_determine_kwargs(self, example_asset_func_with_valid_arg_as_inlet_asset, session):
example_asset = AssetModel(uri="s3://bucket/object1", name="inlet_asset_1")
session.add(example_asset)
session.add(AssetActive.for_asset(example_asset))
session.commit()

asset_definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})(
example_asset_func_with_valid_arg_as_inlet_asset
)

op = _AssetMainOperator(
task_id="__main__",
inlets=[AssetRef(name="inlet_asset_1"), AssetRef(name="inlet_asset_2")],
outlets=[asset_definition],
python_callable=example_asset_func_with_valid_arg_as_inlet_asset,
definition_name="example_asset_func",
)
assert op.determine_kwargs(context={"k": "v"}) == {
"self": Asset(name="example_asset_func"),
"context": {"k": "v"},
"inlet_asset_1": Asset(name="inlet_asset_1", uri="s3://bucket/object1"),
"inlet_asset_2": Asset(name="inlet_asset_2"),
}

0 comments on commit befbbbf

Please sign in to comment.