From b8c297d82f656d862f5e0a25fbc5381aea826ffe Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Thu, 31 Oct 2024 19:23:01 +0100 Subject: [PATCH] Test standard provider with Airflow 2.8 and 2.9 The standard provider has now min version of Airflow = 2.8 since #43553, but we have not tested it for Airflow 2.8 and 2.9. --- .../src/airflow_breeze/global_constants.py | 4 +- docs/apache-airflow/howto/operator/python.rst | 2 +- .../airflow/providers/standard/__init__.py | 9 + .../providers/standard/operators/python.py | 84 +++++--- .../providers/standard/sensors/date_time.py | 23 ++- .../providers/standard/sensors/time.py | 23 ++- .../providers/standard/sensors/time_delta.py | 10 +- .../tests/common/sql/operators/test_sql.py | 1 - .../tests/openlineage/plugins/test_utils.py | 12 +- .../tests/openlineage/utils/test_utils.py | 5 +- .../tests/standard/operators/test_python.py | 182 ++++++++++++------ .../standard/utils/test_python_virtualenv.py | 6 +- 12 files changed, 253 insertions(+), 108 deletions(-) diff --git a/dev/breeze/src/airflow_breeze/global_constants.py b/dev/breeze/src/airflow_breeze/global_constants.py index 471666273fcb..d7bc8bf4b973 100644 --- a/dev/breeze/src/airflow_breeze/global_constants.py +++ b/dev/breeze/src/airflow_breeze/global_constants.py @@ -574,13 +574,13 @@ def get_airflow_extras(): { "python-version": "3.9", "airflow-version": "2.8.4", - "remove-providers": "cloudant fab edge standard", + "remove-providers": "cloudant fab edge", "run-tests": "true", }, { "python-version": "3.9", "airflow-version": "2.9.3", - "remove-providers": "cloudant edge standard", + "remove-providers": "cloudant edge", "run-tests": "true", }, { diff --git a/docs/apache-airflow/howto/operator/python.rst b/docs/apache-airflow/howto/operator/python.rst index 5d06bfa3d3e1..d24f886d01cd 100644 --- a/docs/apache-airflow/howto/operator/python.rst +++ b/docs/apache-airflow/howto/operator/python.rst @@ -253,7 +253,7 @@ With some limitations, you can also use ``Context`` in virtual environments. You can also use ``get_current_context()`` in the same way as before, but with some limitations. - * Requires ``pydantic>=2``. + * Requires ``apache-airflow>=3.0.0``. * Set ``use_airflow_context`` to ``True`` to call ``get_current_context()`` in the virtual environment. diff --git a/providers/src/airflow/providers/standard/__init__.py b/providers/src/airflow/providers/standard/__init__.py index 217e5db96078..47fc7a1e8009 100644 --- a/providers/src/airflow/providers/standard/__init__.py +++ b/providers/src/airflow/providers/standard/__init__.py @@ -15,3 +15,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + +from packaging.version import Version + +from airflow import __version__ as airflow_version + +AIRFLOW_VERSION = Version(airflow_version) +AIRFLOW_V_2_10_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.10.0") +AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("3.0.0") diff --git a/providers/src/airflow/providers/standard/operators/python.py b/providers/src/airflow/providers/standard/operators/python.py index fb4babaf4aa2..8c1f440c73a2 100644 --- a/providers/src/airflow/providers/standard/operators/python.py +++ b/providers/src/airflow/providers/standard/operators/python.py @@ -35,9 +35,7 @@ from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Mapping, NamedTuple, Sequence, cast import lazy_object_proxy -from packaging.version import Version -from airflow import __version__ as airflow_version from airflow.exceptions import ( AirflowConfigException, AirflowException, @@ -50,21 +48,19 @@ from airflow.models.taskinstance import _CURRENT_CONTEXT from airflow.models.variable import Variable from airflow.operators.branch import BranchMixIn +from airflow.providers.standard import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS from airflow.providers.standard.utils.python_virtualenv import prepare_virtualenv, write_python_script from airflow.settings import _ENABLE_AIP_44 from airflow.typing_compat import Literal from airflow.utils import hashlib_wrapper -from airflow.utils.context import context_copy_partial, context_get_outlet_events, context_merge +from airflow.utils.context import context_copy_partial, context_merge from airflow.utils.file import get_unique_dag_module_name -from airflow.utils.operator_helpers import ExecutionCallableRunner, KeywordParameters -from airflow.utils.process_utils import execute_in_subprocess +from airflow.utils.operator_helpers import KeywordParameters +from airflow.utils.process_utils import execute_in_subprocess, execute_in_subprocess_with_kwargs from airflow.utils.session import create_session log = logging.getLogger(__name__) -AIRFLOW_VERSION = Version(airflow_version) -AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("3.0.0") - if TYPE_CHECKING: from pendulum.datetime import DateTime @@ -187,7 +183,15 @@ def __init__( def execute(self, context: Context) -> Any: context_merge(context, self.op_kwargs, templates_dict=self.templates_dict) self.op_kwargs = self.determine_kwargs(context) - self._asset_events = context_get_outlet_events(context) + + if AIRFLOW_V_3_0_PLUS: + from airflow.utils.context import context_get_outlet_events + + self._asset_events = context_get_outlet_events(context) + elif AIRFLOW_V_2_10_PLUS: + from airflow.utils.context import context_get_outlet_events + + self._dataset_events = context_get_outlet_events(context) return_value = self.execute_callable() if self.show_return_value_in_logs: @@ -206,7 +210,15 @@ def execute_callable(self) -> Any: :return: the return value of the call. """ - runner = ExecutionCallableRunner(self.python_callable, self._asset_events, logger=self.log) + try: + from airflow.utils.operator_helpers import ExecutionCallableRunner + + asset_events = self._asset_events if AIRFLOW_V_3_0_PLUS else self._dataset_events + + runner = ExecutionCallableRunner(self.python_callable, asset_events, logger=self.log) + except ImportError: + # Handle Pre Airflow 3.10 case where ExecutionCallableRunner was not available + return self.python_callable(*self.op_args, **self.op_kwargs) return runner.run(*self.op_args, **self.op_kwargs) @@ -348,7 +360,6 @@ class _BasePythonVirtualenvOperator(PythonOperator, metaclass=ABCMeta): "ds_nodash", "expanded_ti_count", "inlets", - "map_index_template", "next_ds", "next_ds_nodash", "outlets", @@ -551,18 +562,25 @@ def _execute_python_callable_in_subprocess(self, python_path: Path): env_vars.update(self.env_vars) try: - execute_in_subprocess( - cmd=[ - os.fspath(python_path), - os.fspath(script_path), - os.fspath(input_path), - os.fspath(output_path), - os.fspath(string_args_path), - os.fspath(termination_log_path), - os.fspath(airflow_context_path), - ], - env=env_vars, - ) + cmd: list[str] = [ + os.fspath(python_path), + os.fspath(script_path), + os.fspath(input_path), + os.fspath(output_path), + os.fspath(string_args_path), + os.fspath(termination_log_path), + os.fspath(airflow_context_path), + ] + if AIRFLOW_V_2_10_PLUS: + execute_in_subprocess( + cmd=cmd, + env=env_vars, + ) + else: + execute_in_subprocess_with_kwargs( + cmd=cmd, + env=env_vars, + ) except subprocess.CalledProcessError as e: if e.returncode in self.skip_on_exit_code: raise AirflowSkipException(f"Process exited with code {e.returncode}. Skipping.") @@ -697,10 +715,15 @@ def __init__( raise AirflowException( "Passing non-string types (e.g. int or float) as python_version not supported" ) - + if use_airflow_context and not AIRFLOW_V_3_0_PLUS: + raise AirflowException( + "The `use_airflow_context=True` is only supported in Airflow 3.0.0 and later." + ) if use_airflow_context and (not expect_airflow and not system_site_packages): - error_msg = "use_airflow_context is set to True, but expect_airflow and system_site_packages are set to False." - raise AirflowException(error_msg) + raise AirflowException( + "The `use_airflow_context` parameter is set to True, but " + "expect_airflow and system_site_packages are set to False." + ) if not requirements: self.requirements: list[str] = [] elif isinstance(requirements, str): @@ -976,9 +999,14 @@ def __init__( ): if not python: raise ValueError("Python Path must be defined in ExternalPythonOperator") + if use_airflow_context and not AIRFLOW_V_3_0_PLUS: + raise AirflowException( + "The `use_airflow_context=True` is only supported in Airflow 3.0.0 and later." + ) if use_airflow_context and not expect_airflow: - error_msg = "use_airflow_context is set to True, but expect_airflow is set to False." - raise AirflowException(error_msg) + raise AirflowException( + "The `use_airflow_context` parameter is set to True, but expect_airflow is set to False." + ) self.python = python self.expect_pendulum = expect_pendulum super().__init__( diff --git a/providers/src/airflow/providers/standard/sensors/date_time.py b/providers/src/airflow/providers/standard/sensors/date_time.py index 20a6a484e05a..35e88df07ba7 100644 --- a/providers/src/airflow/providers/standard/sensors/date_time.py +++ b/providers/src/airflow/providers/standard/sensors/date_time.py @@ -18,10 +18,27 @@ from __future__ import annotations import datetime +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, NoReturn, Sequence +from airflow.providers.standard import AIRFLOW_V_3_0_PLUS from airflow.sensors.base import BaseSensorOperator -from airflow.triggers.base import StartTriggerArgs + +try: + from airflow.triggers.base import StartTriggerArgs +except ImportError: + # TODO: Remove this when min airflow version is 2.10.0 for standard provider + @dataclass + class StartTriggerArgs: # type: ignore[no-redef] + """Arguments required for start task execution from triggerer.""" + + trigger_cls: str + next_method: str + trigger_kwargs: dict[str, Any] | None = None + next_kwargs: dict[str, Any] | None = None + timeout: datetime.timedelta | None = None + + from airflow.triggers.temporal import DateTimeTrigger from airflow.utils import timezone @@ -125,7 +142,9 @@ def execute(self, context: Context) -> NoReturn: trigger=DateTimeTrigger( moment=timezone.parse(self.target_time), end_from_trigger=self.end_from_trigger, - ), + ) + if AIRFLOW_V_3_0_PLUS + else DateTimeTrigger(moment=timezone.parse(self.target_time)), ) def execute_complete(self, context: Context, event: Any = None) -> None: diff --git a/providers/src/airflow/providers/standard/sensors/time.py b/providers/src/airflow/providers/standard/sensors/time.py index 6dba2628fce3..5c1629495297 100644 --- a/providers/src/airflow/providers/standard/sensors/time.py +++ b/providers/src/airflow/providers/standard/sensors/time.py @@ -18,10 +18,27 @@ from __future__ import annotations import datetime +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, NoReturn +from airflow.providers.standard import AIRFLOW_V_2_10_PLUS from airflow.sensors.base import BaseSensorOperator -from airflow.triggers.base import StartTriggerArgs + +try: + from airflow.triggers.base import StartTriggerArgs +except ImportError: + # TODO: Remove this when min airflow version is 2.10.0 for standard provider + @dataclass + class StartTriggerArgs: # type: ignore[no-redef] + """Arguments required for start task execution from triggerer.""" + + trigger_cls: str + next_method: str + trigger_kwargs: dict[str, Any] | None = None + next_kwargs: dict[str, Any] | None = None + timeout: datetime.timedelta | None = None + + from airflow.triggers.temporal import DateTimeTrigger from airflow.utils import timezone @@ -102,7 +119,9 @@ def __init__( def execute(self, context: Context) -> NoReturn: self.defer( - trigger=DateTimeTrigger(moment=self.target_datetime, end_from_trigger=self.end_from_trigger), + trigger=DateTimeTrigger(moment=self.target_datetime, end_from_trigger=self.end_from_trigger) + if AIRFLOW_V_2_10_PLUS + else DateTimeTrigger(moment=self.target_datetime), method_name="execute_complete", ) diff --git a/providers/src/airflow/providers/standard/sensors/time_delta.py b/providers/src/airflow/providers/standard/sensors/time_delta.py index dc78a0e33bc4..eb8bac1c57ea 100644 --- a/providers/src/airflow/providers/standard/sensors/time_delta.py +++ b/providers/src/airflow/providers/standard/sensors/time_delta.py @@ -23,6 +23,7 @@ from airflow.configuration import conf from airflow.exceptions import AirflowSkipException +from airflow.providers.standard import AIRFLOW_V_3_0_PLUS from airflow.sensors.base import BaseSensorOperator from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.utils import timezone @@ -81,7 +82,10 @@ def execute(self, context: Context) -> bool | NoReturn: # If the target datetime is in the past, return immediately return True try: - trigger = DateTimeTrigger(moment=target_dttm, end_from_trigger=self.end_from_trigger) + if AIRFLOW_V_3_0_PLUS: + trigger = DateTimeTrigger(moment=target_dttm, end_from_trigger=self.end_from_trigger) + else: + trigger = DateTimeTrigger(moment=target_dttm) except (TypeError, ValueError) as e: if self.soft_fail: raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e @@ -121,7 +125,9 @@ def __init__( def execute(self, context: Context) -> None: if self.deferrable: self.defer( - trigger=TimeDeltaTrigger(self.time_to_wait, end_from_trigger=True), + trigger=TimeDeltaTrigger(self.time_to_wait, end_from_trigger=True) + if AIRFLOW_V_3_0_PLUS + else TimeDeltaTrigger(self.time_to_wait), method_name="execute_complete", ) else: diff --git a/providers/tests/common/sql/operators/test_sql.py b/providers/tests/common/sql/operators/test_sql.py index 6274fa7ef747..00149a6389a1 100644 --- a/providers/tests/common/sql/operators/test_sql.py +++ b/providers/tests/common/sql/operators/test_sql.py @@ -53,7 +53,6 @@ pytestmark = [ pytest.mark.db_test, - pytest.mark.skipif(reason="Tests for Airflow 2.8.0+ only"), pytest.mark.skip_if_database_isolation_mode, ] diff --git a/providers/tests/openlineage/plugins/test_utils.py b/providers/tests/openlineage/plugins/test_utils.py index 8c553a8d8953..cb294169bfd7 100644 --- a/providers/tests/openlineage/plugins/test_utils.py +++ b/providers/tests/openlineage/plugins/test_utils.py @@ -57,12 +57,6 @@ if AIRFLOW_V_3_0_PLUS: from airflow.utils.types import DagRunTriggeredByType -BASH_OPERATOR_PATH = "airflow.providers.standard.operators.bash" -PYTHON_OPERATOR_PATH = "airflow.providers.standard.operators.python" -if not AIRFLOW_V_2_10_PLUS: - BASH_OPERATOR_PATH = "airflow.operators.bash" - PYTHON_OPERATOR_PATH = "airflow.operators.python" - class SafeStrDict(dict): def __str__(self): @@ -276,7 +270,7 @@ def test_get_fully_qualified_class_name(): from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter result = get_fully_qualified_class_name(BashOperator(task_id="test", bash_command="exit 0;")) - assert result == f"{BASH_OPERATOR_PATH}.BashOperator" + assert result == "airflow.providers.standard.operators.bash.BashOperator" result = get_fully_qualified_class_name(OpenLineageAdapter()) assert result == "airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter" @@ -292,8 +286,8 @@ def test_is_operator_disabled(mock_disabled_operators): assert is_operator_disabled(op) is False mock_disabled_operators.return_value = { - f"{BASH_OPERATOR_PATH}.BashOperator", - f"{PYTHON_OPERATOR_PATH}.PythonOperator", + "airflow.providers.standard.operators.bash.BashOperator", + "airflow.providers.standard.operators.python.PythonOperator", } assert is_operator_disabled(op) is True diff --git a/providers/tests/openlineage/utils/test_utils.py b/providers/tests/openlineage/utils/test_utils.py index f4e286331a51..28be0b630675 100644 --- a/providers/tests/openlineage/utils/test_utils.py +++ b/providers/tests/openlineage/utils/test_utils.py @@ -43,14 +43,11 @@ from airflow.utils.task_group import TaskGroup from airflow.utils.types import DagRunType -from tests_common.test_utils.compat import AIRFLOW_V_2_10_PLUS, BashOperator, PythonOperator +from tests_common.test_utils.compat import BashOperator, PythonOperator from tests_common.test_utils.mock_operators import MockOperator BASH_OPERATOR_PATH = "airflow.providers.standard.operators.bash" PYTHON_OPERATOR_PATH = "airflow.providers.standard.operators.python" -if not AIRFLOW_V_2_10_PLUS: - BASH_OPERATOR_PATH = "airflow.operators.bash" - PYTHON_OPERATOR_PATH = "airflow.operators.python" class CustomOperatorForTest(BashOperator): diff --git a/providers/tests/standard/operators/test_python.py b/providers/tests/standard/operators/test_python.py index b8a8ef5bc122..143c057003d6 100644 --- a/providers/tests/standard/operators/test_python.py +++ b/providers/tests/standard/operators/test_python.py @@ -72,7 +72,7 @@ from airflow.utils.types import NOTSET, DagRunType from tests_common.test_utils import AIRFLOW_MAIN_FOLDER -from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS, AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS from tests_common.test_utils.db import clear_db_runs if AIRFLOW_V_3_0_PLUS: @@ -97,6 +97,10 @@ USE_AIRFLOW_CONTEXT_MARKER = pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is not enabled") +AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE = ( + r"The `use_airflow_context=True` is only supported in Airflow 3.0.0 and later." +) + class BasePythonTest: """Base test class for TestPythonOperator and TestPythonSensor classes""" @@ -509,7 +513,7 @@ def f(): ti = self.create_ti(f) with pytest.raises( AirflowException, - match="'branch_task_ids' expected all task IDs are strings.", + match=r"'branch_task_ids'.*task.*", ): ti.run() @@ -518,7 +522,9 @@ def f(): return "some_task_id" ti = self.create_ti(f) - with pytest.raises(AirflowException, match="Invalid tasks found: {'some_task_id'}"): + with pytest.raises( + AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}.|'branch_task_ids'.*task.*" + ): ti.run() @pytest.mark.skip_if_database_isolation_mode # tests pure logic with run() method, can not run in isolation mode @@ -903,9 +909,16 @@ def test_virtualenv_serializable_context_fields(self, create_task_instance): "ti", "var", # Accessor for Variable; var->json and var->value. "conn", # Accessor for Connection. - "inlet_events", # Accessor for inlet AssetEvent. - "outlet_events", # Accessor for outlet AssetEvent. ] + if AIRFLOW_V_2_9_PLUS: + intentionally_excluded_context_keys.extend( + ["map_index_template"], + ) + if AIRFLOW_V_2_10_PLUS: + intentionally_excluded_context_keys.extend( + # Accessors for inlet_events and outlet_events + ["inlet_events", "outlet_events"], + ) ti = create_task_instance(dag_id=self.dag_id, task_id=self.task_id, schedule=None) context = ti.get_template_context() @@ -1035,13 +1048,17 @@ def f(): context = get_current_context() if not isinstance(context, Context): # type: ignore[misc] - error_msg = f"Expected Context, got {type(context)}" + error_msg = f"Expected Context, got {type(context)}:{context!r}" raise TypeError(error_msg) return [] - ti = self.run_as_task(f, return_ti=True, multiple_outputs=False, use_airflow_context=True) - assert ti.state == TaskInstanceState.SUCCESS + if AIRFLOW_V_3_0_PLUS: + ti = self.run_as_task(f, return_ti=True, multiple_outputs=False, use_airflow_context=True) + assert ti.state == TaskInstanceState.SUCCESS + else: + with pytest.raises(AirflowException, match=AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE): + self.run_as_task(f, return_ti=True, use_airflow_context=True) @USE_AIRFLOW_CONTEXT_MARKER def test_current_context_not_found_error(self): @@ -1051,21 +1068,32 @@ def f(): get_current_context() return [] - with pytest.raises( - AirflowException, - match="Current context was requested but no context was found! " - "Are you running within an airflow task?", - ): - self.run_as_task(f, return_ti=True, multiple_outputs=False, use_airflow_context=False) + if AIRFLOW_V_2_9_PLUS: + with pytest.raises( + AirflowException, + match="Current context was requested but no context was found! " + "Are you running within an airflow task?", + ): + self.run_as_task(f, return_ti=True, multiple_outputs=False, use_airflow_context=False) + else: + with pytest.raises( + AirflowException, + match="Current context was requested but no context was found! " + "Are you running within an airflow task?", + ): + self.run_as_task(f, return_ti=True, use_airflow_context=False) @USE_AIRFLOW_CONTEXT_MARKER def test_current_context_airflow_not_found_error(self): airflow_flag: dict[str, bool] = {"expect_airflow": False} - error_msg = "use_airflow_context is set to True, but expect_airflow is set to False." + error_msg = r"The `use_airflow_context` parameter is set to True, but expect_airflow is set to False." if not issubclass(self.opcls, ExternalPythonOperator): airflow_flag["system_site_packages"] = False - error_msg = "use_airflow_context is set to True, but expect_airflow and system_site_packages are set to False." + error_msg = ( + r"The `use_airflow_context` parameter is set to True, but " + r"expect_airflow and system_site_packages are set to False." + ) def f(): from airflow.providers.standard.operators.python import get_current_context @@ -1073,10 +1101,14 @@ def f(): get_current_context() return [] - with pytest.raises(AirflowException, match=error_msg): - self.run_as_task( - f, return_ti=True, multiple_outputs=False, use_airflow_context=True, **airflow_flag - ) + if AIRFLOW_V_3_0_PLUS: + with pytest.raises(AirflowException, match=error_msg): + self.run_as_task( + f, return_ti=True, multiple_outputs=False, use_airflow_context=True, **airflow_flag + ) + else: + with pytest.raises(AirflowException, match=AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE): + self.run_as_task(f, return_ti=True, use_airflow_context=True, **airflow_flag) @USE_AIRFLOW_CONTEXT_MARKER def test_use_airflow_context_touch_other_variables(self): @@ -1086,13 +1118,17 @@ def f(): context = get_current_context() if not isinstance(context, Context): # type: ignore[misc] - error_msg = f"Expected Context, got {type(context)}" + error_msg = f"Expected Context, got {type(context)}:{context!r}" raise TypeError(error_msg) return [] - ti = self.run_as_task(f, return_ti=True, multiple_outputs=False, use_airflow_context=True) - assert ti.state == TaskInstanceState.SUCCESS + if AIRFLOW_V_3_0_PLUS: + ti = self.run_as_task(f, return_ti=True, multiple_outputs=False, use_airflow_context=True) + assert ti.state == TaskInstanceState.SUCCESS + else: + with pytest.raises(AirflowException, match=AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE): + self.run_as_task(f, return_ti=True, use_airflow_context=True) @pytest.mark.skipif(_ENABLE_AIP_44, reason="AIP-44 is enabled") def test_use_airflow_context_without_aip_44_error(self): @@ -1103,8 +1139,12 @@ def f(): return [] error_msg = "`get_current_context()` needs to be used with AIP-44 enabled." - with pytest.raises(AirflowException, match=re.escape(error_msg)): - self.run_as_task(f, return_ti=True, multiple_outputs=False, use_airflow_context=True) + if AIRFLOW_V_3_0_PLUS: + with pytest.raises(AirflowException, match=re.escape(error_msg)): + self.run_as_task(f, return_ti=True, multiple_outputs=False, use_airflow_context=True) + else: + with pytest.raises(AirflowException, match=re.escape(AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE)): + self.run_as_task(f, return_ti=True, use_airflow_context=True) venv_cache_path = tempfile.mkdtemp(prefix="venv_cache_path") @@ -1520,21 +1560,32 @@ def f(): context = get_current_context() if not isinstance(context, Context): # type: ignore[misc] - error_msg = f"Expected Context, got {type(context)}" + error_msg = f"Expected Context, got {type(context)}:{context!r}" raise TypeError(error_msg) return [] - ti = self.run_as_task( - f, - return_ti=True, - multiple_outputs=False, - use_airflow_context=True, - session=session, - expect_airflow=False, - system_site_packages=True, - ) - assert ti.state == TaskInstanceState.SUCCESS + if AIRFLOW_V_3_0_PLUS: + ti = self.run_as_task( + f, + return_ti=True, + multiple_outputs=False, + use_airflow_context=True, + session=session, + expect_airflow=False, + system_site_packages=True, + ) + assert ti.state == TaskInstanceState.SUCCESS + else: + with pytest.raises(AirflowException, match=AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE): + self.run_as_task( + f, + return_ti=True, + use_airflow_context=True, + session=session, + expect_airflow=False, + system_site_packages=True, + ) # when venv tests are run in parallel to other test they create new processes and this might take @@ -1627,21 +1678,25 @@ def f(a, b, c=False, d=False): else: raise RuntimeError - with pytest.raises(AirflowException, match=r"Invalid tasks found: {\((True|False), 'bool'\)}"): + with pytest.raises( + AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}.|'branch_task_ids'.*task.*" + ): self.run_as_task(f, op_args=[0, 1], op_kwargs={"c": True}) def test_return_false(self): def f(): return False - with pytest.raises(AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}."): + with pytest.raises( + AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}.|'branch_task_ids'.*task.*" + ): self.run_as_task(f) def test_context(self): def f(templates_dict): return templates_dict["ds"] - with pytest.raises(AirflowException, match="Invalid tasks found:"): + with pytest.raises(AirflowException, match="Invalid tasks found:|'branch_task_ids'.*task.*"): self.run_as_task(f, templates_dict={"ds": "{{ ds }}"}) def test_environment_variables(self): @@ -1652,7 +1707,7 @@ def f(): with pytest.raises( AirflowException, - match=r"'branch_task_ids' must contain only valid task_ids. Invalid tasks found: {'ABCDE'}", + match=r"'branch_task_ids'.*task.*", ): self.run_as_task(f, env_vars={"MY_ENV_VAR": "ABCDE"}) @@ -1666,7 +1721,7 @@ def f(): with pytest.raises( AirflowException, - match=r"'branch_task_ids' must contain only valid task_ids. Invalid tasks found: {'QWERT'}", + match=r"'branch_task_ids'.*task.*", ): self.run_as_task(f, inherit_env=True) @@ -1691,7 +1746,7 @@ def f(): with pytest.raises( AirflowException, - match=r"'branch_task_ids' must contain only valid task_ids. Invalid tasks found: {'EFGHI'}", + match=r"'branch_task_ids'.*task.*", ): self.run_as_task(f, env_vars={"MY_ENV_VAR": "EFGHI"}, inherit_env=True) @@ -1706,7 +1761,9 @@ def test_with_no_caching(self): def f(): return False - with pytest.raises(AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}."): + with pytest.raises( + AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}.|'branch_task_ids'.*task.*" + ): self.run_as_task(f, do_not_use_caching=True) def test_with_dag_run(self): @@ -1827,7 +1884,7 @@ def f(): ti = self.create_ti(f) with pytest.raises( AirflowException, - match="'branch_task_ids' expected all task IDs are strings.", + match=r"'branch_task_ids'.*task.*", ): ti.run() @@ -1836,7 +1893,9 @@ def f(): return "some_task_id" ti = self.create_ti(f) - with pytest.raises(AirflowException, match="Invalid tasks found: {'some_task_id'}"): + with pytest.raises( + AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}.|'branch_task_ids'.*task.*" + ): ti.run() @@ -1866,21 +1925,32 @@ def f(): context = get_current_context() if not isinstance(context, Context): # type: ignore[misc] - error_msg = f"Expected Context, got {type(context)}" + error_msg = f"Expected Context, got {type(context)}:{context!r}" raise TypeError(error_msg) return [] - ti = self.run_as_task( - f, - return_ti=True, - multiple_outputs=False, - use_airflow_context=True, - session=session, - expect_airflow=False, - system_site_packages=True, - ) - assert ti.state == TaskInstanceState.SUCCESS + if AIRFLOW_V_3_0_PLUS: + ti = self.run_as_task( + f, + return_ti=True, + multiple_outputs=False, + use_airflow_context=True, + session=session, + expect_airflow=False, + system_site_packages=True, + ) + assert ti.state == TaskInstanceState.SUCCESS + else: + with pytest.raises(AirflowException, match=AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE): + self.run_as_task( + f, + return_ti=True, + use_airflow_context=True, + session=session, + expect_airflow=False, + system_site_packages=True, + ) # when venv tests are run in parallel to other test they create new processes and this might take diff --git a/providers/tests/standard/utils/test_python_virtualenv.py b/providers/tests/standard/utils/test_python_virtualenv.py index 0e10dcf5305c..b5d31679aa5b 100644 --- a/providers/tests/standard/utils/test_python_virtualenv.py +++ b/providers/tests/standard/utils/test_python_virtualenv.py @@ -25,6 +25,7 @@ from airflow.providers.standard.utils.python_virtualenv import _generate_pip_conf, _use_uv, prepare_virtualenv from airflow.utils.decorators import remove_task_decorator +from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests_common.test_utils.config import conf_vars @@ -204,7 +205,10 @@ def test_remove_decorator_no_parens(self): def test_remove_decorator_including_comment(self): py_source = "@task.virtualenv\ndef f():\n# @task.virtualenv\nimport funcsigs" res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv") - assert res == "def f():\n# @task.virtualenv\nimport funcsigs" + if AIRFLOW_V_2_9_PLUS: + assert res == "def f():\n# @task.virtualenv\nimport funcsigs" + else: + assert res == "def f():\n# " def test_remove_decorator_nested(self): py_source = "@foo\n@task.virtualenv\n@bar\ndef f():\nimport funcsigs"