FullyBayesian LogEI (#2058)
Summary:
Pull Request resolved: #2058

This commit adds support for combining LogEI acquisition functions with fully Bayesian models. In particular, it adds the option to compute
```
LogEI(x) = log( E_SAAS[ E_f[ f_SAAS(x) ] ] ),
```
by replacing `mean` with `logmeanexp` in `t_batch_mode_transform`, where `f_SAAS` is the GP conditioned on the hyper-parameter sample `SAAS`, evaluated at `x`. Without this change, the acqf would compute
```
ELogEI(x) = E_SAAS[ log( E_f[ f_SAAS(x)] ) ].
```
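
Since `log` is concave, Jensen's inequality gives `ELogEI(x) <= LogEI(x)`, and the gap can be large when the per-sample EI values span several orders of magnitude. A minimal sketch of the difference (not part of the commit; the per-sample log-EI values are hypothetical):

```
import torch

# log(EI) at x under three SAAS hyper-parameter samples (hypothetical values)
log_ei = torch.tensor([-1.0, -5.0, -20.0])

# LogEI(x) = log(E_SAAS[EI]): logmeanexp, what this commit computes
log_of_mean = torch.logsumexp(log_ei, dim=-1) - torch.log(torch.tensor(3.0))

# ELogEI(x) = E_SAAS[log(EI)]: plain mean, what was computed before
mean_of_log = log_ei.mean(dim=-1)

print(log_of_mean)  # ~= -2.08, dominated by the best hyper-parameter sample
print(mean_of_log)  # ~= -8.67, strictly smaller by Jensen's inequality
```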

Reviewed By: dme65, Balandat

Differential Revision: D50413044

fbshipit-source-id: ec5342d8affd7f6d49dd5af9849166974473022e
SebastianAment authored and facebook-github-bot committed Nov 2, 2023
1 parent 260ad89 commit 0af3ca5
Showing 6 changed files with 123 additions and 6 deletions.
2 changes: 2 additions & 0 deletions botorch/acquisition/acquisition.py
@@ -32,6 +32,8 @@ class AcquisitionFunction(Module, ABC):
:meta private:
"""

_log: bool = False # whether the acquisition utilities are in log-space

def __init__(self, model: Model) -> None:
r"""Constructor for the AcquisitionFunction base class.
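
The new `_log` class attribute is the single opt-in point: any acquisition function whose `forward` returns log-space utilities sets it to `True`. A hypothetical sketch of a custom subclass (`MyLogAcquisition` is not part of the commit):

```
from torch import Tensor
from botorch.acquisition.acquisition import AcquisitionFunction

class MyLogAcquisition(AcquisitionFunction):
    # forward() returns log-space utilities, so the fully Bayesian
    # marginalization in t_batch_mode_transform uses logmeanexp, not mean.
    _log: bool = True

    def forward(self, X: Tensor) -> Tensor:
        ...  # compute and return log-acquisition values of shape `batch`
```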
8 changes: 8 additions & 0 deletions botorch/acquisition/analytic.py
@@ -135,6 +135,8 @@ class LogProbabilityOfImprovement(AnalyticAcquisitionFunction):
>>> log_pi = LogPI(test_X)
"""

_log: bool = True

def __init__(
self,
model: Model,
@@ -375,6 +377,8 @@ class LogExpectedImprovement(AnalyticAcquisitionFunction):
>>> ei = LogEI(test_X)
"""

_log: bool = True

def __init__(
self,
model: Model,
@@ -442,6 +446,8 @@ class LogConstrainedExpectedImprovement(AnalyticAcquisitionFunction):
>>> cei = LogCEI(test_X)
"""

_log: bool = True

def __init__(
self,
model: Model,
@@ -591,6 +597,8 @@ class LogNoisyExpectedImprovement(AnalyticAcquisitionFunction):
>>> nei = LogNEI(test_X)
"""

_log: bool = True

def __init__(
self,
model: GPyTorchModel,
6 changes: 6 additions & 0 deletions botorch/acquisition/multi_objective/logei.py
@@ -48,6 +48,9 @@
class qLogExpectedHypervolumeImprovement(
MultiObjectiveMCAcquisitionFunction, SubsetIndexCachingMixin
):

_log: bool = True

def __init__(
self,
model: Model,
@@ -318,6 +321,9 @@ class qLogNoisyExpectedHypervolumeImprovement(
NoisyExpectedHypervolumeMixin,
qLogExpectedHypervolumeImprovement,
):

_log: bool = True

def __init__(
self,
model: Model,
6 changes: 5 additions & 1 deletion botorch/utils/transforms.py
@@ -15,6 +15,7 @@
from typing import Any, Callable, List, Optional, TYPE_CHECKING

import torch
from botorch.utils.safe_math import logmeanexp
from torch import Tensor

if TYPE_CHECKING:
@@ -255,7 +256,10 @@ def decorated(
X = X if X.dim() > 2 else X.unsqueeze(0)
output = method(acqf, X, *args, **kwargs)
if hasattr(acqf, "model") and is_fully_bayesian(acqf.model):
output = output.mean(dim=-1)
# IDEA: this could be wrapped into SampleReducingMCAcquisitionFunction
output = (
output.mean(dim=-1) if not acqf._log else logmeanexp(output, dim=-1)
)
if assert_output_shape and not _verify_output_shape(
acqf=acqf,
X=X,
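
`logmeanexp` lives in `botorch.utils.safe_math`; it is the numerically stable log of the mean of exponentials. Semantically it matches this sketch (use the library version, which also supports `keepdim`):

```
import math

import torch
from torch import Tensor

def logmeanexp_sketch(x: Tensor, dim: int) -> Tensor:
    # logsumexp subtracts the running max internally for stability;
    # dividing by n becomes a subtraction of log(n) in log-space.
    return torch.logsumexp(x, dim=dim) - math.log(x.shape[dim])
```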
54 changes: 53 additions & 1 deletion test/acquisition/test_logei.py
@@ -13,15 +13,25 @@
import torch
from botorch import settings
from botorch.acquisition import (
AcquisitionFunction,
LogImprovementMCAcquisitionFunction,
qLogExpectedImprovement,
qLogNoisyExpectedImprovement,
)
from botorch.acquisition.analytic import (
ExpectedImprovement,
LogExpectedImprovement,
LogNoisyExpectedImprovement,
NoisyExpectedImprovement,
)
from botorch.acquisition.input_constructors import ACQF_INPUT_CONSTRUCTOR_REGISTRY
from botorch.acquisition.monte_carlo import (
qExpectedImprovement,
qNoisyExpectedImprovement,
)
from botorch.acquisition.multi_objective.logei import (
qLogNoisyExpectedHypervolumeImprovement,
)

from botorch.acquisition.objective import (
ConstrainedMCObjective,
@@ -33,7 +43,8 @@
from botorch.acquisition.utils import prune_inferior_points
from botorch.exceptions import BotorchWarning, UnsupportedError
from botorch.exceptions.errors import BotorchError
from botorch.models import SingleTaskGP
from botorch.models import ModelListGP, SingleTaskGP
from botorch.models.gp_regression import FixedNoiseGP
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
from botorch.utils.low_rank import sample_cached_cholesky
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
@@ -717,3 +728,44 @@ def test_cache_root(self):
best_feas_f, torch.full_like(obj[..., [0]], -infcost.item())
)
# TODO: Test different objectives (incl. constraints)


class TestIsLog(BotorchTestCase):
def test_is_log(self):
# the flag is False by default
self.assertFalse(AcquisitionFunction._log)

# single objective case
X, Y = torch.rand(3, 2), torch.randn(3, 1)
model = FixedNoiseGP(train_X=X, train_Y=Y, train_Yvar=torch.rand_like(Y))

# (q)LogEI
for acqf_class in [LogExpectedImprovement, qLogExpectedImprovement]:
acqf = acqf_class(model=model, best_f=0.0)
self.assertTrue(acqf._log)

# (q)EI
for acqf_class in [ExpectedImprovement, qExpectedImprovement]:
acqf = acqf_class(model=model, best_f=0.0)
self.assertFalse(acqf._log)

# (q)LogNEI
for acqf_class in [LogNoisyExpectedImprovement, qLogNoisyExpectedImprovement]:
# avoiding keywords since they differ: X_observed vs. X_baseline
acqf = acqf_class(model, X)
self.assertTrue(acqf._log)

# (q)NEI
for acqf_class in [NoisyExpectedImprovement, qNoisyExpectedImprovement]:
acqf = acqf_class(model, X)
self.assertFalse(acqf._log)

# multi-objective case
model_list = ModelListGP(model, model)
ref_point = [4, 2] # the meaning of life

# qLogNEHVI
acqf = qLogNoisyExpectedHypervolumeImprovement(
model=model_list, X_baseline=X, ref_point=ref_point
)
self.assertTrue(acqf._log)
53 changes: 49 additions & 4 deletions test/models/test_fully_bayesian.py
@@ -7,11 +7,12 @@

import itertools
from unittest import mock
from unittest.mock import patch

import pyro

import torch
from botorch import fit_fully_bayesian_model_nuts
from botorch import fit_fully_bayesian_model_nuts, utils
from botorch.acquisition.analytic import (
ExpectedImprovement,
PosteriorMean,
@@ -34,6 +35,10 @@
qExpectedHypervolumeImprovement,
qNoisyExpectedHypervolumeImprovement,
)
from botorch.acquisition.multi_objective.logei import (
qLogExpectedHypervolumeImprovement,
qLogNoisyExpectedHypervolumeImprovement,
)
from botorch.acquisition.utils import prune_inferior_points
from botorch.models import ModelList, ModelListGP
from botorch.models.deterministic import GenericDeterministicModel
@@ -51,6 +56,7 @@
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
NondominatedPartitioning,
)
from botorch.utils.safe_math import logmeanexp
from botorch.utils.testing import BotorchTestCase
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import MaternKernel, ScaleKernel
@@ -438,13 +444,13 @@ def test_acquisition_functions(self):
qExpectedImprovement(
model=model, best_f=train_Y.max(), sampler=simple_sampler
),
qLogNoisyExpectedImprovement(
qNoisyExpectedImprovement(
model=model,
X_baseline=train_X,
sampler=simple_sampler,
cache_root=False,
),
qNoisyExpectedImprovement(
qLogNoisyExpectedImprovement(
model=model,
X_baseline=train_X,
sampler=simple_sampler,
@@ -462,6 +468,13 @@
sampler=list_gp_sampler,
cache_root=False,
),
qLogNoisyExpectedHypervolumeImprovement(
model=list_gp,
X_baseline=train_X,
ref_point=torch.zeros(2, **tkwargs),
sampler=list_gp_sampler,
cache_root=False,
),
qExpectedHypervolumeImprovement(
model=list_gp,
ref_point=torch.zeros(2, **tkwargs),
@@ -470,6 +483,14 @@
ref_point=torch.zeros(2, **tkwargs), Y=train_Y.repeat([1, 2])
),
),
qLogExpectedHypervolumeImprovement(
model=list_gp,
ref_point=torch.zeros(2, **tkwargs),
sampler=list_gp_sampler,
partitioning=NondominatedPartitioning(
ref_point=torch.zeros(2, **tkwargs), Y=train_Y.repeat([1, 2])
),
),
# qEHVI/qNEHVI with mixed models
qNoisyExpectedHypervolumeImprovement(
model=mixed_list,
@@ -478,6 +499,13 @@
sampler=mixed_list_sampler,
cache_root=False,
),
qLogNoisyExpectedHypervolumeImprovement(
model=mixed_list,
X_baseline=train_X,
ref_point=torch.zeros(2, **tkwargs),
sampler=mixed_list_sampler,
cache_root=False,
),
qExpectedHypervolumeImprovement(
model=mixed_list,
ref_point=torch.zeros(2, **tkwargs),
@@ -486,12 +514,29 @@
ref_point=torch.zeros(2, **tkwargs), Y=train_Y.repeat([1, 2])
),
),
qLogExpectedHypervolumeImprovement(
model=mixed_list,
ref_point=torch.zeros(2, **tkwargs),
sampler=mixed_list_sampler,
partitioning=NondominatedPartitioning(
ref_point=torch.zeros(2, **tkwargs), Y=train_Y.repeat([1, 2])
),
),
]

for acqf in acquisition_functions:
for batch_shape in [[5], [6, 5, 2]]:
test_X = torch.rand(*batch_shape, 1, 4, **tkwargs)
self.assertEqual(acqf(test_X).shape, torch.Size(batch_shape))
# Testing that the t_batch_mode_transform works correctly for
# fully Bayesian models with log-space acquisition functions.
with patch.object(
utils.transforms, "logmeanexp", wraps=logmeanexp
) as mock:
self.assertEqual(acqf(test_X).shape, torch.Size(batch_shape))
if acqf._log:
mock.assert_called_once()
else:
mock.assert_not_called()

# Test prune_inferior_points
X_pruned = prune_inferior_points(model=model, X=train_X)
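
Putting the pieces together, a sketch of the end-to-end behavior the new test exercises, with illustrative data and abbreviated NUTS settings (none of this is from the diff):

```
import torch
from botorch import fit_fully_bayesian_model_nuts
from botorch.acquisition.logei import qLogExpectedImprovement
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

train_X = torch.rand(10, 4, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)

gp = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_fully_bayesian_model_nuts(
    gp, warmup_steps=32, num_samples=16, thinning=4, disable_progbar=True
)

# qLogExpectedImprovement sets _log = True, so t_batch_mode_transform
# marginalizes over the MCMC hyper-parameter samples via logmeanexp,
# yielding log(E_SAAS[EI]) rather than E_SAAS[log(EI)].
acqf = qLogExpectedImprovement(model=gp, best_f=train_Y.max())
test_X = torch.rand(5, 1, 4, dtype=torch.double)
print(acqf(test_X).shape)  # torch.Size([5]), one value per t-batch
```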
