Skip to content

Commit

Permalink
add indexing of multiple metadata columns in first index, add test fo…
Browse files Browse the repository at this point in the history
…r all objects
  • Loading branch information
sjvenditto committed Oct 23, 2024
1 parent ae2f60b commit 73328ed
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 1 deletion.
12 changes: 12 additions & 0 deletions pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def __setitem__(self, key, value):

def __getitem__(self, key):
if isinstance(key, str):
# self[str]
if key == "start":
return self.values[:, 0]
elif key == "end":
Expand All @@ -284,11 +285,22 @@ def __getitem__(self, key):
raise IndexError(
f"Unknown string argument. Should be in {['start', 'end'] + list(self._metadata.keys())}"
)
elif isinstance(key, list) and all(isinstance(x, str) for x in key):
# self[[*str]]
# easiest to convert to dataframe and then slice
# in case of mixing ["start", "end"] with metadata columns
df = self.as_dataframe()
if all(x in key[1] for x in ["start", "end"]):
return IntervalSet(df[key])

Check warning on line 294 in pynapple/core/interval_set.py

View check run for this annotation

Codecov / codecov/patch

pynapple/core/interval_set.py#L294

Added line #L294 was not covered by tests
else:
return df[key]
elif isinstance(key, Number):
# self[Number]
output = self.values.__getitem__(key)
metadata = _MetadataBase.__getitem__(self, key)
return IntervalSet(start=output[0], end=output[1], **metadata)
elif isinstance(key, (slice, list, np.ndarray, pd.Series)):
# self[array_like]
output = self.values.__getitem__(key)
metadata = _MetadataBase.__getitem__(self, key).reset_index(drop=True)
return IntervalSet(start=output[:, 0], end=output[:, 1], **metadata)
Expand Down
5 changes: 4 additions & 1 deletion pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,10 @@ def __getitem__(self, key, *args, **kwargs):
or hasattr(key, "__iter__")
and all([isinstance(k, str) for k in key])
):
return self.loc[key]
if all(k in self.metadata_columns for k in key):
return _MetadataBase.__getitem__(self, key)
else:
return self.loc[key]

else:
if isinstance(key, pd.Series) and key.index.equals(self.columns):
Expand Down
3 changes: 3 additions & 0 deletions pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ def __getitem__(self, key):
return _MetadataBase.__getitem__(self, key)
else:
raise KeyError(r"Key {} not in group index.".format(key))
elif isinstance(key, list) and all(isinstance(k, str) for k in key):
# index multiple metadata columns
return _MetadataBase.__getitem__(self, key)

# array boolean are transformed into indices
# note that raw boolean are hashable, and won't be
Expand Down
8 changes: 8 additions & 0 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,14 @@ def test_metadata_slicing(self, obj, label, val):
# metadata columns should be the same
assert np.all(obj2.metadata_columns == obj.metadata_columns)

def test_metadata_index_columns(self, obj):
# add metadata
obj.set_info(one=[1, 1, 1, 1], two=[2, 2, 2, 2], three=[3, 3, 3, 3])

# test metadata columns
assert np.all(obj["one"] == 1)
assert np.all(obj[["two", "three"]] == obj._metadata[["two", "three"]])

def test_save_and_load_npz(self, obj):
obj.set_info(label1=[1, 1, 2, 2], label2=[0, 1, 2, 3])

Expand Down

0 comments on commit 73328ed

Please sign in to comment.