Skip to content

Commit

Permalink
Merge pull request #15 from p2p-ld/hdf5-datetime
Browse files Browse the repository at this point in the history
Support datetimes in hdf5 proxies
  • Loading branch information
sneakers-the-rat authored Sep 4, 2024
2 parents c46015d + 9d31ecf commit 2c625e4
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 12 deletions.
5 changes: 5 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

### 1.5.*

#### 1.5.2 - 24-09-03 - `datetime` support for HDF5

- [#15](https://github.com/p2p-ld/numpydantic/pull/15): Datetimes are supported as
dtype annotations for HDF5 arrays when encoded as `S32` isoformatted byte strings

#### 1.5.1 - 24-09-03 - Fix revalidation with proxy classes

Bugfix:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "numpydantic"
version = "1.5.1"
version = "1.5.2"
description = "Type and shape validation and serialization for arbitrary array types in pydantic models"
authors = [
{name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"},
Expand Down
74 changes: 66 additions & 8 deletions src/numpydantic/interface/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,27 @@
To have direct access to the hdf5 dataset, use the
:meth:`.H5Proxy.open` method.
Datetimes
---------
Datetimes are supported as a dtype annotation, but currently they must be stored
as ``S32`` isoformatted byte strings (timezones optional) like:
.. code-block:: python
import h5py
from datetime import datetime
import numpy as np
data = np.array([datetime.now().isoformat().encode('utf-8')], dtype="S32")
h5f = h5py.File('test.hdf5', 'w')
h5f.create_dataset('data', data=data)
"""

import sys
from datetime import datetime
from pathlib import Path
from typing import Any, List, NamedTuple, Optional, Tuple, Union
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union

import numpy as np
from pydantic import SerializationInfo
Expand All @@ -46,6 +62,8 @@

H5Arraylike: TypeAlias = Tuple[Union[Path, str], str]

T = TypeVar("T")


class H5ArrayPath(NamedTuple):
"""Location specifier for arrays within an HDF5 file"""
Expand Down Expand Up @@ -77,18 +95,21 @@ class H5Proxy:
path (str): Path to array within hdf5 file
field (str, list[str]): Optional - refer to a specific field within
a compound dtype
annotation_dtype (dtype): Optional - the dtype of our type annotation
"""

def __init__(
self,
file: Union[Path, str],
path: str,
field: Optional[Union[str, List[str]]] = None,
annotation_dtype: Optional[DtypeType] = None,
):
self._h5f = None
self.file = Path(file)
self.path = path
self.field = field
self._annotation_dtype = annotation_dtype

def array_exists(self) -> bool:
"""Check that there is in fact an array at :attr:`.path` within :attr:`.file`"""
Expand Down Expand Up @@ -120,10 +141,12 @@ def __getattr__(self, item: str):

def __getitem__(
self, item: Union[int, slice, Tuple[Union[int, slice], ...]]
) -> np.ndarray:
) -> Union[np.ndarray, DtypeType]:
with h5py.File(self.file, "r") as h5f:
obj = h5f.get(self.path)
# handle compound dtypes
if self.field is not None:
# handle compound string dtype
if encoding := h5py.h5t.check_string_dtype(obj.dtype[self.field]):
if isinstance(item, tuple):
item = (*item, self.field)
Expand All @@ -132,24 +155,41 @@ def __getitem__(

try:
# single string
return obj[item].decode(encoding.encoding)
val = obj[item].decode(encoding.encoding)
if self._annotation_dtype is np.datetime64:
return np.datetime64(val)
else:
return val
except AttributeError:
# numpy array of bytes
return np.char.decode(obj[item], encoding=encoding.encoding)

val = np.char.decode(obj[item], encoding=encoding.encoding)
if self._annotation_dtype is np.datetime64:
return val.astype(np.datetime64)
else:
return val
# normal compound type
else:
obj = obj.fields(self.field)
else:
if h5py.h5t.check_string_dtype(obj.dtype):
obj = obj.asstr()

return obj[item]
val = obj[item]
if self._annotation_dtype is np.datetime64:
if isinstance(val, str):
return np.datetime64(val)
else:
return val.astype(np.datetime64)
else:
return val

def __setitem__(
self,
key: Union[int, slice, Tuple[Union[int, slice], ...]],
value: Union[int, float, np.ndarray],
value: Union[int, float, datetime, np.ndarray],
):
# TODO: Make a generalized value serdes system instead of ad-hoc type conversion
value = self._serialize_datetime(value)
with h5py.File(self.file, "r+", locking=True) as h5f:
obj = h5f.get(self.path)
if self.field is None:
Expand Down Expand Up @@ -184,6 +224,16 @@ def close(self) -> None:
self._h5f.close()
self._h5f = None

def _serialize_datetime(self, v: Union[T, datetime]) -> Union[T, bytes]:
"""
Convert a datetime into a bytestring
"""
if self._annotation_dtype is np.datetime64:
if not isinstance(v, Iterable):
v = [v]
v = np.array(v).astype("S32")
return v


class H5Interface(Interface):
"""
Expand Down Expand Up @@ -253,6 +303,7 @@ def before_validation(self, array: Any) -> NDArrayType:
"Need to specify a file and a path within an HDF5 file to use the HDF5 "
"Interface"
)
array._annotation_dtype = self.dtype

if not array.array_exists():
raise ValueError(
Expand All @@ -269,7 +320,14 @@ def get_dtype(self, array: NDArrayType) -> DtypeType:
Subclasses to correctly handle
"""
if h5py.h5t.check_string_dtype(array.dtype):
return str
# check for datetimes
try:
if array[0].dtype.type is np.datetime64:
return np.datetime64
else:
return str
except (AttributeError, ValueError, TypeError): # pragma: no cover
return str
else:
return array.dtype

Expand Down
6 changes: 5 additions & 1 deletion src/numpydantic/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,11 @@ def _hash_schema(schema: CoreSchema) -> str:
to produce the same hash.
"""
schema_str = json.dumps(
schema, sort_keys=True, indent=None, separators=(",", ":")
schema,
sort_keys=True,
indent=None,
separators=(",", ":"),
default=lambda x: None,
).encode("utf-8")
hasher = hashlib.blake2b(digest_size=8)
hasher.update(schema_str)
Expand Down
12 changes: 11 additions & 1 deletion tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Type, Union
from warnings import warn
from datetime import datetime, timezone

import h5py
import numpy as np
Expand Down Expand Up @@ -126,15 +127,24 @@ def _hdf5_array(
if not compound:
if dtype is str:
data = np.random.random(shape).astype(bytes)
elif dtype is datetime:
data = np.empty(shape, dtype="S32")
data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8"))
else:
data = np.random.random(shape).astype(dtype)
_ = hdf5_file.create_dataset(array_path, data=data)
return H5ArrayPath(Path(hdf5_file.filename), array_path)
else:

if dtype is str:
dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")])
data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape)
elif dtype is datetime:
dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")])
data = np.array(
[(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)]
* np.prod(shape),
dtype=dt,
).reshape(shape)
else:
dt = np.dtype([("data", dtype), ("extra", "i8")])
data = np.zeros(shape, dtype=dt)
Expand Down
26 changes: 25 additions & 1 deletion tests/test_interface/test_hdf5.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
from datetime import datetime, timezone
from typing import Any

import h5py
import pytest

from pydantic import BaseModel, ValidationError

import numpy as np
Expand Down Expand Up @@ -174,3 +175,26 @@ class MyModel(BaseModel):

instance.array[1] = "sup"
assert all(instance.array[1] == "sup")


@pytest.mark.parametrize("compound", [True, False])
def test_datetime(hdf5_array, compound):
"""
We can treat S32 byte arrays as datetimes if our type annotation
says to, including validation, setting and getting values
"""
array = hdf5_array((10, 10), datetime, compound=compound)

class MyModel(BaseModel):
array: NDArray[Any, datetime]

instance = MyModel(array=array)
assert isinstance(instance.array[0, 0], np.datetime64)
assert instance.array[0:5].dtype.type is np.datetime64

now = datetime.now()

instance.array[0, 0] = now
assert instance.array[0, 0] == now
instance.array[0] = now
assert all(instance.array[0] == now)

0 comments on commit 2c625e4

Please sign in to comment.