Skip to content

Commit

Permalink
Handle NaN values in get_slice (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
adedaran authored Apr 11, 2023
1 parent 428130a commit 72ba4dc
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
4 changes: 2 additions & 2 deletions sliceline/slicefinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def get_slice(self, X, slice_index: int):
Parameters
----------
X: array-like of shape (n_samples, n_features)
- Training data, where `n_samples` is the number of samples
+ Dataset, where `n_samples` is the number of samples
and `n_features` is the number of features.
slice_index: int
Expand All @@ -209,7 +209,7 @@ def get_slice(self, X, slice_index: int):
self._check_top_slices()

# Input validation
- X = check_array(X)
+ X = check_array(X, force_all_finite=False)

slices_masks = self._get_slices_masks(X)

Expand Down
27 changes: 27 additions & 0 deletions tests/test_slicefinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,3 +475,30 @@ def test_get_slice(benchmark, basic_test_data):
[[1, 1, 1, 3], [1, 1, 2, 3], [1, 1, 3, 3], [1, 1, 4, 1]]
)
assert np.array_equal(computed_slice, expected_slice)


def test_get_slice_with_nan(benchmark, basic_test_data):
    """Check that get_slice works on a dataset containing NaN values.

    The fixture model is fitted on the fixture data, then asked for slice 0
    of an inline dataset that holds one NaN in its first feature column and
    one NaN in its third feature column.
    """
    model = basic_test_data["slicefinder_model"]
    model.fit(basic_test_data["X"], basic_test_data["errors"])

    # Written feature by feature, then transposed to (n_samples, n_features).
    feature_columns = [
        [np.nan, 1, 1, 1, 1, 1, 2, 2],
        [1, 1, 1, 1, 2, 2, 1, 1],
        [1, 2, 3, np.nan, 5, 6, 7, 8],
        [3, 3, 3, 1, 3, 1, 2, 1],
    ]
    dataset_nan_case = np.array(feature_columns).T

    observed = benchmark(model.get_slice, dataset_nan_case, 0)

    wanted = np.array([[1, 1, 2, 3], [1, 1, 3, 3], [1, 1, np.nan, 1]])
    # equal_nan=True lets the NaN entry of the expected rows match a NaN
    # at the same position in the computed slice.
    assert np.array_equal(observed, wanted, equal_nan=True)

0 comments on commit 72ba4dc

Please sign in to comment.