TensorFlow implementation of DRAW: A Recurrent Neural Network For Image Generation on the MNIST generation task.
Figure: DRAW on MNIST, with attention (left) and without attention (right).
Although open-source implementations of this paper already exist (see links below), this implementation focuses on simplicity and ease of understanding. I tried to make the code resemble the raw equations as closely as possible.
For a gentle walkthrough of the paper and the implementation, see the writeup here: https://evjang.com/articles/draw.
Running `python draw.py --data_dir=/tmp/draw` downloads the binarized MNIST dataset to `/tmp/draw/mnist` and trains the DRAW model with attention enabled for both reading and writing. After training, output data is written to `/tmp/draw/draw_data.npy`.
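To inspect that output from Python, something like the following may work. It is a minimal sketch that assumes `draw_data.npy` holds a pickled NumPy object array whose first entry is the sequence of canvases with shape `(T, batch, 784)`. That layout is an assumption, not something this README documents, and `load_canvases` is a hypothetical helper.

```python
import numpy as np


def load_canvases(path):
    """Load the canvas sequence from a DRAW output file.

    ASSUMPTION: the file stores a pickled object array whose first entry
    holds the canvases with shape (T, batch, 784).
    """
    # Object arrays are stored with pickle, so loading must opt in explicitly
    data = np.load(path, allow_pickle=True)
    canvases = np.asarray(data[0], dtype=np.float32)
    if canvases.ndim != 3:
        raise ValueError(f"expected (T, batch, pixels), got shape {canvases.shape}")
    return canvases
```

For example, `canvases = load_canvases("/tmp/draw/draw_data.npy")` would give one flattened 28×28 canvas per timestep and batch element, if the assumed layout holds.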
You can visualize the results by running `python plot_data.py <prefix> <output_data>`. For example:

```
python plot_data.py myattn /tmp/draw/draw_data.npy
```
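One way to view a single timestep is to tile the whole minibatch of canvases into one image. The NumPy sketch below does that, assuming each timestep is a `(batch, 784)` array of flattened 28×28 canvases; `canvas_grid` is a hypothetical helper, not a function taken from `plot_data.py`.

```python
import numpy as np


def canvas_grid(canvas_t, rows=10, cols=10, side=28):
    """Tile flattened canvases of shape (batch, side*side) into one
    (rows*side, cols*side) image; canvas k lands at tile row k // cols,
    tile column k % cols."""
    n = rows * cols
    if canvas_t.shape[0] < n:
        raise ValueError(f"need at least {n} canvases, got {canvas_t.shape[0]}")
    imgs = canvas_t[:n].reshape(rows, cols, side, side)
    # (rows, cols, y, x) -> (rows, y, cols, x) so the final reshape stacks tiles
    return imgs.transpose(0, 2, 1, 3).reshape(rows * side, cols * side)
```

With matplotlib, `plt.imshow(canvas_grid(canvases[t]), cmap="gray")` would show step `t`; looping over `t` gives the frames of the drawing sequence.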
To train without attention, run:

```
python draw.py --data_dir=/tmp/draw --read_attn=False --write_attn=False
```
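Setting `--read_attn` and `--write_attn` to `False` swaps the attentive read and write operations for non-attentive ones. The following NumPy sketch follows the equations in the DRAW paper (Gregor et al., 2015) rather than the TensorFlow code in `draw.py`. The function names are hypothetical, and the attention parameters (`gx`, `gy`, `sigma2`, `delta`, `gamma`) are passed in directly instead of being computed from the decoder state.

```python
import numpy as np


def filterbank(gx, gy, sigma2, delta, N, A, B):
    """N x N grid of Gaussian filters for an image A pixels wide and
    B pixels tall. Returns Fx with shape (N, A) and Fy with shape (N, B)."""
    i = np.arange(1, N + 1)  # filter index i = 1..N, so the grid is centred on (gx, gy)
    mu_x = gx + (i - N / 2 - 0.5) * delta
    mu_y = gy + (i - N / 2 - 0.5) * delta
    a = np.arange(A)  # pixel column coordinates
    b = np.arange(B)  # pixel row coordinates
    Fx = np.exp(-((a[None, :] - mu_x[:, None]) ** 2) / (2 * sigma2))
    Fy = np.exp(-((b[None, :] - mu_y[:, None]) ** 2) / (2 * sigma2))
    # Normalise each filter to sum to 1 (the Z_X, Z_Y constants in the paper)
    Fx /= np.maximum(Fx.sum(axis=1, keepdims=True), 1e-8)
    Fy /= np.maximum(Fy.sum(axis=1, keepdims=True), 1e-8)
    return Fx, Fy


def read_no_attn(x, x_hat):
    # Without attention the encoder sees the full image and the full error image
    return np.concatenate([x.ravel(), x_hat.ravel()])


def glimpse(img, Fx, Fy, gamma):
    # N x N glimpse of a B x A image: gamma * Fy @ img @ Fx^T
    return gamma * (Fy @ img @ Fx.T)


def read_attn(x, x_hat, Fx, Fy, gamma):
    # With attention the encoder sees one glimpse of each image
    return np.concatenate([glimpse(x, Fx, Fy, gamma).ravel(),
                           glimpse(x_hat, Fx, Fy, gamma).ravel()])


def write_attn(w, Fx, Fy, gamma):
    # The decoder's N x N patch w is spread back onto the B x A canvas
    return (Fy.T @ w @ Fx) / gamma
```

In the paper, `gx`, `gy`, `log sigma2`, `log delta` and `log gamma` come from a linear layer on the decoder state. Without read attention, the encoder input has `2 * A * B` entries instead of `2 * N * N`, which is why the attentive model can work with much smaller reads.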
Instead of training from scratch, you can load pre-trained weights by uncommenting the following line in `draw.py` and editing the checkpoint path as needed. Save electricity!

```python
saver.restore(sess, "/tmp/draw/drawmodel.ckpt")
```
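If you would rather not toggle that line by hand, one option is to restore when a checkpoint is present and train otherwise. This is a minimal sketch with hypothetical helper names. It assumes a TensorFlow 1.x checkpoint is either a single file at the given path (older V1 format) or a `<path>.index` file plus data shards (V2 format), and it leaves the actual `saver.restore` and training calls to the caller.

```python
import os


def checkpoint_present(prefix):
    # ASSUMPTION: a V1 checkpoint is a single file at `prefix`;
    # a V2 checkpoint writes `prefix.index` alongside its data shards.
    return os.path.isfile(prefix) or os.path.isfile(prefix + ".index")


def restore_or_train(prefix, restore_fn, train_fn):
    """Call restore_fn(prefix) if a checkpoint exists, else train_fn()."""
    if checkpoint_present(prefix):
        restore_fn(prefix)
        return "restored"
    train_fn()
    return "trained"
```

Usage might look like `restore_or_train("data/drawmodel_attn.ckpt", lambda p: saver.restore(sess, p), run_training)`, where `run_training` is a hypothetical wrapper around the training loop.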
This git repository contains the following pre-trained files in the `data/` folder:
Filename | Description |
---|---|
draw_data_attn.npy | Training outputs for DRAW with attention |
drawmodel_attn.ckpt | Saved weights for DRAW with attention |
draw_data_noattn.npy | Training outputs for DRAW without attention |
drawmodel_noattn.ckpt | Saved weights for DRAW without attention |
These were trained for 10,000 iterations with a minibatch size of 100 on a GTX 970 GPU.
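For a sense of scale, 10,000 iterations of 100 images each is 1,000,000 training examples. How many epochs that is depends on the training-set size, which this README does not state; the sketch assumes either 55,000 images (the training split produced by TensorFlow's MNIST tutorial loader) or the full 60,000-image MNIST training set.

```python
def epochs_seen(iterations, batch_size, train_size):
    """Number of passes over the training set implied by a run."""
    return iterations * batch_size / train_size


# ASSUMED training-set sizes; the README does not say which split was used
for train_size in (55_000, 60_000):
    print(train_size, round(epochs_seen(10_000, 100, train_size), 1))
# 55000 18.2
# 60000 16.7
```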