def test_string_types():
    """Check that StringHDF5Dataset writes read back as the expected bytes.

    Five single-element datasets are exercised: variable-length and
    fixed-length strings in both utf-8 and ascii encodings, plus a
    variable-length utf-8 dataset that is handed a non-string value.
    Each value is written through the StringHDF5Dataset wrapper (with
    caching disabled) and then read back from a re-opened file.

    The file lives in a temporary directory so the test does not depend on
    a fixed ``/tmp`` path and leaves nothing behind when an assertion fails.
    """
    # Local import keeps this test self-contained; only stdlib is required.
    import tempfile

    # (name, dataset wrapper, value to write, bytes expected on read-back)
    cases = [
        (
            "variable_utf",
            StringHDF5Dataset("variable_utf", encoding="utf-8", cache=False),
            "variable_utf",
            b"variable_utf",
        ),
        (
            "variable_ascii",
            StringHDF5Dataset("variable_ascii", encoding="ascii", cache=False),
            "variable_ascii",
            b"variable_ascii",
        ),
        (
            "fixed_utf",
            StringHDF5Dataset("fixed_utf", encoding="utf-8", length=9, cache=False),
            "fixed_utf",
            b"fixed_utf",
        ),
        (
            "fixed_ascii",
            StringHDF5Dataset(
                "fixed_ascii", encoding="ascii", length=11, cache=False
            ),
            "fixed_ascii",
            b"fixed_ascii",
        ),
        (
            # Check non-strings can be handled
            "variable_utf_int",
            StringHDF5Dataset("variable_utf_int", encoding="utf-8", cache=False),
            2,
            b"2",
        ),
    ]

    with tempfile.TemporaryDirectory() as tmp_dir:
        file_path = os.path.join(tmp_dir, "strings.h5")

        with h5.File(file_path, "w") as f:
            # Create and attach every dataset first, then perform the writes,
            # mirroring the initialise-all / write-all order of a real run.
            for name, dataset, _value, _expected in cases:
                dataset.initialise(
                    f.create_dataset(
                        # Each dataset is created with its *own* dtype.
                        name, shape=(1,), maxshape=(1,), dtype=dataset.dtype
                    ),
                    0,
                )

            for _name, dataset, value, _expected in cases:
                dataset.write(value)

        # Re-open read-only so the assertions see what was actually stored.
        with h5.File(file_path, "r") as f:
            for name, _dataset, _value, expected in cases:
                assert f[name][0] == expected