Skip to content

Commit

Permalink
require output dim in MultiTaskGP (#383)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #383

- Require output dim in MTGP
- Validate training tensor dimensions
- Validate input scaling

Reviewed By: Balandat

Differential Revision: D20223856

fbshipit-source-id: 8dfd327ecc6a9bb141211dd43148baa34c22ed70
  • Loading branch information
sdaulton authored and facebook-github-bot committed Mar 4, 2020
1 parent 4302a0c commit a6eddcb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
11 changes: 8 additions & 3 deletions botorch/models/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from torch import Tensor

from .gpytorch import MultiTaskGPyTorchModel
from .utils import validate_input_scaling


class MultiTaskGP(ExactGP, MultiTaskGPyTorchModel):
Expand Down Expand Up @@ -74,10 +75,13 @@ def __init__(
>>> train_Y = torch.cat(f1(X1), f2(X2))
>>> model = MultiTaskGP(train_X, train_Y, task_feature=-1)
"""
# TODO: Validate input normalization/scaling
if train_X.ndimension() != 2:
self._validate_tensor_args(X=train_X, Y=train_Y)
validate_input_scaling(train_X=train_X, train_Y=train_Y)
if train_X.ndim != 2:
# Currently, batch mode MTGPs are blocked upstream in GPyTorch
raise ValueError(f"Unsupported shape {train_X.shape} for train_X.")
# squeeze output dim
train_Y = train_Y.squeeze(-1)
d = train_X.shape[-1] - 1
if not (-d <= task_feature <= d):
raise ValueError(f"Must have that -{d} <= task_feature <= {d}")
Expand Down Expand Up @@ -199,6 +203,7 @@ def __init__(
>>> train_Yvar = 0.1 + 0.1 * torch.rand_like(train_Y)
>>> model = FixedNoiseMultiTaskGP(train_X, train_Y, train_Yvar, -1)
"""
self._validate_tensor_args(X=train_X, Y=train_Y, Yvar=train_Yvar)
# We'll instatiate a MultiTaskGP and simply override the likelihood
super().__init__(
train_X=train_X,
Expand All @@ -207,5 +212,5 @@ def __init__(
output_tasks=output_tasks,
rank=rank,
)
self.likelihood = FixedNoiseGaussianLikelihood(noise=train_Yvar)
self.likelihood = FixedNoiseGaussianLikelihood(noise=train_Yvar.squeeze(-1))
self.to(train_X)
6 changes: 3 additions & 3 deletions test/models/test_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _get_random_mt_data(**tkwargs):
full_train_i = torch.cat([train_i_task1, train_i_task2])
full_train_y = torch.cat([train_y1, train_y2])
train_X = torch.stack([full_train_x, full_train_i.type_as(full_train_x)], dim=-1)
train_Y = full_train_y
train_Y = full_train_y.unsqueeze(-1) # add output dim
return train_X, train_Y


Expand Down Expand Up @@ -121,7 +121,7 @@ def test_MultiTaskGP(self):

# test that unsupported batch shape MTGPs throw correct error
with self.assertRaises(ValueError):
MultiTaskGP(torch.rand(2, 2, 2), torch.rand(2, 1), 0)
MultiTaskGP(torch.rand(2, 2, 2), torch.rand(2, 2, 1), 0)

# test that bad feature index throws correct error
train_X, train_Y = _get_random_mt_data(**tkwargs)
Expand Down Expand Up @@ -233,7 +233,7 @@ def test_FixedNoiseMultiTaskGP(self):
# test that unsupported batch shape MTGPs throw correct error
with self.assertRaises(ValueError):
FixedNoiseMultiTaskGP(
torch.rand(2, 2, 2), torch.rand(2, 1), torch.rand(2, 1), 0
torch.rand(2, 2, 2), torch.rand(2, 2, 1), torch.rand(2, 2, 1), 0
)

# test that bad feature index throws correct error
Expand Down

0 comments on commit a6eddcb

Please sign in to comment.