Kernel function extension #558

Open

wants to merge 38 commits into master

38 commits
b9043fd
Initial kernel change for pytorch
Srceh Jul 15, 2022
b5f48c5
Change torch kernel-based methods to support new kernel behaviours.
Srceh Jul 21, 2022
27be93b
Initial TF implementation added.
Srceh Jul 25, 2022
e05ee05
Modify generic detector class and associated tests.
Srceh Jul 27, 2022
753ba72
Fixed prediction behaviour for torch gpu with new base kernel.
Srceh Jul 31, 2022
8832bef
Fixed feature dimension selection function.
Srceh Aug 5, 2022
d53984f
Added support to passing multiple kernel parameters. Doc string refin…
Srceh Aug 8, 2022
657e7b8
(1) refine various points according to the review. (2) re-design the …
Srceh Aug 18, 2022
8322330
revert mmd_cifar10 notebook
Srceh Aug 18, 2022
0199bac
This commit includes a major re-design of the base kernel class, it n…
Srceh Sep 4, 2022
3d235fb
Refine the behaviour of the new base kernel class, added further erro…
Srceh Sep 8, 2022
4587499
Added extra treatments for different kernel class. Also refine the ty…
Srceh Sep 20, 2022
bd1dde9
Add additional tests for the new kernels, and fix notebooks with new …
Srceh Oct 14, 2022
52486da
Address reviewer comments on: (1) doc string, (2) outdated comments, …
Srceh Oct 17, 2022
d6af592
Address some discussion and comments from the reviewer, mainly on : (…
Srceh Nov 11, 2022
43b0b4c
pre-rebase minor fixes.
Srceh Dec 15, 2022
335fe2c
Initial kernel change for pytorch
Srceh Jul 15, 2022
d059f65
Change torch kernel-based methods to support new kernel behaviours.
Srceh Jul 21, 2022
a843973
Initial TF implementation added.
Srceh Jul 25, 2022
50197ab
Modify generic detector class and associated tests.
Srceh Jul 27, 2022
db824e8
Fixed prediction behaviour for torch gpu with new base kernel.
Srceh Jul 31, 2022
d8c8083
Fixed feature dimension selection function.
Srceh Aug 5, 2022
3de8f98
Added support to passing multiple kernel parameters. Doc string refin…
Srceh Aug 8, 2022
0f63f61
(1) refine various points according to the review. (2) re-design the …
Srceh Aug 18, 2022
e9e9874
revert mmd_cifar10 notebook
Srceh Aug 18, 2022
61032e0
This commit includes a major re-design of the base kernel class, it n…
Srceh Sep 4, 2022
82d9dd4
Refine the behaviour of the new base kernel class, added further erro…
Srceh Sep 8, 2022
7e64b57
Added extra treatments for different kernel class. Also refine the ty…
Srceh Sep 20, 2022
1084180
Add additional tests for the new kernels, and fix notebooks with new …
Srceh Oct 14, 2022
8af1c51
Address reviewer comments on: (1) doc string, (2) outdated comments, …
Srceh Oct 17, 2022
bdad9d3
Address some discussion and comments from the reviewer, mainly on : (…
Srceh Nov 11, 2022
bca489b
Initial rebase with the current master.
Srceh Jan 3, 2023
be7d9fa
Initial integrate with the Keops and serialisation
Srceh Feb 1, 2023
6cc1798
Add serialisation for all new kernel classes.
Srceh Feb 24, 2023
1e381f4
Add support for serialisation of composite kernels.
Srceh Mar 20, 2023
b14db5a
Move composite kernel validation functions to loading module. Fixes f…
Srceh Mar 20, 2023
b24655e
(1) add 'kernel_list' key in config dict for better management. (2) m…
Srceh Mar 22, 2023
fb922a2
Fix dimension selection for TF.
Srceh Apr 3, 2023
29 changes: 13 additions & 16 deletions alibi_detect/cd/_domain_clf.py
@@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from typing import Callable
import numpy as np
from sklearn.svm import SVC
from sklearn.calibration import CalibratedClassifierCV
@@ -34,7 +33,6 @@ def predict(self, x: np.ndarray) -> np.ndarray:

class _SVCDomainClf(_DomainClf):
def __init__(self,
kernel: Callable,
cal_method: str = 'sigmoid',
clf_kwargs: dict = None):
"""
@@ -52,52 +50,51 @@ def __init__(self,
clf_kwargs
A dictionary of keyword arguments to be passed to the :py:class:`~sklearn.svm.SVC` classifier.
"""
self.kernel = kernel
self.cal_method = cal_method
clf_kwargs = clf_kwargs or {}
self.clf = SVC(kernel=self.kernel, **clf_kwargs)
self.clf = SVC(kernel='precomputed', **clf_kwargs)

def fit(self, x: np.ndarray, y: np.ndarray):
def fit(self, K_x: np.ndarray, y: np.ndarray):
"""
Method to fit the classifier.

Parameters
----------
x
Array containing conditioning variables for each instance.
K_x
Kernel matrix on the conditioning variables.
y
Boolean array marking the domain each instance belongs to (`0` for reference, `1` for test).
"""
clf = self.clf
clf.fit(x, y)
clf.fit(K_x, y)
self.clf = clf

def calibrate(self, x: np.ndarray, y: np.ndarray):
def calibrate(self, K_x: np.ndarray, y: np.ndarray):
"""
Method to calibrate the classifier's predicted probabilities.

Parameters
----------
x
Array containing conditioning variables for each instance.
K_x
Kernel matrix on the conditioning variables.
y
Boolean array marking the domain each instance belongs to (`0` for reference, `1` for test).
"""
clf = CalibratedClassifierCV(self.clf, method=self.cal_method, cv='prefit')
clf.fit(x, y)
clf.fit(K_x, y)
self.clf = clf

def predict(self, x: np.ndarray) -> np.ndarray:
def predict(self, K_x: np.ndarray) -> np.ndarray:
"""
The classifier's predict method.

Parameters
----------
x
Array containing conditioning variables for each instance.
K_x
Kernel matrix on the conditioning variables.

Returns
-------
Propensity scores (the probability of being test instances).
"""
return self.clf.predict_proba(x)[:, 1]
return self.clf.predict_proba(K_x)[:, 1]
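The change above moves kernel evaluation out of the classifier: `SVC` is constructed with `kernel='precomputed'`, and `fit`/`calibrate`/`predict` now consume a kernel matrix `K_x` instead of raw conditioning variables. A minimal sketch of the new flow (hypothetical usage of this private class, with the kernel matrix built via scikit-learn's RBF kernel purely for illustration):

```python
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel
from alibi_detect.cd._domain_clf import _SVCDomainClf

# Conditioning variables for reference (domain 0) and test (domain 1) instances.
c_ref, c_test = np.random.randn(100, 4), np.random.randn(100, 4)
c_all = np.concatenate([c_ref, c_test])
y = np.array([0] * 100 + [1] * 100)

# The kernel matrix is now computed by the caller; any PSD kernel matrix works.
K_x = rbf_kernel(c_all, c_all)

clf = _SVCDomainClf()      # the kernel argument is gone
clf.fit(K_x, y)
clf.calibrate(K_x, y)      # in practice, calibrate on held-out data
scores = clf.predict(K_x)  # propensity scores of being test instances
```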
13 changes: 1 addition & 12 deletions alibi_detect/cd/base.py
@@ -508,7 +508,6 @@ def __init__(
preprocess_at_init: bool = True,
update_x_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
sigma: Optional[np.ndarray] = None,
configure_kernel_from_x_ref: bool = True,
n_permutations: int = 100,
input_shape: Optional[tuple] = None,
@@ -536,9 +535,6 @@
for reservoir sampling {'reservoir_sampling': n} is passed.
preprocess_fn
Function to preprocess the data before computing the data drift metrics.
sigma
Optionally set the Gaussian RBF kernel bandwidth. Can also pass multiple bandwidth values as an array.
The kernel evaluation is then averaged over those bandwidths.
configure_kernel_from_x_ref
Whether to already configure the kernel bandwidth from the reference data.
n_permutations
@@ -553,12 +549,7 @@
if p_val is None:
logger.warning('No p-value set for the drift threshold. Need to set it to detect data drift.')

self.infer_sigma = configure_kernel_from_x_ref
if configure_kernel_from_x_ref and isinstance(sigma, np.ndarray):
self.infer_sigma = False
logger.warning('`sigma` is specified for the kernel and `configure_kernel_from_x_ref` '
'is set to True. `sigma` argument takes priority over '
'`configure_kernel_from_x_ref` (set to False).')
self.infer_parameter = configure_kernel_from_x_ref

# x_ref preprocessing
self.preprocess_at_init = preprocess_at_init
@@ -668,7 +659,6 @@ def __init__(
preprocess_at_init: bool = True,
update_x_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
sigma: Optional[np.ndarray] = None,
n_permutations: int = 100,
n_kernel_centers: Optional[int] = None,
lambda_rd_max: float = 0.2,
@@ -731,7 +721,6 @@

# Other attributes
self.p_val = p_val
self.sigma = sigma
self.update_x_ref = update_x_ref
self.preprocess_fn = preprocess_fn
self.n = len(x_ref)
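With `sigma` removed from the base classes, bandwidth configuration lives entirely on the kernel object, and `configure_kernel_from_x_ref` now drives a generic `infer_parameter` flag. A before/after sketch of detector construction (the `GaussianRBF` constructor arguments follow this PR's redesigned kernel API, so treat the exact names as assumptions):

```python
import numpy as np
import torch
from alibi_detect.cd import MMDDrift
from alibi_detect.utils.pytorch.kernels import GaussianRBF

x_ref = np.random.randn(500, 10).astype(np.float32)

# Before: the bandwidth was a detector argument.
# cd = MMDDrift(x_ref, backend='pytorch', sigma=np.array([.5, 1., 2.]))

# After: the bandwidth is a parameter of the kernel instance itself.
kernel = GaussianRBF(sigma=torch.tensor([.5, 1., 2.]))
cd = MMDDrift(x_ref, backend='pytorch', kernel=kernel)
```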
10 changes: 6 additions & 4 deletions alibi_detect/cd/context_aware.py
@@ -4,6 +4,8 @@
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, BackendValidator, Framework
from alibi_detect.utils.warnings import deprecated_alias
from alibi_detect.base import DriftConfigMixin
from alibi_detect.utils.pytorch.kernels import BaseKernel as BaseKernel_pt
from alibi_detect.utils.tensorflow.kernels import BaseKernel as BaseKernel_tf

if has_pytorch:
from alibi_detect.cd.pytorch.context_aware import ContextMMDDriftTorch
@@ -26,8 +28,8 @@ def __init__(
preprocess_at_init: bool = True,
update_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
x_kernel: Callable = None,
c_kernel: Callable = None,
x_kernel: Optional[Union[BaseKernel_pt, BaseKernel_tf]] = None,
c_kernel: Optional[Union[BaseKernel_pt, BaseKernel_tf]] = None,
n_permutations: int = 1000,
prop_c_held: float = 0.25,
n_folds: int = 5,
@@ -109,9 +111,9 @@ def __init__(
else:
from alibi_detect.utils.pytorch.kernels import GaussianRBF # type: ignore[no-redef]
if x_kernel is None:
kwargs.update({'x_kernel': GaussianRBF})
kwargs.update({'x_kernel': GaussianRBF()})
if c_kernel is None:
kwargs.update({'c_kernel': GaussianRBF})
kwargs.update({'c_kernel': GaussianRBF()})

if backend == Framework.TENSORFLOW:
kwargs.pop('device', None)
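The typed signature and the `GaussianRBF()` defaults make the new contract explicit: `x_kernel` and `c_kernel` are kernel instances rather than kernel classes. A sketch of the updated call (the data arrays are placeholders):

```python
import numpy as np
from alibi_detect.cd import ContextMMDDrift
from alibi_detect.utils.pytorch.kernels import GaussianRBF

x_ref = np.random.randn(200, 10).astype(np.float32)  # reference data
c_ref = np.random.randn(200, 3).astype(np.float32)   # associated contexts

cd = ContextMMDDrift(
    x_ref, c_ref,
    backend='pytorch',
    x_kernel=GaussianRBF(),  # an instance, where previously the class was passed
    c_kernel=GaussianRBF(),
)
```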
58 changes: 15 additions & 43 deletions alibi_detect/cd/keops/learned_kernel.py
@@ -2,13 +2,12 @@
from functools import partial
from tqdm import tqdm
import numpy as np
from pykeops.torch import LazyTensor
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Callable, Dict, List, Optional, Union, Tuple
from alibi_detect.cd.base import BaseLearnedKernelDrift
from alibi_detect.utils.pytorch import get_device, predict_batch
from alibi_detect.utils.pytorch import get_device
from alibi_detect.utils.pytorch.data import TorchDataset
from alibi_detect.utils.frameworks import Framework

@@ -137,6 +136,7 @@
self.device = get_device(device)
self.original_kernel = kernel
self.kernel = deepcopy(kernel)
self.kernel = self.kernel.to(self.device)

# Check kernel format
self.has_proj = hasattr(self.kernel, 'proj') and isinstance(self.kernel.proj, nn.Module)
@@ -174,21 +174,10 @@ def __init__(self, kernel: nn.Module, var_reg: float, has_proj: bool, has_kernel

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
n = len(x)
if self.has_proj and isinstance(self.kernel.proj, nn.Module):
x_proj, y_proj = self.kernel.proj(x), self.kernel.proj(y)
else:
x_proj, y_proj = x, y
x2_proj, x_proj = LazyTensor(x_proj[None, :, :]), LazyTensor(x_proj[:, None, :])
y2_proj, y_proj = LazyTensor(y_proj[None, :, :]), LazyTensor(y_proj[:, None, :])
if self.has_kernel_b:
x2, x = LazyTensor(x[None, :, :]), LazyTensor(x[:, None, :])
y2, y = LazyTensor(y[None, :, :]), LazyTensor(y[:, None, :])
else:
x, x2, y, y2 = None, None, None, None

k_xy = self.kernel(x_proj, y2_proj, x, y2)
k_xx = self.kernel(x_proj, x2_proj, x, x2)
k_yy = self.kernel(y_proj, y2_proj, y, y2)
k_xy = self.kernel(x, y)
k_xx = self.kernel(x, x)
k_yy = self.kernel(y, y)
h_mat = k_xx + k_yy - k_xy - k_xy.t()

h_i = h_mat.sum(1).squeeze(-1)
@@ -221,6 +210,7 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]:

self.kernel = deepcopy(self.original_kernel) if self.retrain_from_scratch else self.kernel
self.kernel = self.kernel.to(self.device)

train_args = [self.j_hat, (dl_ref_tr, dl_cur_tr), self.device]
LearnedKernelDriftKeops.trainer(*train_args, **self.train_kwargs) # type: ignore

@@ -263,42 +253,24 @@ def _mmd2(self, x_all: Union[list, torch.Tensor], perms: List[torch.Tensor], m:
preprocess_batch_fn = self.train_kwargs['preprocess_fn']
if isinstance(preprocess_batch_fn, Callable): # type: ignore[arg-type]
x_all = preprocess_batch_fn(x_all) # type: ignore[operator]
if self.has_proj:
x_all_proj = predict_batch(x_all, self.kernel.proj, device=self.device, batch_size=self.batch_size_predict,
dtype=x_all.dtype if isinstance(x_all, torch.Tensor) else torch.float32)
else:
x_all_proj = x_all

x, x2, y, y2 = None, None, None, None
x, y = None, None
k_xx, k_yy, k_xy = [], [], []
for batch in range(self.n_batches):
i, j = batch * self.batch_size_perms, (batch + 1) * self.batch_size_perms
# Stack a batch of permuted reference and test tensors and their projections
x_proj = torch.cat([x_all_proj[perm[:m]][None, :, :] for perm in perms[i:j]], 0)
y_proj = torch.cat([x_all_proj[perm[m:]][None, :, :] for perm in perms[i:j]], 0)
if self.has_kernel_b:
x = torch.cat([x_all[perm[:m]][None, :, :] for perm in perms[i:j]], 0)
y = torch.cat([x_all[perm[m:]][None, :, :] for perm in perms[i:j]], 0)
x = torch.cat([x_all[perm[:m]][None, :, :] for perm in perms[i:j]], 0)
y = torch.cat([x_all[perm[m:]][None, :, :] for perm in perms[i:j]], 0)
if batch == 0:
x_proj = torch.cat([x_all_proj[None, :m, :], x_proj], 0)
y_proj = torch.cat([x_all_proj[None, m:, :], y_proj], 0)
if self.has_kernel_b:
x = torch.cat([x_all[None, :m, :], x], 0) # type: ignore[call-overload]
y = torch.cat([x_all[None, m:, :], y], 0) # type: ignore[call-overload]
x_proj, y_proj = x_proj.to(self.device), y_proj.to(self.device)
if self.has_kernel_b:
x, y = x.to(self.device), y.to(self.device)
x = torch.cat([x_all[None, :m, :], x], 0) # type: ignore[call-overload]
y = torch.cat([x_all[None, m:, :], y], 0) # type: ignore[call-overload]
x, y = x.to(self.device), y.to(self.device)

# Batch-wise kernel matrix computation over the permutations
with torch.no_grad():
x2_proj, x_proj = LazyTensor(x_proj[:, None, :, :]), LazyTensor(x_proj[:, :, None, :])
y2_proj, y_proj = LazyTensor(y_proj[:, None, :, :]), LazyTensor(y_proj[:, :, None, :])
if self.has_kernel_b:
x2, x = LazyTensor(x[:, None, :, :]), LazyTensor(x[:, :, None, :])
y2, y = LazyTensor(y[:, None, :, :]), LazyTensor(y[:, :, None, :])
k_xy.append(self.kernel(x_proj, y2_proj, x, y2).sum(1).sum(1).squeeze(-1))
k_xx.append(self.kernel(x_proj, x2_proj, x, x2).sum(1).sum(1).squeeze(-1))
k_yy.append(self.kernel(y_proj, y2_proj, y, y2).sum(1).sum(1).squeeze(-1))
k_xy.append(self.kernel(x, y).sum(1).sum(1).squeeze(-1))
k_xx.append(self.kernel(x, x).sum(1).sum(1).squeeze(-1))
k_yy.append(self.kernel(y, y).sum(1).sum(1).squeeze(-1))

c_xx, c_yy, c_xy = 1 / (m * (m - 1)), 1 / (n * (n - 1)), 2. / (m * n)
# Note that the MMD^2 estimates assume that the diagonal of the kernel matrix consists of 1's
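The simplification above works because the kernel now owns its projection and the KeOps `LazyTensor` plumbing, so the estimator just calls `kernel(x, y)` on raw batches. A sketch of what a deep-kernel call reduces to (assuming `DeepKernel` follows this PR's keops API; `eps` weights the kernel on raw inputs):

```python
import torch
import torch.nn as nn
from alibi_detect.utils.keops.kernels import DeepKernel

proj = nn.Linear(32, 8)              # learnable projection phi
kernel = DeepKernel(proj, eps=0.01)  # k(x, y) = (1-eps)*k_a(phi(x), phi(y)) + eps*k_b(x, y)

x, y = torch.randn(100, 32), torch.randn(100, 32)
k_xy = kernel(x, y)  # projection and LazyTensor wrapping happen inside the kernel
h_mat = kernel(x, x) + kernel(y, y) - k_xy - k_xy.t()  # as in _JHat.forward
```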
57 changes: 32 additions & 25 deletions alibi_detect/cd/keops/mmd.py
@@ -1,10 +1,9 @@
import logging
import numpy as np
from pykeops.torch import LazyTensor
import torch
from typing import Callable, Dict, List, Optional, Tuple, Union
from alibi_detect.cd.base import BaseMMDDrift
from alibi_detect.utils.keops.kernels import GaussianRBF
from alibi_detect.utils.keops.kernels import BaseKernel, GaussianRBF
from alibi_detect.utils.pytorch import get_device
from alibi_detect.utils.frameworks import Framework

@@ -20,8 +19,7 @@ def __init__(
preprocess_at_init: bool = True,
update_x_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
kernel: Callable = GaussianRBF,
sigma: Optional[np.ndarray] = None,
kernel: Union[BaseKernel, Callable] = GaussianRBF,
configure_kernel_from_x_ref: bool = True,
n_permutations: int = 100,
batch_size_permutations: int = 1000000,
@@ -53,9 +51,6 @@
Function to preprocess the data before computing the data drift metrics.
kernel
Kernel used for the MMD computation, defaults to Gaussian RBF kernel.
sigma
Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple bandwidth values as an array.
The kernel evaluation is then averaged over those bandwidths.
configure_kernel_from_x_ref
Whether to already configure the kernel bandwidth from the reference data.
n_permutations
@@ -77,7 +72,6 @@
preprocess_at_init=preprocess_at_init,
update_x_ref=update_x_ref,
preprocess_fn=preprocess_fn,
sigma=sigma,
configure_kernel_from_x_ref=configure_kernel_from_x_ref,
n_permutations=n_permutations,
input_shape=input_shape,
@@ -88,24 +82,39 @@
# set device
self.device = get_device(device)

# initialize kernel
sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma, # type: ignore[assignment]
np.ndarray) else None
self.kernel = kernel(sigma).to(self.device) if kernel == GaussianRBF else kernel
# initialise kernel
if isinstance(kernel, BaseKernel):
self.kernel = kernel
elif kernel == GaussianRBF:
self.kernel = kernel()
else:
raise ValueError("kernel must be an instance of alibi_detect.utils.keops.kernels.BaseKernel or the GaussianRBF class")

self.kernel_parameter_specified = True
if hasattr(self.kernel, 'parameter_dict'):
for param in self.kernel.parameter_dict.keys():
self.kernel.parameter_dict[param].value = self.kernel.parameter_dict[param].value.to(self.device)
if self.kernel.parameter_dict[param].requires_init:
self.kernel_parameter_specified = False
break

if self.kernel_parameter_specified and self.infer_parameter:
self.infer_parameter = False
logger.warning('parameters are specified for the kernel and `configure_kernel_from_x_ref` '
'is set to True. Specified parameters take priority over '
'`configure_kernel_from_x_ref` (set to False).')

# set the correct MMD^2 function based on the batch size for the permutations
self.batch_size = batch_size_permutations
self.n_batches = 1 + (n_permutations - 1) // batch_size_permutations

# infer the kernel bandwidth from the reference data
if isinstance(sigma, torch.Tensor):
self.infer_sigma = False
elif self.infer_sigma:
x = torch.from_numpy(self.x_ref).to(self.device)
_ = self.kernel(LazyTensor(x[:, None, :]), LazyTensor(x[None, :, :]), infer_sigma=self.infer_sigma)
self.infer_sigma = False
if self.infer_parameter:
x = torch.from_numpy(self.x_ref).to(self.device).reshape(1, self.x_ref.shape[0], -1)
_ = self.kernel(x, x, infer_parameter=self.infer_parameter)
self.infer_parameter = False
else:
self.infer_sigma = True
self.infer_parameter = not self.kernel_parameter_specified  # stay False when parameters were given explicitly

def _mmd2(self, x_all: torch.Tensor, perms: List[torch.Tensor], m: int, n: int) \
-> Tuple[torch.Tensor, torch.Tensor]:
@@ -139,12 +148,10 @@ def _mmd2(self, x_all: torch.Tensor, perms: List[torch.Tensor], m: int, n: int)
x, y = x.to(self.device), y.to(self.device)

# batch-wise kernel matrix computation over the permutations
k_xy.append(self.kernel(
LazyTensor(x[:, :, None, :]), LazyTensor(y[:, None, :, :]), self.infer_sigma).sum(1).sum(1).squeeze(-1))
k_xx.append(self.kernel(
LazyTensor(x[:, :, None, :]), LazyTensor(x[:, None, :, :])).sum(1).sum(1).squeeze(-1))
k_yy.append(self.kernel(
LazyTensor(y[:, :, None, :]), LazyTensor(y[:, None, :, :])).sum(1).sum(1).squeeze(-1))
k_xy.append(self.kernel(x, y, infer_parameter=self.infer_parameter).sum(1).sum(1).squeeze(-1))
k_xx.append(self.kernel(x, x, infer_parameter=self.infer_parameter).sum(1).sum(1).squeeze(-1))
k_yy.append(self.kernel(y, y, infer_parameter=self.infer_parameter).sum(1).sum(1).squeeze(-1))

c_xx, c_yy, c_xy = 1 / (m * (m - 1)), 1 / (n * (n - 1)), 2. / (m * n)
# Note that the MMD^2 estimates assume that the diagonal of the kernel matrix consists of 1's
stats = c_xx * (torch.cat(k_xx) - m) + c_yy * (torch.cat(k_yy) - n) - c_xy * torch.cat(k_xy)
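Taken together, the keops detector now accepts a kernel instance, and any parameter flagged `requires_init` is inferred from the reference data. A minimal sketch of the two initialisation paths (assuming this PR's `GaussianRBF`, whose bandwidth requires initialisation when not given):

```python
import numpy as np
import torch
from alibi_detect.cd import MMDDrift
from alibi_detect.utils.keops.kernels import GaussianRBF

x_ref = np.random.randn(1000, 16).astype(np.float32)

# Path 1: no parameter given -> the bandwidth is inferred from x_ref.
cd = MMDDrift(x_ref, backend='keops', kernel=GaussianRBF())

# Path 2: bandwidth given explicitly -> used as-is; if configure_kernel_from_x_ref
# is also True, the detector warns and disables inference (mirrors the old sigma handling).
cd2 = MMDDrift(x_ref, backend='keops', kernel=GaussianRBF(sigma=torch.tensor([1.])))
```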