diff --git a/python/src/odin_data/meta_writer/hdf5dataset.py b/python/src/odin_data/meta_writer/hdf5dataset.py
index 9982fff46..dc06c7fa4 100644
--- a/python/src/odin_data/meta_writer/hdf5dataset.py
+++ b/python/src/odin_data/meta_writer/hdf5dataset.py
@@ -1,10 +1,15 @@
 import logging
 from time import time
+from typing import Optional
 
 import h5py as h5
 import numpy as np
 
 
+def units(unit: str):
+    return {"units": unit}
+
+
 class HDF5CacheBlock(object):
     def __init__(self, shape, fillvalue, dtype):
         self.has_new_data = True
@@ -136,6 +141,7 @@ def __init__(
         cache=True,
         block_size=1000000,
         block_timeout=600,
+        attributes=None
     ):
         """
         Args:
@@ -150,7 +156,7 @@
             or write directly to file
             block_size(int): See HDF5UnlimitedCache
             block_timeout(int): See HDF5UnlimitedCache
-
+            attributes(dict): A dict of attribute names and values to add to the dataset
         """
         self.name = name
         self.dtype = dtype
@@ -182,7 +188,9 @@
             block_timeout=block_timeout,
         )
 
-        self._h5py_dataset = None  # h5py.Dataset
+        self.attributes = attributes or dict()
+
+        self._h5py_dataset: Optional[h5.Dataset] = None
         self._is_written = False
 
         self._logger = logging.getLogger("HDF5Dataset")
@@ -221,6 +229,10 @@
         else:
             self._cache.add_value(value, offset)
 
+    def _add_attributes(self):
+        for attribute, value in self.attributes.items():
+            self._h5py_dataset.attrs[attribute] = value
+
     def write(self, data):
         """Write the entire dataset with the given data
 
@@ -255,6 +267,9 @@
             return
 
         self._h5py_dataset[...] = data
+
+        self._add_attributes()
+
         self._is_written = True
 
     def prepare_data(self, data):
diff --git a/python/tests/test_hdf5dataset.py b/python/tests/test_hdf5dataset.py
index 615d3a87a..a0154efdf 100644
--- a/python/tests/test_hdf5dataset.py
+++ b/python/tests/test_hdf5dataset.py
@@ -7,7 +7,9 @@
 from odin_data.meta_writer.hdf5dataset import (
     HDF5UnlimitedCache,
     StringHDF5Dataset,
+    Int32HDF5Dataset
 )
+import tempfile
 
 
 class _TestMockDataset:
@@ -264,7 +266,8 @@
         "variable_utf_int", encoding="utf-8", cache=False
     )
 
-    with h5.File("/tmp/strings.h5", "w") as f:
+    temp_file = tempfile.TemporaryFile()
+    with h5.File(temp_file, "w") as f:
         variable_utf.initialise(
             f.create_dataset(
                 "variable_utf", shape=(1,), maxshape=(1,), dtype=variable_utf.dtype
@@ -302,11 +305,29 @@
     fixed_ascii.write("fixed_ascii")
     variable_utf_int.write(2)  # Check non-strings can be handled
 
-    with h5.File("/tmp/strings.h5", "r") as f:
+    with h5.File(temp_file, "r") as f:
         assert f["variable_utf"][0] == b"variable_utf"
         assert f["variable_ascii"][0] == b"variable_ascii"
         assert f["fixed_utf"][0] == b"fixed_utf"
         assert f["fixed_ascii"][0] == b"fixed_ascii"
         assert f["variable_utf_int"][0] == b"2"
 
-    os.remove("/tmp/strings.h5")
+
+def test_attributes_added():
+    dataset_with_units = Int32HDF5Dataset("test_dataset", cache=False, attributes={"units": "m"})
+
+    temp_file = tempfile.TemporaryFile()
+    with h5.File(temp_file, "w") as f:
+        dataset_with_units.initialise(
+            f.create_dataset(
+                "test_dataset", shape=(1,), maxshape=(1,), dtype=dataset_with_units.dtype
+            ),
+            0,
+        )
+
+        dataset_with_units.write(2)
+
+    with h5.File(temp_file, "r") as f:
+        attributes = f["test_dataset"].attrs
+        assert len(attributes.values()) == 1
+        assert attributes["units"] == "m"