diff --git a/CHANGELOG.md b/CHANGELOG.md index 00807c5..4e41d90 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Security +## 0.3.3 + +Released on December 6th, 2023. + +### Added + +- Use perfect task name as Ray remote function name - [#103](https://github.com/PrefectHQ/prefect-ray/pull/103) + ## 0.2.6 Released on November 29th, 2023. diff --git a/prefect_ray/task_runners.py b/prefect_ray/task_runners.py index 1d35f2e..39a79a1 100644 --- a/prefect_ray/task_runners.py +++ b/prefect_ray/task_runners.py @@ -173,8 +173,10 @@ async def submit( else: ray_decorator = ray.remote - self._ray_refs[key] = ray_decorator(self._run_prefect_task).remote( - sync_compatible(call.func), *upstream_ray_obj_refs, **call_kwargs + self._ray_refs[key] = ( + ray_decorator(self._run_prefect_task) + .options(name=call.keywords["task_run"].name) + .remote(sync_compatible(call.func), *upstream_ray_obj_refs, **call_kwargs) ) def _exchange_prefect_for_ray_futures(self, kwargs_prefect_futures): diff --git a/tests/test_task_runners.py b/tests/test_task_runners.py index 9cd8815..98f17b0 100644 --- a/tests/test_task_runners.py +++ b/tests/test_task_runners.py @@ -203,18 +203,20 @@ async def test_wait_captures_exceptions_as_crashed_state( lack of re-raise here than the equality of the exception. """ - async def fake_orchestrate_task_run(): + async def fake_orchestrate_task_run(task_run): raise exception - test_key = uuid4() + task_run = TaskRun( + flow_run_id=uuid4(), task_key=str(uuid4()), dynamic_key="bar" + ) async with task_runner.start(): await task_runner.submit( - call=partial(fake_orchestrate_task_run), - key=test_key, + call=partial(fake_orchestrate_task_run, task_run=task_run), + key=task_run.id, ) - state = await task_runner.wait(test_key, 5) + state = await task_runner.wait(task_run.id, 5) assert state is not None, "wait timed out" assert isinstance(state, State), "wait should return a state" assert state.name == "Crashed" @@ -333,7 +335,7 @@ async def test_submit_and_wait(self, task_runner): task_run = TaskRun(flow_run_id=uuid4(), task_key="foo", dynamic_key="bar") - async def fake_orchestrate_task_run(example_kwarg): + async def fake_orchestrate_task_run(example_kwarg, task_run): return State( type=StateType.COMPLETED, data=example_kwarg, @@ -342,7 +344,9 @@ async def fake_orchestrate_task_run(example_kwarg): async with task_runner.start(): await task_runner.submit( key=task_run.id, - call=partial(fake_orchestrate_task_run, example_kwarg=1), + call=partial( + fake_orchestrate_task_run, task_run=task_run, example_kwarg=1 + ), ) state = await task_runner.wait(task_run.id, MAX_WAIT_TIME) assert state is not None, "wait timed out"