Skip to content

Commit

Permalink
Remove Inject
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Oct 14, 2024
1 parent 9cb9596 commit b6c942b
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 43 deletions.
15 changes: 11 additions & 4 deletions src/dodal/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from .coordination import group_uuid, inject
from .coordination import group_uuid
from .maths import in_micros, step_to_num
from .types import MsgGenerator, PlanGenerator
from .types import PlanGenerator

__all__ = [
"group_uuid",
"inject",
"in_micros",
"MsgGenerator",
"PlanGenerator",
"step_to_num",
]


def __getattr__(name):
if name == "MsgGenerator":
raise DeprecationWarning("import from bluesky.utils instead")
if name == "inject":
raise DeprecationWarning("Use result of dodal device call instead")

return globals()[name]
23 changes: 0 additions & 23 deletions src/dodal/common/coordination.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,3 @@ def group_uuid(name: str) -> Group:
readable_uid (Group): name appended with a unique string
"""
return f"{name}-{str(uuid.uuid4())[:6]}"


def inject(name: str) -> Any: # type: ignore
"""
Function to mark a defaulted argument of a plan as a reference to a device stored
in another context and not available to be referenced directly.
Bypasses type checking, returning x as Any and therefore valid as a default
argument, leaving handling to the context from which the plan is called.
Assumes that device.name is unique.
e.g. For a 1-dimensional scan, that is usually performed on a Movable with
name "stage_x"
def scan(x: Movable = inject("stage_x"), start: float = 0.0 ...)
Args:
name (str): Name of a Device to be fetched from an external context
Returns:
Any: name but without typing checking, valid as any default type
"""

return name
41 changes: 41 additions & 0 deletions src/dodal/plans/connect_devices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from collections.abc import Callable, Collection, Mapping, Sequence
from typing import ParamSpec

from bluesky.utils import MsgGenerator, make_decorator
from ophyd_async.core import DEFAULT_TIMEOUT, Device
from ophyd_async.plan_stubs import ensure_connected

P = ParamSpec("P")


def _recursively_find_devices(obj) -> set[Device]:
if isinstance(obj, Device):
return {obj}
if isinstance(obj, Sequence | Collection):
return {dev for arg in obj for dev in _recursively_find_devices(arg)}
if isinstance(obj, Mapping):
return {
dev
for key, value in obj.items()
for dev in _recursively_find_devices(key) | _recursively_find_devices(value)
}
return set()


def ensure_devices_connected(
plan: Callable[P, MsgGenerator],
mock: bool = False,
timeout: float = DEFAULT_TIMEOUT,
force_reconnect: bool = False,
) -> Callable[P, MsgGenerator]:
def plan_with_connected_devices(*args: P.args, **kwargs: P.kwargs):
devices = _recursively_find_devices(args) | _recursively_find_devices(kwargs)
yield from ensure_connected(
*devices, mock=mock, timeout=timeout, force_reconnect=force_reconnect
)
yield from plan(*args, **kwargs)

return plan_with_connected_devices


ensure_devices_connected_decorator = make_decorator(ensure_devices_connected)
17 changes: 1 addition & 16 deletions tests/common/test_coordination.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from inspect import Parameter, signature

import pytest
from bluesky.protocols import Movable
from bluesky.utils import MsgGenerator

from dodal.common.coordination import group_uuid, inject
from dodal.common.coordination import group_uuid

static_uuid = "51aef931-33b4-4b33-b7ad-a8287f541202"

Expand All @@ -14,14 +10,3 @@ def test_group_uid(group: str):
gid = group_uuid(group)
assert gid.startswith(f"{group}-")
assert not gid.endswith(f"{group}-")


def test_type_checking_ignores_inject():
def example_function(x: Movable = inject("foo")) -> MsgGenerator: # noqa: B008
yield from {}

# These asserts are sanity checks
# the real test is whether this test passes type checking
x: Parameter = signature(example_function).parameters["x"]
assert x.annotation == Movable
assert x.default == "foo"

0 comments on commit b6c942b

Please sign in to comment.