Skip to content

Commit

Permalink
add managed environment POC (#3021)
Browse files Browse the repository at this point in the history
Signed-off-by: Grantham Taylor <granthamtaylor@icloud.com>
  • Loading branch information
granthamtaylor authored Jan 6, 2025
1 parent f1e5339 commit f634d53
Show file tree
Hide file tree
Showing 5 changed files with 229 additions and 0 deletions.
1 change: 1 addition & 0 deletions flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@
from flytekit.core.container_task import ContainerTask
from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.environment import Environment
from flytekit.core.gate import approve, sleep, wait_for_input
from flytekit.core.hash import HashMethod
from flytekit.core.launch_plan import LaunchPlan, reference_launch_plan
Expand Down
115 changes: 115 additions & 0 deletions flytekit/core/environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import copy
from functools import partial, wraps
from typing import Any, Callable, TypeVar, Union

from rich.console import Console
from rich.panel import Panel
from rich.pretty import Pretty
from typing_extensions import Concatenate, ParamSpec

from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.task import task

P = ParamSpec("P")
T = TypeVar("T")


# basically, I want the docstring for `flyte.task` to be available for users to see
# this is "copying" the docstring from `flyte.task` to functions wrapped by `forge`
# more details here: https://github.com/python/typing/issues/270
def forge(source: Callable[Concatenate[Any, P], T]) -> Callable[[Callable], Callable[Concatenate[Any, P], T]]:
def wrapper(target: Callable) -> Callable[Concatenate[Any, P], T]:
@wraps(source)
def wrapped(self, *args: P.args, **kwargs: P.kwargs) -> T:
return target(self, *args, **kwargs)

return wrapped

return wrapper


def inherit(old: dict[str, Any], new: dict[str, Any]) -> dict[str, Any]:
out = copy.deepcopy(old)

for key, value in new.items():
if key in out:
if isinstance(value, dict):
out[key] = inherit(out[key], value)
else:
out[key] = value
else:
out[key] = value

return out


class Environment:
@forge(task)
def __init__(self, **overrides: Any) -> None:
_overrides: dict[str, Any] = {}
for key, value in overrides.items():
if key == "_task_function":
raise KeyError("Cannot override task function")

_overrides[key] = value

self.overrides = _overrides

@forge(task)
def update(self, **overrides: Any) -> None:
self.overrides = inherit(self.overrides, overrides)

@forge(task)
def extend(self, **overrides: Any) -> "Environment":
return self.__class__(**inherit(self.overrides, overrides))

@forge(task)
def __call__(
self, _task_function: Union[Callable, None] = None, /, **overrides
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
# no additional overrides are passed
if _task_function is not None:
if callable(_task_function):
return partial(task, **self.overrides)(_task_function)

else:
raise ValueError("The first positional argument must be a callable")

# additional overrides are passed
else:

def inner(_task_function: Callable) -> Callable:
inherited = inherit(self.overrides, overrides)

return partial(task, **inherited)(_task_function)

return inner

def show(self) -> None:
console = Console()

console.print(Panel.fit(Pretty(self.overrides)))

task = __call__

@forge(dynamic)
def dynamic(
self, _task_function: Union[Callable, None] = None, /, **overrides
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
# no additional overrides are passed
if _task_function is not None:
if callable(_task_function):
return partial(dynamic, **self.overrides)(_task_function)

else:
raise ValueError("The first positional argument must be a callable")

# additional overrides are passed
else:

def inner(_task_function: Callable) -> Callable:
inherited = inherit(self.overrides, overrides)

return partial(dynamic, **inherited)(_task_function)

return inner
23 changes: 23 additions & 0 deletions plugins/flytekit-spark/tests/test_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import flytekit
from flytekitplugins.spark import Spark
from flytekit.core.environment import Environment


def test_spark_task():

env = Environment(
task_config=Spark(
spark_conf={"spark": "1"},
executor_path="/usr/bin/python3",
applications_path="local:///usr/local/bin/entrypoint.py",
)
)

@env.task
def my_spark(a: str) -> int:
session = flytekit.current_context().spark_session
assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
return 10

assert my_spark.task_config is not None
assert my_spark.task_config.spark_conf == {"spark": "1"}
90 changes: 90 additions & 0 deletions tests/flytekit/unit/core/test_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from flytekit.core.environment import Environment, inherit


def test_basic_environment():

env = Environment(retries=2)

@env.task
def foo():
pass

@env
def bar():
pass

assert foo._metadata.retries == 2
assert bar._metadata.retries == 2


def test_dynamic_from_environment():

env = Environment(retries=2)

@env.task
def foo():
pass

@env.dynamic
def bar():
foo()

assert foo._metadata.retries == 2
assert bar._metadata.retries == 2


def test_extended_environment():

env = Environment(retries=2)

other = env.extend(retries=0)

@other.task
def foo():
pass

@other
def bar():
pass


assert foo._metadata.retries == 0
assert bar._metadata.retries == 0


def test_updated_environment():

env = Environment(retries=2)

env.update(retries=0)

@env.task
def foo():
pass

@env
def bar():
pass


assert foo._metadata.retries == 0
assert bar._metadata.retries == 0


def test_show_environment():

env = Environment(retries=2)

env.show()


def test_inherit():

old_config = {"cache": False, "timeout": 10}

new_config = {"cache": True}

combined = inherit(old_config, new_config)

assert combined["cache"] == True
assert combined["timeout"] == 10
Empty file.

0 comments on commit f634d53

Please sign in to comment.