-
Notifications
You must be signed in to change notification settings - Fork 225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for linear-time mmd estimator. #475
base: master
Are you sure you want to change the base?
Changes from all commits
f4d2692
8ef820d
ba29712
eaa2e45
ea97f52
f6b93d6
59110ec
7bacbec
7507276
29cd155
eef8def
678eae0
7952ab3
016f23f
05626ec
20e442c
3b96ad4
95f3de4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,12 +4,13 @@ | |
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, BackendValidator | ||
from alibi_detect.utils.warnings import deprecated_alias | ||
from alibi_detect.base import DriftConfigMixin | ||
from alibi_detect.utils._types import Literal | ||
|
||
if has_pytorch: | ||
from alibi_detect.cd.pytorch.mmd import MMDDriftTorch | ||
from alibi_detect.cd.pytorch.mmd import MMDDriftTorch, LinearTimeMMDDriftTorch | ||
|
||
if has_tensorflow: | ||
from alibi_detect.cd.tensorflow.mmd import MMDDriftTF | ||
from alibi_detect.cd.tensorflow.mmd import MMDDriftTF, LinearTimeMMDDriftTF | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -21,6 +22,7 @@ def __init__( | |
x_ref: Union[np.ndarray, list], | ||
backend: str = 'tensorflow', | ||
p_val: float = .05, | ||
estimator: Literal['quad', 'linear'] = 'quad', | ||
x_ref_preprocessed: bool = False, | ||
preprocess_at_init: bool = True, | ||
update_x_ref: Optional[Dict[str, int]] = None, | ||
|
@@ -44,6 +46,11 @@ def __init__( | |
Backend used for the MMD implementation. | ||
p_val | ||
p-value used for the significance of the permutation test. | ||
estimator | ||
Estimator used for the MMD^2 computation. 'quad' is the default and | ||
uses the quadratic u-statistics on each square kernel matrix. 'linear' uses the linear | ||
time estimator as in Gretton et al. (JMLR 2014, sec 6), and the threshold is computed | ||
using the Gaussian asympotic distribution under null. | ||
x_ref_preprocessed | ||
Whether the given reference data `x_ref` has been preprocessed yet. If `x_ref_preprocessed=True`, only | ||
the test data `x` will be preprocessed at prediction time. If `x_ref_preprocessed=False`, the reference | ||
|
@@ -65,7 +72,8 @@ def __init__( | |
configure_kernel_from_x_ref | ||
Whether to already configure the kernel bandwidth from the reference data. | ||
n_permutations | ||
Number of permutations used in the permutation test. | ||
Number of permutations used in the permutation test, only used for the quadratic estimator | ||
(estimator='quad'). | ||
device | ||
Device type used. The default None tries to use the GPU and falls back on CPU if needed. | ||
Can be specified by passing either 'cuda', 'gpu' or 'cpu'. Only relevant for 'pytorch' backend. | ||
|
@@ -80,6 +88,7 @@ def __init__( | |
self._set_config(locals()) | ||
|
||
backend = backend.lower() | ||
estimator = estimator.lower() # type: ignore | ||
BackendValidator( | ||
backend_options={'tensorflow': ['tensorflow'], | ||
'pytorch': ['pytorch']}, | ||
|
@@ -88,7 +97,7 @@ def __init__( | |
|
||
kwargs = locals() | ||
args = [kwargs['x_ref']] | ||
pop_kwargs = ['self', 'x_ref', 'backend', '__class__'] | ||
pop_kwargs = ['self', 'x_ref', 'backend', '__class__', 'estimator'] | ||
[kwargs.pop(k, None) for k in pop_kwargs] | ||
|
||
if kernel is None: | ||
|
@@ -100,9 +109,21 @@ def __init__( | |
|
||
if backend == 'tensorflow' and has_tensorflow: | ||
kwargs.pop('device', None) | ||
self._detector = MMDDriftTF(*args, **kwargs) # type: ignore | ||
if estimator == 'quad': | ||
self._detector = MMDDriftTF(*args, **kwargs) # type: ignore | ||
elif estimator == 'linear': | ||
kwargs.pop('n_permutations', None) | ||
self._detector = LinearTimeMMDDriftTF(*args, **kwargs) # type: ignore | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since the logic to set There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed, will modify the tests. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Simply rewrite the test to go through different |
||
else: | ||
raise NotImplementedError(f'{estimator} not implemented. Use quad or linear instead.') | ||
else: | ||
self._detector = MMDDriftTorch(*args, **kwargs) # type: ignore | ||
if estimator == 'quad': | ||
self._detector = MMDDriftTorch(*args, **kwargs) # type: ignore | ||
elif estimator == 'linear': | ||
kwargs.pop('n_permutations', None) | ||
self._detector = LinearTimeMMDDriftTorch(*args, **kwargs) # type: ignore | ||
else: | ||
raise NotImplementedError(f'{estimator} not implemented. Use quad or linear instead.') | ||
self.meta = self._detector.meta | ||
|
||
def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True, return_distance: bool = True) \ | ||
|
@@ -139,7 +160,7 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]: | |
|
||
Returns | ||
------- | ||
p-value obtained from the permutation test, the MMD^2 between the reference and test set, | ||
p-value obtained from the test, the MMD^2 between the reference and test set, | ||
and the MMD^2 threshold above which drift is flagged. | ||
""" | ||
return self._detector.score(x) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,11 @@ | ||
import logging | ||
import numpy as np | ||
import scipy.stats as stats | ||
import torch | ||
from typing import Callable, Dict, Optional, Tuple, Union | ||
from alibi_detect.cd.base import BaseMMDDrift | ||
from alibi_detect.utils.pytorch.distance import mmd2_from_kernel_matrix, linear_mmd2 | ||
from alibi_detect.utils.pytorch import get_device | ||
from alibi_detect.utils.pytorch.distance import mmd2_from_kernel_matrix | ||
from alibi_detect.utils.pytorch.kernels import GaussianRBF | ||
from alibi_detect.utils.warnings import deprecated_alias | ||
|
||
|
@@ -123,21 +124,162 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]: | |
and the MMD^2 threshold above which drift is flagged. | ||
""" | ||
x_ref, x = self.preprocess(x) | ||
n = x.shape[0] | ||
x_ref = torch.from_numpy(x_ref).to(self.device) # type: ignore[assignment] | ||
x = torch.from_numpy(x).to(self.device) # type: ignore[assignment] | ||
# compute kernel matrix, MMD^2 and apply permutation test using the kernel matrix | ||
# TODO: (See https://github.com/SeldonIO/alibi-detect/issues/540) | ||
n = x.shape[0] # type: ignore | ||
kernel_mat = self.kernel_matrix(x_ref, x) # type: ignore[arg-type] | ||
kernel_mat = kernel_mat - torch.diag(kernel_mat.diag()) # zero diagonal | ||
mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False) | ||
mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False) # type: ignore[assignment] | ||
mmd2_permuted = torch.Tensor( | ||
[mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)] | ||
) | ||
[mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) | ||
for _ in range(self.n_permutations)] | ||
) | ||
if self.device.type == 'cuda': | ||
mmd2, mmd2_permuted = mmd2.cpu(), mmd2_permuted.cpu() | ||
p_val = (mmd2 <= mmd2_permuted).float().mean() | ||
# compute distance threshold | ||
idx_threshold = int(self.p_val * len(mmd2_permuted)) | ||
distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold] | ||
return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy() | ||
|
||
|
||
class LinearTimeMMDDriftTorch(BaseMMDDrift): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since these new subclasses don't make use of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. The default number of permutations then can be initialised in |
||
def __init__( | ||
self, | ||
x_ref: Union[np.ndarray, list], | ||
p_val: float = .05, | ||
x_ref_preprocessed: bool = False, | ||
preprocess_at_init: bool = True, | ||
update_x_ref: Optional[Dict[str, int]] = None, | ||
preprocess_fn: Optional[Callable] = None, | ||
kernel: Callable = GaussianRBF, | ||
sigma: Optional[np.ndarray] = None, | ||
configure_kernel_from_x_ref: bool = True, | ||
device: Optional[str] = None, | ||
input_shape: Optional[tuple] = None, | ||
data_type: Optional[str] = None | ||
) -> None: | ||
""" | ||
Maximum Mean Discrepancy (MMD) data drift detector using a linear-time estimator. | ||
|
||
Parameters | ||
---------- | ||
x_ref | ||
Data used as reference distribution. | ||
p_val | ||
p-value used for the significance of the permutation test. | ||
x_ref_preprocessed | ||
Whether the given reference data `x_ref` has been preprocessed yet. If `x_ref_preprocessed=True`, only | ||
the test data `x` will be preprocessed at prediction time. If `x_ref_preprocessed=False`, the reference | ||
data will also be preprocessed. | ||
preprocess_at_init | ||
Whether to preprocess the reference data when the detector is instantiated. Otherwise, the reference | ||
data will be preprocessed at prediction time. Only applies if `x_ref_preprocessed=False`. | ||
update_x_ref | ||
Reference data can optionally be updated to the last n instances seen by the detector | ||
or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while | ||
for reservoir sampling {'reservoir_sampling': n} is passed. | ||
preprocess_fn | ||
Function to preprocess the data before computing the data drift metrics. | ||
kernel | ||
Kernel used for the MMD computation, defaults to Gaussian RBF kernel. | ||
sigma | ||
Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple bandwidth values as an array. | ||
The kernel evaluation is then averaged over those bandwidths. | ||
configure_kernel_from_x_ref | ||
Whether to already configure the kernel bandwidth from the reference data. | ||
device | ||
Device type used. The default None tries to use the GPU and falls back on CPU if needed. | ||
Can be specified by passing either 'cuda', 'gpu' or 'cpu'. | ||
input_shape | ||
Shape of input data. | ||
data_type | ||
Optionally specify the data type (tabular, image or time-series). Added to metadata. | ||
""" | ||
super().__init__( | ||
x_ref=x_ref, | ||
p_val=p_val, | ||
x_ref_preprocessed=x_ref_preprocessed, | ||
preprocess_at_init=preprocess_at_init, | ||
update_x_ref=update_x_ref, | ||
preprocess_fn=preprocess_fn, | ||
sigma=sigma, | ||
configure_kernel_from_x_ref=configure_kernel_from_x_ref, | ||
input_shape=input_shape, | ||
data_type=data_type | ||
) | ||
self.meta.update({'backend': 'pytorch'}) | ||
|
||
# set backend | ||
if device is None or device.lower() in ['gpu', 'cuda']: | ||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
if self.device.type == 'cpu': | ||
print('No GPU detected, fall back on CPU.') | ||
else: | ||
self.device = torch.device('cpu') | ||
|
||
# initialize kernel | ||
sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma, # type: ignore[assignment] | ||
np.ndarray) else None | ||
self.kernel = kernel(sigma) if kernel == GaussianRBF else kernel | ||
|
||
# compute kernel matrix for the reference data | ||
if self.infer_sigma or isinstance(sigma, torch.Tensor): | ||
n = self.x_ref.shape[0] | ||
n_hat = int(np.floor(n / 2) * 2) | ||
x = torch.from_numpy(self.x_ref[:n_hat, :]).to(self.device) | ||
self.k_xx = self.kernel(x=x[0::2, :], y=x[1::2, :], | ||
pairwise=False, infer_sigma=self.infer_sigma) | ||
self.infer_sigma = False | ||
else: | ||
self.k_xx, self.infer_sigma = None, True | ||
|
||
def kernel_matrix(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Method is not used I believe? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @arnaudvl The base class requires this method for initialisation, I wonder what would be the preferable solution here? the minimal thing could be to simply leave a pseudo method. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO we should remove |
||
""" Compute and return full kernel matrix between arrays x and y. """ | ||
k_xy = self.kernel(x, y, self.infer_sigma) | ||
k_xx = self.k_xx if self.k_xx is not None and self.update_x_ref is None else self.kernel(x, x) | ||
k_yy = self.kernel(y, y) | ||
kernel_mat = torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0) | ||
return kernel_mat | ||
|
||
def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]: | ||
""" | ||
Compute the p-value using the maximum mean discrepancy as a distance measure between the | ||
reference data and the data to be tested. x and x_ref are required to have the same size. | ||
The sample size is then specified as the maximal even number below the data size. | ||
|
||
Parameters | ||
---------- | ||
x | ||
Batch of instances. | ||
|
||
Returns | ||
------- | ||
p-value obtained from the null hypothesis, the MMD^2 between the reference and test set | ||
and the MMD^2 threshold for the given significance level. | ||
""" | ||
x_ref, x = self.preprocess(x) | ||
n = x.shape[0] | ||
m = x_ref.shape[0] | ||
if n != m: | ||
raise ValueError('x and x_ref must have the same size.') | ||
n_hat = int(np.floor(n / 2) * 2) | ||
x_ref = torch.from_numpy(x_ref[:n_hat, :]).to(self.device) # type: ignore[assignment] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe there is a case to be made that there is an explicit check such that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems like quite an issue atm. Agree the safest option would be to explicitly check for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently implemented as raise error for |
||
x = torch.from_numpy(x[:n_hat, :]).to(self.device) # type: ignore[assignment] | ||
if self.k_xx is not None and self.update_x_ref is None: | ||
k_xx = self.k_xx | ||
else: | ||
k_xx = self.kernel(x=x_ref[0::2, :], y=x_ref[1::2, :], pairwise=False) | ||
mmd2, var_mmd2 = linear_mmd2(k_xx, x_ref, x, self.kernel) # type: ignore[arg-type] | ||
if self.device.type == 'cuda': | ||
mmd2 = mmd2.cpu() | ||
mmd2 = mmd2.numpy().item() | ||
var_mmd2 = np.clip(var_mmd2.numpy().item(), 1e-8, 1e8) | ||
std_mmd2 = np.sqrt(var_mmd2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can directly use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new version uses |
||
t = mmd2 / (std_mmd2 / np.sqrt(n_hat / 2.)) | ||
p_val = 1 - stats.t.cdf(t, df=(n_hat / 2.) - 1) | ||
distance_threshold = stats.t.ppf(1 - self.p_val, df=(n_hat / 2.) - 1) | ||
return p_val, t, distance_threshold |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from functools import partial | ||
from itertools import product | ||
import numpy as np | ||
import pytest | ||
import torch | ||
import torch.nn as nn | ||
from typing import Callable, List | ||
from alibi_detect.cd.pytorch.mmd import LinearTimeMMDDriftTorch | ||
from alibi_detect.cd.pytorch.preprocess import HiddenOutput, preprocess_drift | ||
|
||
n, n_hidden, n_classes = 500, 10, 5 | ||
|
||
|
||
class MyModel(nn.Module): | ||
def __init__(self, n_features: int): | ||
super().__init__() | ||
self.dense1 = nn.Linear(n_features, 20) | ||
self.dense2 = nn.Linear(20, 2) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = nn.ReLU()(self.dense1(x)) | ||
return self.dense2(x) | ||
|
||
|
||
# test List[Any] inputs to the detector | ||
def preprocess_list(x: List[np.ndarray]) -> np.ndarray: | ||
return np.concatenate(x, axis=0) | ||
|
||
|
||
n_features = [10] | ||
n_enc = [None, 3] | ||
preprocess = [ | ||
(None, None), | ||
(preprocess_drift, {'model': HiddenOutput, 'layer': -1}), | ||
(preprocess_list, None) | ||
] | ||
update_x_ref = [{'last': 500}, {'reservoir_sampling': 500}, None] | ||
preprocess_at_init = [True, False] | ||
tests_mmddrift = list(product(n_features, n_enc, preprocess, | ||
update_x_ref, preprocess_at_init)) | ||
n_tests = len(tests_mmddrift) | ||
|
||
|
||
@pytest.fixture | ||
def mmd_params(request): | ||
return tests_mmddrift[request.param] | ||
|
||
|
||
@pytest.mark.parametrize('mmd_params', list(range(n_tests)), indirect=True) | ||
def test_mmd(mmd_params): | ||
n_features, n_enc, preprocess, update_x_ref, preprocess_at_init = mmd_params | ||
|
||
np.random.seed(0) | ||
torch.manual_seed(0) | ||
|
||
x_ref = np.random.randn(n * n_features).reshape(n, n_features).astype(np.float32) | ||
preprocess_fn, preprocess_kwargs = preprocess | ||
to_list = False | ||
if hasattr(preprocess_fn, '__name__') and preprocess_fn.__name__ == 'preprocess_list': | ||
if not preprocess_at_init: | ||
return | ||
to_list = True | ||
x_ref = [_[None, :] for _ in x_ref] | ||
elif isinstance(preprocess_fn, Callable) and 'layer' in list(preprocess_kwargs.keys()) \ | ||
and preprocess_kwargs['model'].__name__ == 'HiddenOutput': | ||
model = MyModel(n_features) | ||
layer = preprocess_kwargs['layer'] | ||
preprocess_fn = partial(preprocess_fn, model=HiddenOutput(model=model, layer=layer)) | ||
else: | ||
preprocess_fn = None | ||
|
||
cd = LinearTimeMMDDriftTorch( | ||
x_ref=x_ref, | ||
p_val=.05, | ||
preprocess_at_init=preprocess_at_init if isinstance(preprocess_fn, Callable) else False, | ||
update_x_ref=update_x_ref, | ||
preprocess_fn=preprocess_fn | ||
) | ||
x = x_ref.copy() | ||
preds = cd.predict(x, return_p_val=True) | ||
assert preds['data']['is_drift'] == 0 and preds['data']['p_val'] >= cd.p_val | ||
if isinstance(update_x_ref, dict): | ||
k = list(update_x_ref.keys())[0] | ||
assert cd.n == len(x) + len(x_ref) | ||
assert cd.x_ref.shape[0] == min(update_x_ref[k], len(x) + len(x_ref)) | ||
|
||
x_h1 = np.random.randn(n * n_features).reshape(n, n_features).astype(np.float32) | ||
if to_list: | ||
x_h1 = [_[None, :] for _ in x_h1] | ||
preds = cd.predict(x_h1, return_p_val=True) | ||
if preds['data']['is_drift'] == 1: | ||
assert preds['data']['p_val'] < preds['data']['threshold'] == cd.p_val | ||
assert preds['data']['distance'] > preds['data']['distance_threshold'] | ||
else: | ||
assert preds['data']['p_val'] >= preds['data']['threshold'] == cd.p_val | ||
assert preds['data']['distance'] <= preds['data']['distance_threshold'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Best to clarify in the docstrings that
n_permutations
is not used for the linear estimator.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.