Skip to content

Commit

Permalink
chore: add data shape validation
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <ab93@users.noreply.github.com>
  • Loading branch information
ab93 committed Jul 26, 2023
1 parent 22a8ec9 commit 2da4ea0
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
19 changes: 17 additions & 2 deletions numalogic/models/threshold/_mahalanobis.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from numalogic.base import BaseThresholdModel
from typing_extensions import Self

from numalogic.tools.exceptions import ModelInitializationError
from numalogic.tools.exceptions import ModelInitializationError, InvalidDataShapeError

_INLIER: Final[int] = 0
_OUTLIER: Final[int] = 1
Expand Down Expand Up @@ -84,6 +84,12 @@ def _chebyshev_k(p: float) -> float:
"""Calculate the k value using Chebyshev's inequality."""
return np.reciprocal(np.sqrt(p))

@staticmethod
def _validate_input(x: npt.NDArray[float]) -> None:
    """
    Check that the given array is two-dimensional.

    Parameters
    ----------
    x: input array whose number of dimensions is checked

    Raises
    ------
    InvalidDataShapeError: if the array does not have exactly 2 dimensions
    """
    # Guard clause: a 2D input is the only accepted layout.
    if x.ndim == 2:
        return
    raise InvalidDataShapeError(f"Input matrix should have 2 dims, given shape: {x.shape}.")

def fit(self, x: npt.NDArray[float]) -> Self:
"""
Fit the estimator on the training set.
Expand All @@ -95,7 +101,12 @@ def fit(self, x: npt.NDArray[float]) -> Self:
Returns
-------
self
Raises
------
InvalidDataShapeError: if the input matrix is not 2D
"""
self._validate_input(x)
self._distr_mean = np.mean(x, axis=0)
cov = np.cov(x, rowvar=False)
self._cov_inv = np.linalg.pinv(cov)
Expand Down Expand Up @@ -135,10 +146,12 @@ def predict(self, x: npt.NDArray[float]) -> npt.NDArray[int]:
Raises
------
RuntimeError: if the model is not fitted yet
ModelInitializationError: if the model is not fitted yet
InvalidDataShapeError: if the input matrix is not 2D
"""
if not self._is_fitted:
raise ModelInitializationError("Model not fitted yet.")
self._validate_input(x)
md = self.mahalanobis(x)
y_hat = np.zeros(x.shape[0], dtype=int)
y_hat[md >= self._md_thresh] = _OUTLIER
Expand All @@ -162,7 +175,9 @@ def score_samples(self, x: npt.NDArray[float]) -> npt.NDArray[float]:
Raises
------
ModelInitializationError: if the model is not fitted yet
InvalidDataShapeError: if the input matrix is not 2D
"""
if not self._is_fitted:
raise ModelInitializationError("Model not fitted yet.")
self._validate_input(x)
return self.mahalanobis(x) / self._md_thresh
14 changes: 12 additions & 2 deletions tests/models/test_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
SigmoidThreshold,
MahalanobisThreshold,
)
from numalogic.tools.exceptions import ModelInitializationError
from numalogic.tools.exceptions import ModelInitializationError, InvalidDataShapeError


class TestStdDevThreshold(unittest.TestCase):
Expand Down Expand Up @@ -65,10 +65,20 @@ def test_predict(self):
self.assertEqual(np.max(y), 1)
self.assertEqual(np.min(y), 0)

def test_predict_err(self):
def test_notfitted_err(self):
    """Calling predict or score_samples before fit raises ModelInitializationError."""
    unfitted = MahalanobisThreshold()
    # Both inference entry points are expected to reject an unfitted model.
    for method in (unfitted.predict, unfitted.score_samples):
        with self.assertRaises(ModelInitializationError):
            method(self.x_test)

def test_invalid_input_err(self):
    """Inputs that are not 2D are rejected by fit, predict and score_samples."""
    # fit validates the input shape as its first step, so a fresh,
    # unfitted model is enough to exercise that path.
    with self.assertRaises(InvalidDataShapeError):
        MahalanobisThreshold().fit(np.ones((30, 15, 1)))

    clf = MahalanobisThreshold()
    clf.fit(self.x_train)
    # 3D input to predict
    with self.assertRaises(InvalidDataShapeError):
        clf.predict(np.ones((30, 15, 1)))
    # 1D input to score_samples
    with self.assertRaises(InvalidDataShapeError):
        clf.score_samples(np.ones(30))

def test_score_samples(self):
clf = MahalanobisThreshold()
Expand Down

0 comments on commit 2da4ea0

Please sign in to comment.