Please see the adversarial data programming website to familiarize yourself with the adversarial data programming (ADP) concept. In this tutorial we train a basic generative adversarial network (GAN) on the MNIST digit dataset using labeling functions, in the following steps:
- Model a basic generative adversarial network with linear layers
- Formalise the Labeling Functions Block using a labeling function based on k-means clustering
- Train the model end-to-end on the MNIST dataset
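As a rough sketch of the first and last steps, the PyTorch code below defines a vanilla GAN whose generator and discriminator use only linear layers, plus one alternating training step. The layer sizes, `LATENT_DIM = 100`, the LeakyReLU/Tanh activations, and the names `Generator`, `Discriminator`, and `train_step` are assumptions for illustration. The labeling functions block is deliberately left out, so this is not the tutorial's actual ADP model.

```python
import torch
import torch.nn as nn

LATENT_DIM = 100      # assumed size of the noise vector z
IMG_DIM = 28 * 28     # MNIST images flattened to 784 values
bce = nn.BCELoss()    # binary cross-entropy on the discriminator's probabilities


class Generator(nn.Module):
    """Maps a noise vector z to a flattened 28x28 image in [-1, 1]."""

    def __init__(self, latent_dim=LATENT_DIM, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, IMG_DIM),
            nn.Tanh(),  # assumes real images are normalized to [-1, 1]
        )

    def forward(self, z):
        return self.net(z)


class Discriminator(nn.Module):
    """Outputs the probability that a flattened image is real."""

    def __init__(self, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(IMG_DIM, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)


def train_step(G, D, real, opt_g, opt_d):
    """One alternating GAN update on a batch of flattened real images."""
    batch = real.size(0)
    ones = torch.ones(batch, 1)
    zeros = torch.zeros(batch, 1)

    # Discriminator update: push D(real) toward 1 and D(G(z)) toward 0.
    fake = G(torch.randn(batch, LATENT_DIM))
    d_loss = bce(D(real), ones) + bce(D(fake.detach()), zeros)
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # Generator update (non-saturating loss): push D(G(z)) toward 1.
    g_loss = bce(D(fake), ones)
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()


if __name__ == "__main__":
    G, D = Generator(), Discriminator()
    opt_g = torch.optim.Adam(G.parameters(), lr=2e-4)
    opt_d = torch.optim.Adam(D.parameters(), lr=2e-4)
    real = torch.rand(64, IMG_DIM) * 2 - 1  # random stand-in for a normalized MNIST batch
    print(train_step(G, D, real, opt_g, opt_d))
```

In a real run, `real` would come from a torchvision MNIST loader with images flattened to 784 values and scaled to [-1, 1], and `train_step` would be called once per batch.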
You need to install:
- PyTorch 1.2+
- Python 3.6+
- matplotlib
- torchvision
- kmeans-pytorch, installed with the command `pip install kmeans-pytorch`
- pylab
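Before running the tutorial, a small script like the one below can confirm that the listed requirements are present. It is a hypothetical helper, not part of the repository. It assumes the import names `kmeans_pytorch` (for the kmeans-pytorch package) and `pylab`, and uses `importlib.util.find_spec` so that modules are located without being imported; only torch is imported, to read its version.

```python
import importlib.util
import sys

# Assumed import names for the requirements listed above.
REQUIRED_MODULES = ["torch", "torchvision", "matplotlib", "kmeans_pytorch", "pylab"]


def check_environment(modules=REQUIRED_MODULES):
    """Return (python_ok, {module_name: found}) without importing the modules."""
    python_ok = sys.version_info >= (3, 6)
    found = {name: importlib.util.find_spec(name) is not None for name in modules}
    return python_ok, found


def torch_version_ok(minimum=(1, 2)):
    """True if torch is installed and its major.minor version is at least `minimum`."""
    if importlib.util.find_spec("torch") is None:
        return False
    import torch
    # Strip local suffixes such as "+cpu" before parsing "major.minor".
    major_minor = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2])
    return major_minor >= minimum


if __name__ == "__main__":
    ok, found = check_environment()
    print("Python 3.6+:", ok)
    for name, present in found.items():
        print(name, "found" if present else "MISSING")
    print("PyTorch 1.2+:", torch_version_ok())
```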
The generated images are stored in the samples folder, while the labels can be seen after every iteration.
The GAN code is based on . We also thank kmeans-pytorch for providing the k-means unsupervised clustering code.
See the tutorial on writing labeling functions for a real dataset.
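For a sense of what a k-means based labeling function might look like, here is a self-contained sketch. To stay dependency-free it implements plain Lloyd's k-means with `torch.cdist` rather than calling kmeans-pytorch. The names `kmeans` and `make_kmeans_lf` and the cluster-to-class mapping `cluster_to_label` are hypothetical (the mapping could, for example, come from a majority vote over a few labeled images), and the tutorial's own labeling function may differ.

```python
import torch


def kmeans(x, k, iters=20, seed=0):
    """Plain Lloyd's k-means on the rows of x (N, D); returns (assignments, centroids)."""
    g = torch.Generator().manual_seed(seed)
    # Initialize centroids with k distinct random rows of x.
    centroids = x[torch.randperm(x.size(0), generator=g)[:k]].clone()
    for _ in range(iters):
        assign = torch.cdist(x, centroids).argmin(dim=1)  # nearest centroid per row
        for j in range(k):
            members = x[assign == j]
            if len(members) > 0:  # keep the old centroid if a cluster becomes empty
                centroids[j] = members.mean(dim=0)
    assign = torch.cdist(x, centroids).argmin(dim=1)  # final assignment to final centroids
    return assign, centroids


def make_kmeans_lf(centroids, cluster_to_label):
    """Build a labeling function that labels each image by its nearest centroid.

    `cluster_to_label` is a hypothetical LongTensor mapping cluster id -> class id.
    """
    def lf(images):
        flat = images.reshape(images.size(0), -1)  # flatten (N, 1, 28, 28) to (N, 784)
        cluster = torch.cdist(flat, centroids).argmin(dim=1)
        return cluster_to_label[cluster]
    return lf


if __name__ == "__main__":
    x = torch.rand(200, 28 * 28)  # random stand-in for flattened MNIST images
    _, centroids = kmeans(x, k=10)
    lf = make_kmeans_lf(centroids, torch.arange(10))  # identity mapping, for illustration
    print(lf(x[:5].reshape(5, 1, 28, 28)))
```

Because k-means clusters carry no class names, the mapping from cluster ids to digit classes is what turns the clustering into a (noisy) labeling function.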
If you find the code useful, please cite the paper:
@InProceedings{Pal_2018_CVPR,
author = {Pal, Arghya and Balasubramanian, Vineeth N.},
title = {Adversarial Data Programming: Using GANs to Relax the Bottleneck of Curated Labeled Data},
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2018}
}