diff --git a/README.md b/README.md
index 8ae2b6d1..3917f14d 100644
--- a/README.md
+++ b/README.md
@@ -135,10 +135,10 @@ D = gpx.Dataset(X=x, y=y)
 # Construct the prior
 meanf = gpx.mean_functions.Zero()
 kernel = gpx.kernels.RBF()
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
 
 # Define a likelihood
-likelihood = gpx.Gaussian(num_datapoints=n)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints = n)
 
 # Construct the posterior
 posterior = prior * likelihood
diff --git a/benchmarks/objectives.py b/benchmarks/objectives.py
index dd217b42..65b8569c 100644
--- a/benchmarks/objectives.py
+++ b/benchmarks/objectives.py
@@ -22,7 +22,7 @@ def setup(self, n_datapoints: int, n_dims: int):
         self.data = gpx.Dataset(X=self.X, y=self.y)
         kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
         meanf = gpx.mean_functions.Constant()
-        self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+        self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
         self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
         self.objective = gpx.ConjugateMLL()
         self.posterior = self.prior * self.likelihood
@@ -48,7 +48,7 @@ def setup(self, n_datapoints: int, n_dims: int):
         self.data = gpx.Dataset(X=self.X, y=self.y)
         kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
         meanf = gpx.mean_functions.Constant()
-        self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+        self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
         self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
         self.objective = gpx.LogPosteriorDensity()
         self.posterior = self.prior * self.likelihood
@@ -75,7 +75,7 @@ def setup(self, n_datapoints: int, n_dims: int):
         self.data = gpx.Dataset(X=self.X, y=self.y)
         kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
         meanf = gpx.mean_functions.Constant()
-        self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+        self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
         self.likelihood = gpx.likelihoods.Poisson(num_datapoints=self.data.n)
         self.objective = gpx.LogPosteriorDensity()
         self.posterior = self.prior * self.likelihood
diff --git a/benchmarks/predictions.py b/benchmarks/predictions.py
index eed35d66..a3dd4fa8 100644
--- a/benchmarks/predictions.py
+++ b/benchmarks/predictions.py
@@ -21,7 +21,7 @@ def setup(self, n_test: int, n_dims: int):
         self.data = gpx.Dataset(X=self.X, y=self.y)
         kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
         meanf = gpx.mean_functions.Constant()
-        self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+        self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
         self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
         self.posterior = self.prior * self.likelihood
         key, subkey = jr.split(key)
@@ -46,7 +46,7 @@ def setup(self, n_test: int, n_dims: int):
         self.data = gpx.Dataset(X=self.X, y=self.y)
         kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
         meanf = gpx.mean_functions.Constant()
-        self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+        self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
         self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
         self.posterior = self.prior * self.likelihood
         key, subkey = jr.split(key)
@@ -71,7 +71,7 @@ def setup(self, n_test: int, n_dims: int):
         self.data = gpx.Dataset(X=self.X, y=self.y)
         kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
         meanf = gpx.mean_functions.Constant()
-        self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+        self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
         self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
         self.posterior = self.prior * self.likelihood
         key, subkey = jr.split(key)
diff --git a/benchmarks/sparse.py b/benchmarks/sparse.py
index 759cac9b..de7d4705 100644
--- a/benchmarks/sparse.py
+++ b/benchmarks/sparse.py
@@ -19,7 +19,7 @@ def setup(self, n_datapoints: int, n_inducing: int):
         self.data = gpx.Dataset(X=self.X, y=self.y)
         kernel = gpx.kernels.RBF(active_dims=list(range(1)))
         meanf = gpx.mean_functions.Constant()
-        self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+        self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
         self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
         self.posterior = self.prior * self.likelihood
 
diff --git a/benchmarks/stochastic.py b/benchmarks/stochastic.py
index 14681535..1e530c73 100644
--- a/benchmarks/stochastic.py
+++ b/benchmarks/stochastic.py
@@ -20,7 +20,7 @@ def setup(self, n_datapoints: int, n_inducing: int, batch_size: int):
         self.data = gpx.Dataset(X=self.X, y=self.y)
         kernel = gpx.kernels.RBF(active_dims=list(range(1)))
         meanf = gpx.mean_functions.Constant()
-        self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+        self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
         self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
         self.posterior = self.prior * self.likelihood
 
diff --git a/docs/examples/README.md b/docs/examples/README.md
index a5188c35..2b7d37c1 100644
--- a/docs/examples/README.md
+++ b/docs/examples/README.md
@@ -67,7 +67,7 @@ class Prior(AbstractPrior):
        >>>
        >>> meanf = gpx.mean_functions.Zero()
        >>> kernel = gpx.kernels.RBF()
-       >>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
+       >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
 
     Attributes:
         kernel (Kernel): The kernel function used to parameterise the prior.
diff --git a/docs/examples/barycentres.py b/docs/examples/barycentres.py
index 3733672f..39e12d3c 100644
--- a/docs/examples/barycentres.py
+++ b/docs/examples/barycentres.py
@@ -134,8 +134,13 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
     y = y.reshape(-1, 1)
     D = gpx.Dataset(X=x, y=y)
 
-    likelihood = gpx.Gaussian(num_datapoints=n)
-    posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood
+    likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
+    posterior = (
+        gpx.gps.Prior(
+            mean_function=gpx.mean_functions.Constant(), kernel=gpx.kernels.RBF()
+        )
+        * likelihood
+    )
 
     opt_posterior, _ = gpx.fit_scipy(
         model=posterior,
diff --git a/docs/examples/bayesian_optimisation.py b/docs/examples/bayesian_optimisation.py
index d507fbc0..a2361814 100644
--- a/docs/examples/bayesian_optimisation.py
+++ b/docs/examples/bayesian_optimisation.py
@@ -201,9 +201,9 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
 
 # %%
 def return_optimised_posterior(
-    data: gpx.Dataset, prior: gpx.Module, key: Array
-) -> gpx.Module:
-    likelihood = gpx.Gaussian(
+    data: gpx.Dataset, prior: gpx.base.Module, key: Array
+) -> gpx.base.Module:
+    likelihood = gpx.likelihoods.Gaussian(
         num_datapoints=data.n, obs_stddev=jnp.array(1e-3)
     )  # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
     likelihood = likelihood.replace_trainable(obs_stddev=False)
@@ -230,7 +230,7 @@ def return_optimised_posterior(
 mean = gpx.mean_functions.Zero()
 kernel = gpx.kernels.Matern52()
-prior = gpx.Prior(mean_function=mean, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
 opt_posterior = return_optimised_posterior(D, prior, key)
 
 # %% [markdown]
@@ -315,7 +315,7 @@ def optimise_sample(
 
 # %%
 def plot_bayes_opt(
-    posterior: gpx.Module,
+    posterior: gpx.base.Module,
     sample: FunctionalSample,
     dataset: gpx.Dataset,
     queried_x: ScalarFloat,
@@ -401,7 +401,7 @@ def plot_bayes_opt(
     # Generate optimised posterior using previously observed data
     mean = gpx.mean_functions.Zero()
     kernel = gpx.kernels.Matern52()
-    prior = gpx.Prior(mean_function=mean, kernel=kernel)
+    prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
     opt_posterior = return_optimised_posterior(D, prior, subkey)
 
     # Draw a sample from the posterior, and find the minimiser of it
@@ -543,7 +543,7 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
     kernel = gpx.kernels.Matern52(
         active_dims=[0, 1], lengthscale=jnp.array([1.0, 1.0]), variance=2.0
     )
-    prior = gpx.Prior(mean_function=mean, kernel=kernel)
+    prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
     opt_posterior = return_optimised_posterior(D, prior, subkey)
 
     # Draw a sample from the posterior, and find the minimiser of it
@@ -561,7 +561,8 @@ def six_hump_camel(x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
     # Evaluate the black-box function at the best point observed so far, and add it to the dataset
     y_star = six_hump_camel(x_star)
     print(
-        f"BO Iteration: {i + 1}, Queried Point: {x_star}, Black-Box Function Value: {y_star}"
+        f"BO Iteration: {i + 1}, Queried Point: {x_star}, Black-Box Function Value:"
+        f" {y_star}"
     )
     D = D + gpx.Dataset(X=x_star, y=y_star)
     bo_experiment_results.append(D)
diff --git a/docs/examples/classification.py b/docs/examples/classification.py
index 818a5272..4b955236 100644
--- a/docs/examples/classification.py
+++ b/docs/examples/classification.py
@@ -89,10 +89,10 @@
 # choose a Bernoulli likelihood with a probit link function.
 
 # %%
-kernel = gpx.RBF()
-meanf = gpx.Constant()
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
-likelihood = gpx.Bernoulli(num_datapoints=D.n)
+kernel = gpx.kernels.RBF()
+meanf = gpx.mean_functions.Constant()
+prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
+likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)
 
 # %% [markdown]
 # We construct the posterior through the product of our prior and likelihood.
@@ -116,7 +116,7 @@
 # Optax's optimisers.
 
 # %%
-negative_lpd = jax.jit(gpx.LogPosteriorDensity(negative=True))
+negative_lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=True))
 
 optimiser = ox.adam(learning_rate=0.01)
 
@@ -345,7 +345,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
 num_adapt = 500
 num_samples = 500
 
-lpd = jax.jit(gpx.LogPosteriorDensity(negative=False))
+lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=False))
 unconstrained_lpd = jax.jit(lambda tree: lpd(tree.constrain(), D))
 
 adapt = blackjax.window_adaptation(
diff --git a/docs/examples/collapsed_vi.py b/docs/examples/collapsed_vi.py
index d497997b..6ee7d81d 100644
--- a/docs/examples/collapsed_vi.py
+++ b/docs/examples/collapsed_vi.py
@@ -106,10 +106,10 @@
 # this, it is intractable to evaluate.
 
 # %%
-meanf = gpx.Constant()
-kernel = gpx.RBF()
-likelihood = gpx.Gaussian(num_datapoints=D.n)
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
+meanf = gpx.mean_functions.Constant()
+kernel = gpx.kernels.RBF()
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
+prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
 posterior = prior * likelihood
 
 # %% [markdown]
@@ -119,7 +119,9 @@
 # inducing points into the constructor as arguments.
 
 # %%
-q = gpx.CollapsedVariationalGaussian(posterior=posterior, inducing_inputs=z)
+q = gpx.variational_families.CollapsedVariationalGaussian(
+    posterior=posterior, inducing_inputs=z
+)
 
 # %% [markdown]
 # We define our variational inference algorithm through `CollapsedVI`. This defines
@@ -127,7 +129,7 @@
 # Titsias (2009).
 
 # %%
-elbo = gpx.CollapsedELBO(negative=True)
+elbo = gpx.objectives.CollapsedELBO(negative=True)
 
 # %% [markdown]
 # For researchers, GPJax has the capacity to print the bibtex citation for objects such
@@ -241,14 +243,14 @@
 # full model.
 
 # %%
-full_rank_model = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.RBF()) * gpx.Gaussian(
-    num_datapoints=D.n
-)
-negative_mll = jit(gpx.ConjugateMLL(negative=True).step)
+full_rank_model = gpx.gps.Prior(
+    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
+) * gpx.likelihoods.Gaussian(num_datapoints=D.n)
+negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True).step)
 # %timeit negative_mll(full_rank_model, D).block_until_ready()
 
 # %%
-negative_elbo = jit(gpx.CollapsedELBO(negative=True).step)
+negative_elbo = jit(gpx.objectives.CollapsedELBO(negative=True).step)
 # %timeit negative_elbo(q, D).block_until_ready()
 
 # %% [markdown]
diff --git a/docs/examples/constructing_new_kernels.py b/docs/examples/constructing_new_kernels.py
index 27001f38..073015ca 100644
--- a/docs/examples/constructing_new_kernels.py
+++ b/docs/examples/constructing_new_kernels.py
@@ -89,7 +89,7 @@
 meanf = gpx.mean_functions.Zero()
 
 for k, ax in zip(kernels, axes.ravel()):
-    prior = gpx.Prior(mean_function=meanf, kernel=k)
+    prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
     rv = prior(x)
     y = rv.sample(seed=key, sample_shape=(10,))
     ax.plot(x, y.T, alpha=0.7)
@@ -263,13 +263,13 @@ def __call__(
 
 # Define polar Gaussian process
 PKern = Polar()
 meanf = gpx.mean_functions.Zero()
-likelihood = gpx.Gaussian(num_datapoints=n)
-circular_posterior = gpx.Prior(mean_function=meanf, kernel=PKern) * likelihood
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
+circular_posterior = gpx.gps.Prior(mean_function=meanf, kernel=PKern) * likelihood
 
 # Optimise GP's marginal log-likelihood using BFGS
 opt_posterior, history = gpx.fit_scipy(
     model=circular_posterior,
-    objective=jit(gpx.ConjugateMLL(negative=True)),
+    objective=jit(gpx.objectives.ConjugateMLL(negative=True)),
     train_data=D,
 )
diff --git a/docs/examples/decision_making.py b/docs/examples/decision_making.py
index 0bf08e60..66dc77a9 100644
--- a/docs/examples/decision_making.py
+++ b/docs/examples/decision_making.py
@@ -136,9 +136,9 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
 # mean function and kernel for the job at hand:
 
 # %%
-mean = gpx.Zero()
-kernel = gpx.Matern52()
-prior = gpx.Prior(mean_function=mean, kernel=kernel)
+mean = gpx.mean_functions.Zero()
+kernel = gpx.kernels.Matern52()
+prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
 
 # %% [markdown]
 # One difference from GPJax is the way in which we define our likelihood. In GPJax, we
@@ -153,7 +153,7 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
 # with the correct number of datapoints:
 
 # %%
-likelihood_builder = lambda n: gpx.Gaussian(
+likelihood_builder = lambda n: gpx.likelihoods.Gaussian(
     num_datapoints=n, obs_stddev=jnp.array(1e-3)
 )
 
@@ -174,7 +174,7 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
 posterior_handler = PosteriorHandler(
     prior,
     likelihood_builder=likelihood_builder,
-    optimization_objective=gpx.ConjugateMLL(negative=True),
+    optimization_objective=gpx.objectives.ConjugateMLL(negative=True),
     optimizer=ox.adam(learning_rate=0.01),
     num_optimization_iters=1000,
 )
diff --git a/docs/examples/deep_kernels.py b/docs/examples/deep_kernels.py
index 72eae15a..97123fbd 100644
--- a/docs/examples/deep_kernels.py
+++ b/docs/examples/deep_kernels.py
@@ -163,16 +163,16 @@ def __call__(self, x):
 # kernel and assume a Gaussian likelihood.
 
 # %%
-base_kernel = gpx.Matern52(
+base_kernel = gpx.kernels.Matern52(
     active_dims=list(range(feature_space_dim)),
     lengthscale=jnp.ones((feature_space_dim,)),
 )
 kernel = DeepKernelFunction(
     network=forward_linear, base_kernel=base_kernel, key=key, dummy_x=x
 )
-meanf = gpx.Zero()
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
-likelihood = gpx.Gaussian(num_datapoints=D.n)
+meanf = gpx.mean_functions.Zero()
+prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
 posterior = prior * likelihood
 # %% [markdown]
 # ### Optimisation
@@ -207,7 +207,7 @@ def __call__(self, x):
 
 opt_posterior, history = gpx.fit(
     model=posterior,
-    objective=jax.jit(gpx.ConjugateMLL(negative=True)),
+    objective=jax.jit(gpx.objectives.ConjugateMLL(negative=True)),
     train_data=D,
     optim=optimiser,
     num_iters=800,
diff --git a/docs/examples/graph_kernels.py b/docs/examples/graph_kernels.py
index f64c7b38..9344e340 100644
--- a/docs/examples/graph_kernels.py
+++ b/docs/examples/graph_kernels.py
@@ -94,13 +94,13 @@
 # %%
 x = jnp.arange(G.number_of_nodes()).reshape(-1, 1)
-true_kernel = gpx.GraphKernel(
+true_kernel = gpx.kernels.GraphKernel(
     laplacian=L,
     lengthscale=2.3,
     variance=3.2,
     smoothness=6.1,
 )
-prior = gpx.Prior(mean_function=gpx.Zero(), kernel=true_kernel)
+prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=true_kernel)
 
 fx = prior(x)
 y = fx.sample(seed=key, sample_shape=(1,)).reshape(-1, 1)
@@ -136,9 +136,9 @@
 # We do this using the BFGS optimiser provided in `scipy` via 'jaxopt'.
 
 # %%
-likelihood = gpx.Gaussian(num_datapoints=D.n)
-kernel = gpx.GraphKernel(laplacian=L)
-prior = gpx.Prior(mean_function=gpx.Zero(), kernel=kernel)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
+kernel = gpx.kernels.GraphKernel(laplacian=L)
+prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=kernel)
 posterior = prior * likelihood
 
 # %% [markdown]
diff --git a/docs/examples/intro_to_kernels.py b/docs/examples/intro_to_kernels.py
index 60e76a20..55113be6 100644
--- a/docs/examples/intro_to_kernels.py
+++ b/docs/examples/intro_to_kernels.py
@@ -160,7 +160,7 @@
 meanf = gpx.mean_functions.Zero()
 
 for k, ax in zip(kernels, axes.ravel()):
-    prior = gpx.Prior(mean_function=meanf, kernel=k)
+    prior = gpx.gps.Prior(mean_function=meanf, kernel=k)
     rv = prior(x)
     y = rv.sample(seed=key, sample_shape=(10,))
     ax.plot(x, y.T, alpha=0.7)
@@ -219,9 +219,9 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
     lengthscale=jnp.array(0.1)
 )  # Initialise our kernel lengthscale to 0.1
 
-prior = gpx.Prior(mean_function=mean, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
 
-likelihood = gpx.Gaussian(
+likelihood = gpx.likelihoods.Gaussian(
     num_datapoints=D.n, obs_stddev=jnp.array(1e-3)
 )  # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
 likelihood = likelihood.replace_trainable(obs_stddev=False)
@@ -352,7 +352,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
 # %%
 mean = gpx.mean_functions.Zero()
 kernel = gpx.kernels.Periodic()
-prior = gpx.Prior(mean_function=mean, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
 
 x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)
 rv = prior(x)
@@ -375,7 +375,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
 # %%
 mean = gpx.mean_functions.Zero()
 kernel = gpx.kernels.Linear()
-prior = gpx.Prior(mean_function=mean, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
 
 x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)
 rv = prior(x)
@@ -411,7 +411,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
 kernel_two = gpx.kernels.Periodic()
 sum_kernel = gpx.kernels.SumKernel(kernels=[kernel_one, kernel_two])
 mean = gpx.mean_functions.Zero()
-prior = gpx.Prior(mean_function=mean, kernel=sum_kernel)
+prior = gpx.gps.Prior(mean_function=mean, kernel=sum_kernel)
 
 x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)
 rv = prior(x)
@@ -436,7 +436,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
 kernel_two = gpx.kernels.Periodic()
 sum_kernel = gpx.kernels.ProductKernel(kernels=[kernel_one, kernel_two])
 mean = gpx.mean_functions.Zero()
-prior = gpx.Prior(mean_function=mean, kernel=sum_kernel)
+prior = gpx.gps.Prior(mean_function=mean, kernel=sum_kernel)
 
 x = jnp.linspace(-3.0, 3.0, num=200).reshape(-1, 1)
 rv = prior(x)
@@ -522,8 +522,8 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
 sum_kernel = gpx.kernels.SumKernel(kernels=[linear_kernel, periodic_kernel])
 final_kernel = gpx.kernels.SumKernel(kernels=[rbf_kernel, sum_kernel])
 
-prior = gpx.Prior(mean_function=mean, kernel=final_kernel)
-likelihood = gpx.Gaussian(num_datapoints=D.n)
+prior = gpx.gps.Prior(mean_function=mean, kernel=final_kernel)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
 
 posterior = prior * likelihood
 
@@ -645,7 +645,8 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
 
 # %%
 print(
-    f"Periodic Kernel Period: {[i for i in opt_posterior.prior.kernel.kernels if isinstance(i, gpx.kernels.Periodic)][0].period}"
+    "Periodic Kernel Period:"
+    f" {[i for i in opt_posterior.prior.kernel.kernels if isinstance(i, gpx.kernels.Periodic)][0].period}"
 )
 
 # %% [markdown]
diff --git a/docs/examples/likelihoods_guide.py b/docs/examples/likelihoods_guide.py
index 6fd79521..9199d9e3 100644
--- a/docs/examples/likelihoods_guide.py
+++ b/docs/examples/likelihoods_guide.py
@@ -124,11 +124,11 @@
 # $\mathbf{y}^{\star}$.
 
 # +
-kernel = gpx.Matern32()
-meanf = gpx.Zero()
-prior = gpx.Prior(kernel=kernel, mean_function=meanf)
+kernel = gpx.kernels.Matern32()
+meanf = gpx.mean_functions.Zero()
+prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
 
-likelihood = gpx.Gaussian(num_datapoints=D.n, obs_stddev=0.1)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=0.1)
 
 posterior = prior * likelihood
 
@@ -158,7 +158,7 @@
 # Similarly, for a Bernoulli likelihood function, the samples of $y$ would be binary.
 
 # +
-likelihood = gpx.Bernoulli(num_datapoints=D.n)
+likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)
 
 fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(9, 2))
 
@@ -231,7 +231,7 @@
 # +
 z = jnp.linspace(-3.0, 3.0, 10).reshape(-1, 1)
 
-q = gpx.VariationalGaussian(posterior=posterior, inducing_inputs=z)
+q = gpx.variational_families.VariationalGaussian(posterior=posterior, inducing_inputs=z)
 
 
 def q_moments(x):
@@ -251,7 +251,7 @@ def q_moments(x):
 # However, had we wanted to do this using quadrature, then we would have done the
 # following:
 
-lquad = gpx.Gaussian(
+lquad = gpx.likelihoods.Gaussian(
     num_datapoints=D.n,
     obs_stddev=jnp.array([0.1]),
     integrator=gpx.integrators.GHQuadratureIntegrator(num_points=20),
diff --git a/docs/examples/oceanmodelling.py b/docs/examples/oceanmodelling.py
index 162fba0b..f917ec04 100644
--- a/docs/examples/oceanmodelling.py
+++ b/docs/examples/oceanmodelling.py
@@ -223,8 +223,8 @@ def __call__(
 # %%
 def initialise_gp(kernel, mean, dataset):
-    prior = gpx.Prior(mean_function=mean, kernel=kernel)
-    likelihood = gpx.Gaussian(
+    prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
+    likelihood = gpx.likelihoods.Gaussian(
         num_datapoints=dataset.n, obs_stddev=jnp.array([1.0e-3], dtype=jnp.float64)
     )
     posterior = prior * likelihood
diff --git a/docs/examples/poisson.py b/docs/examples/poisson.py
index 3543b9da..61ff0f61 100644
--- a/docs/examples/poisson.py
+++ b/docs/examples/poisson.py
@@ -83,10 +83,10 @@
 # kernel, chosen for the purpose of exposition. We adopt the Poisson likelihood available in GPJax.
 
 # %%
-kernel = gpx.RBF()
-meanf = gpx.Constant()
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
-likelihood = gpx.Poisson(num_datapoints=D.n)
+kernel = gpx.kernels.RBF()
+meanf = gpx.mean_functions.Constant()
+prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
+likelihood = gpx.likelihoods.Poisson(num_datapoints=D.n)
 
 # %% [markdown]
 # We construct the posterior through the product of our prior and likelihood.
@@ -135,7 +135,7 @@
 num_adapt = 100
 num_samples = 200
 
-lpd = jax.jit(gpx.LogPosteriorDensity(negative=False))
+lpd = jax.jit(gpx.objectives.LogPosteriorDensity(negative=False))
 unconstrained_lpd = jax.jit(lambda tree: lpd(tree.constrain(), D))
 
 adapt = blackjax.window_adaptation(
diff --git a/docs/examples/regression.py b/docs/examples/regression.py
index 47b8439a..f1fc77a3 100644
--- a/docs/examples/regression.py
+++ b/docs/examples/regression.py
@@ -108,7 +108,7 @@
 # %%
 kernel = gpx.kernels.RBF()
 meanf = gpx.mean_functions.Zero()
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
 
 # %% [markdown]
 #
@@ -152,7 +152,7 @@
 # This is defined in GPJax through calling a `Gaussian` instance.
 
 # %%
-likelihood = gpx.Gaussian(num_datapoints=D.n)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
 
 # %% [markdown]
 # The posterior is proportional to the prior multiplied by the likelihood, written as
diff --git a/docs/examples/uncollapsed_vi.py b/docs/examples/uncollapsed_vi.py
index 5511a9f1..0c099464 100644
--- a/docs/examples/uncollapsed_vi.py
+++ b/docs/examples/uncollapsed_vi.py
@@ -203,10 +203,10 @@
 # %%
 meanf = gpx.mean_functions.Zero()
-likelihood = gpx.Gaussian(num_datapoints=n)
-prior = gpx.Prior(mean_function=meanf, kernel=jk.RBF())
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
+prior = gpx.gps.Prior(mean_function=meanf, kernel=jk.RBF())
 p = prior * likelihood
-q = gpx.VariationalGaussian(posterior=p, inducing_inputs=z)
+q = gpx.variational_families.VariationalGaussian(posterior=p, inducing_inputs=z)
 
 # %% [markdown]
 # Here, the variational process $q(\cdot)$ depends on the prior through
@@ -232,7 +232,7 @@
 # its negative.
 
 # %%
-negative_elbo = gpx.ELBO(negative=True)
+negative_elbo = gpx.objectives.ELBO(negative=True)
 
 # %% [markdown]
 # For researchers, GPJax has the capacity to print the bibtex citation for objects such
diff --git a/docs/examples/yacht.py b/docs/examples/yacht.py
index 59b7f8c9..dc40f514 100644
--- a/docs/examples/yacht.py
+++ b/docs/examples/yacht.py
@@ -173,9 +173,9 @@
     lengthscale=0.1 * np.ones((n_covariates,)),
 )
 meanf = gpx.mean_functions.Zero()
-prior = gpx.Prior(mean_function=meanf, kernel=kernel)
+prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
 
-likelihood = gpx.Gaussian(num_datapoints=n_train)
+likelihood = gpx.likelihoods.Gaussian(num_datapoints=n_train)
 posterior = prior * likelihood
 
diff --git a/gpjax/__init__.py b/gpjax/__init__.py
index 6e45ae72..0fa71556 100644
--- a/gpjax/__init__.py
+++ b/gpjax/__init__.py
@@ -13,8 +13,15 @@
 # limitations under the License.
 # ==============================================================================
 from gpjax import (
+    base,
     decision_making,
+    gps,
     integrators,
+    kernels,
+    likelihoods,
+    mean_functions,
+    objectives,
+    variational_families,
 )
 from gpjax.base import (
     Module,
@@ -26,111 +33,27 @@
     fit,
     fit_scipy,
 )
-from gpjax.gps import (
-    Prior,
-    construct_posterior,
-)
-from gpjax.kernels import (
-    RBF,
-    RFF,
-    AbstractKernel,
-    BasisFunctionComputation,
-    ConstantDiagonalKernelComputation,
-    DenseKernelComputation,
-    DiagonalKernelComputation,
-    EigenKernelComputation,
-    GraphKernel,
-    Linear,
-    Matern12,
-    Matern32,
-    Matern52,
-    Periodic,
-    Polynomial,
-    PoweredExponential,
-    ProductKernel,
-    RationalQuadratic,
-    SumKernel,
-    White,
-)
-from gpjax.likelihoods import (
-    Bernoulli,
-    Gaussian,
-    Poisson,
-)
-from gpjax.mean_functions import (
-    Constant,
-    Zero,
-)
-from gpjax.objectives import (
-    ELBO,
-    CollapsedELBO,
-    ConjugateLOOCV,
-    ConjugateMLL,
-    LogPosteriorDensity,
-    NonConjugateMLL,
-)
-from gpjax.variational_families import (
-    CollapsedVariationalGaussian,
-    ExpectationVariationalGaussian,
-    NaturalVariationalGaussian,
-    VariationalGaussian,
-    WhitenedVariationalGaussian,
-)
 
 __license__ = "MIT"
 __description__ = "Didactic Gaussian processes in JAX"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.7.3"
+__version__ = "0.8.0"
 
 __all__ = [
-    "Module",
-    "param_field",
-    "cite",
+    "base",
     "decision_making",
+    "gps",
+    "integrators",
     "kernels",
+    "likelihoods",
+    "mean_functions",
+    "objectives",
+    "variational_families",
+    "Dataset",
+    "cite",
     "fit",
+    "Module",
+    "param_field",
     "fit_scipy",
-    "Prior",
-    "construct_posterior",
-    "integrators",
-    "RBF",
-    "GraphKernel",
-    "Matern12",
-    "Matern32",
-    "Matern52",
-    "Polynomial",
-    "ProductKernel",
-    "SumKernel",
-    "Bernoulli",
-    "Gaussian",
-    "Poisson",
-    "Constant",
-    "Zero",
-    "Dataset",
-    "CollapsedVariationalGaussian",
-    "ExpectationVariationalGaussian",
-    "NaturalVariationalGaussian",
-    "VariationalGaussian",
-    "WhitenedVariationalGaussian",
-    "CollapsedVI",
-    "StochasticVI",
-    "ConjugateMLL",
-    "ConjugateLOOCV",
-    "NonConjugateMLL",
-    "LogPosteriorDensity",
-    "CollapsedELBO",
-    "ELBO",
-    "AbstractKernel",
-    "Linear",
-    "DenseKernelComputation",
-    "DiagonalKernelComputation",
-    "ConstantDiagonalKernelComputation",
-    "EigenKernelComputation",
-    "PoweredExponential",
-    "Periodic",
-    "RationalQuadratic",
-    "White",
-    "BasisFunctionComputation",
-    "RFF",
 ]
diff --git a/gpjax/fit.py b/gpjax/fit.py
index 6986549b..9cdcdcd3 100644
--- a/gpjax/fit.py
+++ b/gpjax/fit.py
@@ -76,9 +76,9 @@ def fit(  # noqa: PLR0913
        >>> D = gpx.Dataset(X, y)
        >>>
        >>> # (2) Define your model:
-       >>> class LinearModel(gpx.Module):
-               weight: float = gpx.param_field()
-               bias: float = gpx.param_field()
+       >>> class LinearModel(gpx.base.Module):
+               weight: float = gpx.base.param_field()
+               bias: float = gpx.base.param_field()
 
                def __call__(self, x):
                    return self.weight * x + self.bias
       >>>
       >>> model = LinearModel(weight=1.0, bias=1.0)
       >>>
       >>> # (3) Define your loss function:
-      >>> class MeanSquareError(gpx.AbstractObjective):
+      >>> class MeanSquareError(gpx.objectives.AbstractObjective):
              def evaluate(self, model: LinearModel, train_data: gpx.Dataset) -> float:
                  return jnp.mean((train_data.y - model(train_data.X)) ** 2)
       >>>
diff --git a/gpjax/gps.py b/gpjax/gps.py
index 5928ef49..f1fa477b 100644
--- a/gpjax/gps.py
+++ b/gpjax/gps.py
@@ -133,7 +133,7 @@ class Prior(AbstractPrior):
        >>> kernel = gpx.kernels.RBF()
        >>> meanf = gpx.mean_functions.Zero()
-       >>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
+       >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
        ```
    """
@@ -167,7 +167,7 @@ def __mul__(self, other):
            >>>
            >>> meanf = gpx.mean_functions.Zero()
            >>> kernel = gpx.kernels.RBF()
-           >>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
+           >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
            >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100)
            >>>
            >>> prior * likelihood
@@ -228,7 +228,7 @@ def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
            >>>
            >>> kernel = gpx.kernels.RBF()
            >>> meanf = gpx.mean_functions.Zero()
-           >>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
+           >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
            >>>
            >>> prior.predict(jnp.linspace(0, 1, 100))
            ```
@@ -289,7 +289,7 @@ def sample_approx(
            >>>
            >>> meanf = gpx.mean_functions.Zero()
            >>> kernel = gpx.kernels.RBF()
-           >>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
+           >>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
            >>>
            >>> sample_fn = prior.sample_approx(10, key)
            >>> sample_fn(jnp.linspace(0, 1, 100).reshape(-1, 1))
@@ -413,7 +413,7 @@ class ConjugatePosterior(AbstractPosterior):
        >>> import gpjax as gpx
        >>> import jax.numpy as jnp
 
-       >>> prior = gpx.Prior(
+       >>> prior = gpx.gps.Prior(
                mean_function = gpx.mean_functions.Zero(),
                kernel = gpx.kernels.RBF()
            )
@@ -461,8 +461,8 @@ def predict(
            >>> D = gpx.Dataset(X=xtrain, y=ytrain)
            >>> xtest = jnp.linspace(0, 1).reshape(-1, 1)
            >>>
-           >>> prior = gpx.Prior(mean_function = gpx.Zero(), kernel = gpx.RBF())
-           >>> posterior = prior * gpx.Gaussian(num_datapoints = D.n)
+           >>> prior = gpx.gps.Prior(mean_function = gpx.mean_functions.Zero(), kernel = gpx.kernels.RBF())
+           >>> posterior = prior * gpx.likelihoods.Gaussian(num_datapoints = D.n)
            >>> predictive_dist = posterior(xtest, D)
            ```
diff --git a/gpjax/objectives.py b/gpjax/objectives.py
index c07290c4..9b684f89 100644
--- a/gpjax/objectives.py
+++ b/gpjax/objectives.py
@@ -100,10 +100,10 @@ def step(
        >>> meanf = gpx.mean_functions.Constant()
        >>> kernel = gpx.kernels.RBF()
        >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
-       >>> prior = gpx.Prior(mean_function = meanf, kernel=kernel)
+       >>> prior = gpx.gps.Prior(mean_function = meanf, kernel=kernel)
        >>> posterior = prior * likelihood
        >>>
-       >>> mll = gpx.ConjugateMLL(negative=True)
+       >>> mll = gpx.objectives.ConjugateMLL(negative=True)
        >>> mll(posterior, train_data = D)
        ```
@@ -112,13 +112,13 @@ def step(
    marginal log-likelihood. This can be realised through
 
    ```python
-       mll = gpx.ConjugateMLL(negative=True)
+       mll = gpx.objectives.ConjugateMLL(negative=True)
    ```
 
    For optimal performance, the marginal log-likelihood should be ``jax.jit`` compiled.
 
    ```python
-       mll = jit(gpx.ConjugateMLL(negative=True))
+       mll = jit(gpx.objectives.ConjugateMLL(negative=True))
    ```
 
    Args:
@@ -180,10 +180,10 @@ def step(
        >>> meanf = gpx.mean_functions.Constant()
        >>> kernel = gpx.kernels.RBF()
        >>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
-       >>> prior = gpx.Prior(mean_function = meanf, kernel=kernel)
+       >>> prior = gpx.gps.Prior(mean_function = meanf, kernel=kernel)
        >>> posterior = prior * likelihood
        >>>
-       >>> loocv = gpx.ConjugateLOOCV(negative=True)
+       >>> loocv = gpx.objectives.ConjugateLOOCV(negative=True)
        >>> loocv(posterior, train_data = D)
        ```
@@ -192,13 +192,13 @@ def step(
    leave-one-out log predictive probability. This can be realised through
 
    ```python
-       mll = gpx.ConjugateLOOCV(negative=True)
+       mll = gpx.objectives.ConjugateLOOCV(negative=True)
    ```
 
    For optimal performance, the objective should be ``jax.jit`` compiled.
 
    ```python
-       mll = jit(gpx.ConjugateLOOCV(negative=True))
+       mll = jit(gpx.objectives.ConjugateLOOCV(negative=True))
    ```
 
    Args:
diff --git a/gpjax/progress_bar.py b/gpjax/progress_bar.py
index 090a03ea..3072b71b 100644
--- a/gpjax/progress_bar.py
+++ b/gpjax/progress_bar.py
@@ -36,10 +36,10 @@ def progress_bar(num_iters: int, log_rate: int) -> Callable:
        >>>
        >>> carry = jnp.array(0.0)
        >>> iteration_numbers = jnp.arange(100)
-
+       >>>
        >>> @progress_bar(num_iters=iteration_numbers.shape[0], log_rate=10)
        >>> def body_func(carry, x):
-       ...     return carry, x
+       >>>     return carry, x
        >>>
        >>> carry, _ = jax.lax.scan(body_func, carry, iteration_numbers)
        ```
diff --git a/pyproject.toml b/pyproject.toml
index 6e2aa453..198ccac6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "gpjax"
-version = "0.7.3"
+version = "0.8.0"
 description = "Gaussian processes in JAX."
 authors = [
     "Thomas Pinder ",
diff --git a/tests/test_citations.py b/tests/test_citations.py
index b7fe8a47..fd2531f5 100644
--- a/tests/test_citations.py
+++ b/tests/test_citations.py
@@ -6,7 +6,6 @@
 import jax.numpy as jnp
 import pytest
 
-import gpjax as gpx
 from gpjax.citation import (
     AbstractCitation,
     BookCitation,
@@ -30,6 +29,13 @@
     Matern32,
     Matern52,
 )
+from gpjax.objectives import (
+    ELBO,
+    CollapsedELBO,
+    ConjugateMLL,
+    LogPosteriorDensity,
+    NonConjugateMLL,
+)
 
 
 def _check_no_fallback(citation: AbstractCitation):
@@ -103,7 +109,7 @@ def test_missing_citation(kernel):
 
 
 @pytest.mark.parametrize(
-    "mll", [gpx.ConjugateMLL(), gpx.NonConjugateMLL(), gpx.LogPosteriorDensity()]
+    "mll", [ConjugateMLL(), NonConjugateMLL(), LogPosteriorDensity()]
 )
 def test_mlls(mll):
     citation = cite(mll)
@@ -115,7 +121,7 @@ def test_mlls(mll):
 
 
 def test_uncollapsed_elbo():
-    elbo = gpx.ELBO()
+    elbo = ELBO()
     citation = cite(elbo)
 
     assert isinstance(citation, PaperCitation)
@@ -128,7 +134,7 @@
 
 
 def test_collapsed_elbo():
-    elbo = gpx.CollapsedELBO()
+    elbo = CollapsedELBO()
     citation = cite(elbo)
 
     assert isinstance(citation, PaperCitation)
@@ -158,7 +164,8 @@ def test_thompson_sampling():
     )
     assert (
         citation.authors
-        == "Wilson, James and Borovitskiy, Viacheslav and Terenin, Alexander and Mostowsky, Peter and Deisenroth, Marc"
+        == "Wilson, James and Borovitskiy, Viacheslav and Terenin, Alexander and"
+        " Mostowsky, Peter and Deisenroth, Marc"
     )
     assert citation.year == "2020"
     assert citation.booktitle == "International Conference on Machine Learning"
@@ -205,7 +212,7 @@ def test_logarithmic_goldstein_price():
 
 @pytest.mark.parametrize(
     "objective",
-    [gpx.ELBO(), gpx.CollapsedELBO(), gpx.LogPosteriorDensity(), gpx.ConjugateMLL()],
+    [ELBO(), CollapsedELBO(), LogPosteriorDensity(), ConjugateMLL()],
 )
 def test_jitted_fallback(objective):
     with pytest.raises(RuntimeError):
diff --git a/tests/test_decision_making/test_decision_maker.py b/tests/test_decision_making/test_decision_maker.py
index 4d79f8d5..23a87a2f 100644
--- a/tests/test_decision_making/test_decision_maker.py
+++ b/tests/test_decision_making/test_decision_maker.py
@@ -61,16 +61,16 @@ def search_space() -> ContinuousSearchSpace:
 
 @pytest.fixture
 def posterior_handler() -> PosteriorHandler:
-    mean = gpx.Zero()
-    kernel = gpx.Matern52(lengthscale=jnp.array(1.0), variance=jnp.array(1.0))
-    prior = gpx.Prior(mean_function=mean, kernel=kernel)
-    likelihood_builder = lambda x: gpx.Gaussian(
+    mean = gpx.mean_functions.Zero()
+    kernel = gpx.kernels.Matern52(lengthscale=jnp.array(1.0), variance=jnp.array(1.0))
+    prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
+    likelihood_builder = lambda x: gpx.likelihoods.Gaussian(
         num_datapoints=x, obs_stddev=jnp.array(1e-3)
     )
     posterior_handler = PosteriorHandler(
         prior=prior,
         likelihood_builder=likelihood_builder,
-        optimization_objective=gpx.ConjugateMLL(negative=True),
+        optimization_objective=gpx.objectives.ConjugateMLL(negative=True),
         optimizer=ox.adam(learning_rate=0.01),
         num_optimization_iters=100,
     )
diff --git a/tests/test_mean_functions.py b/tests/test_mean_functions.py
index fded647c..2fb74283 100644
--- a/tests/test_mean_functions.py
+++ b/tests/test_mean_functions.py
@@ -65,8 +65,10 @@ def test_zero_mean_remains_zero() -> None:
         constant=False
     )  # Prevent kernel from modelling non-zero mean
     meanf = Zero()
-    prior = gpx.Prior(mean_function=meanf, kernel=kernel)
-    likelihood = gpx.Gaussian(num_datapoints=D.n, obs_stddev=jnp.array(1e-3))
+    prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
+    likelihood = gpx.likelihoods.Gaussian(
+        num_datapoints=D.n, obs_stddev=jnp.array(1e-3)
+    )
     likelihood = likelihood.replace_trainable(obs_stddev=False)
     posterior = prior * likelihood
 
diff --git a/tests/test_objectives.py b/tests/test_objectives.py
index 5817f13a..fbb53f14 100644
--- a/tests/test_objectives.py
+++ b/tests/test_objectives.py
@@ -5,12 +5,9 @@
 import pytest
 
 import gpjax as gpx
-from gpjax import (
-    Bernoulli,
-    Gaussian,
-    Prior,
-)
 from gpjax.dataset import Dataset
+from gpjax.gps import Prior
+from gpjax.likelihoods import Gaussian
 from gpjax.objectives import (
     ELBO,
     AbstractObjective,
@@ -64,10 +61,11 @@ def test_conjugate_mll(
     D = build_data(num_datapoints, num_dims, key, binary=False)
 
     # Build model
-    p = Prior(
-        kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()
+    p = gpx.gps.Prior(
+        kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))),
+        mean_function=gpx.mean_functions.Constant(),
     )
-    likelihood = Gaussian(num_datapoints=num_datapoints)
+    likelihood = gpx.likelihoods.Gaussian(num_datapoints=num_datapoints)
     post = p * likelihood
 
     mll = ConjugateMLL(negative=negative)
@@ -94,7 +92,8 @@ def test_conjugate_loocv(
 
     # Build model
     p = Prior(
-        kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()
+        kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))),
+        mean_function=gpx.mean_functions.Constant(),
     )
     likelihood = Gaussian(num_datapoints=num_datapoints)
     post = p * likelihood
@@ -122,10 +121,11 @@ def test_non_conjugate_mll(
     D = build_data(num_datapoints, num_dims, key, binary=True)
 
     # Build model
-    p = Prior(
-        kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()
+    p = gpx.gps.Prior(
+        kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))),
+        mean_function=gpx.mean_functions.Constant(),
     )
-    likelihood = Bernoulli(num_datapoints=num_datapoints)
+    likelihood = gpx.likelihoods.Bernoulli(num_datapoints=num_datapoints)
     post = p * likelihood
 
     mll = NonConjugateMLL(negative=negative)
@@ -158,11 +158,14 @@ def test_collapsed_elbo(
         key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints // 2, num_dims)
     )
 
-    p = Prior(
-        kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()
+    p = gpx.gps.Prior(
+        kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))),
+        mean_function=gpx.mean_functions.Constant(),
+    )
+    likelihood = gpx.likelihoods.Gaussian(num_datapoints=num_datapoints)
+    q = gpx.variational_families.CollapsedVariationalGaussian(
+        posterior=p * likelihood, inducing_inputs=z
     )
-    likelihood = Gaussian(num_datapoints=num_datapoints)
-    q = gpx.CollapsedVariationalGaussian(posterior=p * likelihood, inducing_inputs=z)
 
     negative_elbo = CollapsedELBO(negative=negative)
@@ -176,7 +179,9 @@ def test_collapsed_elbo(
     assert evaluation.shape == ()
 
     # Data on the full dataset should be the same as the marginal likelihood
-    q = gpx.CollapsedVariationalGaussian(posterior=p * likelihood, inducing_inputs=D.X)
+    q = gpx.variational_families.CollapsedVariationalGaussian(
+        posterior=p * likelihood, inducing_inputs=D.X
+    )
     mll = ConjugateMLL(negative=negative)
     expected_value = mll(p * likelihood, D)
     actual_value = negative_elbo(q, D)
@@ -203,16 +208,17 @@ def test_elbo(
         key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints // 2, num_dims)
     )
 
-    p = Prior(
-        kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()
+    p = gpx.gps.Prior(
+        kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))),
+        mean_function=gpx.mean_functions.Constant(),
     )
     if binary:
-        likelihood = Bernoulli(num_datapoints=num_datapoints)
+        likelihood = gpx.likelihoods.Bernoulli(num_datapoints=num_datapoints)
     else:
-        likelihood = Gaussian(num_datapoints=num_datapoints)
+        likelihood = gpx.likelihoods.Gaussian(num_datapoints=num_datapoints)
     post = p * likelihood
 
-    q = gpx.VariationalGaussian(posterior=post, inducing_inputs=z)
+    q = gpx.variational_families.VariationalGaussian(posterior=post, inducing_inputs=z)
 
     negative_elbo = ELBO(
         negative=negative,
diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py
index 8a576c2c..66a3c650 100644
--- a/tests/test_variational_families.py
+++ b/tests/test_variational_families.py
@@ -114,8 +114,10 @@ def test_variational_gaussians(
     variational_family: AbstractVariationalFamily,
 ) -> None:
     # Initialise variational family:
-    prior = gpx.Prior(kernel=gpx.RBF(), mean_function=gpx.Constant())
-    likelihood = gpx.Gaussian(123)
+    prior = gpx.gps.Prior(
+        kernel=gpx.kernels.RBF(), mean_function=gpx.mean_functions.Constant()
+    )
+    likelihood = gpx.likelihoods.Gaussian(123)
     inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1)
     test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1)
 
@@ -223,14 +225,16 @@ def test_collapsed_variational_gaussian(
     x = jnp.hstack([x] * point_dim)
     D = gpx.Dataset(X=x, y=y)
 
-    prior = gpx.Prior(kernel=gpx.RBF(), mean_function=gpx.Constant())
+    prior = gpx.gps.Prior(
+        kernel=gpx.kernels.RBF(), mean_function=gpx.mean_functions.Constant()
+    )
 
     inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1)
     inducing_inputs = jnp.hstack([inducing_inputs] * point_dim)
     test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1)
     test_inputs = jnp.hstack([test_inputs] * point_dim)
 
-    posterior = prior * gpx.Gaussian(num_datapoints=D.n)
+    posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)
 
     variational_family = CollapsedVariationalGaussian(
         posterior=posterior,
         inducing_inputs=inducing_inputs,
     )
@@ -240,7 +244,7 @@
     # We should raise an error for non-Gaussian likelihoods:
     with pytest.raises(TypeError):
         CollapsedVariationalGaussian(
-            posterior=prior * gpx.Bernoulli(num_datapoints=D.n),
+            posterior=prior * gpx.likelihoods.Bernoulli(num_datapoints=D.n),
             inducing_inputs=inducing_inputs,
         )
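
Taken together, the hunks above replace the old top-level re-exports (`gpx.Prior`, `gpx.Gaussian`, `gpx.ConjugateMLL`, ...) with namespaced access through the submodules now exported from `gpjax/__init__.py`. Below is a minimal end-to-end sketch of the resulting 0.8.0-style workflow, assembled only from calls that appear in the added lines of this diff; the toy data, key, and hyperparameters are illustrative and not part of the change set.

```python
import jax.numpy as jnp
import jax.random as jr

import gpjax as gpx

# Toy 1D regression data (illustrative only, not from the diff).
key = jr.PRNGKey(0)
x = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
y = jnp.sin(4.0 * x) + 0.1 * jr.normal(key, shape=x.shape)
D = gpx.Dataset(X=x, y=y)

# Constructors are now reached through their submodules.
prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
posterior = prior * likelihood

# Objectives likewise live under gpx.objectives.
opt_posterior, history = gpx.fit_scipy(
    model=posterior,
    objective=gpx.objectives.ConjugateMLL(negative=True),
    train_data=D,
)
```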