From 013c18e8a2209b8012a2b42f0105e9a77da33fb1 Mon Sep 17 00:00:00 2001 From: Justin Trautmann Date: Mon, 13 Nov 2023 16:07:53 +0100 Subject: [PATCH] tests --- prefect_ray/task_runners.py | 43 ++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/prefect_ray/task_runners.py b/prefect_ray/task_runners.py index 3ce4dde..641d9e6 100644 --- a/prefect_ray/task_runners.py +++ b/prefect_ray/task_runners.py @@ -173,29 +173,8 @@ async def submit( else: ray_decorator = ray.remote - def _run_prefect_task(func, *upstream_ray_obj_refs, **kwargs): - """Resolves Ray futures before calling the actual Prefect task function. - - Passing upstream_ray_obj_refs directly as args enables Ray to wait for - upstream tasks before running this remote function. - This variable is otherwise unused as the ray object refs are also - contained in kwargs. - """ - - def resolve_ray_future(expr): - """Resolves Ray future.""" - if isinstance(expr, ray.ObjectRef): - return ray.get(expr) - return expr - - kwargs = visit_collection( - kwargs, visit_fn=resolve_ray_future, return_data=True - ) - - return func(**kwargs) - self._ray_refs[key] = ( - ray_decorator(_run_prefect_task) + ray_decorator(self._run_prefect_task) .options(call.keywords["task_run"].name) .remote(sync_compatible(call.func), *upstream_ray_obj_refs, **call_kwargs) ) @@ -222,6 +201,26 @@ def exchange_prefect_for_ray_future(expr): return kwargs_ray_futures, upstream_ray_obj_refs + @staticmethod + def _run_prefect_task(func, *upstream_ray_obj_refs, **kwargs): + """Resolves Ray futures before calling the actual Prefect task function. + + Passing upstream_ray_obj_refs directly as args enables Ray to wait for + upstream tasks before running this remote function. + This variable is otherwise unused as the ray object refs are also + contained in kwargs. + """ + + def resolve_ray_future(expr): + """Resolves Ray future.""" + if isinstance(expr, ray.ObjectRef): + return ray.get(expr) + return expr + + kwargs = visit_collection(kwargs, visit_fn=resolve_ray_future, return_data=True) + + return func(**kwargs) + async def wait(self, key: UUID, timeout: float = None) -> Optional[State]: ref = self._get_ray_ref(key)