This repository contains the implementation of the algorithms in the paper Learning-Rate-Free Stochastic Optimization over Riemannian Manifolds by Daniel Dodd, Louis Sharrock, and Christopher Nemeth.
git clone https://github.com/daniel-dodd/riemannian_dog.git
cd riemannian_dog
poetry install
cd riemannian_dog
pytest
We consider the problem of maximizing the Rayleigh quotient of a positive-definite matrix over the unit sphere.
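In matrix form (a standard statement of the problem, included here for context), this amounts to

$$\max_{x \in \mathbb{S}^{n-1}} \; \tfrac{1}{2}\, x^\top A x,$$

where $A$ is a symmetric positive-definite matrix. The maximizer is the unit eigenvector associated with the largest eigenvalue of $A$, so below we minimize $-\tfrac{1}{2}\, x^\top A x$ over the sphere.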
We begin by importing the necessary dependencies and defining a key for reproducibility.
import jax.numpy as jnp
from jaxtyping import Float, Array, Key
from jax import random, config, vmap
config.update("jax_enable_x64", True)
from manifold.geometry import Sphere
from manifold.optimisers import rdog
from manifold.gradient import rgrad
# Settings.
SEED = 123
DIMENSION = 10
key = random.PRNGKey(SEED)
Now we generate a positive-definite matrix whose Rayleigh quotient we wish to maximize.
key, subkey = random.split(key)
sqrt = random.normal(subkey, shape = (DIMENSION, DIMENSION))
matrix = sqrt @ sqrt.T
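As an optional sanity check (an aside, not part of the original walkthrough), we can confirm that the matrix is symmetric and positive definite:

```python
# Optional check: sqrt @ sqrt.T is symmetric and, with probability one, positive definite.
assert jnp.allclose(matrix, matrix.T)
assert jnp.all(jnp.linalg.eigvalsh(matrix) > 0.0)
```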
Next we write our loss function.
def loss(point: Float[Array, "N"], matrix: Float[Array, "N N"]) -> Float[Array, ""]:
    """Rayleigh quotient objective.

    Args:
        point: Point to evaluate the Rayleigh quotient at.
        matrix: Matrix we are evaluating the Rayleigh quotient of.

    Returns:
        Rayleigh quotient objective.
    """
    return -0.5 * jnp.dot(point, jnp.dot(matrix, point))
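As an aside (not needed for the optimisation below), the `vmap` import above can be used to evaluate the loss at a batch of points in a single call:

```python
# Illustrative aside: evaluate the loss at several sphere points at once with vmap.
demo_key = random.PRNGKey(0)  # separate key so the walkthrough below is unaffected
points = random.normal(demo_key, shape=(5, DIMENSION))
points = points / jnp.linalg.norm(points, axis=-1, keepdims=True)  # normalise onto the sphere
batch_losses = vmap(loss, in_axes=(0, None))(points, matrix)       # shape (5,)
```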
Now we proceed to define our `Sphere` manifold and generate an initial point via the `.random_point` method.
manifold = Sphere(DIMENSION)
key, subkey = random.split(key)
point = manifold.random_point(subkey)
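Since points on the sphere have unit Euclidean norm, we can quickly verify the sampled point (an optional check):

```python
# Points on the unit sphere have norm one.
assert jnp.isclose(jnp.linalg.norm(point), 1.0)
```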
We define our parameter-free optimiser.
opt = rdog()
opt_state = opt.init(manifold, point)
And optimise!
for _ in range(500):
    rg = rgrad(manifold, loss)(point, matrix)  # I'm a Riemannian gradient!
    updates, opt_state = opt.update(manifold, rg, opt_state, point)
    point = manifold.exp(point, updates)
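If you would like to monitor convergence, a minimal variation of the same loop (using only the calls shown above) prints the objective every 100 iterations:

```python
# Same loop as above, with simple progress logging.
for step in range(500):
    rg = rgrad(manifold, loss)(point, matrix)
    updates, opt_state = opt.update(manifold, rg, opt_state, point)
    point = manifold.exp(point, updates)
    if step % 100 == 0:
        print(f"step {step}: loss = {loss(point, matrix):.6f}")
```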
Finally, to sanity-check what we have run, we compare against the standard numerical solution obtained via an eigendecomposition.
eigenvalues, eigenvectors = jnp.linalg.eigh(matrix)
max_eig = jnp.max(eigenvalues)
min_eig = jnp.min(eigenvalues)
sol = eigenvectors[:, jnp.argmax(eigenvalues)]
sol = sol / jnp.linalg.norm(sol)
We compute the distance between the numerical solution and the local optimum found above (the eigenvector is only determined up to sign, so we take the minimum over both signs).
jnp.minimum(manifold.distance(point, sol), manifold.distance(point, -sol))
We see that we obtain the same answer!
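As a further optional check (not part of the original walkthrough), the attained objective value should match half the largest eigenvalue of the matrix:

```python
# The maximum of 0.5 * x^T A x over the unit sphere equals 0.5 * max_eig.
print(-loss(point, matrix))  # attained objective value
print(0.5 * max_eig)         # theoretical optimum
```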
The experiment scripts are contained in the `experiments` directory.

- `experiments/toy` contains the code to reproduce Figure 1. Please run `experiments/toy/toy.py` to run the experiments and cache the results, followed by `experiments/toy/plot_toy.py` to generate the plot.
- `experiments/sphere_rayleigh` contains the code to reproduce Figure 2. Please run `experiments/sphere_rayleigh/sphere.py` to run the experiments and cache the results, followed by `experiments/sphere_rayleigh/plot_sphere.py` to generate the plot.
- `experiments/grassmann_pca` contains the code to reproduce Figure 3. For (a), please run `experiments/grassmann_pca/wine.py` to run the experiments and cache the results, then `experiments/grassmann_pca/plot_wine.py` to generate the plot. For (b), please run `experiments/grassmann_pca/waveform.py`, then `experiments/grassmann_pca/plot_waveform.py`. For (c), please run `experiments/grassmann_pca/tiny_image_net.py`, then `experiments/grassmann_pca/plot_tiny_image_net.py`.
- `experiments/poincare_wordnet` contains the code to reproduce Figure 4. For (a), please run the files in `experiments/five_dimensional` followed by the plotting scripts. For (b), please run the files in `experiments/two_dimensional` followed by the plotting scripts.