Skip to content

Commit

Permalink
Merge pull request #96 from legend-exp/fix
Browse files Browse the repository at this point in the history
Pint and NumPy compatibility fixes
  • Loading branch information
gipert authored Jul 2, 2024
2 parents 826daa1 + 7591e54 commit d050bf7
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 19 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ ci:

repos:
- repo: https://github.com/adamchainz/blacken-docs
rev: "1.16.0"
rev: "1.18.0"
hooks:
- id: blacken-docs
additional_dependencies: [black==23.*]
Expand Down Expand Up @@ -40,14 +40,14 @@ repos:
]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.3.7"
rev: "v0.5.0"
hooks:
- id: ruff
args: ["--fix", "--show-fixes"]
- id: ruff-format

- repo: https://github.com/codespell-project/codespell
rev: "v2.2.6"
rev: "v2.3.0"
hooks:
- id: codespell

Expand All @@ -72,12 +72,12 @@ repos:
args: [--prose-wrap=always]

- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.16
rev: v0.18
hooks:
- id: validate-pyproject

- repo: https://github.com/python-jsonschema/check-jsonschema
rev: 0.28.2
rev: 0.28.6
hooks:
- id: check-dependabot
- id: check-github-workflows
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ dependencies = [
"numpy>=1.21",
"pandas>=1.4.4",
"parse",
"pint",
"pint!=0.24",
"pint-pandas",
]
dynamic = [
Expand Down
6 changes: 3 additions & 3 deletions src/lgdo/lh5/_serializers/read/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _h5_read_lgdo(
lh5_file = list(h5f)
n_rows_read = 0

for i, h5f in enumerate(lh5_file):
for i, _h5f in enumerate(lh5_file):
if isinstance(idx, list) and len(idx) > 0 and not np.isscalar(idx[0]):
# a list of lists: must be one per file
idx_i = idx[i]
Expand All @@ -65,7 +65,7 @@ def _h5_read_lgdo(
if not (isinstance(idx, tuple) and len(idx) == 1):
idx = (idx,)
# idx is a long continuous array
n_rows_i = read_n_rows(name, h5f)
n_rows_i = read_n_rows(name, _h5f)
# find the length of the subset of idx that contains indices
# that are less than n_rows_i
n_rows_to_read_i = bisect.bisect_left(idx[0], n_rows_i)
Expand All @@ -78,7 +78,7 @@ def _h5_read_lgdo(

obj_buf, n_rows_read_i = _h5_read_lgdo(
name,
h5f,
_h5f,
start_row=start_row,
n_rows=n_rows_i,
idx=idx_i,
Expand Down
29 changes: 24 additions & 5 deletions src/lgdo/lh5/_serializers/write/vector_of_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import logging

import numpy as np

from .... import types
from ... import utils
from ...exceptions import LH5EncodeError
Expand Down Expand Up @@ -31,12 +33,15 @@ def _h5_write_vector_of_vectors(

# if appending we need to add an offset to the cumulative lengths
# so that they match the object already stored in the file
offset = 0 # declare here because we have to subtract it off at the end
# declare here because we have to subtract it off at the end
offset = np.int64(0)
if (wo_mode in ("a", "o")) and "cumulative_length" in group:
len_cl = len(group["cumulative_length"])
# if append, ignore write_start and set it to total number of vectors
if wo_mode == "a":
write_start = len_cl
if len_cl > 0:
# set offset to the number of elements already in flattened_data before write_start
offset = group["cumulative_length"][write_start - 1]

# First write flattened_data array. Only write rows with data.
Expand Down Expand Up @@ -71,15 +76,23 @@ def _h5_write_vector_of_vectors(
)

# now offset is used to give appropriate in-file values for
# cumulative_length. Need to adjust it for start_row
# cumulative_length. Need to adjust it for start_row, if different from zero
if start_row > 0:
offset -= obj.cumulative_length.nda[start_row - 1]

# Add offset to obj.cumulative_length itself to avoid memory allocation.
# Then subtract it off after writing! (otherwise it will be changed
# upon return)
cl_dtype = obj.cumulative_length.nda.dtype.type
obj.cumulative_length.nda += cl_dtype(offset)

# NOTE: this operation is not numerically safe (uint overflow in the lower
# part of the array), but this is not a problem because those values are
# not written to disk and we are going to restore the offset at the end
np.add(
obj.cumulative_length.nda,
offset,
out=obj.cumulative_length.nda,
casting="unsafe",
)

_h5_write_array(
obj.cumulative_length,
Expand All @@ -92,4 +105,10 @@ def _h5_write_vector_of_vectors(
write_start=write_start,
**h5py_kwargs,
)
obj.cumulative_length.nda -= cl_dtype(offset)

np.subtract(
obj.cumulative_length.nda,
offset,
out=obj.cumulative_length.nda,
casting="unsafe",
)
2 changes: 1 addition & 1 deletion src/lgdo/lh5/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def load_nda(
f = sto.gimme_file(ff, "r")
for par in par_list:
if f"{lh5_group}/{par}" not in f:
msg = f"'{lh5_group}/{par}' not in file {f_list[ii]}"
msg = f"'{lh5_group}/{par}' not in file {ff}"
raise RuntimeError(msg)

if idx_list is None:
Expand Down
2 changes: 1 addition & 1 deletion src/lgdo/types/fixedsizearray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,4 @@ def view_as(self, library: str, with_units: bool = False):
--------
.LGDO.view_as
"""
return super.view_as(library, with_units=with_units)
return super().view_as(library, with_units=with_units)
2 changes: 1 addition & 1 deletion src/lgdo/types/vectorofvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def _set_vector_unsafe(
lens = np.array([lens], dtype="u4")

# calculate stop index in flattened_data
cum_lens = start + lens.cumsum()
cum_lens = np.add(start, lens.cumsum(), dtype=int)

# fill with fast vectorized routine
vovutils._nb_fill(vec, lens, self.flattened_data.nda[start : cum_lens[-1]])
Expand Down
2 changes: 1 addition & 1 deletion src/lgdo/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
import pint

default_units_registry = pint.get_application_registry()
default_units_registry.default_format = "~P"
default_units_registry.formatter.default_format = "~P"
2 changes: 1 addition & 1 deletion tests/compression/test_uleb128_zigzag_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_uleb128zzdiff_encode_decode_equality():
pos = varlen.uleb128_encode(varlen.zigzag_encode(int(s) - last), encx)
assert np.array_equal(sig_out[offset : offset + pos], encx[:pos])
offset += pos
last = s
last = int(s)

sig_in_dec = np.empty(100, dtype="uint32")
siglen = np.empty(1, dtype="uint32")
Expand Down

0 comments on commit d050bf7

Please sign in to comment.