Skip to content

Commit

Permalink
Refactor ShapeProvider into DatasetDescriber (#555)
Browse files Browse the repository at this point in the history
* add np_datatype property instead of using final axis
  • Loading branch information
DiamondJoseph authored Sep 4, 2024
1 parent 27ad5b9 commit ab4dc46
Show file tree
Hide file tree
Showing 13 changed files with 39 additions and 38 deletions.
2 changes: 1 addition & 1 deletion docs/examples/foo_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(
self.hdf,
path_provider,
lambda: self.name,
adcore.ADBaseShapeProvider(self.drv),
adcore.ADBaseDatasetDescriber(self.drv),
),
config_sigs=(self.drv.acquire_time,),
name=name,
Expand Down
4 changes: 2 additions & 2 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@
from ._providers import (
AutoIncrementFilenameProvider,
AutoIncrementingPathProvider,
DatasetDescriber,
FilenameProvider,
NameProvider,
PathInfo,
PathProvider,
ShapeProvider,
StaticFilenameProvider,
StaticPathProvider,
UUIDFilenameProvider,
Expand Down Expand Up @@ -117,7 +117,7 @@
"NameProvider",
"PathInfo",
"PathProvider",
"ShapeProvider",
"DatasetDescriber",
"StaticFilenameProvider",
"StaticPathProvider",
"UUIDFilenameProvider",
Expand Down
8 changes: 6 additions & 2 deletions src/ophyd_async/core/_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,11 @@ def __call__(self) -> str:
"""Get the name to be used as a data_key in the descriptor document"""


class ShapeProvider(Protocol):
class DatasetDescriber(Protocol):
@abstractmethod
async def __call__(self) -> tuple:
async def np_datatype(self) -> str:
"""Represents the numpy datatype"""

@abstractmethod
async def shape(self) -> tuple[int, ...]:
"""Get the shape of the data collection"""
2 changes: 1 addition & 1 deletion src/ophyd_async/epics/adaravis/_aravis.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
self.hdf,
path_provider,
lambda: self.name,
adcore.ADBaseShapeProvider(self.drv),
adcore.ADBaseDatasetDescriber(self.drv),
),
config_sigs=(self.drv.acquire_time,),
name=name,
Expand Down
4 changes: 2 additions & 2 deletions src/ophyd_async/epics/adcore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
)
from ._core_logic import (
DEFAULT_GOOD_STATES,
ADBaseShapeProvider,
ADBaseDatasetDescriber,
set_exposure_time_and_acquire_period_if_supplied,
start_acquiring_driver_and_ensure_status,
)
Expand All @@ -31,7 +31,7 @@
"NDFileHDFIO",
"NDPluginStatsIO",
"DEFAULT_GOOD_STATES",
"ADBaseShapeProvider",
"ADBaseDatasetDescriber",
"set_exposure_time_and_acquire_period_if_supplied",
"start_acquiring_driver_and_ensure_status",
"ADHDFWriter",
Expand Down
11 changes: 7 additions & 4 deletions src/ophyd_async/epics/adcore/_core_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from ophyd_async.core import (
DEFAULT_TIMEOUT,
AsyncStatus,
DatasetDescriber,
DetectorControl,
ShapeProvider,
set_and_wait_for_value,
)
from ophyd_async.epics.adcore._utils import convert_ad_dtype_to_np

from ._core_io import ADBaseIO, DetectorState

Expand All @@ -18,15 +19,17 @@
)


class ADBaseShapeProvider(ShapeProvider):
class ADBaseDatasetDescriber(DatasetDescriber):
def __init__(self, driver: ADBaseIO) -> None:
self._driver = driver

async def __call__(self) -> tuple:
async def np_datatype(self) -> str:
return convert_ad_dtype_to_np(await self._driver.data_type.get_value())

async def shape(self) -> tuple[int, int]:
shape = await asyncio.gather(
self._driver.array_size_y.get_value(),
self._driver.array_size_x.get_value(),
self._driver.data_type.get_value(),
)
return shape

Expand Down
20 changes: 7 additions & 13 deletions src/ophyd_async/epics/adcore/_hdf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from ophyd_async.core import (
DEFAULT_TIMEOUT,
AsyncStatus,
DatasetDescriber,
DetectorWriter,
HDFDataset,
HDFFile,
NameProvider,
PathProvider,
ShapeProvider,
observe_value,
set_and_wait_for_value,
wait_for_value,
Expand All @@ -22,7 +22,6 @@
from ._core_io import NDArrayBaseIO, NDFileHDFIO
from ._utils import (
FileWriteMode,
convert_ad_dtype_to_np,
convert_param_dtype_to_np,
convert_pv_dtype_to_np,
)
Expand All @@ -34,13 +33,13 @@ def __init__(
hdf: NDFileHDFIO,
path_provider: PathProvider,
name_provider: NameProvider,
shape_provider: ShapeProvider,
dataset_describer: DatasetDescriber,
*plugins: NDArrayBaseIO,
) -> None:
self.hdf = hdf
self._path_provider = path_provider
self._name_provider = name_provider
self._shape_provider = shape_provider
self._dataset_describer = dataset_describer

self._plugins = plugins
self._capture_status: Optional[AsyncStatus] = None
Expand Down Expand Up @@ -79,23 +78,18 @@ async def open(self, multiplier: int = 1) -> Dict[str, DataKey]:
# Wait for it to start, stashing the status that tells us when it finishes
self._capture_status = await set_and_wait_for_value(self.hdf.capture, True)
name = self._name_provider()
detector_shape = tuple(await self._shape_provider())
detector_shape = await self._dataset_describer.shape()
np_dtype = await self._dataset_describer.np_datatype()
self._multiplier = multiplier
outer_shape = (multiplier,) if multiplier > 1 else ()
frame_shape = detector_shape[:-1] if len(detector_shape) > 0 else []
dtype_numpy = (
convert_ad_dtype_to_np(detector_shape[-1])
if len(detector_shape) > 0
else ""
)

# Add the main data
self._datasets = [
HDFDataset(
data_key=name,
dataset="/entry/data/data",
shape=frame_shape,
dtype_numpy=dtype_numpy,
shape=detector_shape,
dtype_numpy=np_dtype,
multiplier=multiplier,
)
]
Expand Down
2 changes: 1 addition & 1 deletion src/ophyd_async/epics/adkinetix/_kinetix.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
self.hdf,
path_provider,
lambda: self.name,
adcore.ADBaseShapeProvider(self.drv),
adcore.ADBaseDatasetDescriber(self.drv),
),
config_sigs=(self.drv.acquire_time,),
name=name,
Expand Down
2 changes: 1 addition & 1 deletion src/ophyd_async/epics/adpilatus/_pilatus.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
self.hdf,
path_provider,
lambda: self.name,
adcore.ADBaseShapeProvider(self.drv),
adcore.ADBaseDatasetDescriber(self.drv),
),
config_sigs=(self.drv.acquire_time,),
name=name,
Expand Down
2 changes: 1 addition & 1 deletion src/ophyd_async/epics/adsimdetector/_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
self.hdf,
path_provider,
lambda: self.name,
adcore.ADBaseShapeProvider(self.drv),
adcore.ADBaseDatasetDescriber(self.drv),
),
config_sigs=config_sigs,
name=name,
Expand Down
2 changes: 1 addition & 1 deletion src/ophyd_async/epics/advimba/_vimba.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
self.hdf,
path_provider,
lambda: self.name,
adcore.ADBaseShapeProvider(self.drv),
adcore.ADBaseDatasetDescriber(self.drv),
),
config_sigs=(self.drv.acquire_time,),
name=name,
Expand Down
2 changes: 1 addition & 1 deletion tests/epics/adcore/test_scans.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def writer(RE, static_path_provider, tmp_path: Path) -> adcore.ADHDFWriter:
hdf,
path_provider=static_path_provider,
name_provider=lambda: "test",
shape_provider=AsyncMock(),
dataset_describer=AsyncMock(),
)


Expand Down
16 changes: 8 additions & 8 deletions tests/epics/adcore/test_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import pytest

from ophyd_async.core import (
DatasetDescriber,
DeviceCollector,
PathProvider,
ShapeProvider,
StandardDetector,
StaticPathProvider,
)
Expand All @@ -16,12 +16,12 @@
from ophyd_async.plan_stubs._nd_attributes import setup_ndattributes, setup_ndstats_sum


class DummyShapeProvider(ShapeProvider):
def __init__(self) -> None:
pass
class DummyDatasetDescriber(DatasetDescriber):
async def np_datatype(self) -> str:
return "<u2"

async def __call__(self) -> tuple:
return (10, 10, adcore.ADBaseDataType.UInt16)
async def shape(self) -> tuple[int, int]:
return (10, 10)


@pytest.fixture
Expand All @@ -35,7 +35,7 @@ async def hdf_writer(
hdf,
static_path_provider,
lambda: "test",
DummyShapeProvider(),
DummyDatasetDescriber(),
)


Expand All @@ -51,7 +51,7 @@ async def hdf_writer_with_stats(
hdf,
static_path_provider,
lambda: "test",
DummyShapeProvider(),
DummyDatasetDescriber(),
stats,
)

Expand Down

0 comments on commit ab4dc46

Please sign in to comment.