This is the code for the models in the NeurIPS submission "The Autoencoding Variational Autoencoder" (AVAE).
The folder contains code to train the AVAE model in JAX; we will upload the evaluation setup soon.
Code files in the folder:
- checkpointer.py: Checkpointing abstraction
- data_iterators.py: Datasets to be used
- decoders.py: VAE decoder network architectures
- encoders.py: VAE encoder network architectures
- kl.py: KL divergence between two Gaussians
- train.py: Function to train a model given an ELBO, a network, and data
- train_main.py: Main file to train AVAE
- vae.py: VAE model defining various ELBOs
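For readers unfamiliar with the quantity computed in kl.py, the standard closed-form KL divergence between two diagonal Gaussians is sketched below. This is an illustrative NumPy sketch, not the repository's JAX implementation: the function name `gaussian_kl` and its log-variance parameterization are assumptions, and kl.py may use a different signature (e.g. standard deviations instead of log-variances).

```python
import numpy as np


def gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    """KL(q || p) for diagonal Gaussians, summed over the last (latent) axis.

    Hypothetical helper: arguments are means and log-variances of q and p.
    Per dimension:
      KL = 0.5 * (log var_p - log var_q + (var_q + (mu_q - mu_p)^2) / var_p - 1)
    """
    var_q = np.exp(logvar_q)
    var_p = np.exp(logvar_p)
    kl = 0.5 * (logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)
    return np.sum(kl, axis=-1)


# Usage: KL from an encoder posterior q(z|x) to a standard normal prior p(z),
# for a batch of 2 examples with a 64-dimensional latent (as in --latent_dim=64).
latent_dim = 64
mu_q = np.zeros((2, latent_dim))
mu_q[1, 0] = 1.0  # shift one latent dimension of the second example
logvar_q = np.zeros((2, latent_dim))
prior_mu = np.zeros(latent_dim)
prior_logvar = np.zeros(latent_dim)
kl_per_example = gaussian_kl(mu_q, logvar_q, prior_mu, prior_logvar)
print(kl_per_example)
```

Working in log-variances keeps the variances positive without constraining the network outputs, which is a common choice in VAE code.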
To set up a Python 3 virtual environment with the required dependencies, run:
python -m venv avae_env
source avae_env/bin/activate
pip install --upgrade pip
pip install -r avae/requirements.txt
The following command runs AVAE training on the ColorMnist dataset using MLP encoder and decoder architectures:
python -m avae.train_main \
--dataset='color_mnist' \
--latent_dim=64 \
--checkpoint_dir='/tmp/avae_checkpoints' \
--checkpoint_filename='color_mnist_mlp_avae' \
--rho=0.975 \
--encoder='color_mnist_mlp_encoder' \
--decoder='color_mnist_mlp_decoder'
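If you want to launch several runs (for example, a sweep over `--rho` or `--latent_dim`), it can be convenient to build the same command from Python. The sketch below only assembles the argument list using the flag names shown above; the helper `build_train_command` is hypothetical and not part of the repository, and it does not validate flags against what avae.train_main actually accepts.

```python
import shlex
from typing import List, Mapping, Union


def build_train_command(flags: Mapping[str, Union[str, int, float]]) -> List[str]:
    """Hypothetical helper: turn a flag dict into an argv list for avae.train_main."""
    cmd = ["python", "-m", "avae.train_main"]
    for name, value in flags.items():
        # Each flag becomes --name=value, matching the command-line form above.
        cmd.append(f"--{name}={value}")
    return cmd


flags = {
    "dataset": "color_mnist",
    "latent_dim": 64,
    "checkpoint_dir": "/tmp/avae_checkpoints",
    "checkpoint_filename": "color_mnist_mlp_avae",
    "rho": 0.975,
    "encoder": "color_mnist_mlp_encoder",
    "decoder": "color_mnist_mlp_decoder",
}
cmd = build_train_command(flags)
# Print a shell-quoted version that can be pasted into a terminal.
print(shlex.join(cmd))
```

To actually start training from Python, pass the list to `subprocess.run(cmd, check=True)`; when sweeping a flag, give each run its own `checkpoint_filename` so checkpoints do not overwrite each other.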
If you use this code for your research, please consider citing our paper:
@article{cemgil2020autoencoding,
title={The Autoencoding Variational Autoencoder},
author={Cemgil, Taylan and Ghaisas, Sumedh and Dvijotham, Krishnamurthy and Gowal, Sven and Kohli, Pushmeet},
journal={Advances in Neural Information Processing Systems},
volume={33},
year={2020}
}