Official Code Repository for the paper Generative Modeling on Manifolds Through Mixture of Riemannian Diffusion Processes.
In this repository, we implement the Riemannian Diffusion Mixture using JAX.
A PyTorch implementation is available in a separate repository, riemannian-diffusion-mixture-torch.
- Simple design of the generative process as a mixture of Riemannian bridge processes, which, unlike previous denoising approaches, does not require heat kernel estimation.
- Geometric interpretation of the mixture process as the weighted mean of tangent directions on manifolds.
- Scales to higher dimensions with significantly faster training than previous diffusion models.
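To make the geometric interpretation more concrete, the sketch below works on the unit sphere: at a point x it computes the tangent directions log_x(z_i) toward a few candidate endpoints z_i and combines them as a weighted mean, which is again a tangent vector at x. This is only an illustration under simplifying assumptions: the endpoints, the weights, and the helper names (sphere_log, weighted_tangent_mean) are hypothetical, NumPy stands in for JAX (jax.numpy is a near drop-in replacement here), and the exact drift and weighting used in the paper are not reproduced.

```python
import numpy as np


def sphere_log(x, y, eps=1e-7):
    """Riemannian log map on the unit sphere: tangent vector at x pointing to y.

    Undefined for antipodal points; this sketch does not handle that case.
    """
    cos_theta = np.clip(np.dot(x, y), -1.0, 1.0)
    theta = np.arccos(cos_theta)  # geodesic distance between x and y
    if theta < eps:
        return np.zeros_like(x)  # y coincides with x: no direction to move in
    # Component of y orthogonal to x, rescaled to have length theta.
    return theta / np.sin(theta) * (y - cos_theta * x)


def weighted_tangent_mean(x, endpoints, weights):
    """Weighted mean of the tangent directions log_x(z_i).

    `weights` are hypothetical non-negative mixture weights; they are
    normalised here so that they sum to one.
    """
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()
    directions = np.stack([sphere_log(x, z) for z in endpoints])
    return weights @ directions  # a single tangent vector at x


if __name__ == "__main__":
    x = np.array([0.0, 0.0, 1.0])  # north pole
    endpoints = [np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0])]
    v = weighted_tangent_mean(x, endpoints, weights=[0.5, 0.5])
    print(v, np.dot(v, x))  # v is orthogonal to x, i.e. tangent at x
```

Because each log_x(z_i) lies in the tangent space at x, any weighted combination of them does too, so the mixed direction never leaves the tangent space.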
Create an environment with Python 3.9.0 and install JAX 0.4.13 with CUDA 11 support using the following commands:
pip install https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.13+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl
pip install jax==0.4.13
Install the remaining requirements with the following commands:
pip install -r requirements.txt
conda install -c conda-forge cartopy python-kaleido
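As an optional sanity check after installation, the sketch below compares the installed jax and jaxlib versions against the ones pinned in the commands above and, if JAX is importable, reports its default backend. It uses only the standard library plus jax.default_backend(); the helper names are hypothetical and not part of this repo.

```python
from importlib import metadata

# Public versions pinned by the installation commands above
# (the jaxlib wheel additionally carries a "+cuda11.cudnn86" local tag).
EXPECTED = {"jax": "0.4.13", "jaxlib": "0.4.13"}


def installed_versions(packages):
    """Return {name: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def version_mismatches(expected=EXPECTED):
    """Return {name: (expected, installed)} for packages that do not match."""
    found = installed_versions(expected)
    bad = {}
    for name, want in expected.items():
        have = found[name]
        # Compare only the public part, ignoring local tags such as "+cuda11.cudnn86".
        if have is None or have.split("+")[0] != want:
            bad[name] = (want, have)
    return bad


if __name__ == "__main__":
    print(version_mismatches() or "jax and jaxlib match the pinned versions")
    try:
        import jax
        # Expected to report "gpu" when the CUDA build of jaxlib finds a GPU.
        print("default backend:", jax.default_backend())
    except ImportError:
        print("jax is not importable")
```

If the backend prints as "cpu", the CUDA wheel was most likely not picked up, or no GPU is visible to the process.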
The following manifolds are supported in this repo:
- Euclidean
- Hypersphere
- Torus
- Hyperboloid
- Triangular mesh
- Special orthogonal group
To implement a new manifold, add a Python file that defines its geometry in /geomstats/geometry. Please refer to the existing files in geomstats/geometry for examples.
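For a rough idea of what a geometry definition involves, the sketch below implements the basic operations (projection, exponential map, logarithm map, distance) for the flat torus. FlatTorus and its structure are hypothetical and standalone, not the geomstats interface; the argument order follows geomstats' usual exp(tangent_vec, base_point) / log(point, base_point) convention, but the actual base classes and required methods for a new file are those used by the existing files in geomstats/geometry.

```python
import numpy as np


class FlatTorus:
    """Hypothetical minimal geometry of the flat torus [0, 2*pi)^dim.

    Points are stored as angles; tangent vectors are unconstrained vectors in R^dim.
    """

    def __init__(self, dim):
        self.dim = dim

    def projection(self, point):
        """Wrap arbitrary angles back into [0, 2*pi)."""
        return np.mod(point, 2 * np.pi)

    def exp(self, tangent_vec, base_point):
        """Follow the straight geodesic from base_point along tangent_vec."""
        return self.projection(base_point + tangent_vec)

    def log(self, point, base_point):
        """Shortest tangent vector at base_point reaching point (entries in [-pi, pi))."""
        diff = point - base_point
        return np.mod(diff + np.pi, 2 * np.pi) - np.pi

    def dist(self, point_a, point_b):
        """Geodesic distance: the norm of the log map."""
        return np.linalg.norm(self.log(point_b, point_a))
```

The key consistency property for any such definition is that exp undoes log: exp(log(y, x), x) should return y, which is also a convenient unit test for a new manifold.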
This repo supports experiments on the following datasets:
- Earth and climate science datasets: Volcano, Earthquake, Flood, and Fire
- Triangular mesh datasets: Spot the Cow and Stanford Bunny
- Hyperboloid datasets
Please refer to riemannian-diffusion-mixture-torch for running experiments on protein datasets and high-dimensional tori.
Create triangular mesh datasets with the following command:
python data/create_mesh_dataset.py --data $DATA --k $K --plot
where $DATA denotes spot or bunny, and $K denotes 10, 50, or 100.
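To create every combination of $DATA and $K in one go, the command can be wrapped in a small script. The sketch below only reuses the flags shown above; the helper names are hypothetical, and it assumes it is run from the repository root so that the relative script path resolves.

```python
import subprocess
import sys
from itertools import product

DATASETS = ("spot", "bunny")  # allowed values of $DATA
K_VALUES = (10, 50, 100)      # allowed values of $K


def mesh_dataset_commands(plot=True):
    """Build one create_mesh_dataset.py command per ($DATA, $K) pair."""
    commands = []
    for data, k in product(DATASETS, K_VALUES):
        cmd = [sys.executable, "data/create_mesh_dataset.py",
               "--data", data, "--k", str(k)]
        if plot:
            cmd.append("--plot")
        commands.append(cmd)
    return commands


def run_all(plot=True):
    """Run every command in turn; check=True stops at the first failure."""
    for cmd in mesh_dataset_commands(plot):
        subprocess.run(cmd, check=True)
```

Calling run_all() from the repository root runs all six configurations; pass plot=False to skip the --plot flag.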
Running the command creates .pkl files in the /data/mesh directory.
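The names and internal structure of the generated .pkl files are not documented here, so the sketch below only lists which files exist and what type of object each one stores at the top level. It assumes that /data/mesh refers to data/mesh under the repository root; summarize_pickles is a hypothetical helper.

```python
import pickle
from pathlib import Path


def summarize_pickles(directory="data/mesh"):
    """Map each .pkl file in `directory` to the type name of the object it stores."""
    summary = {}
    for path in sorted(Path(directory).glob("*.pkl")):
        with path.open("rb") as f:
            obj = pickle.load(f)  # only unpickle files you generated yourself
        summary[path.name] = type(obj).__name__
    return summary
```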
The configurations are provided in the config/ directory in YAML format. Run an experiment with the following command:
CUDA_VISIBLE_DEVICES=0 python main.py -m \
experiment=<exp> \
seed=0,1,2,3,4 \
n_jobs=5
where <exp> is the name of one of the experiment configs in config/experiment/*.yaml.
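To see which values of <exp> are available, it is enough to list the YAML files in config/experiment/. The sketch below does that and rebuilds the command for a given experiment; both helpers are hypothetical, and taking the file stem as the experiment name is an assumption based on the override syntax.

```python
from pathlib import Path


def available_experiments(config_dir="config/experiment"):
    """Candidate values for <exp>: the stem of each YAML file in config/experiment/."""
    return sorted(p.stem for p in Path(config_dir).glob("*.yaml"))


def experiment_command(exp, seeds=(0, 1, 2, 3, 4), n_jobs=5, device=0):
    """Build the multirun command line for one experiment as a shell string."""
    seed_list = ",".join(str(s) for s in seeds)
    return (f"CUDA_VISIBLE_DEVICES={device} python main.py -m "
            f"experiment={exp} seed={seed_list} n_jobs={n_jobs}")
```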
For example, to run the earthquake experiment with seeds 0 to 4:
CUDA_VISIBLE_DEVICES=0 python main.py -m \
experiment=earthquake \
seed=0,1,2,3,4 \
n_jobs=5
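The -m flag and the comma-separated seed=0,1,2,3,4 override look like Hydra's multirun syntax. Assuming that is what main.py uses, the basic sweeper launches one job per combination of the comma-separated values, so the example command corresponds to five runs of the earthquake experiment, one per seed. The pure-Python sketch below mimics that expansion; expand_sweep is a hypothetical helper, not part of Hydra or this repo, and it ignores Hydra's richer sweep syntax.

```python
from itertools import product


def expand_sweep(overrides):
    """Expand key=v1,v2,... overrides into one dict per job (cartesian product).

    Simplified on purpose: ignores Hydra's quoting, ranges and other sweep syntax.
    """
    keys, choices = [], []
    for override in overrides:
        key, value = override.split("=", 1)
        keys.append(key)
        choices.append(value.split(","))
    return [dict(zip(keys, combo)) for combo in product(*choices)]
```

For instance, expand_sweep(["experiment=earthquake", "seed=0,1,2,3,4", "n_jobs=5"]) yields five job configurations that differ only in seed.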
If you find this code or our paper useful in your work, we kindly request that you cite our work:
@inproceedings{jo2024riemannian,
author = {Jaehyeong Jo and
Sung Ju Hwang},
title = {Generative Modeling on Manifolds Through Mixture of Riemannian Diffusion Processes},
booktitle = {International Conference on Machine Learning},
year = {2024},
}
Our code builds upon geomstats, with JAX functionality added. We thank the authors of Riemannian Score-Based Generative Modelling for their pioneering work.