Skip to content

Commit

Permalink
Break up qNEHVI test for readibility (#2076)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2076

This commit splits the `qNEHVI` test into three parts: 1) tests for the base functionality, as well as 2) with, and 3) without the cached box decomposition (CBD).

Reviewed By: esantorella

Differential Revision: D50808514

fbshipit-source-id: 84247c61c6a2c569cc4c1d7618ee887e68af7557
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Nov 1, 2023
1 parent 01b2503 commit 85ccd2d
Showing 1 changed file with 110 additions and 32 deletions.
142 changes: 110 additions & 32 deletions test/acquisition/multi_objective/test_monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,33 +721,67 @@ def setUp(self):
super().setUp()

def test_q_noisy_expected_hypervolume_improvement(self):
for dtype, m in product(
(torch.float, torch.double),
(1, 2, 3),
):
with self.subTest(dtype=dtype, m=m):
self._test_q_noisy_expected_hypervolume_improvement(
qNoisyExpectedHypervolumeImprovement, dtype, m
)
for dtype in (torch.float, torch.double):
self._test_q_noisy_expected_hypervolume_improvement_m1(
qNoisyExpectedHypervolumeImprovement, dtype
)
for m in (2, 3):
with self.subTest(dtype=dtype, m=m):
self._test_q_noisy_expected_hypervolume_improvement(
qNoisyExpectedHypervolumeImprovement, dtype, m
)

def test_q_log_noisy_expected_hypervolume_improvement(self):
for dtype, m in product(
(torch.float, torch.double),
(1, 2, 3),
for dtype in (torch.float, torch.double):
self._test_q_noisy_expected_hypervolume_improvement_m1(
qLogNoisyExpectedHypervolumeImprovement, dtype
)
for m in (2, 3):
with self.subTest(dtype=dtype, m=m):
self._test_q_noisy_expected_hypervolume_improvement(
qLogNoisyExpectedHypervolumeImprovement, dtype, m
)

def _test_q_noisy_expected_hypervolume_improvement_m1(
self, acqf_class: Type[AcquisitionFunction], dtype: torch.dtype
):
# special case test for m = 1.
(
ref_point,
X,
X_baseline,
mm,
sampler,
samples,
baseline_samples,
tkwargs,
) = self._setup_qnehvi_test(dtype=dtype, m=1)
# test error is raised if m == 1
with self.assertRaisesRegex(
ValueError,
"NoisyExpectedHypervolumeMixin supports m>=2 outcomes ",
):
with self.subTest(dtype=dtype, m=m):
self._test_q_noisy_expected_hypervolume_improvement(
qLogNoisyExpectedHypervolumeImprovement, dtype, m
)
acqf_class(
model=mm,
ref_point=ref_point,
X_baseline=X_baseline,
sampler=sampler,
cache_root=False,
)

def _test_q_noisy_expected_hypervolume_improvement(
self, acqf_class: Type[AcquisitionFunction], dtype: torch.dtype, m: int
):
) -> None:
self._test_qnehvi_base(acqf_class, dtype, m)
# test with and without cached box decomposition (CBD)
self._test_qnehvi_with_CBD(acqf_class, dtype, m)
self._test_qnehvi_without_CBD(acqf_class, dtype, m)

def _setup_qnehvi_test(self, dtype: torch.dtype, m: int) -> None:
tkwargs = {"device": self.device}
tkwargs["dtype"] = dtype
ref_point = self.ref_point[:m]
Y = self.Y_raw[:, :m].to(**tkwargs)
pareto_Y = self.pareto_Y_raw[:, :m].to(**tkwargs)
X_baseline = torch.rand(Y.shape[0], 1, **tkwargs)
# the event shape is `b x q + r x m` = 1 x 1 x 2
baseline_samples = Y
Expand All @@ -759,22 +793,21 @@ def _test_q_noisy_expected_hypervolume_improvement(
X = torch.zeros(1, 1, **tkwargs)
# basic test
sampler = IIDNormalSampler(sample_shape=torch.Size([1]))
return ref_point, X, X_baseline, mm, sampler, samples, baseline_samples, tkwargs

# test error is raised if m == 1
if m == 1:
with self.assertRaisesRegex(
ValueError,
"NoisyExpectedHypervolumeMixin supports m>=2 outcomes ",
):
acqf = acqf_class(
model=mm,
ref_point=ref_point,
X_baseline=X_baseline,
sampler=sampler,
cache_root=False,
)
return

def _test_qnehvi_base(
self, acqf_class: Type[AcquisitionFunction], dtype: torch.dtype, m: int
) -> None:
(
ref_point,
X,
X_baseline,
mm,
sampler,
samples,
baseline_samples,
tkwargs,
) = self._setup_qnehvi_test(dtype=dtype, m=m)
acqf = acqf_class(
model=mm,
ref_point=ref_point,
Expand Down Expand Up @@ -934,6 +967,21 @@ def _test_q_noisy_expected_hypervolume_improvement(
self.assertEqual(list(b.shape), [1, 1, m])
self.assertEqual(list(b.shape), [1, 1, m])

def _test_qnehvi_with_CBD(
self, acqf_class: Type[AcquisitionFunction], dtype: torch.dtype, m: int
) -> None:
(
ref_point,
X,
X_baseline,
mm,
sampler,
samples,
baseline_samples,
tkwargs,
) = self._setup_qnehvi_test(dtype=dtype, m=m)
pareto_Y = self.pareto_Y_raw[:, :m].to(**tkwargs)

# test no baseline points
ref_point2 = [15.0, 14.0, 16.0][:m]
sampler = IIDNormalSampler(sample_shape=torch.Size([1]))
Expand Down Expand Up @@ -1146,6 +1194,21 @@ def _test_q_noisy_expected_hypervolume_improvement(
self.assertTrue(torch.equal(acqf_pareto_Y[:-2], expected_pareto_Y))
self.assertTrue(torch.equal(acqf_pareto_Y[-2:], expected_new_Y2))

def _test_qnehvi_without_CBD(
self, acqf_class: Type[AcquisitionFunction], dtype: torch.dtype, m: int
) -> None:
tkwargs = {"device": self.device}
tkwargs["dtype"] = dtype
ref_point = self.ref_point[:m]
Y = self.Y_raw[:, :m].to(**tkwargs)
pareto_Y = self.pareto_Y_raw[:, :m].to(**tkwargs)
X_baseline = torch.rand(Y.shape[0], 1, **tkwargs)
# the event shape is `b x q + r x m` = 1 x 1 x 2
baseline_samples = Y
mm = MockModel(MockPosterior(samples=baseline_samples))

X_pending = torch.rand(1, 1, dtype=dtype, device=self.device)

# test qNEHVI without CBD
mm._posterior._samples = baseline_samples
sampler = IIDNormalSampler(sample_shape=torch.Size([1]))
Expand All @@ -1158,6 +1221,7 @@ def _test_q_noisy_expected_hypervolume_improvement(
cache_pending=False,
cache_root=False,
)
new_Y = torch.tensor([[0.5, 3.0, 0.5][:m]], dtype=dtype, device=self.device)
mm._posterior._samples = torch.cat(
[
baseline_samples,
Expand All @@ -1168,15 +1232,25 @@ def _test_q_noisy_expected_hypervolume_improvement(
acqf.set_X_pending(X_pending10)
self.assertTrue(torch.equal(acqf.X_pending, X_pending10))
acqf_pareto_Y = acqf.partitioning.pareto_Y[0]
expected_pareto_Y = pareto_Y if m == 2 else pareto_Y.cpu()
self.assertTrue(torch.equal(acqf_pareto_Y, expected_pareto_Y))
acqf.set_X_pending(X_pending)
# test incremental nehvi in forward
new_Y2 = torch.cat(
[
new_Y,
torch.tensor([[0.25, 9.5, 1.5][:m]], dtype=dtype, device=self.device),
],
dim=0,
)
mm._posterior._samples = torch.cat(
[
baseline_samples,
new_Y2,
]
).unsqueeze(0)
with torch.no_grad():
X_test = torch.rand(1, 1, dtype=dtype, device=self.device)
val = evaluate(acqf, X_test)
bd = DominatedPartitioning(
ref_point=torch.tensor(ref_point).to(**tkwargs), Y=pareto_Y
Expand Down Expand Up @@ -1212,6 +1286,10 @@ def _test_q_noisy_expected_hypervolume_improvement(
# test X_pending is not None on __init__
mm._posterior._samples = torch.zeros(1, 5, m, **tkwargs)
sampler = IIDNormalSampler(sample_shape=torch.Size([1]))
# add another point
X_pending2 = torch.cat(
[X_pending, torch.rand(1, 1, dtype=dtype, device=self.device)], dim=0
)
acqf = acqf_class(
model=mm,
ref_point=ref_point,
Expand Down

0 comments on commit 85ccd2d

Please sign in to comment.