diff --git a/gpytorch/variational/variational_strategy.py b/gpytorch/variational/variational_strategy.py index 453447da2..3435599df 100644 --- a/gpytorch/variational/variational_strategy.py +++ b/gpytorch/variational/variational_strategy.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, division, print_function, unicode_literals +import math import torch import gpytorch from ..lazy import RootLazyTensor, PsdSumLazyTensor, DiagLazyTensor @@ -132,7 +133,7 @@ def forward(self, x): if gpytorch.beta_features.diagonal_correction.on(): fake_diagonal = (inv_product * induc_data_covar).sum(0) real_diagonal = data_data_covar.diag() - diag_correction = DiagLazyTensor(real_diagonal - fake_diagonal) + diag_correction = DiagLazyTensor((real_diagonal - fake_diagonal).clamp(0, math.inf)) predictive_covar = PsdSumLazyTensor(predictive_covar, diag_correction) return MultivariateNormal(predictive_mean, predictive_covar)