This repository contains a PyTorch implementation of the method and experiments from the paper "(q,p)-Wasserstein GANs: Comparing Ground Metrics for Wasserstein GANs".
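For reference, the distance in the title can be written in the standard Kantorovich form, with an l_q ground metric raised to the power p. This is stated as a reminder of the usual definition rather than quoted from the paper; Π(μ, ν) denotes the set of couplings with marginals μ and ν.

```latex
\|z\|_q = \Big( \sum_{k=1}^{d} |z_k|^q \Big)^{1/q},
\qquad
W_{q,p}(\mu, \nu) = \Big( \inf_{\pi \in \Pi(\mu, \nu)} \int \|x - y\|_q^{\,p} \, d\pi(x, y) \Big)^{1/p}
```

With q = 2 and p = 1 this reduces to the Euclidean 1-Wasserstein distance used by the original WGAN.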
.
├── bin
│ ├── cifar10.sh
│ └── mnist.sh
├── data
├── figs
│ └── gif
├── models
│ ├── __init__.py
│ ├── cifar10.py
│ ├── gaussian_model.py
│ └── mnist.py
├── src
│ ├── __init__.py
│ ├── main.py
│ ├── qpwgan.py
│ ├── discrete_measures.py
│ ├── gaussian_mixture.py
│ ├── metrics.py
│ ├── plot_nearest_distance.py
│ └── utils.py
├── README.md
├── requirements.txt
└── setup.py
Install the dependencies and the package (in editable mode):
pip install -r requirements.txt
pip install -e .
To use Weights & Biases (wandb) tracking, log in beforehand:
wandb login
Optimization of the Wasserstein metric on discrete measures:
python src/discrete_measures.py
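As a small illustration of the discrete setting, the following self-contained NumPy sketch (not the code in src/discrete_measures.py) computes the (q,p)-Wasserstein distance exactly between two uniform point clouds of the same size. It relies on the fact that, for equal uniform weights, an optimal coupling can be taken to be a permutation. The function names qp_cost_matrix and wasserstein_qp_uniform are hypothetical.

```python
import itertools

import numpy as np


def qp_cost_matrix(x, y, q, p):
    """Pairwise costs C[i, j] = ||x_i - y_j||_q ** p for x of shape (n, d), y of shape (m, d)."""
    diff = x[:, None, :] - y[None, :, :]  # (n, m, d) pairwise differences
    return np.linalg.norm(diff, ord=q, axis=-1) ** p


def wasserstein_qp_uniform(x, y, q=2.0, p=1.0):
    """Exact W_{q,p} between two uniform measures with the same number of atoms.

    With equal uniform weights an optimal coupling is a permutation, so this
    brute-forces all n! assignments. Only suitable for very small n.
    """
    n = len(x)
    if len(y) != n:
        raise ValueError("x and y must contain the same number of points")
    cost = qp_cost_matrix(x, y, q, p)
    rows = np.arange(n)
    # Average cost of the best one-to-one matching, i.e. W_{q,p}^p.
    best = min(cost[rows, list(perm)].mean() for perm in itertools.permutations(range(n)))
    return best ** (1.0 / p)


if __name__ == "__main__":
    x = np.array([[0.0, 0.0], [1.0, 0.0]])
    y = np.array([[1.0, 1.0], [0.0, 1.0]])
    print(wasserstein_qp_uniform(x, y, q=1, p=2))  # 1.0
```

For more than a handful of points, the permutation search can be replaced by `scipy.optimize.linear_sum_assignment(cost)`, which solves the same assignment problem in polynomial time.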
Approximating a Gaussian mixture distribution:
python src/gaussian_mixture.py \
--n_epoch 601 \
--search_space full \
--n_critic_iter 2 \
--reg_coef1 0.1 \
--reg_coef2 1 \
--batch_size 64
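The critic in a Wasserstein GAN plays the role of a Kantorovich potential. As a rough, generic illustration (an assumption about the general approach, not a transcription of src/qpwgan.py), the dual of W_{q,p}^p can be estimated on a minibatch by taking the c-transform of the critic over the real batch, with cost c(x, y) = ||x - y||_q^p. The critic would be trained to maximize this estimate and the generator to minimize it. The sketch uses NumPy to stay self-contained; in a training loop the same operations would be written with torch tensors so gradients can flow. All function names are hypothetical.

```python
import numpy as np


def qp_cost_matrix(x, y, q, p):
    """Pairwise costs C[i, j] = ||x_i - y_j||_q ** p."""
    diff = x[:, None, :] - y[None, :, :]
    return np.linalg.norm(diff, ord=q, axis=-1) ** p


def minibatch_c_transform(phi_x, x, y, q, p):
    """phi^c(y_j) = min_i [c(x_i, y_j) - phi(x_i)], with the infimum restricted to the batch x."""
    cost = qp_cost_matrix(x, y, q, p)            # (n, m)
    return (cost - phi_x[:, None]).min(axis=0)  # (m,)


def dual_estimate(phi_x, x, y, q, p):
    """mean(phi(x)) + mean(phi^c(y)): a lower bound on W_{q,p}^p between the two batches."""
    return phi_x.mean() + minibatch_c_transform(phi_x, x, y, q, p).mean()


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    real = rng.normal(size=(64, 2))           # stand-in for a batch of real samples
    fake = rng.normal(loc=3.0, size=(64, 2))  # stand-in for a batch of generated samples
    phi_real = rng.normal(size=64)            # stand-in for critic outputs on the real batch
    print(dual_estimate(phi_real, real, fake, q=2, p=1))
```

The lower bound follows from the construction: phi(x_i) + phi^c(y_j) <= c(x_i, y_j) for every pair, so averaging over any coupling of the two batches cannot exceed its transport cost.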
MNIST
bash bin/mnist.sh
CIFAR10
bash bin/cifar10.sh
CIFAR10 Progress
[Figure: CIFAR10 training progress for q=1, p=1; q=1, p=2; and q=2, p=2, each shown with critic iters = 1 and critic iters = 5.]