Merge branch 'master' of github.com:jrg365/gpytorch
jacobrgardner committed Apr 15, 2019
2 parents 2c58abe + f90f944 commit 5f691af
Showing 11 changed files with 292 additions and 52 deletions.
10 changes: 1 addition & 9 deletions gpytorch/beta_features.py
@@ -39,14 +39,6 @@ class checkpoint_kernel(_value_context):
_global_value = 0


class diagonal_correction(_feature_flag):
"""
Add a diagonal correction to scalable inducing point methods
"""

_state = True


class default_preconditioner(_feature_flag):
"""
Add a diagonal correction to scalable inducing point methods
@@ -55,4 +47,4 @@ class default_preconditioner(_feature_flag):
pass


__all__ = ["checkpoint_kernel", "diagonal_correction", "default_preconditioner"]
__all__ = ["checkpoint_kernel", "default_preconditioner"]
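
For orientation, a minimal sketch of how the two remaining beta features might be enabled together, assuming both behave as context managers like other GPyTorch feature flags (the checkpointing partition size of 10000 is an arbitrary illustrative value):

import gpytorch

# Enable kernel checkpointing and the default preconditioner for the duration of the block;
# both flags revert to their previous state on exit.
with gpytorch.beta_features.checkpoint_kernel(10000), gpytorch.beta_features.default_preconditioner():
    pass  # train or evaluate a model here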
30 changes: 23 additions & 7 deletions gpytorch/likelihoods/gaussian_likelihood.py
@@ -1,5 +1,7 @@
#!/usr/bin/env python3

from copy import deepcopy

import math
import warnings
from typing import Any, Optional
@@ -37,8 +39,8 @@ def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: An
return self.noise_covar(*params, shape=shape, **kwargs)

def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> base_distributions.Normal:
observation_noise = self._shaped_noise_covar(function_samples.shape, *params, **kwargs).diag()
return base_distributions.Normal(function_samples, observation_noise.sqrt())
noise = self._shaped_noise_covar(function_samples.shape, *params, **kwargs).diag()
return base_distributions.Normal(function_samples, noise.sqrt())

def marginal(self, function_dist: MultivariateNormal, *params: Any, **kwargs: Any) -> MultivariateNormal:
mean, covar = function_dist.mean, function_dist.lazy_covariance_matrix
@@ -94,9 +96,8 @@ class FixedNoiseGaussianLikelihood(_GaussianLikelihoodBase):
:attr:`learn_additional_noise` (bool, optional):
Set to true if you additionally want to learn added diagonal noise, similar to GaussianLikelihood.
Note that this likelihood takes an additional argument when you call it, `observation_noise`, that adds
a specified amount of noise to the passed MultivariateNormal. This allows for adding known observational noise
to test data.
Note that this likelihood takes an additional argument when you call it, `noise`, that adds a specified amount
of noise to the passed MultivariateNormal. This allows for adding known observational noise to test data.
Example:
>>> train_x = torch.randn(55, 2)
@@ -106,7 +107,7 @@ class FixedNoiseGaussianLikelihood(_GaussianLikelihoodBase):
>>>
>>> test_x = torch.randn(21, 2)
>>> test_noises = torch.ones(21) * 0.02
>>> pred_y = likelihood(gp_model(test_x), observation_noise=test_noises)
>>> pred_y = likelihood(gp_model(test_x), noise=test_noises)
"""
def __init__(
self,
@@ -155,6 +156,21 @@ def second_noise(self, value: Tensor) -> None:
)
self.second_noise_covar.initialize(noise=value)

def get_fantasy_likelihood(self, **kwargs):
if "noise" not in kwargs:
raise RuntimeError("FixedNoiseGaussianLikelihood.fantasize requires a `noise` kwarg")
old_noise_covar = self.noise_covar
self.noise_covar = None
fantasy_liklihood = deepcopy(self)
self.noise_covar = old_noise_covar

old_noise = old_noise_covar.noise
new_noise = kwargs.get("noise")
if old_noise.dim() != new_noise.dim():
old_noise = old_noise.expand(*new_noise.shape[:-1], old_noise.shape[-1])
fantasy_liklihood.noise_covar = FixedGaussianNoise(noise=torch.cat([old_noise, new_noise], -1))
return fantasy_liklihood

def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: Any):
if len(params) > 0:
# we can infer the shape from the params
@@ -170,7 +186,7 @@ def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: An
elif isinstance(res, ZeroLazyTensor):
warnings.warn(
"You have passed data through a FixedNoiseGaussianLikelihood that did not match the size "
"of the fixed noise, *and* you did not specify observation_noise. This is treated as a no-op."
"of the fixed noise, *and* you did not specify noise. This is treated as a no-op."
)

return res
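
A self-contained sketch of the call pattern the updated docstring describes, with the `noise` kwarg supplying known observation noise for test points. The `_ExampleGP` model class is a hypothetical minimal ExactGP added here for illustration; it is not part of this file, and training is omitted:

import torch
import gpytorch

class _ExampleGP(gpytorch.models.ExactGP):  # hypothetical minimal model for illustration
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

train_x = torch.randn(55, 2)
train_y = torch.randn(55)
noises = torch.ones(55) * 0.01
likelihood = gpytorch.likelihoods.FixedNoiseGaussianLikelihood(noise=noises, learn_additional_noise=True)
gp_model = _ExampleGP(train_x, train_y, likelihood)

gp_model.eval()
likelihood.eval()
test_x = torch.randn(21, 2)
test_noises = torch.ones(21) * 0.02
with torch.no_grad():
    pred_y = likelihood(gp_model(test_x), noise=test_noises)  # adds the known test-point noise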
5 changes: 5 additions & 0 deletions gpytorch/likelihoods/likelihood.py
@@ -1,5 +1,7 @@
#!/usr/bin/env python3

from copy import deepcopy

import functools
import torch
import warnings
@@ -109,6 +111,9 @@ def variational_log_probability(self, function_dist, observations):
)
return self.expected_log_prob(observations, function_dist)

def get_fantasy_likelihood(self, **kwargs):
return deepcopy(self)

def __call__(self, input, *params, **kwargs):
# Conditional
if torch.is_tensor(input):
6 changes: 3 additions & 3 deletions gpytorch/likelihoods/multitask_gaussian_likelihood.py
@@ -117,9 +117,9 @@ def _shaped_noise_covar(self, base_shape, *params):
return task_covar

def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> base_distributions.Normal:
observation_noise = self._shaped_noise_covar(function_samples.shape, *params, **kwargs).diag()
observation_noise = observation_noise.view(*observation_noise.shape[:-1], *function_samples.shape[-2:])
return base_distributions.Normal(function_samples, observation_noise.sqrt())
noise = self._shaped_noise_covar(function_samples.shape, *params, **kwargs).diag()
noise = noise.view(*noise.shape[:-1], *function_samples.shape[-2:])
return base_distributions.Normal(function_samples, noise.sqrt())


class MultitaskGaussianLikelihood(_MultitaskGaussianLikelihoodBase):
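
A small shape check of the reshape in the rewritten forward, with illustrative sizes (2 samples, 5 data points, 3 tasks); the real code takes the flattened diagonal from _shaped_noise_covar(...).diag() rather than drawing it at random:

import torch

n_samples, n, t = 2, 5, 3
function_samples = torch.randn(n_samples, n, t)
flat_noise = torch.rand(n_samples, n * t)  # stand-in for the flattened noise diagonal
noise = flat_noise.view(*flat_noise.shape[:-1], *function_samples.shape[-2:])
assert noise.shape == function_samples.shape  # (2, 5, 3), ready for Normal(function_samples, noise.sqrt())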
56 changes: 44 additions & 12 deletions gpytorch/likelihoods/noise_models.py
@@ -18,14 +18,27 @@ class Noise(Module):


class _HomoskedasticNoiseBase(Noise):
def __init__(self, noise_prior=None, noise_constraint=None, batch_shape=torch.Size(), num_tasks=1):
def __init__(
self,
noise_prior=None,
noise_constraint=None,
batch_shape=torch.Size(),
num_tasks=1,
):
super().__init__()
if noise_constraint is None:
noise_constraint = GreaterThan(1e-4)

self.register_parameter(name="raw_noise", parameter=Parameter(torch.zeros(*batch_shape, num_tasks)))
self.register_parameter(
name="raw_noise", parameter=Parameter(torch.zeros(*batch_shape, num_tasks))
)
if noise_prior is not None:
self.register_prior("noise_prior", noise_prior, lambda: self.noise, lambda v: self._set_noise(v))
self.register_prior(
"noise_prior",
noise_prior,
lambda: self.noise,
lambda v: self._set_noise(v),
)

self.register_constraint("raw_noise", noise_constraint)

@@ -42,7 +55,9 @@ def _set_noise(self, value: Tensor) -> None:
value = torch.as_tensor(value).to(self.raw_noise)
self.initialize(raw_noise=self.raw_noise_constraint.inverse_transform(value))

def forward(self, *params: Any, shape: Optional[torch.Size] = None) -> DiagLazyTensor:
def forward(
self, *params: Any, shape: Optional[torch.Size] = None
) -> DiagLazyTensor:
"""In the homoskedastic case, the parameters are only used to infer the required shape.
Here are the possible scenarios:
- non-batched noise, non-batched input, non-MT -> noise_diag shape is `n`
@@ -75,7 +90,9 @@ def forward(self, *params: Any, shape: Optional[torch.Size] = None) -> DiagLazyT


class HomoskedasticNoise(_HomoskedasticNoiseBase):
def __init__(self, noise_prior=None, noise_constraint=None, batch_shape=torch.Size()):
def __init__(
self, noise_prior=None, noise_constraint=None, batch_shape=torch.Size()
):
super().__init__(
noise_prior=noise_prior,
noise_constraint=noise_constraint,
@@ -85,7 +102,13 @@ def __init__(self, noise_prior=None, noise_constraint=None, batch_shape=torch.Si


class MultitaskHomoskedasticNoise(_HomoskedasticNoiseBase):
def __init__(self, num_tasks, noise_prior=None, noise_constraint=None, batch_shape=torch.Size()):
def __init__(
self,
num_tasks,
noise_prior=None,
noise_constraint=None,
batch_shape=torch.Size(),
):
super().__init__(
noise_prior=noise_prior,
noise_constraint=noise_constraint,
@@ -104,17 +127,26 @@ def __init__(self, noise_model, noise_indices=None, noise_constraint=None):
self._noise_indices = noise_indices

def forward(
self, *params: Any, batch_shape: Optional[torch.Size] = None, shape: Optional[torch.Size] = None
self,
*params: Any,
batch_shape: Optional[torch.Size] = None,
shape: Optional[torch.Size] = None
) -> DiagLazyTensor:
if len(params) == 1 and not torch.is_tensor(params[0]):
output = self.noise_model(*params[0])
else:
output = self.noise_model(*params)
if not isinstance(output, MultivariateNormal):
raise NotImplementedError("Currently only noise models that return a MultivariateNormal are supported")
raise NotImplementedError(
"Currently only noise models that return a MultivariateNormal are supported"
)
# note: this also works with MultitaskMultivariateNormal, where this
# will return a batched DiagLazyTensors of size n x num_tasks x num_tasks
noise_diag = output.mean if self._noise_indices is None else output.mean[..., self._noise_indices]
noise_diag = (
output.mean
if self._noise_indices is None
else output.mean[..., self._noise_indices]
)
return DiagLazyTensor(self._noise_constraint.transform(noise_diag))


@@ -127,15 +159,15 @@ def forward(
self,
*params: Any,
shape: Optional[torch.Size] = None,
observation_noise: Optional[Tensor] = None,
noise: Optional[Tensor] = None,
**kwargs: Any
) -> DiagLazyTensor:
if shape is None:
p = params[0] if torch.is_tensor(params[0]) else params[0][0]
shape = p.shape if len(p.shape) == 1 else p.shape[:-1]

if observation_noise is not None:
return DiagLazyTensor(observation_noise)
if noise is not None:
return DiagLazyTensor(noise)
elif shape[-1] == self.noise.shape[-1]:
return DiagLazyTensor(self.noise)
else:
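
A brief sketch of the renamed kwarg on FixedGaussianNoise, assuming the noise module can be invoked directly the way the likelihood invokes it internally; sizes and values below are illustrative:

import torch
from gpytorch.likelihoods.noise_models import FixedGaussianNoise

fixed = FixedGaussianNoise(noise=torch.full((10,), 0.01))
covar_train = fixed(torch.randn(10, 2))  # input size matches the stored noise -> DiagLazyTensor of that noise
# Known per-point noise for new data is now supplied via `noise` rather than `observation_noise`.
covar_test = fixed(torch.randn(4, 2), noise=torch.full((4,), 0.02))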
16 changes: 13 additions & 3 deletions gpytorch/models/exact_gp.py
@@ -114,7 +114,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
if not isinstance(inputs, list):
inputs = [inputs]

inputs = list(i.unsqueeze(-1) if i.ndimension() == 1 else i for i in inputs)
inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in inputs]

# If input is n x d but targets is b x n x d, expand input to b x n x d
for i, input in enumerate(inputs):
@@ -133,28 +133,38 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
full_inputs = [torch.cat([train_input, input], dim=-2) for train_input, input in zip(train_inputs, inputs)]
full_targets = torch.cat([train_targets, targets], dim=-1)

try:
fantasy_kwargs = {"noise": kwargs.pop("noise")}
except KeyError:
fantasy_kwargs = {}

full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)

# Copy model without copying training data or prediction strategy (since we'll overwrite those)
old_pred_strat = self.prediction_strategy
old_train_inputs = self.train_inputs
old_train_targets = self.train_targets
old_likelihood = self.likelihood
self.prediction_strategy = None
self.train_inputs = None
self.train_targets = None
self.likelihood = None
new_model = deepcopy(self)
self.prediction_strategy = old_pred_strat
self.train_inputs = old_train_inputs
self.train_targets = old_train_targets
self.likelihood = old_likelihood

new_model.train_inputs = full_inputs
new_model.train_targets = full_targets
new_model.likelihood = old_likelihood.get_fantasy_likelihood(**fantasy_kwargs)
new_model.prediction_strategy = old_pred_strat.get_fantasy_strategy(
inputs,
targets,
full_inputs,
full_targets,
full_output
full_output,
**fantasy_kwargs,
)

return new_model
@@ -179,7 +189,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,

def __call__(self, *args, **kwargs):
train_inputs = list(self.train_inputs) if self.train_inputs is not None else []
inputs = list(i.unsqueeze(-1) if i.ndimension() == 1 else i for i in args)
inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in args]

# Training mode: optimizing
if self.training:
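
A condensed sketch of the intended call sequence, assuming `model` is an ExactGP paired with a FixedNoiseGaussianLikelihood (for instance the one sketched after the gaussian_likelihood.py hunks above) that has already been trained on (train_x, train_y); the tensors below are illustrative:

import torch

model.eval()
with torch.no_grad():
    _ = model(torch.randn(2, 2))  # any posterior evaluation populates the prediction strategy

new_x = torch.randn(3, 2)
new_y = torch.randn(3)
# The `noise` kwarg is popped into fantasy_kwargs above and forwarded to
# get_fantasy_likelihood and get_fantasy_strategy for the fantasy points.
fantasy_model = model.get_fantasy_model(new_x, new_y, noise=torch.full((3,), 0.02))
with torch.no_grad():
    posterior = fantasy_model(torch.randn(5, 2))  # posterior conditioned on train + fantasy data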
15 changes: 9 additions & 6 deletions gpytorch/models/exact_prediction_strategies.py
@@ -41,7 +41,7 @@ def __init__(self, train_inputs, train_prior_dist, train_labels, likelihood):
self.lik_train_train_covar = mvn.lazy_covariance_matrix

def __deepcopy__(self, memo):
# deepcopying prediciton strategies of a model evaluated on inputs that require gradients fails
# deepcopying prediction strategies of a model evaluated on inputs that require gradients fails
# with RuntimeError (Only Tensors created explicitly by the user (graph leaves) support the deepcopy
# protocol at the moment). Overwriting this method make sure that the prediction strategies of a
# model are set to None upon deepcopying.
@@ -80,7 +80,7 @@ def _exact_predictive_covar_inv_quad_form_root(self, precomputed_cache, test_tra
# where S S^T = (K_XX + sigma^2 I)^-1
return test_train_covar.matmul(precomputed_cache)

def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_output):
def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_output, **kwargs):
"""
Returns a new PredictionStrategy that incorporates the specified inputs and targets as new training data.
@@ -108,9 +108,11 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
# Evaluate fant x train and fant x fant covariance matrices, leave train x train unevaluated.
fant_fant_covar = full_covar[..., num_train:, num_train:]
fant_mean = full_mean[..., num_train:]
mvn = self.likelihood(self.train_prior_dist.__class__(fant_mean, fant_fant_covar), inputs)
fant_fant_covar = mvn.covariance_matrix
mvn = self.train_prior_dist.__class__(fant_mean, fant_fant_covar)
self.likelihood = self.likelihood.get_fantasy_likelihood(**kwargs)
mvn_obs = self.likelihood(mvn, inputs, **kwargs)

fant_fant_covar = mvn_obs.covariance_matrix
fant_train_covar = delazify(full_covar[..., num_train:, :num_train])

self.fantasy_inputs = inputs
@@ -134,10 +136,11 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
# Solve for "b", the lower portion of the *new* \\alpha corresponding to the fantasy points.
schur_complement = fant_fant_covar - fant_train_covar.matmul(fant_solve)
small_system_rhs = targets - fant_mean - fant_train_covar.matmul(self.mean_cache)
small_system_rhs = small_system_rhs.unsqueeze(-1)
# Schur complement of a spd matrix is guaranteed to be positive definite
if small_system_rhs.requires_grad or schur_complement.requires_grad:
# TODO: Delete this part of the if statement when PyTorch implements cholesky_solve derivative.
fant_cache_lower = torch.gesv(small_system_rhs.unsqueeze(-1), schur_complement)[0]
fant_cache_lower = torch.gesv(small_system_rhs, schur_complement)[0]
else:
fant_cache_lower = cholesky_solve(small_system_rhs, torch.cholesky(schur_complement))

@@ -355,7 +358,7 @@ def _exact_predictive_covar_inv_quad_form_root(self, precomputed_cache, test_tra
res = left_interp(test_interp_indices, test_interp_values, precomputed_cache)
return res

def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_output):
def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_output, **kwargs):
raise NotImplementedError(
"Fantasy observation updates not yet supported for models using InterpolatedLazyTensors"
)
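
For reference, the fantasy cache update above rests on the standard block-solve identity: with training covariance K, fantasy-train cross covariance C, fantasy covariance F, and Schur complement S = F - C K^-1 C^T, the lower part of the new cache is b = S^-1 (r2 - C K^-1 r1) and the upper part is a = K^-1 r1 - (K^-1 C^T) b. Below is a minimal dense numeric check of that identity; the sizes and names are illustrative, the tensors are dense rather than lazy, and torch.linalg.solve from current PyTorch stands in for the older torch.gesv / cholesky_solve calls:

import torch

torch.manual_seed(0)
n, m = 5, 2  # n training points, m fantasy points (illustrative sizes)
A = torch.randn(n + m, n + m)
full = A @ A.t() + (n + m) * torch.eye(n + m)  # SPD joint (train + fantasy) covariance incl. noise
K, C, F = full[:n, :n], full[n:, :n], full[n:, n:]
rhs = torch.randn(n + m)

alpha_direct = torch.linalg.solve(full, rhs)      # solve the full system in one go

mean_cache = torch.linalg.solve(K, rhs[:n])       # existing cache on the training block, K^-1 r1
fant_solve = torch.linalg.solve(K, C.t())         # K^-1 C^T
schur_complement = F - C @ fant_solve             # SPD, as noted in the comment above
cache_lower = torch.linalg.solve(schur_complement, rhs[n:] - C @ mean_cache)
cache_upper = mean_cache - fant_solve @ cache_lower
assert torch.allclose(torch.cat([cache_upper, cache_lower]), alpha_direct, atol=1e-4)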
13 changes: 9 additions & 4 deletions gpytorch/variational/variational_strategy.py
@@ -2,7 +2,7 @@

import math
import torch
from .. import beta_features, settings
from .. import settings
from ..lazy import DiagLazyTensor, CachedCGLazyTensor, CholLazyTensor, PsdSumLazyTensor, RootLazyTensor
from ..module import Module
from ..distributions import MultivariateNormal
@@ -171,12 +171,17 @@ def forward(self, x):
inv_products = induc_induc_covar.inv_matmul(induc_data_covar, left_tensors.transpose(-1, -2))
predictive_mean = torch.add(test_mean, inv_products[..., 0, :])
predictive_covar = RootLazyTensor(inv_products[..., 1:, :].transpose(-1, -2))
if beta_features.diagonal_correction.on():
if self.training:
interp_data_data_var, _ = induc_induc_covar.inv_quad_logdet(
induc_data_covar, logdet=False, reduce_inv_quad=False
)
diag_correction = DiagLazyTensor((data_data_covar.diag() - interp_data_data_var).clamp(0, math.inf))
predictive_covar = PsdSumLazyTensor(predictive_covar, diag_correction)
data_covariance = DiagLazyTensor((data_data_covar.diag() - interp_data_data_var).clamp(0, math.inf))
else:
neg_induc_data_data_covar = induc_induc_covar.inv_matmul(
induc_data_covar, left_tensor=induc_data_covar.transpose(-1, -2).mul(-1)
)
data_covariance = data_data_covar + neg_induc_data_data_covar
predictive_covar = PsdSumLazyTensor(predictive_covar, data_covariance)

return MultivariateNormal(predictive_mean, predictive_covar)

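
For context, a rough dense illustration of the correction term this code path adds to the predictive covariance: during training only the clamped diagonal of K_xx - K_xz K_zz^-1 K_zx is kept, while at evaluation time the full residual is added. The kernel, sizes, and names below are illustrative stand-ins for the lazy-tensor arithmetic above:

import torch

torch.manual_seed(0)
n, m = 6, 3  # n data points, m inducing points (illustrative sizes)
X, Z = torch.randn(n, 1), torch.randn(m, 1)

def rbf(a, b):
    # toy RBF kernel with unit lengthscale and outputscale
    return torch.exp(-0.5 * (a - b.t()) ** 2)

K_xx, K_xz, K_zz = rbf(X, X), rbf(X, Z), rbf(Z, Z) + 1e-4 * torch.eye(m)

q_xx = K_xz @ torch.linalg.solve(K_zz, K_xz.t())  # Nystrom term K_xz K_zz^-1 K_zx
diag_correction = torch.diag((K_xx - q_xx).diag().clamp(min=0.0))  # training branch
full_correction = K_xx - q_xx                                      # eval branch
# Either correction is then summed with the RootLazyTensor part of the predictive covariance.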