diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index 3f88752fdd..71d8ae76d1 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -63,7 +63,7 @@ ) from ._signal_backend import RuntimeSubsetEnum, SignalBackend, SubsetEnum from ._soft_signal_backend import SignalMetadata, SoftSignalBackend -from ._status import AsyncStatus, WatchableAsyncStatus +from ._status import AsyncStatus, WatchableAsyncStatus, completed_status from ._utils import ( DEFAULT_TIMEOUT, CalculatableTimeout, @@ -158,4 +158,5 @@ "get_unique", "in_micros", "wait_for_connection", + "completed_status", ] diff --git a/src/ophyd_async/core/_status.py b/src/ophyd_async/core/_status.py index d2d285e675..ca35362ace 100644 --- a/src/ophyd_async/core/_status.py +++ b/src/ophyd_async/core/_status.py @@ -4,7 +4,16 @@ import functools import time from dataclasses import asdict, replace -from typing import AsyncIterator, Awaitable, Callable, Generic, Type, TypeVar, cast +from typing import ( + AsyncIterator, + Awaitable, + Callable, + Generic, + Optional, + Type, + TypeVar, + cast, +) from bluesky.protocols import Status @@ -132,3 +141,10 @@ def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS: return cls(f(*args, **kwargs)) return cast(Callable[P, WAS], wrap_f) + + +@AsyncStatus.wrap +async def completed_status(exception: Optional[Exception] = None): + if exception: + raise exception + return None diff --git a/tests/core/test_status.py b/tests/core/test_status.py index 6e8da751c9..f8af39bbc4 100644 --- a/tests/core/test_status.py +++ b/tests/core/test_status.py @@ -7,7 +7,7 @@ from bluesky.protocols import Movable, Status from bluesky.utils import FailedStatus -from ophyd_async.core import AsyncStatus, Device +from ophyd_async.core import AsyncStatus, Device, completed_status async def test_async_status_success(): @@ -188,3 +188,9 @@ async def coro_status(x: int, y: int, *, z=False): assert test_result == 12 loop.run_until_complete(do_test()) + + +async def test_completed_status(): + with pytest.raises(ValueError): + await completed_status(ValueError()) + await completed_status()