Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
j-tr committed Nov 13, 2023
1 parent a90a195 commit 013c18e
Showing 1 changed file with 21 additions and 22 deletions.
43 changes: 21 additions & 22 deletions prefect_ray/task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,29 +173,8 @@ async def submit(
else:
ray_decorator = ray.remote

def _run_prefect_task(func, *upstream_ray_obj_refs, **kwargs):
"""Resolves Ray futures before calling the actual Prefect task function.
Passing upstream_ray_obj_refs directly as args enables Ray to wait for
upstream tasks before running this remote function.
This variable is otherwise unused as the ray object refs are also
contained in kwargs.
"""

def resolve_ray_future(expr):
"""Resolves Ray future."""
if isinstance(expr, ray.ObjectRef):
return ray.get(expr)
return expr

kwargs = visit_collection(
kwargs, visit_fn=resolve_ray_future, return_data=True
)

return func(**kwargs)

self._ray_refs[key] = (
ray_decorator(_run_prefect_task)
ray_decorator(self._run_prefect_task)
.options(call.keywords["task_run"].name)
.remote(sync_compatible(call.func), *upstream_ray_obj_refs, **call_kwargs)
)
Expand All @@ -222,6 +201,26 @@ def exchange_prefect_for_ray_future(expr):

return kwargs_ray_futures, upstream_ray_obj_refs

@staticmethod
def _run_prefect_task(func, *upstream_ray_obj_refs, **kwargs):
"""Resolves Ray futures before calling the actual Prefect task function.
Passing upstream_ray_obj_refs directly as args enables Ray to wait for
upstream tasks before running this remote function.
This variable is otherwise unused as the ray object refs are also
contained in kwargs.
"""

def resolve_ray_future(expr):
"""Resolves Ray future."""
if isinstance(expr, ray.ObjectRef):
return ray.get(expr)
return expr

kwargs = visit_collection(kwargs, visit_fn=resolve_ray_future, return_data=True)

return func(**kwargs)

async def wait(self, key: UUID, timeout: float = None) -> Optional[State]:
ref = self._get_ray_ref(key)

Expand Down

0 comments on commit 013c18e

Please sign in to comment.