diff --git a/README.md b/README.md
index 8ae2b6d1..3917f14d 100644
--- a/README.md
+++ b/README.md
@@ -135,10 +135,10 @@ D = gpx.Dataset(X=x, y=y)
# Construct the prior
meanf = gpx.mean_functions.Zero()
kernel = gpx.kernels.RBF()
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
# Define a likelihood
-likelihood = gpx.Gaussian(num_datapoints=n)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints = n)
# Construct the posterior
posterior = prior * likelihood
diff --git a/benchmarks/objectives.py b/benchmarks/objectives.py
index dd217b42..65b8569c 100644
--- a/benchmarks/objectives.py
+++ b/benchmarks/objectives.py
@@ -22,7 +22,7 @@ def setup(self, n_datapoints: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
- self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+ self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.objective = gpx.ConjugateMLL()
self.posterior = self.prior * self.likelihood
@@ -48,7 +48,7 @@ def setup(self, n_datapoints: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
- self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+ self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
self.objective = gpx.LogPosteriorDensity()
self.posterior = self.prior * self.likelihood
@@ -75,7 +75,7 @@ def setup(self, n_datapoints: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
- self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+ self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Poisson(num_datapoints=self.data.n)
self.objective = gpx.LogPosteriorDensity()
self.posterior = self.prior * self.likelihood
diff --git a/benchmarks/predictions.py b/benchmarks/predictions.py
index eed35d66..a3dd4fa8 100644
--- a/benchmarks/predictions.py
+++ b/benchmarks/predictions.py
@@ -21,7 +21,7 @@ def setup(self, n_test: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
- self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+ self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
key, subkey = jr.split(key)
@@ -46,7 +46,7 @@ def setup(self, n_test: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
- self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+ self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
key, subkey = jr.split(key)
@@ -71,7 +71,7 @@ def setup(self, n_test: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
- self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+ self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
key, subkey = jr.split(key)
diff --git a/benchmarks/sparse.py b/benchmarks/sparse.py
index 759cac9b..de7d4705 100644
--- a/benchmarks/sparse.py
+++ b/benchmarks/sparse.py
@@ -19,7 +19,7 @@ def setup(self, n_datapoints: int, n_inducing: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(1)))
meanf = gpx.mean_functions.Constant()
- self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+ self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
diff --git a/benchmarks/stochastic.py b/benchmarks/stochastic.py
index 14681535..1e530c73 100644
--- a/benchmarks/stochastic.py
+++ b/benchmarks/stochastic.py
@@ -20,7 +20,7 @@ def setup(self, n_datapoints: int, n_inducing: int, batch_size: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(1)))
meanf = gpx.mean_functions.Constant()
- self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+ self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
diff --git a/docs/examples/README.md b/docs/examples/README.md
index a5188c35..2b7d37c1 100644
--- a/docs/examples/README.md
+++ b/docs/examples/README.md
@@ -67,7 +67,7 @@ class Prior(AbstractPrior):
>>>
>>> meanf = gpx.mean_functions.Zero()
>>> kernel = gpx.kernels.RBF()
- >>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
+ >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
Attributes:
kernel (Kernel): The kernel function used to parameterise the prior.
diff --git a/docs/examples/barycentres.py b/docs/examples/barycentres.py
index 3733672f..39e12d3c 100644
--- a/docs/examples/barycentres.py
+++ b/docs/examples/barycentres.py
@@ -134,8 +134,13 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
y = y.reshape(-1, 1)
D = gpx.Dataset(X=x, y=y)
- likelihood = gpx.Gaussian(num_datapoints=n)
- posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood
+ likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
+ posterior = (
+ gpx.gps.Prior(
+ mean_function=gpx.mean_functions.Constant(), kernel=gpx.kernels.RBF()
+ )
+ * likelihood
+ )
opt_posterior, _ = gpx.fit_scipy(
model=posterior,
diff --git a/docs/examples/bayesian_optimisation.py b/docs/examples/bayesian_optimisation.py
index d507fbc0..a2361814 100644
--- a/docs/examples/bayesian_optimisation.py
+++ b/docs/examples/bayesian_optimisation.py
@@ -201,9 +201,9 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
# %%
def return_optimised_posterior(
- data: gpx.Dataset, prior: gpx.Module, key: Array
-) -> gpx.Module:
- likelihood = gpx.Gaussian(
+ data: gpx.Dataset, prior: gpx.base.Module, key: Array
+) -> gpx.base.Module:
+ likelihood = gpx.likelihoods.Gaussian(
num_datapoints=data.n, obs_stddev=jnp.array(1e-3)
) # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
likelihood = likelihood.replace_trainable(obs_stddev=False)
@@ -230,7 +230,7 @@ def return_optimised_posterior(
mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Matern52()
-prior = gpx.Prior(mean_function=mean, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
opt_posterior = return_optimised_posterior(D, prior, key)
# %% [markdown]
@@ -315,7 +315,7 @@ def optimise_sample(
# %%
def plot_bayes_opt(
- posterior: gpx.Module,
+ posterior: gpx.base.Module,
sample: FunctionalSample,
dataset: gpx.Dataset,
queried_x: ScalarFloat,
@@ -401,7 +401,7 @@ def plot_bayes_opt(
# Generate optimised posterior using previously observed data
mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Matern52()
- prior = gpx.Prior(mean_function=mean, kernel=kernel)
+ prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
opt_posterior = return_optimised_posterior(D, prior, subkey)
# Draw a sample from the posterior, and find the minimiser of it
@@ -543,7 +543,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
kernel = gpx.kernels.Matern52(
active_dims=[0, 1], lengthscale=jnp.array([1.0, 1.0]), variance=2.0
)
- prior = gpx.Prior(mean_function=mean, kernel=kernel)
+ prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
opt_posterior = return_optimised_posterior(D, prior, subkey)
# Draw a sample from the posterior, and find the minimiser of it
@@ -561,7 +561,8 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
# Evaluate the black-box function at the best point observed so far, and add it to the dataset
y_star = six_hump_camel(x_star)
print(
- f"BO Iteration: {i + 1}, Queried Point: {x_star}, Black-Box Function Value: {y_star}"
+ f"BO Iteration: {i + 1}, Queried Point: {x_star}, Black-Box Function Value:"
+ f" {y_star}"
)
D = D + gpx.Dataset(X=x_star, y=y_star)
bo_experiment_results.append(D)
diff --git a/docs/examples/classification.py b/docs/examples/classification.py
index 818a5272..4b955236 100644
--- a/docs/examples/classification.py
+++ b/docs/examples/classification.py
@@ -89,10 +89,10 @@
# choose a Bernoulli likelihood with a probit link function.
# %%
-kernel = gpx.RBF()
-meanf = gpx.Constant()
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
-likelihood = gpx.Bernoulli(num_datapoints=D.n)
+kernel = gpx.kernels.RBF()
+meanf = gpx.mean_functions.Constant()
+prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
+likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)
# %% [markdown]
# We construct the posterior through the product of our prior and likelihood.
@@ -116,7 +116,7 @@
# Optax's optimisers.
# %%
-negative_lpd = jax.jit(gpx.LogPosteriorDensity(negative=True))
+negative_lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=True))
optimiser = ox.adam(learning_rate=0.01)
@@ -345,7 +345,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
num_adapt = 500
num_samples = 500
-lpd = jax.jit(gpx.LogPosteriorDensity(negative=False))
+lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=False))
unconstrained_lpd = jax.jit(lambda tree: lpd(tree.constrain(), D))
adapt = blackjax.window_adaptation(
diff --git a/docs/examples/collapsed_vi.py b/docs/examples/collapsed_vi.py
index d497997b..6ee7d81d 100644
--- a/docs/examples/collapsed_vi.py
+++ b/docs/examples/collapsed_vi.py
@@ -106,10 +106,10 @@
# this, it is intractable to evaluate.
# %%
-meanf = gpx.Constant()
-kernel = gpx.RBF()
-likelihood = gpx.Gaussian(num_datapoints=D.n)
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
+meanf = gpx.mean_functions.Constant()
+kernel = gpx.kernels.RBF()
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
+prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
posterior = prior * likelihood
# %% [markdown]
@@ -119,7 +119,9 @@
# inducing points into the constructor as arguments.
# %%
-q = gpx.CollapsedVariationalGaussian(posterior=posterior, inducing_inputs=z)
+q = gpx.variational_families.CollapsedVariationalGaussian(
+ posterior=posterior, inducing_inputs=z
+)
# %% [markdown]
# We define our variational inference algorithm through `CollapsedVI`. This defines
@@ -127,7 +129,7 @@
# Titsias (2009).
# %%
-elbo = gpx.CollapsedELBO(negative=True)
+elbo = gpx.objectives.CollapsedELBO(negative=True)
# %% [markdown]
# For researchers, GPJax has the capacity to print the bibtex citation for objects such
@@ -241,14 +243,14 @@
# full model.
# %%
-full_rank_model = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.RBF()) * gpx.Gaussian(
- num_datapoints=D.n
-)
-negative_mll = jit(gpx.ConjugateMLL(negative=True).step)
+full_rank_model = gpx.gps.Prior(
+ mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
+) * gpx.likelihoods.Gaussian(num_datapoints=D.n)
+negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True).step)
# %timeit negative_mll(full_rank_model, D).block_until_ready()
# %%
-negative_elbo = jit(gpx.CollapsedELBO(negative=True).step)
+negative_elbo = jit(gpx.objectives.CollapsedELBO(negative=True).step)
# %timeit negative_elbo(q, D).block_until_ready()
# %% [markdown]
diff --git a/docs/examples/constructing_new_kernels.py b/docs/examples/constructing_new_kernels.py
index 27001f38..073015ca 100644
--- a/docs/examples/constructing_new_kernels.py
+++ b/docs/examples/constructing_new_kernels.py
@@ -89,7 +89,7 @@
meanf = gpx.mean_functions.Zero()
for k, ax in zip(kernels, axes.ravel()):
- prior = gpx.Prior(mean_function=meanf, kernel=k)
+ prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
rv = prior(x)
y = rv.sample(seed=key, sample_shape=(10,))
ax.plot(x, y.T, alpha=0.7)
@@ -263,13 +263,13 @@ def __call__(
# Define polar Gaussian process
PKern = Polar()
meanf = gpx.mean_functions.Zero()
-likelihood = gpx.Gaussian(num_datapoints=n)
-circular_posterior = gpx.Prior(mean_function=meanf, kernel=PKern) * likelihood
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
+circular_posterior = gpx.gps.Prior(mean_function=meanf, kernel=PKern) * likelihood
# Optimise GP's marginal log-likelihood using BFGS
opt_posterior, history = gpx.fit_scipy(
model=circular_posterior,
- objective=jit(gpx.ConjugateMLL(negative=True)),
+ objective=jit(gpx.objectives.ConjugateMLL(negative=True)),
train_data=D,
)
diff --git a/docs/examples/decision_making.py b/docs/examples/decision_making.py
index 0bf08e60..66dc77a9 100644
--- a/docs/examples/decision_making.py
+++ b/docs/examples/decision_making.py
@@ -136,9 +136,9 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
# mean function and kernel for the job at hand:
# %%
-mean = gpx.Zero()
-kernel = gpx.Matern52()
-prior = gpx.Prior(mean_function=mean, kernel=kernel)
+mean = gpx.mean_functions.Zero()
+kernel = gpx.kernels.Matern52()
+prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
# %% [markdown]
# One difference from GPJax is the way in which we define our likelihood. In GPJax, we
@@ -153,7 +153,7 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
# with the correct number of datapoints:
# %%
-likelihood_builder = lambda n: gpx.Gaussian(
+likelihood_builder = lambda n: gpx.likelihoods.Gaussian(
num_datapoints=n, obs_stddev=jnp.array(1e-3)
)
@@ -174,7 +174,7 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
posterior_handler = PosteriorHandler(
prior,
likelihood_builder=likelihood_builder,
- optimization_objective=gpx.ConjugateMLL(negative=True),
+ optimization_objective=gpx.objectives.ConjugateMLL(negative=True),
optimizer=ox.adam(learning_rate=0.01),
num_optimization_iters=1000,
)
diff --git a/docs/examples/deep_kernels.py b/docs/examples/deep_kernels.py
index 72eae15a..97123fbd 100644
--- a/docs/examples/deep_kernels.py
+++ b/docs/examples/deep_kernels.py
@@ -163,16 +163,16 @@ def __call__(self, x):
# kernel and assume a Gaussian likelihood.
# %%
-base_kernel = gpx.Matern52(
+base_kernel = gpx.kernels.Matern52(
active_dims=list(range(feature_space_dim)),
lengthscale=jnp.ones((feature_space_dim,)),
)
kernel = DeepKernelFunction(
network=forward_linear, base_kernel=base_kernel, key=key, dummy_x=x
)
-meanf = gpx.Zero()
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
-likelihood = gpx.Gaussian(num_datapoints=D.n)
+meanf = gpx.mean_functions.Zero()
+prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
posterior = prior * likelihood
# %% [markdown]
# ### Optimisation
@@ -207,7 +207,7 @@ def __call__(self, x):
opt_posterior, history = gpx.fit(
model=posterior,
- objective=jax.jit(gpx.ConjugateMLL(negative=True)),
+ objective=jax.jit(gpx.objectives.ConjugateMLL(negative=True)),
train_data=D,
optim=optimiser,
num_iters=800,
diff --git a/docs/examples/graph_kernels.py b/docs/examples/graph_kernels.py
index f64c7b38..9344e340 100644
--- a/docs/examples/graph_kernels.py
+++ b/docs/examples/graph_kernels.py
@@ -94,13 +94,13 @@
# %%
x = jnp.arange(G.number_of_nodes()).reshape(-1, 1)
-true_kernel = gpx.GraphKernel(
+true_kernel = gpx.kernels.GraphKernel(
laplacian=L,
lengthscale=2.3,
variance=3.2,
smoothness=6.1,
)
-prior = gpx.Prior(mean_function=gpx.Zero(), kernel=true_kernel)
+prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=true_kernel)
fx = prior(x)
y = fx.sample(seed=key, sample_shape=(1,)).reshape(-1, 1)
@@ -136,9 +136,9 @@
# We do this using the BFGS optimiser provided in `scipy` via 'jaxopt'.
# %%
-likelihood = gpx.Gaussian(num_datapoints=D.n)
-kernel = gpx.GraphKernel(laplacian=L)
-prior = gpx.Prior(mean_function=gpx.Zero(), kernel=kernel)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
+kernel = gpx.kernels.GraphKernel(laplacian=L)
+prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=kernel)
posterior = prior * likelihood
# %% [markdown]
diff --git a/docs/examples/intro_to_kernels.py b/docs/examples/intro_to_kernels.py
index 60e76a20..55113be6 100644
--- a/docs/examples/intro_to_kernels.py
+++ b/docs/examples/intro_to_kernels.py
@@ -160,7 +160,7 @@
meanf = gpx.mean_functions.Zero()
for k, ax in zip(kernels, axes.ravel()):
- prior = gpx.Prior(mean_function=meanf, kernel=k)
+ prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
rv = prior(x)
y = rv.sample(seed=key, sample_shape=(10,))
ax.plot(x, y.T, alpha=0.7)
@@ -219,9 +219,9 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
lengthscale=jnp.array(0.1)
) # Initialise our kernel lengthscale to 0.1
-prior = gpx.Prior(mean_function=mean, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
-likelihood = gpx.Gaussian(
+likelihood = gpx.likelihoods.Gaussian(
num_datapoints=D.n, obs_stddev=jnp.array(1e-3)
) # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
likelihood = likelihood.replace_trainable(obs_stddev=False)
@@ -352,7 +352,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
# %%
mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Periodic()
-prior = gpx.Prior(mean_function=mean, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)
rv = prior(x)
@@ -375,7 +375,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
# %%
mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Linear()
-prior = gpx.Prior(mean_function=mean, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)
rv = prior(x)
@@ -411,7 +411,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
kernel_two = gpx.kernels.Periodic()
sum_kernel = gpx.kernels.SumKernel(kernels=[kernel_one, kernel_two])
mean = gpx.mean_functions.Zero()
-prior = gpx.Prior(mean_function=mean, kernel=sum_kernel)
+prior = gpx.gps.Prior(mean_function=mean, kernel=sum_kernel)
x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)
rv = prior(x)
@@ -436,7 +436,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
kernel_two = gpx.kernels.Periodic()
sum_kernel = gpx.kernels.ProductKernel(kernels=[kernel_one, kernel_two])
mean = gpx.mean_functions.Zero()
-prior = gpx.Prior(mean_function=mean, kernel=sum_kernel)
+prior = gpx.gps.Prior(mean_function=mean, kernel=sum_kernel)
x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)
rv = prior(x)
@@ -522,8 +522,8 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
sum_kernel = gpx.kernels.SumKernel(kernels=[linear_kernel, periodic_kernel])
final_kernel = gpx.kernels.SumKernel(kernels=[rbf_kernel, sum_kernel])
-prior = gpx.Prior(mean_function=mean, kernel=final_kernel)
-likelihood = gpx.Gaussian(num_datapoints=D.n)
+prior = gpx.gps.Prior(mean_function=mean, kernel=final_kernel)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
posterior = prior * likelihood
@@ -645,7 +645,8 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
# %%
print(
- f"Periodic Kernel Period: {[i for i in opt_posterior.prior.kernel.kernels if isinstance(i, gpx.kernels.Periodic)][0].period}"
+ "Periodic Kernel Period:"
+ f" {[i for i in opt_posterior.prior.kernel.kernels if isinstance(i, gpx.kernels.Periodic)][0].period}"
)
# %% [markdown]
diff --git a/docs/examples/likelihoods_guide.py b/docs/examples/likelihoods_guide.py
index 6fd79521..9199d9e3 100644
--- a/docs/examples/likelihoods_guide.py
+++ b/docs/examples/likelihoods_guide.py
@@ -124,11 +124,11 @@
# $\mathbf{y}^{\star}$.
# +
-kernel = gpx.Matern32()
-meanf = gpx.Zero()
-prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+kernel = gpx.kernels.Matern32()
+meanf = gpx.mean_functions.Zero()
+prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
-likelihood = gpx.Gaussian(num_datapoints=D.n, obs_stddev=0.1)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=0.1)
posterior = prior * likelihood
@@ -158,7 +158,7 @@
# Similarly, for a Bernoulli likelihood function, the samples of $y$ would be binary.
# +
-likelihood = gpx.Bernoulli(num_datapoints=D.n)
+likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(9, 2))
@@ -231,7 +231,7 @@
# +
z = jnp.linspace(-3.0, 3.0, 10).reshape(-1, 1)
-q = gpx.VariationalGaussian(posterior=posterior, inducing_inputs=z)
+q = gpx.variational_families.VariationalGaussian(posterior=posterior, inducing_inputs=z)
def q_moments(x):
@@ -251,7 +251,7 @@ def q_moments(x):
# However, had we wanted to do this using quadrature, then we would have done the
# following:
-lquad = gpx.Gaussian(
+lquad = gpx.likelihoods.Gaussian(
num_datapoints=D.n,
obs_stddev=jnp.array([0.1]),
integrator=gpx.integrators.GHQuadratureIntegrator(num_points=20),
diff --git a/docs/examples/oceanmodelling.py b/docs/examples/oceanmodelling.py
index 162fba0b..f917ec04 100644
--- a/docs/examples/oceanmodelling.py
+++ b/docs/examples/oceanmodelling.py
@@ -223,8 +223,8 @@ def __call__(
# %%
def initialise_gp(kernel, mean, dataset):
- prior = gpx.Prior(mean_function=mean, kernel=kernel)
- likelihood = gpx.Gaussian(
+ prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
+ likelihood = gpx.likelihoods.Gaussian(
num_datapoints=dataset.n, obs_stddev=jnp.array([1.0e-3], dtype=jnp.float64)
)
posterior = prior * likelihood
diff --git a/docs/examples/poisson.py b/docs/examples/poisson.py
index 3543b9da..61ff0f61 100644
--- a/docs/examples/poisson.py
+++ b/docs/examples/poisson.py
@@ -83,10 +83,10 @@
# kernel, chosen for the purpose of exposition. We adopt the Poisson likelihood available in GPJax.
# %%
-kernel = gpx.RBF()
-meanf = gpx.Constant()
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
-likelihood = gpx.Poisson(num_datapoints=D.n)
+kernel = gpx.kernels.RBF()
+meanf = gpx.mean_functions.Constant()
+prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
+likelihood = gpx.likelihoods.Poisson(num_datapoints=D.n)
# %% [markdown]
# We construct the posterior through the product of our prior and likelihood.
@@ -135,7 +135,7 @@
num_adapt = 100
num_samples = 200
-lpd = jax.jit(gpx.LogPosteriorDensity(negative=False))
+lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=False))
unconstrained_lpd = jax.jit(lambda tree: lpd(tree.constrain(), D))
adapt = blackjax.window_adaptation(
diff --git a/docs/examples/regression.py b/docs/examples/regression.py
index 47b8439a..f1fc77a3 100644
--- a/docs/examples/regression.py
+++ b/docs/examples/regression.py
@@ -108,7 +108,7 @@
# %%
kernel = gpx.kernels.RBF()
meanf = gpx.mean_functions.Zero()
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
# %% [markdown]
#
@@ -152,7 +152,7 @@
# This is defined in GPJax through calling a `Gaussian` instance.
# %%
-likelihood = gpx.Gaussian(num_datapoints=D.n)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
# %% [markdown]
# The posterior is proportional to the prior multiplied by the likelihood, written as
diff --git a/docs/examples/uncollapsed_vi.py b/docs/examples/uncollapsed_vi.py
index 5511a9f1..0c099464 100644
--- a/docs/examples/uncollapsed_vi.py
+++ b/docs/examples/uncollapsed_vi.py
@@ -203,10 +203,10 @@
# %%
meanf = gpx.mean_functions.Zero()
-likelihood = gpx.Gaussian(num_datapoints=n)
-prior = gpx.Prior(mean_function=meanf, kernel=jk.RBF())
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
+prior = gpx.gps.Prior(mean_function=meanf, kernel=jk.RBF())
p = prior * likelihood
-q = gpx.VariationalGaussian(posterior=p, inducing_inputs=z)
+q = gpx.variational_families.VariationalGaussian(posterior=p, inducing_inputs=z)
# %% [markdown]
# Here, the variational process $q(\cdot)$ depends on the prior through
@@ -232,7 +232,7 @@
# its negative.
# %%
-negative_elbo = gpx.ELBO(negative=True)
+negative_elbo = gpx.objectives.ELBO(negative=True)
# %% [markdown]
# For researchers, GPJax has the capacity to print the bibtex citation for objects such
diff --git a/docs/examples/yacht.py b/docs/examples/yacht.py
index 59b7f8c9..dc40f514 100644
--- a/docs/examples/yacht.py
+++ b/docs/examples/yacht.py
@@ -173,9 +173,9 @@
lengthscale=0.1 * np.ones((n_covariates,)),
)
meanf = gpx.mean_functions.Zero()
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
-likelihood = gpx.Gaussian(num_datapoints=n_train)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=n_train)
posterior = prior * likelihood
diff --git a/gpjax/__init__.py b/gpjax/__init__.py
index 6e45ae72..0fa71556 100644
--- a/gpjax/__init__.py
+++ b/gpjax/__init__.py
@@ -13,8 +13,15 @@
# limitations under the License.
# ==============================================================================
from gpjax import (
+ base,
decision_making,
+ gps,
integrators,
+ kernels,
+ likelihoods,
+ mean_functions,
+ objectives,
+ variational_families,
)
from gpjax.base import (
Module,
@@ -26,111 +33,27 @@
fit,
fit_scipy,
)
-from gpjax.gps import (
- Prior,
- construct_posterior,
-)
-from gpjax.kernels import (
- RBF,
- RFF,
- AbstractKernel,
- BasisFunctionComputation,
- ConstantDiagonalKernelComputation,
- DenseKernelComputation,
- DiagonalKernelComputation,
- EigenKernelComputation,
- GraphKernel,
- Linear,
- Matern12,
- Matern32,
- Matern52,
- Periodic,
- Polynomial,
- PoweredExponential,
- ProductKernel,
- RationalQuadratic,
- SumKernel,
- White,
-)
-from gpjax.likelihoods import (
- Bernoulli,
- Gaussian,
- Poisson,
-)
-from gpjax.mean_functions import (
- Constant,
- Zero,
-)
-from gpjax.objectives import (
- ELBO,
- CollapsedELBO,
- ConjugateLOOCV,
- ConjugateMLL,
- LogPosteriorDensity,
- NonConjugateMLL,
-)
-from gpjax.variational_families import (
- CollapsedVariationalGaussian,
- ExpectationVariationalGaussian,
- NaturalVariationalGaussian,
- VariationalGaussian,
- WhitenedVariationalGaussian,
-)
__license__ = "MIT"
__description__ = "Didactic Gaussian processes in JAX"
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.7.3"
+__version__ = "0.8.0"
__all__ = [
- "Module",
- "param_field",
- "cite",
+ "base",
"decision_making",
+ "gps",
+ "integrators",
"kernels",
+ "likelihoods",
+ "mean_functions",
+ "objectives",
+ "variational_families",
+ "Dataset",
+ "cite",
"fit",
+ "Module",
+ "param_field",
"fit_scipy",
- "Prior",
- "construct_posterior",
- "integrators",
- "RBF",
- "GraphKernel",
- "Matern12",
- "Matern32",
- "Matern52",
- "Polynomial",
- "ProductKernel",
- "SumKernel",
- "Bernoulli",
- "Gaussian",
- "Poisson",
- "Constant",
- "Zero",
- "Dataset",
- "CollapsedVariationalGaussian",
- "ExpectationVariationalGaussian",
- "NaturalVariationalGaussian",
- "VariationalGaussian",
- "WhitenedVariationalGaussian",
- "CollapsedVI",
- "StochasticVI",
- "ConjugateMLL",
- "ConjugateLOOCV",
- "NonConjugateMLL",
- "LogPosteriorDensity",
- "CollapsedELBO",
- "ELBO",
- "AbstractKernel",
- "Linear",
- "DenseKernelComputation",
- "DiagonalKernelComputation",
- "ConstantDiagonalKernelComputation",
- "EigenKernelComputation",
- "PoweredExponential",
- "Periodic",
- "RationalQuadratic",
- "White",
- "BasisFunctionComputation",
- "RFF",
]
diff --git a/gpjax/fit.py b/gpjax/fit.py
index 6986549b..9cdcdcd3 100644
--- a/gpjax/fit.py
+++ b/gpjax/fit.py
@@ -76,9 +76,9 @@ def fit( # noqa: PLR0913
>>> D = gpx.Dataset(X, y)
>>>
>>> # (2) Define your model:
- >>> class LinearModel(gpx.Module):
- weight: float = gpx.param_field()
- bias: float = gpx.param_field()
+ >>> class LinearModel(gpx.base.Module):
+ weight: float = gpx.base.param_field()
+ bias: float = gpx.base.param_field()
def __call__(self, x):
return self.weight * x + self.bias
@@ -86,7 +86,7 @@ def __call__(self, x):
>>> model = LinearModel(weight=1.0, bias=1.0)
>>>
>>> # (3) Define your loss function:
- >>> class MeanSquareError(gpx.AbstractObjective):
+ >>> class MeanSquareError(gpx.objectives.AbstractObjective):
def evaluate(self, model: LinearModel, train_data: gpx.Dataset) -> float:
return jnp.mean((train_data.y - model(train_data.X)) ** 2)
>>>
diff --git a/gpjax/gps.py b/gpjax/gps.py
index 5928ef49..f1fa477b 100644
--- a/gpjax/gps.py
+++ b/gpjax/gps.py
@@ -133,7 +133,7 @@ class Prior(AbstractPrior):
>>> kernel = gpx.kernels.RBF()
>>> meanf = gpx.mean_functions.Zero()
- >>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
+ >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
```
"""
@@ -167,7 +167,7 @@ def __mul__(self, other):
>>>
>>> meanf = gpx.mean_functions.Zero()
>>> kernel = gpx.kernels.RBF()
- >>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
+ >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100)
>>>
>>> prior * likelihood
@@ -228,7 +228,7 @@ def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
>>>
>>> kernel = gpx.kernels.RBF()
>>> meanf = gpx.mean_functions.Zero()
- >>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
+ >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
>>>
>>> prior.predict(jnp.linspace(0, 1, 100))
```
@@ -289,7 +289,7 @@ def sample_approx(
>>>
>>> meanf = gpx.mean_functions.Zero()
>>> kernel = gpx.kernels.RBF()
- >>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
+ >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
>>>
>>> sample_fn = prior.sample_approx(10, key)
>>> sample_fn(jnp.linspace(0, 1, 100).reshape(-1, 1))
@@ -413,7 +413,7 @@ class ConjugatePosterior(AbstractPosterior):
>>> import gpjax as gpx
>>> import jax.numpy as jnp
- >>> prior = gpx.Prior(
+ >>> prior = gpx.gps.Prior(
mean_function = gpx.mean_functions.Zero(),
kernel = gpx.kernels.RBF()
)
@@ -461,8 +461,8 @@ def predict(
>>> D = gpx.Dataset(X=xtrain, y=ytrain)
>>> xtest = jnp.linspace(0, 1).reshape(-1, 1)
>>>
- >>> prior = gpx.Prior(mean_function = gpx.Zero(), kernel = gpx.RBF())
- >>> posterior = prior * gpx.Gaussian(num_datapoints = D.n)
+ >>> prior = gpx.gps.Prior(mean_function = gpx.mean_functions.Zero(), kernel = gpx.kernels.RBF())
+ >>> posterior = prior * gpx.likelihoods.Gaussian(num_datapoints = D.n)
>>> predictive_dist = posterior(xtest, D)
```
diff --git a/gpjax/objectives.py b/gpjax/objectives.py
index c07290c4..9b684f89 100644
--- a/gpjax/objectives.py
+++ b/gpjax/objectives.py
@@ -100,10 +100,10 @@ def step(
>>> meanf = gpx.mean_functions.Constant()
>>> kernel = gpx.kernels.RBF()
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
- >>> prior = gpx.Prior(mean_function = meanf, kernel=kernel)
+ >>> prior = gpx.gps.Prior(mean_function = meanf, kernel=kernel)
>>> posterior = prior * likelihood
>>>
- >>> mll = gpx.ConjugateMLL(negative=True)
+ >>> mll = gpx.objectives.ConjugateMLL(negative=True)
>>> mll(posterior, train_data = D)
```
@@ -112,13 +112,13 @@ def step(
marginal log-likelihood. This can be realised through
```python
- mll = gpx.ConjugateMLL(negative=True)
+ mll = gpx.objectives.ConjugateMLL(negative=True)
```
For optimal performance, the marginal log-likelihood should be ``jax.jit``
compiled.
```python
- mll = jit(gpx.ConjugateMLL(negative=True))
+ mll = jit(gpx.objectives.ConjugateMLL(negative=True))
```
Args:
@@ -180,10 +180,10 @@ def step(
>>> meanf = gpx.mean_functions.Constant()
>>> kernel = gpx.kernels.RBF()
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
- >>> prior = gpx.Prior(mean_function = meanf, kernel=kernel)
+ >>> prior = gpx.gps.Prior(mean_function = meanf, kernel=kernel)
>>> posterior = prior * likelihood
>>>
- >>> loocv = gpx.ConjugateLOOCV(negative=True)
+ >>> loocv = gpx.objectives.ConjugateLOOCV(negative=True)
>>> loocv(posterior, train_data = D)
```
@@ -192,13 +192,13 @@ def step(
leave-one-out log predictive probability. This can be realised through
```python
- mll = gpx.ConjugateLOOCV(negative=True)
+ mll = gpx.objectives.ConjugateLOOCV(negative=True)
```
For optimal performance, the objective should be ``jax.jit``
compiled.
```python
- mll = jit(gpx.ConjugateLOOCV(negative=True))
+ mll = jit(gpx.objectives.ConjugateLOOCV(negative=True))
```
Args:
diff --git a/gpjax/progress_bar.py b/gpjax/progress_bar.py
index 090a03ea..3072b71b 100644
--- a/gpjax/progress_bar.py
+++ b/gpjax/progress_bar.py
@@ -36,10 +36,10 @@ def progress_bar(num_iters: int, log_rate: int) -> Callable:
>>>
>>> carry = jnp.array(0.0)
>>> iteration_numbers = jnp.arange(100)
-
+ >>>
>>> @progress_bar(num_iters=iteration_numbers.shape[0], log_rate=10)
>>> def body_func(carry, x):
- ... return carry, x
+ >>> return carry, x
>>>
>>> carry, _ = jax.lax.scan(body_func, carry, iteration_numbers)
```
diff --git a/pyproject.toml b/pyproject.toml
index 6e2aa453..198ccac6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "gpjax"
-version = "0.7.3"
+version = "0.8.0"
description = "Gaussian processes in JAX."
authors = [
"Thomas Pinder ",
diff --git a/tests/test_citations.py b/tests/test_citations.py
index b7fe8a47..fd2531f5 100644
--- a/tests/test_citations.py
+++ b/tests/test_citations.py
@@ -6,7 +6,6 @@
import jax.numpy as jnp
import pytest
-import gpjax as gpx
from gpjax.citation import (
AbstractCitation,
BookCitation,
@@ -30,6 +29,13 @@
Matern32,
Matern52,
)
+from gpjax.objectives import (
+ ELBO,
+ CollapsedELBO,
+ ConjugateMLL,
+ LogPosteriorDensity,
+ NonConjugateMLL,
+)
def _check_no_fallback(citation: AbstractCitation):
@@ -103,7 +109,7 @@ def test_missing_citation(kernel):
@pytest.mark.parametrize(
- "mll", [gpx.ConjugateMLL(), gpx.NonConjugateMLL(), gpx.LogPosteriorDensity()]
+ "mll", [ConjugateMLL(), NonConjugateMLL(), LogPosteriorDensity()]
)
def test_mlls(mll):
citation = cite(mll)
@@ -115,7 +121,7 @@ def test_mlls(mll):
def test_uncollapsed_elbo():
- elbo = gpx.ELBO()
+ elbo = ELBO()
citation = cite(elbo)
assert isinstance(citation, PaperCitation)
@@ -128,7 +134,7 @@ def test_uncollapsed_elbo():
def test_collapsed_elbo():
- elbo = gpx.CollapsedELBO()
+ elbo = CollapsedELBO()
citation = cite(elbo)
assert isinstance(citation, PaperCitation)
@@ -158,7 +164,8 @@ def test_thompson_sampling():
)
assert (
citation.authors
- == "Wilson, James and Borovitskiy, Viacheslav and Terenin, Alexander and Mostowsky, Peter and Deisenroth, Marc"
+ == "Wilson, James and Borovitskiy, Viacheslav and Terenin, Alexander and"
+ " Mostowsky, Peter and Deisenroth, Marc"
)
assert citation.year == "2020"
assert citation.booktitle == "International Conference on Machine Learning"
@@ -205,7 +212,7 @@ def test_logarithmic_goldstein_price():
@pytest.mark.parametrize(
"objective",
- [gpx.ELBO(), gpx.CollapsedELBO(), gpx.LogPosteriorDensity(), gpx.ConjugateMLL()],
+ [ELBO(), CollapsedELBO(), LogPosteriorDensity(), ConjugateMLL()],
)
def test_jitted_fallback(objective):
with pytest.raises(RuntimeError):
diff --git a/tests/test_decision_making/test_decision_maker.py b/tests/test_decision_making/test_decision_maker.py
index 4d79f8d5..23a87a2f 100644
--- a/tests/test_decision_making/test_decision_maker.py
+++ b/tests/test_decision_making/test_decision_maker.py
@@ -61,16 +61,16 @@ def search_space() -> ContinuousSearchSpace:
@pytest.fixture
def posterior_handler() -> PosteriorHandler:
- mean = gpx.Zero()
- kernel = gpx.Matern52(lengthscale=jnp.array(1.0), variance=jnp.array(1.0))
- prior = gpx.Prior(mean_function=mean, kernel=kernel)
- likelihood_builder = lambda x: gpx.Gaussian(
+ mean = gpx.mean_functions.Zero()
+ kernel = gpx.kernels.Matern52(lengthscale=jnp.array(1.0), variance=jnp.array(1.0))
+ prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
+ likelihood_builder = lambda x: gpx.likelihoods.Gaussian(
num_datapoints=x, obs_stddev=jnp.array(1e-3)
)
posterior_handler = PosteriorHandler(
prior=prior,
likelihood_builder=likelihood_builder,
- optimization_objective=gpx.ConjugateMLL(negative=True),
+ optimization_objective=gpx.objectives.ConjugateMLL(negative=True),
optimizer=ox.adam(learning_rate=0.01),
num_optimization_iters=100,
)
diff --git a/tests/test_mean_functions.py b/tests/test_mean_functions.py
index fded647c..2fb74283 100644
--- a/tests/test_mean_functions.py
+++ b/tests/test_mean_functions.py
@@ -65,8 +65,10 @@ def test_zero_mean_remains_zero() -> None:
constant=False
) # Prevent kernel from modelling non-zero mean
meanf = Zero()
- prior = gpx.Prior(mean_function=meanf, kernel=kernel)
- likelihood = gpx.Gaussian(num_datapoints=D.n, obs_stddev=jnp.array(1e-3))
+ prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
+ likelihood = gpx.likelihoods.Gaussian(
+ num_datapoints=D.n, obs_stddev=jnp.array(1e-3)
+ )
likelihood = likelihood.replace_trainable(obs_stddev=False)
posterior = prior * likelihood
diff --git a/tests/test_objectives.py b/tests/test_objectives.py
index 5817f13a..fbb53f14 100644
--- a/tests/test_objectives.py
+++ b/tests/test_objectives.py
@@ -5,12 +5,9 @@
import pytest
import gpjax as gpx
-from gpjax import (
- Bernoulli,
- Gaussian,
- Prior,
-)
from gpjax.dataset import Dataset
+from gpjax.gps import Prior
+from gpjax.likelihoods import Gaussian
from gpjax.objectives import (
ELBO,
AbstractObjective,
@@ -64,10 +61,11 @@ def test_conjugate_mll(
D = build_data(num_datapoints, num_dims, key, binary=False)
# Build model
- p = Prior(
- kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()
+ p = gpx.gps.Prior(
+ kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))),
+ mean_function=gpx.mean_functions.Constant(),
)
- likelihood = Gaussian(num_datapoints=num_datapoints)
+ likelihood = gpx.likelihoods.Gaussian(num_datapoints=num_datapoints)
post = p * likelihood
mll = ConjugateMLL(negative=negative)
@@ -94,7 +92,8 @@ def test_conjugate_loocv(
# Build model
p = Prior(
- kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()
+ kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))),
+ mean_function=gpx.mean_functions.Constant(),
)
likelihood = Gaussian(num_datapoints=num_datapoints)
post = p * likelihood
@@ -122,10 +121,11 @@ def test_non_conjugate_mll(
D = build_data(num_datapoints, num_dims, key, binary=True)
# Build model
- p = Prior(
- kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()
+ p = gpx.gps.Prior(
+ kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))),
+ mean_function=gpx.mean_functions.Constant(),
)
- likelihood = Bernoulli(num_datapoints=num_datapoints)
+ likelihood = gpx.likelihoods.Bernoulli(num_datapoints=num_datapoints)
post = p * likelihood
mll = NonConjugateMLL(negative=negative)
@@ -158,11 +158,14 @@ def test_collapsed_elbo(
key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints // 2, num_dims)
)
- p = Prior(
- kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()
+ p = gpx.gps.Prior(
+ kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))),
+ mean_function=gpx.mean_functions.Constant(),
+ )
+ likelihood = gpx.likelihoods.Gaussian(num_datapoints=num_datapoints)
+ q = gpx.variational_families.CollapsedVariationalGaussian(
+ posterior=p * likelihood, inducing_inputs=z
)
- likelihood = Gaussian(num_datapoints=num_datapoints)
- q = gpx.CollapsedVariationalGaussian(posterior=p * likelihood, inducing_inputs=z)
negative_elbo = CollapsedELBO(negative=negative)
@@ -176,7 +179,9 @@ def test_collapsed_elbo(
assert evaluation.shape == ()
# Data on the full dataset should be the same as the marginal likelihood
- q = gpx.CollapsedVariationalGaussian(posterior=p * likelihood, inducing_inputs=D.X)
+ q = gpx.variational_families.CollapsedVariationalGaussian(
+ posterior=p * likelihood, inducing_inputs=D.X
+ )
mll = ConjugateMLL(negative=negative)
expected_value = mll(p * likelihood, D)
actual_value = negative_elbo(q, D)
@@ -203,16 +208,17 @@ def test_elbo(
key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints // 2, num_dims)
)
- p = Prior(
- kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()
+ p = gpx.gps.Prior(
+ kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))),
+ mean_function=gpx.mean_functions.Constant(),
)
if binary:
- likelihood = Bernoulli(num_datapoints=num_datapoints)
+ likelihood = gpx.likelihoods.Bernoulli(num_datapoints=num_datapoints)
else:
- likelihood = Gaussian(num_datapoints=num_datapoints)
+ likelihood = gpx.likelihoods.Gaussian(num_datapoints=num_datapoints)
post = p * likelihood
- q = gpx.VariationalGaussian(posterior=post, inducing_inputs=z)
+ q = gpx.variational_families.VariationalGaussian(posterior=post, inducing_inputs=z)
negative_elbo = ELBO(
negative=negative,
diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py
index 8a576c2c..66a3c650 100644
--- a/tests/test_variational_families.py
+++ b/tests/test_variational_families.py
@@ -114,8 +114,10 @@ def test_variational_gaussians(
variational_family: AbstractVariationalFamily,
) -> None:
# Initialise variational family:
- prior = gpx.Prior(kernel=gpx.RBF(), mean_function=gpx.Constant())
- likelihood = gpx.Gaussian(123)
+ prior = gpx.gps.Prior(
+ kernel=gpx.kernels.RBF(), mean_function=gpx.mean_functions.Constant()
+ )
+ likelihood = gpx.likelihoods.Gaussian(123)
inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1)
test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1)
@@ -223,14 +225,16 @@ def test_collapsed_variational_gaussian(
x = jnp.hstack([x] * point_dim)
D = gpx.Dataset(X=x, y=y)
- prior = gpx.Prior(kernel=gpx.RBF(), mean_function=gpx.Constant())
+ prior = gpx.gps.Prior(
+ kernel=gpx.kernels.RBF(), mean_function=gpx.mean_functions.Constant()
+ )
inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1)
inducing_inputs = jnp.hstack([inducing_inputs] * point_dim)
test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1)
test_inputs = jnp.hstack([test_inputs] * point_dim)
- posterior = prior * gpx.Gaussian(num_datapoints=D.n)
+ posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)
variational_family = CollapsedVariationalGaussian(
posterior=posterior,
@@ -240,7 +244,7 @@ def test_collapsed_variational_gaussian(
# We should raise an error for non-Gaussian likelihoods:
with pytest.raises(TypeError):
CollapsedVariationalGaussian(
- posterior=prior * gpx.Bernoulli(num_datapoints=D.n),
+ posterior=prior * gpx.likelihoods.Bernoulli(num_datapoints=D.n),
inducing_inputs=inducing_inputs,
)