Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
use prefect task name as remote function name (#103)
Browse files Browse the repository at this point in the history
* task name

* using options

* changelog

* Update CHANGELOG.md

---------

Co-authored-by: Alexander Streed <desertaxle@users.noreply.github.com>
  • Loading branch information
j-tr and desertaxle authored Dec 6, 2023
1 parent b271115 commit a5a82c7
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 9 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Security

## 0.3.3

Released on December 6th, 2023.

### Added

- Use perfect task name as Ray remote function name - [#103](https://github.com/PrefectHQ/prefect-ray/pull/103)

## 0.2.6

Released on November 29th, 2023.
Expand Down
6 changes: 4 additions & 2 deletions prefect_ray/task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,10 @@ async def submit(
else:
ray_decorator = ray.remote

self._ray_refs[key] = ray_decorator(self._run_prefect_task).remote(
sync_compatible(call.func), *upstream_ray_obj_refs, **call_kwargs
self._ray_refs[key] = (
ray_decorator(self._run_prefect_task)
.options(name=call.keywords["task_run"].name)
.remote(sync_compatible(call.func), *upstream_ray_obj_refs, **call_kwargs)
)

def _exchange_prefect_for_ray_futures(self, kwargs_prefect_futures):
Expand Down
18 changes: 11 additions & 7 deletions tests/test_task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,18 +203,20 @@ async def test_wait_captures_exceptions_as_crashed_state(
lack of re-raise here than the equality of the exception.
"""

async def fake_orchestrate_task_run():
async def fake_orchestrate_task_run(task_run):
raise exception

test_key = uuid4()
task_run = TaskRun(
flow_run_id=uuid4(), task_key=str(uuid4()), dynamic_key="bar"
)

async with task_runner.start():
await task_runner.submit(
call=partial(fake_orchestrate_task_run),
key=test_key,
call=partial(fake_orchestrate_task_run, task_run=task_run),
key=task_run.id,
)

state = await task_runner.wait(test_key, 5)
state = await task_runner.wait(task_run.id, 5)
assert state is not None, "wait timed out"
assert isinstance(state, State), "wait should return a state"
assert state.name == "Crashed"
Expand Down Expand Up @@ -333,7 +335,7 @@ async def test_submit_and_wait(self, task_runner):

task_run = TaskRun(flow_run_id=uuid4(), task_key="foo", dynamic_key="bar")

async def fake_orchestrate_task_run(example_kwarg):
async def fake_orchestrate_task_run(example_kwarg, task_run):
return State(
type=StateType.COMPLETED,
data=example_kwarg,
Expand All @@ -342,7 +344,9 @@ async def fake_orchestrate_task_run(example_kwarg):
async with task_runner.start():
await task_runner.submit(
key=task_run.id,
call=partial(fake_orchestrate_task_run, example_kwarg=1),
call=partial(
fake_orchestrate_task_run, task_run=task_run, example_kwarg=1
),
)
state = await task_runner.wait(task_run.id, MAX_WAIT_TIME)
assert state is not None, "wait timed out"
Expand Down

0 comments on commit a5a82c7

Please sign in to comment.