TensorFlow implementation of Matching Networks for One Shot Learning by Vinyals et al.
- Python 2.7+
- NumPy
- SciPy
- tqdm
- TensorFlow r1.0+
- Download and extract the Omniglot dataset, then modify omniglot_train and omniglot_test in utils.py to point to your location.
- The first training run will generate omniglot.npy in the directory. The shape should be (1623, 80, 28, 28, 1), meaning 1623 classes, 80 samples per class (20 originals * 4 rotations of 0, 90, 180, and 270 degrees), height, width, channel. 1200 classes are used for training and 423 for testing.
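The layout described above can be illustrated with a small NumPy sketch. This is not the code in utils.py: the function names `add_rotations` and `split_classes` are made up for illustration, the ordering of the rotated copies along the sample axis is an assumption, and a tiny random array stands in for the real images so the example runs quickly.

```python
import numpy as np

N_TRAIN_CLASSES = 1200  # from the description above; the remaining 423 classes are for testing


def add_rotations(images):
    """Hypothetical helper: (classes, samples, H, W, C) -> 4x as many samples per class."""
    # Rotate every image in the height/width plane by 0, 90, 180 and 270 degrees.
    rotated = [np.rot90(images, k=k, axes=(2, 3)) for k in range(4)]
    # Assumption: rotated copies are stacked along the sample axis (20 -> 80).
    return np.concatenate(rotated, axis=1)


def split_classes(data, n_train=N_TRAIN_CLASSES):
    """Hypothetical helper: split along the class axis so test classes are unseen."""
    return data[:n_train], data[n_train:]


# Tiny stand-in for the real data: 5 classes instead of 1623.
demo = np.random.rand(5, 20, 28, 28, 1).astype(np.float32)
augmented = add_rotations(demo)
train, test = split_classes(augmented, n_train=3)
print(augmented.shape)  # (5, 80, 28, 28, 1)
print(train.shape[0], test.shape[0])  # 3 2
```

With the real file, `np.load("omniglot.npy")` followed by `split_classes(data)` would give 1200 training classes and 423 test classes, provided the array has the shape stated above.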
python main.py --train
Train from a previous checkpoint at epoch X:
python main.py --train --modelpath=ckpt/model-X
Check out tunable hyper-parameters:
python main.py
Evaluate a trained model:
python main.py --eval
- Evaluation accuracy is measured after every epoch.
- As the paper indicates, training on Omniglot with FCE (full context embeddings) does not improve results, but I still implemented them (as far as I know, no other repos fully implement the FCEs so far).
- The authors did not mention the number of time steps K used in FCE_f. In the cited paper, K is tested with 0, 1, 5, and 10, as shown in Table 1.
- When using data I generated myself (through utils.py), the evaluation accuracy at epoch 100 is around 82.00% (training accuracy 83.14%) without data augmentation.
- Nevertheless, when using the data provided by zergylord in his repo, this implementation can achieve up to 96.61% evaluation accuracy (training accuracy 97.22%) at epoch 100.
- Issues are welcome!
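To make the role of the time steps K in FCE_f (mentioned in the notes above) more concrete, here is a minimal NumPy sketch of the attention loop as I read it from the paper. It is not this repo's TensorFlow code: the weights are random, `lstm_step` and `fce_f` are illustrative names, and the initial states (h_0 = f'(x_hat), zero cell state) are assumptions chosen so that K = 0 returns the input embedding unchanged.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def softmax(z):
    e = np.exp(z - np.max(z))  # subtract the max for numerical stability
    return e / e.sum()


def lstm_step(x, h_r, c, W, b):
    """One LSTM step: input x (d,), hidden input h_r = [h, r] (2d,), cell c (d,)."""
    d = c.shape[0]
    gates = W.dot(np.concatenate([x, h_r])) + b  # shape (4d,)
    i = sigmoid(gates[:d])            # input gate
    f = sigmoid(gates[d:2 * d])       # forget gate
    o = sigmoid(gates[2 * d:3 * d])   # output gate
    g = np.tanh(gates[3 * d:])        # candidate cell update
    c_new = f * c + i * g
    return o * np.tanh(c_new), c_new


def fce_f(f_x, g_S, K, W, b):
    """Refine query embedding f_x (d,) by attending over support embeddings g_S (n, d)."""
    h = f_x.copy()              # assumption: h_0 = f'(x_hat), so K = 0 is a no-op
    c = np.zeros_like(f_x)      # assumption: zero initial cell state
    for _ in range(K):
        a = softmax(g_S.dot(h))  # attention over the support set, shape (n,)
        r = a.dot(g_S)           # attention-weighted read-out, shape (d,)
        h_hat, c = lstm_step(f_x, np.concatenate([h, r]), c, W, b)
        h = h_hat + f_x          # skip connection back to f'(x_hat)
    return h


rng = np.random.RandomState(0)
d, n = 8, 5                       # embedding size and support set size (arbitrary)
f_x = rng.randn(d)
g_S = rng.randn(n, d)
W = 0.1 * rng.randn(4 * d, 3 * d)  # random weights, for illustration only
b = np.zeros(4 * d)
for K in (0, 1, 5, 10):
    out = fce_f(f_x, g_S, K, W, b)
    print(K, out.shape)
```

Each extra step lets the query embedding re-read the support set once more, which is why K acts as a tunable "processing depth" rather than a sequence length.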
- The paper.
- Referred to this repo.
- Karpathy's note helps a lot.