Add support for linear-time mmd estimator. #475

Open · wants to merge 18 commits into base: master
Changes from 10 commits
14 changes: 9 additions & 5 deletions alibi_detect/cd/base.py
@@ -570,12 +570,16 @@ def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True, return_
        'data' contains the drift prediction and optionally the p-value, threshold and MMD metric.
        """
        # compute drift scores
-        p_val, dist, dist_permutations = self.score(x)
-        drift_pred = int(p_val < self.p_val)
+        p_val, dist, tmp_v = self.score(x)
+        if len(np.shape(tmp_v)) > 0:
+            dist_permutations = tmp_v
+            # compute distance threshold
+            idx_threshold = int(self.p_val * len(dist_permutations))
+            distance_threshold = np.sort(dist_permutations)[::-1][idx_threshold]
+        else:
+            distance_threshold = tmp_v
Contributor: It doesn't seem ideal to have to deal with the differing behaviour of score for the original vs new detectors in predict. Maybe we could move the distance_threshold computation to score for the original MMD detectors, and then the above would be simplified quite a bit? Draft PR for this here: #489

Author: Yes indeed, I guess the best thing to do here is to follow your draft PR's template to modify the linear-time detector.

Contributor (@ascillitoe, Apr 25, 2022): I'd say if @arnaudvl and @ojcobb agree with the change in #489, we should merge that; then it would be a super quick change to this PR.


-        # compute distance threshold
-        idx_threshold = int(self.p_val * len(dist_permutations))
-        distance_threshold = np.sort(dist_permutations)[::-1][idx_threshold]
+        drift_pred = int(p_val < self.p_val)

        # update reference dataset
        if isinstance(self.update_x_ref, dict) and self.preprocess_fn is not None and self.preprocess_x_ref:
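For context on the discussion above, the new predict branches on the shape of the third value returned by score: the quadratic-time detectors return the array of permuted MMD^2 values (from which the threshold is derived), while the linear-time detector returns a scalar threshold directly. A minimal sketch of that dispatch, with illustrative names rather than the PR's exact code:

```python
import numpy as np

def distance_threshold_from_score(p_val_target: float, third_return):
    # Sketch only: mirrors the branching added to predict above.
    # Array of permuted MMD^2 values (quadratic-time detectors) -> take the
    # (1 - p_val) quantile; scalar (linear-time detector) -> already a threshold.
    if len(np.shape(third_return)) > 0:
        idx_threshold = int(p_val_target * len(third_return))
        return np.sort(third_return)[::-1][idx_threshold]
    return float(third_return)

# Illustrative usage with synthetic permutation statistics and a scalar threshold.
perms = np.random.rand(1000)
print(distance_threshold_from_score(.05, perms))  # quantile of the permuted values
print(distance_threshold_from_score(.05, 0.12))   # passed through unchanged
```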
25 changes: 20 additions & 5 deletions alibi_detect/cd/mmd.py
@@ -4,10 +4,10 @@
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow

if has_pytorch:
-    from alibi_detect.cd.pytorch.mmd import MMDDriftTorch
+    from alibi_detect.cd.pytorch.mmd import MMDDriftTorch, LinearTimeDriftTorch
Contributor: Change to LinearTimeMMDDriftTorch for consistency with TensorFlow.

Author: Fixed.


if has_tensorflow:
-    from alibi_detect.cd.tensorflow.mmd import MMDDriftTF
+    from alibi_detect.cd.tensorflow.mmd import MMDDriftTF, LinearTimeMMDDriftTF

logger = logging.getLogger(__name__)

@@ -18,6 +18,7 @@ def __init__(
            x_ref: Union[np.ndarray, list],
            backend: str = 'tensorflow',
            p_val: float = .05,
+            estimator: str = 'quad',
Contributor: Would estimator_complexity be more descriptive? (Or at least make clear in the docstring.)

Author: Added extra description in the docstring.

            preprocess_x_ref: bool = True,
            update_x_ref: Optional[Dict[str, int]] = None,
            preprocess_fn: Optional[Callable] = None,
@@ -40,6 +41,8 @@ def __init__(
            Backend used for the MMD implementation.
        p_val
            p-value used for the significance of the permutation test.
+        estimator
+            Estimator used for the MMD^2 computation {'quad', 'linear'}.
        preprocess_x_ref
            Whether to already preprocess and store the reference data.
        update_x_ref
@@ -76,7 +79,7 @@ def __init__(

        kwargs = locals()
        args = [kwargs['x_ref']]
-        pop_kwargs = ['self', 'x_ref', 'backend', '__class__']
+        pop_kwargs = ['self', 'x_ref', 'backend', '__class__', 'estimator']
        [kwargs.pop(k, None) for k in pop_kwargs]

        if kernel is None:
@@ -88,9 +91,21 @@

        if backend == 'tensorflow' and has_tensorflow:
            kwargs.pop('device', None)
-            self._detector = MMDDriftTF(*args, **kwargs)  # type: ignore
+            if estimator == 'quad':
+                self._detector = MMDDriftTF(*args, **kwargs)  # type: ignore
+            elif estimator == 'linear':
+                kwargs.pop('n_permutations', None)
Contributor: Best to clarify in the docstrings that n_permutations is not used for the linear estimator.

Author: Fixed.

+                self._detector = LinearTimeMMDDriftTF(*args, **kwargs)  # type: ignore
Contributor: Since the logic to set self._detector is located here, we should add additional tests to alibi_detect/cd/tests/test_mmd.py to check that the correct subclass is selected conditional on backend and estimator.

Author: Indeed, will modify the tests.

Author: Simply rewriting the test to go through the different backend and estimator options should do the job.

+            else:
+                raise NotImplementedError(f'{estimator} not implemented. Use quad or linear instead.')
        else:
-            self._detector = MMDDriftTorch(*args, **kwargs)  # type: ignore
+            if estimator == 'quad':
+                self._detector = MMDDriftTorch(*args, **kwargs)  # type: ignore
+            elif estimator == 'linear':
+                kwargs.pop('n_permutations', None)
+                self._detector = LinearTimeDriftTorch(*args, **kwargs)  # type: ignore
+            else:
+                raise NotImplementedError(f'{estimator} not implemented. Use quad or linear instead.')
        self.meta = self._detector.meta

    def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True, return_distance: bool = True) \
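To illustrate the public API change above, a hedged usage sketch (synthetic data; the estimator keyword is as introduced in this revision of the PR, and n_permutations is ignored when estimator='linear'):

```python
import numpy as np
from alibi_detect.cd import MMDDrift

x_ref = np.random.randn(500, 10).astype(np.float32)         # reference sample
x_test = np.random.randn(500, 10).astype(np.float32) + .5   # shifted test sample

# 'quad' keeps the original permutation-test detector; 'linear' dispatches to the
# new linear-time detector (LinearTimeMMDDriftTF / LinearTimeDriftTorch in this revision).
cd = MMDDrift(x_ref, backend='tensorflow', p_val=.05, estimator='linear')
preds = cd.predict(x_test, return_p_val=True, return_distance=True)
print(preds['data']['is_drift'], preds['data']['p_val'])
```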
137 changes: 132 additions & 5 deletions alibi_detect/cd/pytorch/mmd.py
@@ -1,9 +1,10 @@
import logging
import numpy as np
+import scipy.stats as stats
import torch
from typing import Callable, Dict, Optional, Tuple, Union
from alibi_detect.cd.base import BaseMMDDrift
-from alibi_detect.utils.pytorch.distance import mmd2_from_kernel_matrix
+from alibi_detect.utils.pytorch.distance import mmd2_from_kernel_matrix, linear_mmd2
from alibi_detect.utils.pytorch.kernels import GaussianRBF

logger = logging.getLogger(__name__)
@@ -118,17 +119,143 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray]:
        and the MMD^2 values from the permutation test.
        """
        x_ref, x = self.preprocess(x)
+        n = x.shape[0]
        x_ref = torch.from_numpy(x_ref).to(self.device)  # type: ignore[assignment]
        x = torch.from_numpy(x).to(self.device)  # type: ignore[assignment]
        # compute kernel matrix, MMD^2 and apply permutation test using the kernel matrix
-        n = x.shape[0]
        kernel_mat = self.kernel_matrix(x_ref, x)  # type: ignore[arg-type]
        kernel_mat = kernel_mat - torch.diag(kernel_mat.diag())  # zero diagonal
-        mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
+        mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)  # type: ignore[assignment]
        mmd2_permuted = torch.Tensor(
-            [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
-        )
+            [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False)
+             for _ in range(self.n_permutations)]
+        )
        if self.device.type == 'cuda':
            mmd2, mmd2_permuted = mmd2.cpu(), mmd2_permuted.cpu()
        p_val = (mmd2 <= mmd2_permuted).float().mean()
        return p_val.numpy().item(), mmd2.numpy().item(), mmd2_permuted.numpy()


class LinearTimeDriftTorch(BaseMMDDrift):
Contributor: As mentioned before, this should probably be LinearTimeMMDDriftTorch (what a name...).

Contributor: The names just get better and better 😅

Author: Fixed.

def __init__(
self,
x_ref: Union[np.ndarray, list],
p_val: float = .05,
preprocess_x_ref: bool = True,
update_x_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
kernel: Callable = GaussianRBF,
sigma: Optional[np.ndarray] = None,
configure_kernel_from_x_ref: bool = True,
n_permutations: int = 100,
Contributor: Should not contain the n_permutations kwarg.

Author: Fixed.

device: Optional[str] = None,
input_shape: Optional[tuple] = None,
data_type: Optional[str] = None
) -> None:
"""
Maximum Mean Discrepancy (MMD) data drift detector using a permutation test, with linear-time estimator.
Contributor: It doesn't use a permutation test though?

Author: No it doesn't, removed now.


Parameters
----------
x_ref
Data used as reference distribution.
p_val
p-value used for the significance of the permutation test.
preprocess_x_ref
Whether to already preprocess and store the reference data.
update_x_ref
Reference data can optionally be updated to the last n instances seen by the detector
or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while
for reservoir sampling {'reservoir_sampling': n} is passed.
preprocess_fn
Function to preprocess the data before computing the data drift metrics.
kernel
Kernel used for the MMD computation, defaults to Gaussian RBF kernel.
sigma
Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple bandwidth values as an array.
The kernel evaluation is then averaged over those bandwidths.
configure_kernel_from_x_ref
Whether to already configure the kernel bandwidth from the reference data.
n_permutations
Contributor: Again, remove.

Author: Fixed.

Number of permutations used in the permutation test.
device
Device type used. The default None tries to use the GPU and falls back on CPU if needed.
Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
input_shape
Shape of input data.
data_type
Optionally specify the data type (tabular, image or time-series). Added to metadata.
"""
super().__init__(
x_ref=x_ref,
p_val=p_val,
preprocess_x_ref=preprocess_x_ref,
update_x_ref=update_x_ref,
preprocess_fn=preprocess_fn,
sigma=sigma,
configure_kernel_from_x_ref=configure_kernel_from_x_ref,
n_permutations=n_permutations,
Contributor: Does not need to pass n_permutations here.

Author: Fixed.

input_shape=input_shape,
data_type=data_type
)
self.meta.update({'backend': 'pytorch'})

# set backend
if device is None or device.lower() in ['gpu', 'cuda']:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if self.device.type == 'cpu':
print('No GPU detected, fall back on CPU.')
else:
self.device = torch.device('cpu')

# initialize kernel
sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma, # type: ignore[assignment]
np.ndarray) else None
self.kernel = kernel(sigma) if kernel == GaussianRBF else kernel

# compute kernel matrix for the reference data
if self.infer_sigma or isinstance(sigma, torch.Tensor):
x = torch.from_numpy(self.x_ref).to(self.device)
self.k_xx = self.kernel(x, x, infer_sigma=self.infer_sigma)
Contributor: self.k_xx is not used?

Author: Fixed, now self.k_xx is reused for a later calculation.

self.infer_sigma = False
else:
self.k_xx, self.infer_sigma = None, True

def kernel_matrix(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
Contributor: Method is not used, I believe?

Author: @arnaudvl The base class requires this method for initialisation; I wonder what the preferable solution would be here? The minimal thing could be to simply leave a pseudo method.

Contributor: IMO we should remove kernel_matrix from BaseMMDDrift, so that it is no longer an abstractmethod. I don't think it makes sense to have it as an abstract method if not all subclasses use/need it.

""" Compute and return full kernel matrix between arrays x and y. """
k_xy = self.kernel(x, y, self.infer_sigma)
k_xx = self.k_xx if self.k_xx is not None and self.update_x_ref is None else self.kernel(x, x)
k_yy = self.kernel(y, y)
kernel_mat = torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0)
return kernel_mat

def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray]:
"""
Compute the p-value resulting from a permutation test using the maximum mean discrepancy
Contributor: Remove reference to permutation test.

Author: Fixed.

as a distance measure between the reference data and the data to be tested.

Parameters
----------
x
Batch of instances.

Returns
-------
p-value obtained from the permutation test, the MMD^2 between the reference and test set
and the MMD^2 values from the permutation test.
Contributor: Regarding "the MMD^2 values from the permutation test": does this need updating, since the score method's third return arg is now the distance threshold?

Author: Fixed, thanks for spotting.

"""
x_ref, x = self.preprocess(x)
n = x.shape[0]
m = x_ref.shape[0]
n_hat = int(np.floor(min(n, m) / 2) * 2)
Contributor: This behaviour needs to be well documented.

Author: Yes, now described in the docstring.

x_ref = torch.from_numpy(x_ref[:n_hat, :]).to(self.device) # type: ignore[assignment]
Contributor: I believe there is a case to be made for an explicit check that n == m, with an error raised if not. The reason is that some unexpected behaviour can silently occur by only selecting the first n_hat reference/test instances. If, say, the reference data is ordered and contains samples from classes 1, 2 and 3, then only choosing :n_hat could ignore all samples from class 3 and no longer form an i.i.d. sample. So my preference would be explicit behaviour around this (raising errors) or, if we allow this (which I am not in favour of for now), randomly sampling n_hat instances from x_ref and x. Good to have some opinions @jklaise @ascillitoe @ojcobb

Contributor (@ascillitoe, Apr 27, 2022): This seems like quite an issue atm. Agree the safest option would be to explicitly check for n == m and raise an error. Otherwise, we could check, and randomly subsample if n != m, with a warning raised to inform the user we are doing this.

Author: Currently implemented as raising an error for n != m. I guess the subsampling should be implemented as a stand-alone component, so that it can be used with other detectors?
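A minimal sketch of the behaviour discussed in this thread (an explicit size check, with the random subsampling shown as the alternative); the helper name and keyword are illustrative, not part of the PR:

```python
import warnings
import numpy as np

def check_equal_sizes(x_ref: np.ndarray, x: np.ndarray, allow_subsample: bool = False):
    # Raise for unequal sample sizes, or optionally subsample both sets at random
    # down to a common even size n_hat and warn the user.
    m, n = x_ref.shape[0], x.shape[0]
    if n == m:
        return x_ref, x
    if not allow_subsample:
        raise ValueError(f'Expected equal sample sizes, got len(x_ref)={m} and len(x)={n}.')
    n_hat = int(np.floor(min(n, m) / 2) * 2)
    warnings.warn(f'Unequal sample sizes; randomly subsampling both sets to {n_hat} instances.')
    idx_ref = np.random.choice(m, n_hat, replace=False)
    idx = np.random.choice(n, n_hat, replace=False)
    return x_ref[idx_ref], x[idx]
```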

x = torch.from_numpy(x[:n_hat, :]).to(self.device) # type: ignore[assignment]
mmd2, var_mmd2 = linear_mmd2(x_ref, x, self.kernel) # type: ignore[arg-type]
if self.device.type == 'cuda':
mmd2 = mmd2.cpu()
mmd2 = mmd2.numpy().item()
var_mmd2 = var_mmd2.numpy().item()
std_mmd2 = np.sqrt(var_mmd2)
Contributor: Could we directly use torch.std(...) in linear_mmd2? This would remove the few additional lines of code here.

Author: The new version uses np.sqrt(np.clip(var_mmd2, 1e-8, 1e-8)) for numeric stability.

p_val = 1 - stats.norm.cdf(mmd2 * np.sqrt(n_hat), loc=0., scale=std_mmd2*np.sqrt(2))
Contributor: Nitpick, but should this be a t-test?

Author: Nice spot, now fixed with a t-test for both versions.

distance_threshold = stats.norm.ppf(1 - self.p_val, loc=0., scale=std_mmd2*np.sqrt(2))
return p_val, mmd2 * np.sqrt(n_hat), distance_threshold
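For reference, the linear-time estimator used above is the pairing statistic of Gretton et al. (2012): consecutive samples are grouped into non-overlapping pairs and h(i) = k(x_{2i-1}, x_{2i}) + k(y_{2i-1}, y_{2i}) - k(x_{2i-1}, y_{2i}) - k(x_{2i}, y_{2i-1}) is averaged, so only O(n) kernel evaluations are needed; under H0, sqrt(n) * MMD_l^2 is approximately N(0, 2 * Var(h)), which is the approximation the p-value above relies on. A self-contained NumPy sketch (fixed illustrative bandwidth, not the library's linear_mmd2):

```python
import numpy as np
from scipy import stats

def rbf(a: np.ndarray, b: np.ndarray, sigma: float = 1.) -> np.ndarray:
    # Gaussian RBF kernel evaluated row-wise on paired samples.
    return np.exp(-np.sum((a - b) ** 2, axis=-1) / (2 * sigma ** 2))

def linear_time_mmd2(x: np.ndarray, y: np.ndarray, sigma: float = 1.):
    # Linear-time MMD^2: average of h over consecutive, non-overlapping pairs.
    n_hat = int(np.floor(min(len(x), len(y)) / 2) * 2)  # same truncation as above
    x, y = x[:n_hat], y[:n_hat]
    x1, x2, y1, y2 = x[0::2], x[1::2], y[0::2], y[1::2]
    h = rbf(x1, x2, sigma) + rbf(y1, y2, sigma) - rbf(x1, y2, sigma) - rbf(x2, y1, sigma)
    return h.mean(), h.var(), n_hat

x = np.random.randn(1000, 5)
y = np.random.randn(1000, 5) + .2
mmd2, var_h, n_hat = linear_time_mmd2(x, y)
# Normal approximation under H0, mirroring the score method above.
p_val = 1 - stats.norm.cdf(mmd2 * np.sqrt(n_hat), loc=0., scale=np.sqrt(2 * var_h))
print(mmd2, p_val)
```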
115 changes: 112 additions & 3 deletions alibi_detect/cd/tensorflow/mmd.py
@@ -1,9 +1,10 @@
import logging
import numpy as np
+import scipy.stats as stats
import tensorflow as tf
from typing import Callable, Dict, Optional, Tuple, Union
from alibi_detect.cd.base import BaseMMDDrift
-from alibi_detect.utils.tensorflow.distance import mmd2_from_kernel_matrix
+from alibi_detect.utils.tensorflow.distance import mmd2_from_kernel_matrix, linear_mmd2
from alibi_detect.utils.tensorflow.kernels import GaussianRBF

logger = logging.getLogger(__name__)
@@ -112,7 +113,115 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray]:
        mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False).numpy()
        mmd2_permuted = np.array(
            [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False).numpy()
-             for _ in range(self.n_permutations)]
-        )
+             for _ in range(self.n_permutations)])
        p_val = (mmd2 <= mmd2_permuted).mean()
        return p_val, mmd2, mmd2_permuted


class LinearTimeMMDDriftTF(BaseMMDDrift):
def __init__(
self,
x_ref: Union[np.ndarray, list],
p_val: float = .05,
preprocess_x_ref: bool = True,
update_x_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
kernel: Callable = GaussianRBF,
sigma: Optional[np.ndarray] = None,
configure_kernel_from_x_ref: bool = True,
input_shape: Optional[tuple] = None,
data_type: Optional[str] = None
) -> None:
"""
Maximum Mean Discrepancy (MMD) data drift detector using a permutation test, with linear-time estimator.

Parameters
----------
x_ref
Data used as reference distribution.
p_val
p-value used for the significance of the permutation test.
preprocess_x_ref
Whether to already preprocess and store the reference data.
update_x_ref
Reference data can optionally be updated to the last n instances seen by the detector
or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while
for reservoir sampling {'reservoir_sampling': n} is passed.
preprocess_fn
Function to preprocess the data before computing the data drift metrics.
kernel
Kernel used for the MMD computation, defaults to Gaussian RBF kernel.
sigma
Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple bandwidth values as an array.
The kernel evaluation is then averaged over those bandwidths.
configure_kernel_from_x_ref
Whether to already configure the kernel bandwidth from the reference data.
n_permutations
Number of permutations used in the permutation test.
input_shape
Shape of input data.
data_type
Optionally specify the data type (tabular, image or time-series). Added to metadata.
"""
super().__init__(
x_ref=x_ref,
p_val=p_val,
preprocess_x_ref=preprocess_x_ref,
update_x_ref=update_x_ref,
preprocess_fn=preprocess_fn,
sigma=sigma,
configure_kernel_from_x_ref=configure_kernel_from_x_ref,
input_shape=input_shape,
data_type=data_type
)
self.meta.update({'backend': 'tensorflow'})

# initialize kernel
if isinstance(sigma, np.ndarray):
sigma = tf.convert_to_tensor(sigma)
self.kernel = kernel(sigma) if kernel == GaussianRBF else kernel

# compute kernel matrix for the reference data
if self.infer_sigma or isinstance(sigma, tf.Tensor):
self.k_xx = self.kernel(self.x_ref, self.x_ref, infer_sigma=self.infer_sigma)
self.infer_sigma = False
else:
self.k_xx, self.infer_sigma = None, True

def kernel_matrix(self, x: Union[np.ndarray, tf.Tensor], y: Union[np.ndarray, tf.Tensor]) -> tf.Tensor:
""" Compute and return full kernel matrix between arrays x and y. """
k_xy = self.kernel(x, y, self.infer_sigma)
k_xx = self.k_xx if self.k_xx is not None and self.update_x_ref is None else self.kernel(x, x)
k_yy = self.kernel(y, y)
kernel_mat = tf.concat([tf.concat([k_xx, k_xy], 1), tf.concat([tf.transpose(k_xy, (1, 0)), k_yy], 1)], 0)
return kernel_mat

def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray]:
"""
Compute the p-value resulting from a permutation test using the maximum mean discrepancy
as a distance measure between the reference data and the data to be tested.

Parameters
----------
x
Batch of instances.

Returns
-------
p-value obtained from the permutation test, the MMD^2 between the reference and test set
and the MMD^2 values from the permutation test.
"""
x_ref, x = self.preprocess(x)
# compute kernel matrix, MMD^2 and apply permutation test using the kernel matrix
n = x.shape[0]
m = x_ref.shape[0]
n_hat = int(np.floor(min(n, m) / 2) * 2)
x_ref = x_ref[:n_hat, :]
x = x[:n_hat, :]
mmd2, var_mmd2 = linear_mmd2(x_ref, x, self.kernel, permute=False)
mmd2 = mmd2.numpy()
var_mmd2 = var_mmd2.numpy()
std_mmd2 = np.sqrt(var_mmd2)
p_val = 1 - stats.norm.cdf(mmd2 * np.sqrt(n_hat), loc=0., scale=std_mmd2*np.sqrt(2))
distance_threshold = stats.norm.ppf(1 - self.p_val, loc=0., scale=std_mmd2*np.sqrt(2))
return p_val, mmd2 * np.sqrt(n_hat), distance_threshold
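As a small consistency check on the two values returned above (illustrative numbers only): under this normal approximation, the "p_val < self.p_val" decision in predict and the comparison of the scaled statistic against distance_threshold are equivalent:

```python
import numpy as np
from scipy import stats

p_val_target = .05
std_mmd2 = .8        # illustrative std of the h-statistics
scaled_mmd2 = 2.1    # illustrative mmd2 * sqrt(n_hat)

p_val = 1 - stats.norm.cdf(scaled_mmd2, loc=0., scale=std_mmd2 * np.sqrt(2))
distance_threshold = stats.norm.ppf(1 - p_val_target, loc=0., scale=std_mmd2 * np.sqrt(2))

# p-value below the target <=> scaled statistic above the threshold
assert (p_val < p_val_target) == (scaled_mmd2 > distance_threshold)
print(p_val, distance_threshold)
```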