Skip to content

Commit

Permalink
return tsdframe when indexing a single number to preserve columns and…
Browse files Browse the repository at this point in the history
… metadata
  • Loading branch information
sjvenditto committed Oct 29, 2024
1 parent f70cc71 commit 726821e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
19 changes: 9 additions & 10 deletions pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,17 +1101,16 @@ def __getitem__(self, key, *args, **kwargs):
index = np.array([index])

if all(is_array_like(a) for a in [index, output]):
if output.shape[0] == index.shape[0]:
# if isinstance(columns, pd.Index):
# if not pd.api.types.is_integer_dtype(columns):
kwargs["columns"] = columns
kwargs["metadata"] = self._metadata.loc[columns]
# reshape output if single index to preserve column axis
if len(index) == 1:
output = output[None, :]

return _get_class(output)(
t=index, d=output, time_support=self.time_support, **kwargs
)
else:
return output
kwargs["columns"] = columns
kwargs["metadata"] = self._metadata.loc[columns]

return _get_class(output)(
t=index, d=output, time_support=self.time_support, **kwargs
)
else:
return output

Expand Down
14 changes: 10 additions & 4 deletions tests/test_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,12 +956,18 @@ def test_horizontal_slicing(self, tsdframe, index, nap_type):
tsdframe[:, index].metadata_index == tsdframe.metadata_index[index]
)

@pytest.mark.parametrize("index", [slice(0, 10), [0, 2]])
@pytest.mark.parametrize("index", [0, slice(0, 10), [0, 2]])
def test_vertical_slicing(self, tsdframe, index):
assert isinstance(tsdframe[index], nap.TsdFrame)
np.testing.assert_array_almost_equal(
tsdframe.values[index], tsdframe[index].values
)
if len(tsdframe[index] == 1):
# use ravel to ignore shape mismatch
np.testing.assert_array_almost_equal(
tsdframe.values[index].ravel(), tsdframe[index].values.ravel()
)
else:
np.testing.assert_array_almost_equal(
tsdframe.values[index], tsdframe[index].values
)
assert isinstance(tsdframe[index].time_support, nap.IntervalSet)
np.testing.assert_array_almost_equal(
tsdframe[index].time_support, tsdframe.time_support
Expand Down

0 comments on commit 726821e

Please sign in to comment.