diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 6cd2b85564..5afb9d4bfb 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -228,6 +228,7 @@ from flytekit.core.container_task import ContainerTask from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager from flytekit.core.dynamic_workflow_task import dynamic +from flytekit.core.environment import Environment from flytekit.core.gate import approve, sleep, wait_for_input from flytekit.core.hash import HashMethod from flytekit.core.launch_plan import LaunchPlan, reference_launch_plan diff --git a/flytekit/core/environment.py b/flytekit/core/environment.py new file mode 100644 index 0000000000..8058b84b8b --- /dev/null +++ b/flytekit/core/environment.py @@ -0,0 +1,115 @@ +import copy +from functools import partial, wraps +from typing import Any, Callable, TypeVar, Union + +from rich.console import Console +from rich.panel import Panel +from rich.pretty import Pretty +from typing_extensions import Concatenate, ParamSpec + +from flytekit.core.dynamic_workflow_task import dynamic +from flytekit.core.task import task + +P = ParamSpec("P") +T = TypeVar("T") + + +# basically, I want the docstring for `flyte.task` to be available for users to see +# this is "copying" the docstring from `flyte.task` to functions wrapped by `forge` +# more details here: https://github.com/python/typing/issues/270 +def forge(source: Callable[Concatenate[Any, P], T]) -> Callable[[Callable], Callable[Concatenate[Any, P], T]]: + def wrapper(target: Callable) -> Callable[Concatenate[Any, P], T]: + @wraps(source) + def wrapped(self, *args: P.args, **kwargs: P.kwargs) -> T: + return target(self, *args, **kwargs) + + return wrapped + + return wrapper + + +def inherit(old: dict[str, Any], new: dict[str, Any]) -> dict[str, Any]: + out = copy.deepcopy(old) + + for key, value in new.items(): + if key in out: + if isinstance(value, dict): + out[key] = inherit(out[key], value) + else: + out[key] = value + else: + out[key] = value + + return out + + +class Environment: + @forge(task) + def __init__(self, **overrides: Any) -> None: + _overrides: dict[str, Any] = {} + for key, value in overrides.items(): + if key == "_task_function": + raise KeyError("Cannot override task function") + + _overrides[key] = value + + self.overrides = _overrides + + @forge(task) + def update(self, **overrides: Any) -> None: + self.overrides = inherit(self.overrides, overrides) + + @forge(task) + def extend(self, **overrides: Any) -> "Environment": + return self.__class__(**inherit(self.overrides, overrides)) + + @forge(task) + def __call__( + self, _task_function: Union[Callable, None] = None, /, **overrides + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + # no additional overrides are passed + if _task_function is not None: + if callable(_task_function): + return partial(task, **self.overrides)(_task_function) + + else: + raise ValueError("The first positional argument must be a callable") + + # additional overrides are passed + else: + + def inner(_task_function: Callable) -> Callable: + inherited = inherit(self.overrides, overrides) + + return partial(task, **inherited)(_task_function) + + return inner + + def show(self) -> None: + console = Console() + + console.print(Panel.fit(Pretty(self.overrides))) + + task = __call__ + + @forge(dynamic) + def dynamic( + self, _task_function: Union[Callable, None] = None, /, **overrides + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + # no additional overrides are passed + if _task_function is not None: + if callable(_task_function): + return partial(dynamic, **self.overrides)(_task_function) + + else: + raise ValueError("The first positional argument must be a callable") + + # additional overrides are passed + else: + + def inner(_task_function: Callable) -> Callable: + inherited = inherit(self.overrides, overrides) + + return partial(dynamic, **inherited)(_task_function) + + return inner diff --git a/plugins/flytekit-spark/tests/test_environment.py b/plugins/flytekit-spark/tests/test_environment.py new file mode 100644 index 0000000000..e9d695c843 --- /dev/null +++ b/plugins/flytekit-spark/tests/test_environment.py @@ -0,0 +1,23 @@ +import flytekit +from flytekitplugins.spark import Spark +from flytekit.core.environment import Environment + + +def test_spark_task(): + + env = Environment( + task_config=Spark( + spark_conf={"spark": "1"}, + executor_path="/usr/bin/python3", + applications_path="local:///usr/local/bin/entrypoint.py", + ) + ) + + @env.task + def my_spark(a: str) -> int: + session = flytekit.current_context().spark_session + assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local" + return 10 + + assert my_spark.task_config is not None + assert my_spark.task_config.spark_conf == {"spark": "1"} diff --git a/tests/flytekit/unit/core/test_environment.py b/tests/flytekit/unit/core/test_environment.py new file mode 100644 index 0000000000..1756b64618 --- /dev/null +++ b/tests/flytekit/unit/core/test_environment.py @@ -0,0 +1,90 @@ +from flytekit.core.environment import Environment, inherit + + +def test_basic_environment(): + + env = Environment(retries=2) + + @env.task + def foo(): + pass + + @env + def bar(): + pass + + assert foo._metadata.retries == 2 + assert bar._metadata.retries == 2 + + +def test_dynamic_from_environment(): + + env = Environment(retries=2) + + @env.task + def foo(): + pass + + @env.dynamic + def bar(): + foo() + + assert foo._metadata.retries == 2 + assert bar._metadata.retries == 2 + + +def test_extended_environment(): + + env = Environment(retries=2) + + other = env.extend(retries=0) + + @other.task + def foo(): + pass + + @other + def bar(): + pass + + + assert foo._metadata.retries == 0 + assert bar._metadata.retries == 0 + + +def test_updated_environment(): + + env = Environment(retries=2) + + env.update(retries=0) + + @env.task + def foo(): + pass + + @env + def bar(): + pass + + + assert foo._metadata.retries == 0 + assert bar._metadata.retries == 0 + + +def test_show_environment(): + + env = Environment(retries=2) + + env.show() + + +def test_inherit(): + + old_config = {"cache": False, "timeout": 10} + + new_config = {"cache": True} + + combined = inherit(old_config, new_config) + + assert combined["cache"] == True + assert combined["timeout"] == 10 diff --git a/tests/flytekit/unit/experimental/__init__.py b/tests/flytekit/unit/experimental/__init__.py deleted file mode 100644 index e69de29bb2..0000000000