diff --git a/botorch/acquisition/acquisition.py b/botorch/acquisition/acquisition.py
index 94cb6ef5c7..d37e3e4030 100644
--- a/botorch/acquisition/acquisition.py
+++ b/botorch/acquisition/acquisition.py
@@ -32,6 +32,8 @@ class AcquisitionFunction(Module, ABC):
     :meta private:
     """

+    _log: bool = False  # whether the acquisition utilities are in log-space
+
     def __init__(self, model: Model) -> None:
         r"""Constructor for the AcquisitionFunction base class.

diff --git a/botorch/acquisition/analytic.py b/botorch/acquisition/analytic.py
index 41c0f9477b..07d39d1454 100644
--- a/botorch/acquisition/analytic.py
+++ b/botorch/acquisition/analytic.py
@@ -135,6 +135,8 @@ class LogProbabilityOfImprovement(AnalyticAcquisitionFunction):
         >>> log_pi = LogPI(test_X)
     """

+    _log: bool = True
+
     def __init__(
         self,
         model: Model,
@@ -375,6 +377,8 @@ class LogExpectedImprovement(AnalyticAcquisitionFunction):
         >>> ei = LogEI(test_X)
     """

+    _log: bool = True
+
     def __init__(
         self,
         model: Model,
@@ -442,6 +446,8 @@ class LogConstrainedExpectedImprovement(AnalyticAcquisitionFunction):
         >>> cei = LogCEI(test_X)
     """

+    _log: bool = True
+
     def __init__(
         self,
         model: Model,
@@ -591,6 +597,8 @@ class LogNoisyExpectedImprovement(AnalyticAcquisitionFunction):
         >>> nei = LogNEI(test_X)
     """

+    _log: bool = True
+
     def __init__(
         self,
         model: GPyTorchModel,
diff --git a/botorch/acquisition/multi_objective/logei.py b/botorch/acquisition/multi_objective/logei.py
index a0e1832b1e..295e5ce6e6 100644
--- a/botorch/acquisition/multi_objective/logei.py
+++ b/botorch/acquisition/multi_objective/logei.py
@@ -48,6 +48,9 @@ class qLogExpectedHypervolumeImprovement(
     MultiObjectiveMCAcquisitionFunction, SubsetIndexCachingMixin
 ):
+
+    _log: bool = True
+
     def __init__(
         self,
         model: Model,
@@ -318,6 +321,9 @@ class qLogNoisyExpectedHypervolumeImprovement(
     NoisyExpectedHypervolumeMixin,
     qLogExpectedHypervolumeImprovement,
 ):
+
+    _log: bool = True
+
     def __init__(
         self,
         model: Model,
diff --git a/botorch/utils/transforms.py b/botorch/utils/transforms.py
index 7ce87737c1..729bd591b7 100644
--- a/botorch/utils/transforms.py
+++ b/botorch/utils/transforms.py
@@ -15,6 +15,7 @@
 from typing import Any, Callable, List, Optional, TYPE_CHECKING

 import torch
+from botorch.utils.safe_math import logmeanexp
 from torch import Tensor

 if TYPE_CHECKING:
@@ -255,7 +256,10 @@ def decorated(
         X = X if X.dim() > 2 else X.unsqueeze(0)
         output = method(acqf, X, *args, **kwargs)
         if hasattr(acqf, "model") and is_fully_bayesian(acqf.model):
-            output = output.mean(dim=-1)
+            # IDEA: this could be wrapped into SampleReducingMCAcquisitionFunction
+            output = (
+                output.mean(dim=-1) if not acqf._log else logmeanexp(output, dim=-1)
+            )
         if assert_output_shape and not _verify_output_shape(
             acqf=acqf,
             X=X,
diff --git a/test/acquisition/test_logei.py b/test/acquisition/test_logei.py
index 15c623990c..4f58bb8ec6 100644
--- a/test/acquisition/test_logei.py
+++ b/test/acquisition/test_logei.py
@@ -13,15 +13,25 @@
 import torch
 from botorch import settings
 from botorch.acquisition import (
+    AcquisitionFunction,
     LogImprovementMCAcquisitionFunction,
     qLogExpectedImprovement,
     qLogNoisyExpectedImprovement,
 )
+from botorch.acquisition.analytic import (
+    ExpectedImprovement,
+    LogExpectedImprovement,
+    LogNoisyExpectedImprovement,
+    NoisyExpectedImprovement,
+)
 from botorch.acquisition.input_constructors import ACQF_INPUT_CONSTRUCTOR_REGISTRY
 from botorch.acquisition.monte_carlo import (
     qExpectedImprovement,
     qNoisyExpectedImprovement,
 )
+from botorch.acquisition.multi_objective.logei import (
+    qLogNoisyExpectedHypervolumeImprovement,
+)
 from botorch.acquisition.objective import (
     ConstrainedMCObjective,
@@ -33,7 +43,8 @@
 from botorch.acquisition.utils import prune_inferior_points
 from botorch.exceptions import BotorchWarning, UnsupportedError
 from botorch.exceptions.errors import BotorchError
-from botorch.models import SingleTaskGP
+from botorch.models import ModelListGP, SingleTaskGP
+from botorch.models.gp_regression import FixedNoiseGP
 from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
 from botorch.utils.low_rank import sample_cached_cholesky
 from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
@@ -717,3 +728,44 @@ def test_cache_root(self):
             best_feas_f, torch.full_like(obj[..., [0]], -infcost.item())
         )
         # TODO: Test different objectives (incl. constraints)
+
+
+class TestIsLog(BotorchTestCase):
+    def test_is_log(self):
+        # the flag is False by default
+        self.assertFalse(AcquisitionFunction._log)
+
+        # single objective case
+        X, Y = torch.rand(3, 2), torch.randn(3, 1)
+        model = FixedNoiseGP(train_X=X, train_Y=Y, train_Yvar=torch.rand_like(Y))
+
+        # (q)LogEI
+        for acqf_class in [LogExpectedImprovement, qLogExpectedImprovement]:
+            acqf = acqf_class(model=model, best_f=0.0)
+            self.assertTrue(acqf._log)
+
+        # (q)EI
+        for acqf_class in [ExpectedImprovement, qExpectedImprovement]:
+            acqf = acqf_class(model=model, best_f=0.0)
+            self.assertFalse(acqf._log)
+
+        # (q)LogNEI
+        for acqf_class in [LogNoisyExpectedImprovement, qLogNoisyExpectedImprovement]:
+            # avoiding keywords since they differ: X_observed vs. X_baseline
+            acqf = acqf_class(model, X)
+            self.assertTrue(acqf._log)
+
+        # (q)NEI
+        for acqf_class in [NoisyExpectedImprovement, qNoisyExpectedImprovement]:
+            acqf = acqf_class(model, X)
+            self.assertFalse(acqf._log)
+
+        # multi-objective case
+        model_list = ModelListGP(model, model)
+        ref_point = [4, 2]  # the meaning of life
+
+        # qLogNEHVI
+        acqf = qLogNoisyExpectedHypervolumeImprovement(
+            model=model_list, X_baseline=X, ref_point=ref_point
+        )
+        self.assertTrue(acqf._log)
diff --git a/test/models/test_fully_bayesian.py b/test/models/test_fully_bayesian.py
index afd8b54dd5..cfa242d5d7 100644
--- a/test/models/test_fully_bayesian.py
+++ b/test/models/test_fully_bayesian.py
@@ -7,11 +7,12 @@
 import itertools

 from unittest import mock
+from unittest.mock import patch

 import pyro
 import torch
-from botorch import fit_fully_bayesian_model_nuts
+from botorch import fit_fully_bayesian_model_nuts, utils
 from botorch.acquisition.analytic import (
     ExpectedImprovement,
     PosteriorMean,
@@ -34,6 +35,10 @@
     qExpectedHypervolumeImprovement,
     qNoisyExpectedHypervolumeImprovement,
 )
+from botorch.acquisition.multi_objective.logei import (
+    qLogExpectedHypervolumeImprovement,
+    qLogNoisyExpectedHypervolumeImprovement,
+)
 from botorch.acquisition.utils import prune_inferior_points
 from botorch.models import ModelList, ModelListGP
 from botorch.models.deterministic import GenericDeterministicModel
@@ -51,6 +56,7 @@
 from botorch.utils.multi_objective.box_decompositions.non_dominated import (
     NondominatedPartitioning,
 )
+from botorch.utils.safe_math import logmeanexp
 from botorch.utils.testing import BotorchTestCase
 from gpytorch.distributions import MultivariateNormal
 from gpytorch.kernels import MaternKernel, ScaleKernel
@@ -438,13 +444,13 @@ def test_acquisition_functions(self):
             qExpectedImprovement(
                 model=model, best_f=train_Y.max(), sampler=simple_sampler
             ),
-            qLogNoisyExpectedImprovement(
+            qNoisyExpectedImprovement(
                 model=model,
                 X_baseline=train_X,
                 sampler=simple_sampler,
                 cache_root=False,
             ),
-            qNoisyExpectedImprovement(
+            qLogNoisyExpectedImprovement(
                 model=model,
                 X_baseline=train_X,
                 sampler=simple_sampler,
@@ -462,6 +468,13 @@
                 sampler=list_gp_sampler,
                 cache_root=False,
             ),
+            qLogNoisyExpectedHypervolumeImprovement(
+                model=list_gp,
+                X_baseline=train_X,
+                ref_point=torch.zeros(2, **tkwargs),
+                sampler=list_gp_sampler,
+                cache_root=False,
+            ),
             qExpectedHypervolumeImprovement(
                 model=list_gp,
                 ref_point=torch.zeros(2, **tkwargs),
@@ -470,6 +483,14 @@
                     ref_point=torch.zeros(2, **tkwargs), Y=train_Y.repeat([1, 2])
                 ),
             ),
+            qLogExpectedHypervolumeImprovement(
+                model=list_gp,
+                ref_point=torch.zeros(2, **tkwargs),
+                sampler=list_gp_sampler,
+                partitioning=NondominatedPartitioning(
+                    ref_point=torch.zeros(2, **tkwargs), Y=train_Y.repeat([1, 2])
+                ),
+            ),
             # qEHVI/qNEHVI with mixed models
             qNoisyExpectedHypervolumeImprovement(
                 model=mixed_list,
@@ -478,6 +499,13 @@
                 X_baseline=train_X,
                 ref_point=torch.zeros(2, **tkwargs),
                 sampler=mixed_list_sampler,
                 cache_root=False,
             ),
+            qLogNoisyExpectedHypervolumeImprovement(
+                model=mixed_list,
+                X_baseline=train_X,
+                ref_point=torch.zeros(2, **tkwargs),
+                sampler=mixed_list_sampler,
+                cache_root=False,
+            ),
             qExpectedHypervolumeImprovement(
                 model=mixed_list,
                 ref_point=torch.zeros(2, **tkwargs),
@@ -486,12 +514,29 @@
                     ref_point=torch.zeros(2, **tkwargs), Y=train_Y.repeat([1, 2])
                 ),
             ),
+            qLogExpectedHypervolumeImprovement(
+                model=mixed_list,
+                ref_point=torch.zeros(2, **tkwargs),
+                sampler=mixed_list_sampler,
+                partitioning=NondominatedPartitioning(
+                    ref_point=torch.zeros(2, **tkwargs), Y=train_Y.repeat([1, 2])
+                ),
+            ),
         ]
         for acqf in acquisition_functions:
             for batch_shape in [[5], [6, 5, 2]]:
                 test_X = torch.rand(*batch_shape, 1, 4, **tkwargs)
-                self.assertEqual(acqf(test_X).shape, torch.Size(batch_shape))
+                # Testing that the t_batch_mode_transform works correctly for
+                # fully Bayesian models with log-space acquisition functions.
+                with patch.object(
+                    utils.transforms, "logmeanexp", wraps=logmeanexp
+                ) as mock:
+                    self.assertEqual(acqf(test_X).shape, torch.Size(batch_shape))
+                if acqf._log:
+                    mock.assert_called_once()
+                else:
+                    mock.assert_not_called()

         # Test prune_inferior_points
         X_pruned = prune_inferior_points(model=model, X=train_X)
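
Note on the change (reviewer's sketch, not part of the diff): a fully Bayesian (MCMC) model yields one acquisition value per posterior sample of the hyperparameters, and `t_batch_mode_transform` marginalizes that trailing sample dimension. For ordinary acquisition functions an arithmetic `mean` is the right reduction, but the `Log*` variants return log-utilities, so the average has to be taken back in utility space: `log(mean(exp(...)))`, which is exactly `logmeanexp`. A plain `mean` of log-values computes the log of the geometric mean instead, which can underestimate by orders of magnitude when the per-sample utilities are spread out. A minimal illustration with plain `torch` follows; the local `logmeanexp` here is a simplified stand-in for `botorch.utils.safe_math.logmeanexp` imported in the diff.

```python
import math

import torch


def logmeanexp(x: torch.Tensor, dim: int) -> torch.Tensor:
    # log(mean(exp(x))) along `dim`, computed stably via logsumexp;
    # simplified stand-in for botorch.utils.safe_math.logmeanexp.
    return torch.logsumexp(x, dim=dim) - math.log(x.shape[dim])


# Log-space acquisition values of a single candidate under three MCMC
# samples of the hyperparameters; utilities span many orders of magnitude.
log_acq = torch.tensor([1e-10, 1e-2, 1.0]).log()

# Correct marginalization: average the utilities, then return to log space.
print(logmeanexp(log_acq, dim=-1).item())  # log(0.3367...) ~ -1.09

# What the old `output.mean(dim=-1)` branch would do to log-values:
# the log of the *geometric* mean, a severe underestimate here.
print(log_acq.mean(dim=-1).item())  # ~ -9.21, i.e. log(1e-4)
```

The `_log` class attribute is what lets the decorator pick the right reduction without inspecting types, and the `patch.object(..., wraps=logmeanexp)` assertions in `test_fully_bayesian.py` pin down that the log branch is taken exactly once per call when `acqf._log` is set and never otherwise.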