Skip to content

Commit

Permalink
Merge pull request #408 from JaxGaussianProcesses/namespace_cleanup
Browse files Browse the repository at this point in the history
Namespace cleanup
  • Loading branch information
thomaspinder authored Nov 30, 2023
2 parents ac47576 + 86490b1 commit 07d99db
Show file tree
Hide file tree
Showing 32 changed files with 198 additions and 247 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,10 @@ D = gpx.Dataset(X=x, y=y)
# Construct the prior
meanf = gpx.mean_functions.Zero()
kernel = gpx.kernels.RBF()
prior = gpx.Prior(mean_function=meanf, kernel=kernel)
prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)

# Define a likelihood
likelihood = gpx.Gaussian(num_datapoints=n)
likelihood = gpx.likelihoods.Gaussian(num_datapoints = n)

# Construct the posterior
posterior = prior * likelihood
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def setup(self, n_datapoints: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.objective = gpx.ConjugateMLL()
self.posterior = self.prior * self.likelihood
Expand All @@ -48,7 +48,7 @@ def setup(self, n_datapoints: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
self.objective = gpx.LogPosteriorDensity()
self.posterior = self.prior * self.likelihood
Expand All @@ -75,7 +75,7 @@ def setup(self, n_datapoints: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Poisson(num_datapoints=self.data.n)
self.objective = gpx.LogPosteriorDensity()
self.posterior = self.prior * self.likelihood
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def setup(self, n_test: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
key, subkey = jr.split(key)
Expand All @@ -46,7 +46,7 @@ def setup(self, n_test: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
key, subkey = jr.split(key)
Expand All @@ -71,7 +71,7 @@ def setup(self, n_test: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
key, subkey = jr.split(key)
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def setup(self, n_datapoints: int, n_inducing: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(1)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood

Expand Down
2 changes: 1 addition & 1 deletion benchmarks/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def setup(self, n_datapoints: int, n_inducing: int, batch_size: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(1)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class Prior(AbstractPrior):
>>>
>>> meanf = gpx.mean_functions.Zero()
>>> kernel = gpx.kernels.RBF()
>>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
Attributes:
kernel (Kernel): The kernel function used to parameterise the prior.
Expand Down
9 changes: 7 additions & 2 deletions docs/examples/barycentres.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,13 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
y = y.reshape(-1, 1)
D = gpx.Dataset(X=x, y=y)

likelihood = gpx.Gaussian(num_datapoints=n)
posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood
likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
posterior = (
gpx.gps.Prior(
mean_function=gpx.mean_functions.Constant(), kernel=gpx.kernels.RBF()
)
* likelihood
)

opt_posterior, _ = gpx.fit_scipy(
model=posterior,
Expand Down
17 changes: 9 additions & 8 deletions docs/examples/bayesian_optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,9 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:

# %%
def return_optimised_posterior(
data: gpx.Dataset, prior: gpx.Module, key: Array
) -> gpx.Module:
likelihood = gpx.Gaussian(
data: gpx.Dataset, prior: gpx.base.Module, key: Array
) -> gpx.base.Module:
likelihood = gpx.likelihoods.Gaussian(
num_datapoints=data.n, obs_stddev=jnp.array(1e-3)
) # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
likelihood = likelihood.replace_trainable(obs_stddev=False)
Expand All @@ -230,7 +230,7 @@ def return_optimised_posterior(

mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Matern52()
prior = gpx.Prior(mean_function=mean, kernel=kernel)
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
opt_posterior = return_optimised_posterior(D, prior, key)

# %% [markdown]
Expand Down Expand Up @@ -315,7 +315,7 @@ def optimise_sample(

# %%
def plot_bayes_opt(
posterior: gpx.Module,
posterior: gpx.base.Module,
sample: FunctionalSample,
dataset: gpx.Dataset,
queried_x: ScalarFloat,
Expand Down Expand Up @@ -401,7 +401,7 @@ def plot_bayes_opt(
# Generate optimised posterior using previously observed data
mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Matern52()
prior = gpx.Prior(mean_function=mean, kernel=kernel)
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
opt_posterior = return_optimised_posterior(D, prior, subkey)

# Draw a sample from the posterior, and find the minimiser of it
Expand Down Expand Up @@ -543,7 +543,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
kernel = gpx.kernels.Matern52(
active_dims=[0, 1], lengthscale=jnp.array([1.0, 1.0]), variance=2.0
)
prior = gpx.Prior(mean_function=mean, kernel=kernel)
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
opt_posterior = return_optimised_posterior(D, prior, subkey)

# Draw a sample from the posterior, and find the minimiser of it
Expand All @@ -561,7 +561,8 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
# Evaluate the black-box function at the best point observed so far, and add it to the dataset
y_star = six_hump_camel(x_star)
print(
f"BO Iteration: {i + 1}, Queried Point: {x_star}, Black-Box Function Value: {y_star}"
f"BO Iteration: {i + 1}, Queried Point: {x_star}, Black-Box Function Value:"
f" {y_star}"
)
D = D + gpx.Dataset(X=x_star, y=y_star)
bo_experiment_results.append(D)
Expand Down
12 changes: 6 additions & 6 deletions docs/examples/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@
# choose a Bernoulli likelihood with a probit link function.

# %%
kernel = gpx.RBF()
meanf = gpx.Constant()
prior = gpx.Prior(mean_function=meanf, kernel=kernel)
likelihood = gpx.Bernoulli(num_datapoints=D.n)
kernel = gpx.kernels.RBF()
meanf = gpx.mean_functions.Constant()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)

# %% [markdown]
# We construct the posterior through the product of our prior and likelihood.
Expand All @@ -116,7 +116,7 @@
# Optax's optimisers.

# %%
negative_lpd = jax.jit(gpx.LogPosteriorDensity(negative=True))
negative_lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=True))

optimiser = ox.adam(learning_rate=0.01)

Expand Down Expand Up @@ -345,7 +345,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
num_adapt = 500
num_samples = 500

lpd = jax.jit(gpx.LogPosteriorDensity(negative=False))
lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=False))
unconstrained_lpd = jax.jit(lambda tree: lpd(tree.constrain(), D))

adapt = blackjax.window_adaptation(
Expand Down
24 changes: 13 additions & 11 deletions docs/examples/collapsed_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@
# this, it is intractable to evaluate.

# %%
meanf = gpx.Constant()
kernel = gpx.RBF()
likelihood = gpx.Gaussian(num_datapoints=D.n)
prior = gpx.Prior(mean_function=meanf, kernel=kernel)
meanf = gpx.mean_functions.Constant()
kernel = gpx.kernels.RBF()
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
posterior = prior * likelihood

# %% [markdown]
Expand All @@ -119,15 +119,17 @@
# inducing points into the constructor as arguments.

# %%
q = gpx.CollapsedVariationalGaussian(posterior=posterior, inducing_inputs=z)
q = gpx.variational_families.CollapsedVariationalGaussian(
posterior=posterior, inducing_inputs=z
)

# %% [markdown]
# We define our variational inference algorithm through `CollapsedVI`. This defines
# the collapsed variational free energy bound considered in
# <strong data-cite="titsias2009">Titsias (2009)</strong>.

# %%
elbo = gpx.CollapsedELBO(negative=True)
elbo = gpx.objectives.CollapsedELBO(negative=True)

# %% [markdown]
# For researchers, GPJax has the capacity to print the bibtex citation for objects such
Expand Down Expand Up @@ -241,14 +243,14 @@
# full model.

# %%
full_rank_model = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.RBF()) * gpx.Gaussian(
num_datapoints=D.n
)
negative_mll = jit(gpx.ConjugateMLL(negative=True).step)
full_rank_model = gpx.gps.Prior(
mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
) * gpx.likelihoods.Gaussian(num_datapoints=D.n)
negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True).step)
# %timeit negative_mll(full_rank_model, D).block_until_ready()

# %%
negative_elbo = jit(gpx.CollapsedELBO(negative=True).step)
negative_elbo = jit(gpx.objectives.CollapsedELBO(negative=True).step)
# %timeit negative_elbo(q, D).block_until_ready()

# %% [markdown]
Expand Down
8 changes: 4 additions & 4 deletions docs/examples/constructing_new_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
meanf = gpx.mean_functions.Zero()

for k, ax in zip(kernels, axes.ravel()):
prior = gpx.Prior(mean_function=meanf, kernel=k)
prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
rv = prior(x)
y = rv.sample(seed=key, sample_shape=(10,))
ax.plot(x, y.T, alpha=0.7)
Expand Down Expand Up @@ -263,13 +263,13 @@ def __call__(
# Define polar Gaussian process
PKern = Polar()
meanf = gpx.mean_functions.Zero()
likelihood = gpx.Gaussian(num_datapoints=n)
circular_posterior = gpx.Prior(mean_function=meanf, kernel=PKern) * likelihood
likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
circular_posterior = gpx.gps.Prior(mean_function=meanf, kernel=PKern) * likelihood

# Optimise GP's marginal log-likelihood using BFGS
opt_posterior, history = gpx.fit_scipy(
model=circular_posterior,
objective=jit(gpx.ConjugateMLL(negative=True)),
objective=jit(gpx.objectives.ConjugateMLL(negative=True)),
train_data=D,
)

Expand Down
10 changes: 5 additions & 5 deletions docs/examples/decision_making.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
# mean function and kernel for the job at hand:

# %%
mean = gpx.Zero()
kernel = gpx.Matern52()
prior = gpx.Prior(mean_function=mean, kernel=kernel)
mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Matern52()
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)

# %% [markdown]
# One difference from GPJax is the way in which we define our likelihood. In GPJax, we
Expand All @@ -153,7 +153,7 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
# with the correct number of datapoints:

# %%
likelihood_builder = lambda n: gpx.Gaussian(
likelihood_builder = lambda n: gpx.likelihoods.Gaussian(
num_datapoints=n, obs_stddev=jnp.array(1e-3)
)

Expand All @@ -174,7 +174,7 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
posterior_handler = PosteriorHandler(
prior,
likelihood_builder=likelihood_builder,
optimization_objective=gpx.ConjugateMLL(negative=True),
optimization_objective=gpx.objectives.ConjugateMLL(negative=True),
optimizer=ox.adam(learning_rate=0.01),
num_optimization_iters=1000,
)
Expand Down
10 changes: 5 additions & 5 deletions docs/examples/deep_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,16 @@ def __call__(self, x):
# kernel and assume a Gaussian likelihood.

# %%
base_kernel = gpx.Matern52(
base_kernel = gpx.kernels.Matern52(
active_dims=list(range(feature_space_dim)),
lengthscale=jnp.ones((feature_space_dim,)),
)
kernel = DeepKernelFunction(
network=forward_linear, base_kernel=base_kernel, key=key, dummy_x=x
)
meanf = gpx.Zero()
prior = gpx.Prior(mean_function=meanf, kernel=kernel)
likelihood = gpx.Gaussian(num_datapoints=D.n)
meanf = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
posterior = prior * likelihood
# %% [markdown]
# ### Optimisation
Expand Down Expand Up @@ -207,7 +207,7 @@ def __call__(self, x):

opt_posterior, history = gpx.fit(
model=posterior,
objective=jax.jit(gpx.ConjugateMLL(negative=True)),
objective=jax.jit(gpx.objectives.ConjugateMLL(negative=True)),
train_data=D,
optim=optimiser,
num_iters=800,
Expand Down
10 changes: 5 additions & 5 deletions docs/examples/graph_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@
# %%
x = jnp.arange(G.number_of_nodes()).reshape(-1, 1)

true_kernel = gpx.GraphKernel(
true_kernel = gpx.kernels.GraphKernel(
laplacian=L,
lengthscale=2.3,
variance=3.2,
smoothness=6.1,
)
prior = gpx.Prior(mean_function=gpx.Zero(), kernel=true_kernel)
prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=true_kernel)

fx = prior(x)
y = fx.sample(seed=key, sample_shape=(1,)).reshape(-1, 1)
Expand Down Expand Up @@ -136,9 +136,9 @@
# We do this using the BFGS optimiser provided in `scipy` via 'jaxopt'.

# %%
likelihood = gpx.Gaussian(num_datapoints=D.n)
kernel = gpx.GraphKernel(laplacian=L)
prior = gpx.Prior(mean_function=gpx.Zero(), kernel=kernel)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
kernel = gpx.kernels.GraphKernel(laplacian=L)
prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=kernel)
posterior = prior * likelihood

# %% [markdown]
Expand Down
Loading

0 comments on commit 07d99db

Please sign in to comment.