From 2da4ea03b90fe4f30bd5ba922783efcc090d98e0 Mon Sep 17 00:00:00 2001 From: Avik Basu Date: Wed, 26 Jul 2023 14:45:45 -0400 Subject: [PATCH] chore: add data shape validation Signed-off-by: Avik Basu --- numalogic/models/threshold/_mahalanobis.py | 19 +++++++++++++++++-- tests/models/test_threshold.py | 14 ++++++++++++-- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/numalogic/models/threshold/_mahalanobis.py b/numalogic/models/threshold/_mahalanobis.py index e8838485..6b996bcc 100644 --- a/numalogic/models/threshold/_mahalanobis.py +++ b/numalogic/models/threshold/_mahalanobis.py @@ -17,7 +17,7 @@ from numalogic.base import BaseThresholdModel from typing_extensions import Self -from numalogic.tools.exceptions import ModelInitializationError +from numalogic.tools.exceptions import ModelInitializationError, InvalidDataShapeError _INLIER: Final[int] = 0 _OUTLIER: Final[int] = 1 @@ -84,6 +84,12 @@ def _chebyshev_k(p: float) -> float: """Calculate the k value using Chebyshev's inequality.""" return np.reciprocal(np.sqrt(p)) + @staticmethod + def _validate_input(x: npt.NDArray[float]) -> None: + """Validate the input matrix shape.""" + if x.ndim != 2: + raise InvalidDataShapeError(f"Input matrix should have 2 dims, given shape: {x.shape}.") + def fit(self, x: npt.NDArray[float]) -> Self: """ Fit the estimator on the training set. @@ -95,7 +101,12 @@ def fit(self, x: npt.NDArray[float]) -> Self: Returns ------- self + + Raises + ------ + InvalidDataShapeError: if the input matrix is not 2D """ + self._validate_input(x) self._distr_mean = np.mean(x, axis=0) cov = np.cov(x, rowvar=False) self._cov_inv = np.linalg.pinv(cov) @@ -135,10 +146,12 @@ def predict(self, x: npt.NDArray[float]) -> npt.NDArray[int]: Raises ------ - RuntimeError: if the model is not fitted yet + ModelInitializationError: if the model is not fitted yet + InvalidDataShapeError: if the input matrix is not 2D """ if not self._is_fitted: raise ModelInitializationError("Model not fitted yet.") + self._validate_input(x) md = self.mahalanobis(x) y_hat = np.zeros(x.shape[0], dtype=int) y_hat[md >= self._md_thresh] = _OUTLIER @@ -162,7 +175,9 @@ def score_samples(self, x: npt.NDArray[float]) -> npt.NDArray[float]: Raises ------ RuntimeError: if the model is not fitted yet + InvalidDataShapeError: if the input matrix is not 2D """ if not self._is_fitted: raise ModelInitializationError("Model not fitted yet.") + self._validate_input(x) return self.mahalanobis(x) / self._md_thresh diff --git a/tests/models/test_threshold.py b/tests/models/test_threshold.py index 4fdf1a8a..bd450904 100644 --- a/tests/models/test_threshold.py +++ b/tests/models/test_threshold.py @@ -8,7 +8,7 @@ SigmoidThreshold, MahalanobisThreshold, ) -from numalogic.tools.exceptions import ModelInitializationError +from numalogic.tools.exceptions import ModelInitializationError, InvalidDataShapeError class TestStdDevThreshold(unittest.TestCase): @@ -65,10 +65,20 @@ def test_predict(self): self.assertEqual(np.max(y), 1) self.assertEqual(np.min(y), 0) - def test_predict_err(self): + def test_notfitted_err(self): clf = MahalanobisThreshold() with self.assertRaises(ModelInitializationError): clf.predict(self.x_test) + with self.assertRaises(ModelInitializationError): + clf.score_samples(self.x_test) + + def test_invalid_input_err(self): + clf = MahalanobisThreshold() + clf.fit(self.x_train) + with self.assertRaises(InvalidDataShapeError): + clf.predict(np.ones((30, 15, 1))) + with self.assertRaises(InvalidDataShapeError): + clf.score_samples(np.ones(30)) def test_score_samples(self): clf = MahalanobisThreshold()