-
Notifications
You must be signed in to change notification settings - Fork 302
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Grantham Taylor <granthamtaylor@icloud.com>
- Loading branch information
1 parent
f1e5339
commit f634d53
Showing
5 changed files
with
229 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import copy | ||
from functools import partial, wraps | ||
from typing import Any, Callable, TypeVar, Union | ||
|
||
from rich.console import Console | ||
from rich.panel import Panel | ||
from rich.pretty import Pretty | ||
from typing_extensions import Concatenate, ParamSpec | ||
|
||
from flytekit.core.dynamic_workflow_task import dynamic | ||
from flytekit.core.task import task | ||
|
||
P = ParamSpec("P") | ||
T = TypeVar("T") | ||
|
||
|
||
# basically, I want the docstring for `flyte.task` to be available for users to see | ||
# this is "copying" the docstring from `flyte.task` to functions wrapped by `forge` | ||
# more details here: https://github.com/python/typing/issues/270 | ||
def forge(source: Callable[Concatenate[Any, P], T]) -> Callable[[Callable], Callable[Concatenate[Any, P], T]]: | ||
def wrapper(target: Callable) -> Callable[Concatenate[Any, P], T]: | ||
@wraps(source) | ||
def wrapped(self, *args: P.args, **kwargs: P.kwargs) -> T: | ||
return target(self, *args, **kwargs) | ||
|
||
return wrapped | ||
|
||
return wrapper | ||
|
||
|
||
def inherit(old: dict[str, Any], new: dict[str, Any]) -> dict[str, Any]: | ||
out = copy.deepcopy(old) | ||
|
||
for key, value in new.items(): | ||
if key in out: | ||
if isinstance(value, dict): | ||
out[key] = inherit(out[key], value) | ||
else: | ||
out[key] = value | ||
else: | ||
out[key] = value | ||
|
||
return out | ||
|
||
|
||
class Environment: | ||
@forge(task) | ||
def __init__(self, **overrides: Any) -> None: | ||
_overrides: dict[str, Any] = {} | ||
for key, value in overrides.items(): | ||
if key == "_task_function": | ||
raise KeyError("Cannot override task function") | ||
|
||
_overrides[key] = value | ||
|
||
self.overrides = _overrides | ||
|
||
@forge(task) | ||
def update(self, **overrides: Any) -> None: | ||
self.overrides = inherit(self.overrides, overrides) | ||
|
||
@forge(task) | ||
def extend(self, **overrides: Any) -> "Environment": | ||
return self.__class__(**inherit(self.overrides, overrides)) | ||
|
||
@forge(task) | ||
def __call__( | ||
self, _task_function: Union[Callable, None] = None, /, **overrides | ||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]: | ||
# no additional overrides are passed | ||
if _task_function is not None: | ||
if callable(_task_function): | ||
return partial(task, **self.overrides)(_task_function) | ||
|
||
else: | ||
raise ValueError("The first positional argument must be a callable") | ||
|
||
# additional overrides are passed | ||
else: | ||
|
||
def inner(_task_function: Callable) -> Callable: | ||
inherited = inherit(self.overrides, overrides) | ||
|
||
return partial(task, **inherited)(_task_function) | ||
|
||
return inner | ||
|
||
def show(self) -> None: | ||
console = Console() | ||
|
||
console.print(Panel.fit(Pretty(self.overrides))) | ||
|
||
task = __call__ | ||
|
||
@forge(dynamic) | ||
def dynamic( | ||
self, _task_function: Union[Callable, None] = None, /, **overrides | ||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]: | ||
# no additional overrides are passed | ||
if _task_function is not None: | ||
if callable(_task_function): | ||
return partial(dynamic, **self.overrides)(_task_function) | ||
|
||
else: | ||
raise ValueError("The first positional argument must be a callable") | ||
|
||
# additional overrides are passed | ||
else: | ||
|
||
def inner(_task_function: Callable) -> Callable: | ||
inherited = inherit(self.overrides, overrides) | ||
|
||
return partial(dynamic, **inherited)(_task_function) | ||
|
||
return inner |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import flytekit | ||
from flytekitplugins.spark import Spark | ||
from flytekit.core.environment import Environment | ||
|
||
|
||
def test_spark_task(): | ||
|
||
env = Environment( | ||
task_config=Spark( | ||
spark_conf={"spark": "1"}, | ||
executor_path="/usr/bin/python3", | ||
applications_path="local:///usr/local/bin/entrypoint.py", | ||
) | ||
) | ||
|
||
@env.task | ||
def my_spark(a: str) -> int: | ||
session = flytekit.current_context().spark_session | ||
assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local" | ||
return 10 | ||
|
||
assert my_spark.task_config is not None | ||
assert my_spark.task_config.spark_conf == {"spark": "1"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from flytekit.core.environment import Environment, inherit | ||
|
||
|
||
def test_basic_environment(): | ||
|
||
env = Environment(retries=2) | ||
|
||
@env.task | ||
def foo(): | ||
pass | ||
|
||
@env | ||
def bar(): | ||
pass | ||
|
||
assert foo._metadata.retries == 2 | ||
assert bar._metadata.retries == 2 | ||
|
||
|
||
def test_dynamic_from_environment(): | ||
|
||
env = Environment(retries=2) | ||
|
||
@env.task | ||
def foo(): | ||
pass | ||
|
||
@env.dynamic | ||
def bar(): | ||
foo() | ||
|
||
assert foo._metadata.retries == 2 | ||
assert bar._metadata.retries == 2 | ||
|
||
|
||
def test_extended_environment(): | ||
|
||
env = Environment(retries=2) | ||
|
||
other = env.extend(retries=0) | ||
|
||
@other.task | ||
def foo(): | ||
pass | ||
|
||
@other | ||
def bar(): | ||
pass | ||
|
||
|
||
assert foo._metadata.retries == 0 | ||
assert bar._metadata.retries == 0 | ||
|
||
|
||
def test_updated_environment(): | ||
|
||
env = Environment(retries=2) | ||
|
||
env.update(retries=0) | ||
|
||
@env.task | ||
def foo(): | ||
pass | ||
|
||
@env | ||
def bar(): | ||
pass | ||
|
||
|
||
assert foo._metadata.retries == 0 | ||
assert bar._metadata.retries == 0 | ||
|
||
|
||
def test_show_environment(): | ||
|
||
env = Environment(retries=2) | ||
|
||
env.show() | ||
|
||
|
||
def test_inherit(): | ||
|
||
old_config = {"cache": False, "timeout": 10} | ||
|
||
new_config = {"cache": True} | ||
|
||
combined = inherit(old_config, new_config) | ||
|
||
assert combined["cache"] == True | ||
assert combined["timeout"] == 10 |
Empty file.