We use Anaconda to manage the Python environment.
# Build the conda environment declared in environment.yml, then enter it so the
# pip installs below land inside it.
conda env create -f environment.yml
conda activate neural_dsc

# Pinned CUDA 11.1 builds of PyTorch, resolved from the PyTorch wheel index.
pip install --find-links https://download.pytorch.org/whl/torch_stable.html \
  torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0

# Horovod with its PyTorch extra; HOROVOD_GPU_OPERATIONS=NCCL asks the Horovod
# build to use NCCL for GPU operations. --no-cache-dir forces a fresh build.
HOROVOD_GPU_OPERATIONS=NCCL pip install -U --no-cache-dir "horovod[pytorch]"
# Activate the neural_dsc environment (in each new shell) before running the experiment commands.
conda activate neural_dsc
The KITTI Stereo experiments (Figure 2) can be reproduced by running the run-kitti.sh script.
NOTE: Please contact the corresponding author at acnagle@utexas.edu
for more information on gathering the KITTI dataset.
The VQ-VAE experiments on Bernoulli sequences (Figure 6) can be reproduced by running the run-discus.sh
script. Note that the LDPC baseline is implemented in MATLAB; the script that gathers the LDPC results is located in the baseline/nonlearning/
directory. This MATLAB script saves its results to an iid_flip.csv file.
Download celeba-tfr.tar into the data/ directory, then run the following command:
# Prepare the celebahq256 dataset from the celeba-tfr.tar archive placed in data/.
python run_top.py prep celebahq256
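Repeat the following commands with different --codebook_bits values to control the total rate,
using a --root_dir that matches the chosen bit count so runs at different rates do not overwrite each other's checkpoints.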
# --codebook_bits controls the total rate. Set CODEBOOK_BITS=<n> to sweep it
# (default 4). The bit count is also baked into each --root_dir, so runs at
# different rates write to separate checkpoint directories instead of
# overwriting the same "_4bit" one.
codebook_bits="${CODEBOOK_BITS:-4}"

# Joint VQ-VAE (--enc_si True, --dec_si True)
horovodrun -n 2 python -O run_top.py train --dataset celebahq256 --arch vqvae_top_8x --ch_latent 1 \
  --root_dir "checkpoints/celeba256_vqvae_joint_${codebook_bits}bit" --dec_si True --enc_si True --codebook_bits "$codebook_bits"
# Distributed VQ-VAE (--enc_si False, --dec_si True)
horovodrun -n 2 python -O run_top.py train --dataset celebahq256 --arch vqvae_top_8x --ch_latent 1 \
  --root_dir "checkpoints/celeba256_vqvae_dist_${codebook_bits}bit" --dec_si True --enc_si False --codebook_bits "$codebook_bits"
# Separate VQ-VAE (--enc_si False, --dec_si False)
horovodrun -n 2 python -O run_top.py train --dataset celebahq256 --arch vqvae_top_8x --ch_latent 1 \
  --root_dir "checkpoints/celeba256_vqvae_separate_${codebook_bits}bit" --dec_si False --enc_si False --codebook_bits "$codebook_bits"
# Evaluation step via run_top.py eval (presumably of the checkpoints trained above — confirm).
# NOTE(review): this command ends with a line continuation ("\"), but the
# argument line(s) that should follow it (presumably the checkpoint paths to
# evaluate) are missing from this copy, so the backslash joins the command to
# whatever text comes next. TODO: restore the missing arguments.
horovodrun -n 2 python run_top.py eval --batch_size 250 \
All generated plots are stored in the paper/ folder.
# Plot the rate-distortion curves; the figures are written to paper/.
python plot_rd_curves.py
# The following command may take a while to finish because the dataset download can be slow.
python run_mnist_grad.py prep mnist
# Gather gradients into checkpoints/mnist_grad_data; the train_vqvae commands
# read checkpoints/mnist_grad_data/grads.pt from there.
python -O run_mnist_grad.py gather_gradients --out_dir checkpoints/mnist_grad_data
# Train one VQ-VAE (40-d latent, 8 codebook bits) on the gathered gradients.
# Arguments:
#   $1 - variant tag used in the checkpoint directory name
#   $2 - value passed to --enc_si (True/False)
#   $3 - value passed to --dec_si (True/False)
train_mnist_grad_vqvae() {
  local variant=$1 enc_si=$2 dec_si=$3
  python -O run_mnist_grad.py train_vqvae \
    --grad_dump checkpoints/mnist_grad_data/grads.pt \
    --d_latent 40 --codebook_bits 8 \
    --enc_si "$enc_si" --dec_si "$dec_si" \
    --root_dir "checkpoints/mnist_grad_vqvae_${variant}_40d_8bits"
}

train_mnist_grad_vqvae joint True True       # Joint VQ-VAE
train_mnist_grad_vqvae dist False True       # Distributed VQ-VAE
train_mnist_grad_vqvae separate False False  # Separate VQ-VAE
# Evaluate the joint/dist/separate epoch-500 checkpoints once per seed, then
# make a single plot over all seeds into paper/.
# The plot runs once, after the loop, because it is passed every seed via
# --seeds. The seed list is built from num_seeds so the loop and the plot
# always agree.
num_seeds=20
seeds_csv=
for ((seed = 1; seed <= num_seeds; seed++)); do
  python run_mnist_grad.py eval checkpoints/mnist_grad_vqvae_{joint,dist,separate}_40d_8bits/ckpt_ep=500_step=0391000.pt --seed "$seed"
  seeds_csv+="${seeds_csv:+,}${seed}"
done

python run_mnist_grad.py plot checkpoints/mnist_grad_vqvae_{joint,dist,separate}_40d_8bits/ckpt_ep=500_step=0391000.pt \
  --seeds "$seeds_csv" --out_dir paper --labels Joint,Distributed,Separate