diff --git a/src/signature_mahalanobis_knn/sig_mahal_knn.py b/src/signature_mahalanobis_knn/sig_mahal_knn.py index dc88996..d831460 100644 --- a/src/signature_mahalanobis_knn/sig_mahal_knn.py +++ b/src/signature_mahalanobis_knn/sig_mahal_knn.py @@ -181,6 +181,7 @@ def conformance( signatures_test: np.ndarray | None = None, n_neighbors: int = 20, return_indices: bool = False, + take_median: bool = False, ) -> np.ndarray: """ Compute the conformance scores for the data points either passed in @@ -202,9 +203,14 @@ def conformance( signatures_test : np.ndarray | None, optional Signatures of the data points, by default None. Two dimensional array of shape (n_samples, sig_dim). + n_neighbors : int, optional return_indices : bool, optional Whether to return the indices of the nearest neighbors, by default False. + take_median : bool, optional + Whether to take the median distance over the k nearest neighbours, by default False. + When False, the minimum distance is used (the minimum over the k nearest neighbours is the distance to the closest neighbour). + Setting take_median to True gives a more robust measure. 
Returns ------- @@ -274,8 +280,16 @@ def conformance( candidate_distances[denominator < self.mahal_distance.zero_thres] = 0 candidate_distances[rho > self.mahal_distance.subspace_thres] = np.inf - # compute the minimum of the candidate distances for each data point - if return_indices: - return np.min(candidate_distances, axis=-1), train_indices + + if take_median: + # compute the median of the k nearest neighbour distances for each data point for robustness + if return_indices: + return np.median(candidate_distances, axis=-1), train_indices + return np.median(candidate_distances, axis=-1) + + else: + # compute the minimum of the candidate distances for each data point = 1-nearest neighbour distance (the vanilla implementation) + if return_indices: + return np.min(candidate_distances, axis=-1), train_indices - return np.min(candidate_distances, axis=-1) + return np.min(candidate_distances, axis=-1) \ No newline at end of file