Add support for linear-time mmd estimator. #475
base: master
Changes from 12 commits
Changes to the `MMDDrift` wrapper module:
@@ -4,10 +4,10 @@
 from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow

 if has_pytorch:
-    from alibi_detect.cd.pytorch.mmd import MMDDriftTorch
+    from alibi_detect.cd.pytorch.mmd import MMDDriftTorch, LinearTimeDriftTorch

 if has_tensorflow:
-    from alibi_detect.cd.tensorflow.mmd import MMDDriftTF
+    from alibi_detect.cd.tensorflow.mmd import MMDDriftTF, LinearTimeMMDDriftTF

 logger = logging.getLogger(__name__)

Review comment (on the new PyTorch import): Change to …
Reply: Fixed.
@@ -18,6 +18,7 @@ def __init__(
         x_ref: Union[np.ndarray, list],
         backend: str = 'tensorflow',
         p_val: float = .05,
+        estimator: str = 'quad',
         preprocess_x_ref: bool = True,
         update_x_ref: Optional[Dict[str, int]] = None,
         preprocess_fn: Optional[Callable] = None,

Review comment (on the new `estimator` argument): Would …
Reply: Added extra description in the docstring.
@@ -40,6 +41,11 @@ def __init__(
             Backend used for the MMD implementation.
         p_val
             p-value used for the significance of the permutation test.
+        estimator
+            Estimator used for the MMD^2 computation {'quad', 'linear'}. 'quad' is the default and
+            uses the quadratic u-statistic on each square kernel matrix. 'linear' uses the linear-time
+            estimator as in Gretton et al. (2014), and the threshold is computed using the Gaussian
+            asymptotic distribution under the null.
         preprocess_x_ref
             Whether to already preprocess and store the reference data.
         update_x_ref

Review comment (on the new docstring): Would be helpful to add a link to the paper + reference to the specific section in the paper.
Reply: Fixed.
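For readers unfamiliar with the linear-time estimator referenced in this docstring, here is a minimal NumPy sketch of the averaged h-statistic form of MMD^2 (following Gretton et al.). The helper names, the fixed bandwidth and the toy data are illustrative only and are not part of the alibi-detect API or of this PR.

```python
import numpy as np

def rbf(a: np.ndarray, b: np.ndarray, sigma: float = 1.) -> np.ndarray:
    # kernel values between corresponding (paired) rows of a and b
    return np.exp(-np.sum((a - b) ** 2, axis=-1) / (2 * sigma ** 2))

def linear_time_mmd2(x: np.ndarray, y: np.ndarray, sigma: float = 1.) -> float:
    # pair consecutive instances and average the h-statistic
    # h((x1, y1), (x2, y2)) = k(x1, x2) + k(y1, y2) - k(x1, y2) - k(x2, y1)
    n = (min(len(x), len(y)) // 2) * 2  # truncate both samples to an even common length
    x1, x2 = x[:n:2], x[1:n:2]
    y1, y2 = y[:n:2], y[1:n:2]
    h = rbf(x1, x2, sigma) + rbf(y1, y2, sigma) - rbf(x1, y2, sigma) - rbf(x2, y1, sigma)
    return float(h.mean())

# toy usage: the statistic should be close to 0 for two samples from the same distribution
x, y = np.random.randn(500, 10), np.random.randn(500, 10)
print(linear_time_mmd2(x, y))
```

The cost is linear in the number of paired instances, since only one kernel evaluation per pair is needed instead of the full kernel matrix used by the quadratic estimator.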
@@ -76,7 +82,7 @@

         kwargs = locals()
         args = [kwargs['x_ref']]
-        pop_kwargs = ['self', 'x_ref', 'backend', '__class__']
+        pop_kwargs = ['self', 'x_ref', 'backend', '__class__', 'estimator']
         [kwargs.pop(k, None) for k in pop_kwargs]

         if kernel is None:
@@ -88,9 +94,21 @@

         if backend == 'tensorflow' and has_tensorflow:
             kwargs.pop('device', None)
-            self._detector = MMDDriftTF(*args, **kwargs)  # type: ignore
+            if estimator == 'quad':
+                self._detector = MMDDriftTF(*args, **kwargs)  # type: ignore
+            elif estimator == 'linear':
+                kwargs.pop('n_permutations', None)
+                self._detector = LinearTimeMMDDriftTF(*args, **kwargs)  # type: ignore
+            else:
+                raise NotImplementedError(f'{estimator} not implemented. Use quad or linear instead.')
         else:
-            self._detector = MMDDriftTorch(*args, **kwargs)  # type: ignore
+            if estimator == 'quad':
+                self._detector = MMDDriftTorch(*args, **kwargs)  # type: ignore
+            elif estimator == 'linear':
+                kwargs.pop('n_permutations', None)
+                self._detector = LinearTimeDriftTorch(*args, **kwargs)  # type: ignore
+            else:
+                raise NotImplementedError(f'{estimator} not implemented. Use quad or linear instead.')
         self.meta = self._detector.meta

     def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True, return_distance: bool = True) \

Review comment (on `kwargs.pop('n_permutations', None)`): Best to clarify in the docstrings that …
Reply: Fixed.
Review comment (on the backend dispatch): Since the logic to set …
Reply: Indeed, will modify the tests.
Reply: Simply rewrite the test to go through different …
@@ -128,6 +146,7 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray]:
         Returns
         -------
         p-value obtained from the permutation test, the MMD^2 between the reference and test set
-        and the MMD^2 values from the permutation test.
+        and the MMD^2 values from the quadratic permutation test, or the threshold for the given
+        significance level for the linear-time test.
         """
         return self._detector.score(x)
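A hedged usage sketch of the wrapper API as changed above, showing how the new `estimator` keyword would be passed to `MMDDrift`. The data, feature dimension and backend choice are placeholders; the returned prediction follows the detector's standard dictionary structure.

```python
import numpy as np
from alibi_detect.cd import MMDDrift

x_ref = np.random.randn(1000, 32).astype(np.float32)   # reference sample
x_test = np.random.randn(200, 32).astype(np.float32)   # sample to test for drift

# default: quadratic-time estimator with a permutation test
cd_quad = MMDDrift(x_ref, backend='pytorch', p_val=.05, estimator='quad')

# linear-time estimator: n_permutations is dropped internally and the threshold
# comes from the asymptotic null distribution instead of permutations
cd_linear = MMDDrift(x_ref, backend='pytorch', p_val=.05, estimator='linear')

preds = cd_linear.predict(x_test, return_p_val=True, return_distance=True)
print(preds['data'])  # drift flag, p-value and distance information
```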
Changes to the PyTorch backend module:
@@ -1,9 +1,10 @@
 import logging
 import numpy as np
+import scipy.stats as stats
 import torch
 from typing import Callable, Dict, Optional, Tuple, Union
 from alibi_detect.cd.base import BaseMMDDrift
-from alibi_detect.utils.pytorch.distance import mmd2_from_kernel_matrix
+from alibi_detect.utils.pytorch.distance import mmd2_from_kernel_matrix, linear_mmd2
 from alibi_detect.utils.pytorch.kernels import GaussianRBF

 logger = logging.getLogger(__name__)
@@ -118,17 +119,144 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray]:
         and the MMD^2 values from the permutation test.
         """
         x_ref, x = self.preprocess(x)
+        n = x.shape[0]
         x_ref = torch.from_numpy(x_ref).to(self.device)  # type: ignore[assignment]
         x = torch.from_numpy(x).to(self.device)  # type: ignore[assignment]
         # compute kernel matrix, MMD^2 and apply permutation test using the kernel matrix
-        n = x.shape[0]
         kernel_mat = self.kernel_matrix(x_ref, x)  # type: ignore[arg-type]
         kernel_mat = kernel_mat - torch.diag(kernel_mat.diag())  # zero diagonal
-        mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
+        mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)  # type: ignore[assignment]
         mmd2_permuted = torch.Tensor(
-            [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
-        )
+            [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False)
+             for _ in range(self.n_permutations)]
+        )
         if self.device.type == 'cuda':
             mmd2, mmd2_permuted = mmd2.cpu(), mmd2_permuted.cpu()
         p_val = (mmd2 <= mmd2_permuted).float().mean()
         return p_val.numpy().item(), mmd2.numpy().item(), mmd2_permuted.numpy()
+
+
+class LinearTimeDriftTorch(BaseMMDDrift):

Review comment (on the class name): As mentioned before, should probably be …
Reply: The names just get better and better 😅
Reply: Fixed.
+    def __init__(
+            self,
+            x_ref: Union[np.ndarray, list],
+            p_val: float = .05,
+            preprocess_x_ref: bool = True,
+            update_x_ref: Optional[Dict[str, int]] = None,
+            preprocess_fn: Optional[Callable] = None,
+            kernel: Callable = GaussianRBF,
+            sigma: Optional[np.ndarray] = None,
+            configure_kernel_from_x_ref: bool = True,
+            n_permutations: int = 100,
+            device: Optional[str] = None,
+            input_shape: Optional[tuple] = None,
+            data_type: Optional[str] = None
+    ) -> None:
+        """
+        Maximum Mean Discrepancy (MMD) data drift detector using a permutation test, with linear-time estimator.

Review comment (on `n_permutations: int = 100`): Should not contain the …
Reply: Fixed.
Review comment (on the docstring): It doesn't use a permutation test though?
Reply: No it doesn't, removed now.
+
+        Parameters
+        ----------
+        x_ref
+            Data used as reference distribution.
+        p_val
+            p-value used for the significance of the permutation test.
+        preprocess_x_ref
+            Whether to already preprocess and store the reference data.
+        update_x_ref
+            Reference data can optionally be updated to the last n instances seen by the detector
+            or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while
+            for reservoir sampling {'reservoir_sampling': n} is passed.
+        preprocess_fn
+            Function to preprocess the data before computing the data drift metrics.
+        kernel
+            Kernel used for the MMD computation, defaults to Gaussian RBF kernel.
+        sigma
+            Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple bandwidth values as an array.
+            The kernel evaluation is then averaged over those bandwidths.
+        configure_kernel_from_x_ref
+            Whether to already configure the kernel bandwidth from the reference data.
+        n_permutations
+            Number of permutations used in the permutation test.
+        device
+            Device type used. The default None tries to use the GPU and falls back on CPU if needed.
+            Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
+        input_shape
+            Shape of input data.
+        data_type
+            Optionally specify the data type (tabular, image or time-series). Added to metadata.
+        """
+        super().__init__(
+            x_ref=x_ref,
+            p_val=p_val,
+            preprocess_x_ref=preprocess_x_ref,
+            update_x_ref=update_x_ref,
+            preprocess_fn=preprocess_fn,
+            sigma=sigma,
+            configure_kernel_from_x_ref=configure_kernel_from_x_ref,
+            n_permutations=n_permutations,
+            input_shape=input_shape,
+            data_type=data_type
+        )
+        self.meta.update({'backend': 'pytorch'})

Review comment (on the `n_permutations` docstring): Again remove …
Reply: Fixed.
Review comment (on passing `n_permutations` to `super().__init__`): Does not need to pass …
Reply: Fixed.
+
+        # set backend
+        if device is None or device.lower() in ['gpu', 'cuda']:
+            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+            if self.device.type == 'cpu':
+                print('No GPU detected, fall back on CPU.')
+        else:
+            self.device = torch.device('cpu')
+
+        # initialize kernel
+        sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma,  # type: ignore[assignment]
+                                                                      np.ndarray) else None
+        self.kernel = kernel(sigma) if kernel == GaussianRBF else kernel
+
+        # compute kernel matrix for the reference data
+        if self.infer_sigma or isinstance(sigma, torch.Tensor):
+            x = torch.from_numpy(self.x_ref).to(self.device)
+            self.k_xx = self.kernel(x, x, infer_sigma=self.infer_sigma)
+            self.infer_sigma = False
+        else:
+            self.k_xx, self.infer_sigma = None, True
+

Review comment (on `self.k_xx = self.kernel(x, x, infer_sigma=self.infer_sigma)`): …
Reply: Fixed, not …
+    def kernel_matrix(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        """ Compute and return full kernel matrix between arrays x and y. """
+        k_xy = self.kernel(x, y, self.infer_sigma)
+        k_xx = self.k_xx if self.k_xx is not None and self.update_x_ref is None else self.kernel(x, x)
+        k_yy = self.kernel(y, y)
+        kernel_mat = torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0)
+        return kernel_mat
+

Review comment (on `kernel_matrix`): Method is not used I believe?
Reply: @arnaudvl The base class requires this method for initialisation, I wonder what would be the preferable solution here? The minimal thing could be to simply leave a pseudo method.
Reply: IMO we should remove …
+    def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray]:
+        """
+        Compute the p-value resulting from a permutation test using the maximum mean discrepancy
+        as a distance measure between the reference data and the data to be tested.
+
+        Parameters
+        ----------
+        x
+            Batch of instances.
+
+        Returns
+        -------
+        p-value obtained from the null hypothesis, the MMD^2 between the reference and test set
+        and the MMD^2 threshold for the given significance level.
+        """
+        x_ref, x = self.preprocess(x)
+        n = x.shape[0]
+        m = x_ref.shape[0]
+        n_hat = int(np.floor(min(n, m) / 2) * 2)
+        x_ref = torch.from_numpy(x_ref[:n_hat, :]).to(self.device)  # type: ignore[assignment]
+        x = torch.from_numpy(x[:n_hat, :]).to(self.device)  # type: ignore[assignment]
+        mmd2, var_mmd2 = linear_mmd2(x_ref, x, self.kernel)  # type: ignore[arg-type]
+        if self.device.type == 'cuda':
+            mmd2 = mmd2.cpu()
+        mmd2 = mmd2.numpy().item()
+        var_mmd2 = var_mmd2.numpy().item()
+        std_mmd2 = np.sqrt(var_mmd2)
+        t = mmd2 / (std_mmd2 / np.sqrt(n_hat / 2.))
+        p_val = 1 - stats.t.cdf(t, df=(n_hat / 2.) - 1)
+        distance_threshold = stats.t.ppf(1 - self.p_val, df=(n_hat / 2.) - 1)
+        return p_val, t, distance_threshold

Review comment (on the docstring): Remove reference to permutation test.
Reply: Fixed.
Review comment (on `n_hat = int(np.floor(min(n, m) / 2) * 2)`): This behaviour needs to be well documented.
Reply: Yes, now described in the docstring.
Review comment (on truncating `x_ref` to `n_hat`): I believe there is a case to be made that there is an explicit check such that …
Reply: This seems like quite an issue atm. Agree the safest option would be to explicitly check for …
Reply: Currently implemented as raise error for …
Review comment (on `std_mmd2 = np.sqrt(var_mmd2)`): Can directly use …
Reply: The new version uses …
Review comment: It doesn't seem ideal to have to deal with the differing behaviour of `score` for the original vs new detectors in `predict`. Maybe we could move the `distance_threshold` computation to `score` for the original MMD detectors, and then the above would be simplified quite a bit? Draft PR for this here: #489
Reply: Yes indeed, I guess the best thing to do here is to follow your draft PR's template to modify the linear-time detector.
Reply: I'd say if @arnaudvl and @ojcobb agree with the change in #489, we should merge that, then it would be a super quick change to this PR.
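To make the values returned by the linear-time `score` above concrete, here is a small illustrative sketch of its decision rule: both samples are truncated to the same even length `n_hat`, and the statistic is compared against a Student-t null with `n_hat / 2 - 1` degrees of freedom. The sample sizes and the `mmd2`/`var_mmd2` numbers below are placeholders, not outputs of the actual detector.

```python
import numpy as np
import scipy.stats as stats

n, m = 103, 250                                # test and reference sample sizes
n_hat = int(np.floor(min(n, m) / 2) * 2)       # -> 102: largest even length common to both samples

mmd2, var_mmd2 = 0.012, 0.0009                 # placeholder outputs of linear_mmd2
t = mmd2 / (np.sqrt(var_mmd2) / np.sqrt(n_hat / 2.))
p_val = 1 - stats.t.cdf(t, df=n_hat / 2. - 1)
distance_threshold = stats.t.ppf(1 - .05, df=n_hat / 2. - 1)  # for a p_val level of .05

drift = int(t > distance_threshold)            # equivalent to p_val < .05
print(p_val, t, distance_threshold, drift)
```

This mirrors the triple returned by the new `score` (p-value, test statistic, threshold), which is where it diverges from the quadratic detector's return values and motivates the discussion of #489 above.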