# Matching Networks for One Shot Learning

TensorFlow implementation of [Matching Networks for One Shot Learning](https://arxiv.org/abs/1606.04080) by Vinyals et al.
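At prediction time, a matching network classifies a query as an attention-weighted combination of the support-set labels. The following NumPy sketch only illustrates that rule; the function and variable names are illustrative and are not this repo's API:

```python
import numpy as np

def matching_net_predict(support_emb, support_labels, query_emb, n_classes):
    """Attention-kernel classification from the paper:
    y_hat = sum_i softmax(cosine(query, x_i)) * y_i, with one-hot y_i."""
    # Cosine similarity between the query and each support embedding.
    s = support_emb / (np.linalg.norm(support_emb, axis=1, keepdims=True) + 1e-8)
    q = query_emb / (np.linalg.norm(query_emb) + 1e-8)
    sims = s @ q                                  # (n_support,)
    att = np.exp(sims) / np.exp(sims).sum()       # softmax attention weights
    one_hot = np.eye(n_classes)[support_labels]   # (n_support, n_classes)
    return att @ one_hot                          # class probabilities

# Toy usage: 5-way 1-shot with random 64-d embeddings.
probs = matching_net_predict(np.random.randn(5, 64),
                             np.arange(5),
                             np.random.randn(64), 5)
print(probs.argmax())
```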

## Prerequisites

## Data

### Preparation

1. Download and extract the Omniglot dataset, then point `omniglot_train` and `omniglot_test` in `utils.py` to your local paths.

2. The first training run generates `omniglot.npy` in the working directory. Its shape should be `(1623, 80, 28, 28, 1)`: 1623 classes, 80 samples per class (20 drawings x 4 rotations of 0, 90, 180, and 270 degrees), height, width, and channel. 1200 classes are used for training and 423 for testing. A rough sketch of the expected layout follows this list.
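As an illustration of the expected layout only (the actual loading and augmentation code lives in `utils.py`), assuming the cached file sits in the working directory:

```python
import numpy as np

data = np.load('omniglot.npy')            # expected shape: (1623, 80, 28, 28, 1)
assert data.shape == (1623, 80, 28, 28, 1)

# 80 samples per class = 20 drawings x 4 rotations (0, 90, 180, 270 degrees).
train_classes, test_classes = data[:1200], data[1200:]   # 1200 / 423 class split
print(train_classes.shape, test_classes.shape)
```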

## Train

`python main.py --train`

Train from a previous checkpoint saved at epoch X:

`python main.py --train --modelpath=ckpt/model-X`

To see the tunable hyper-parameters, run the script with no arguments:

`python main.py`

## Test

`python main.py --eval`

## Notes

- The model evaluates accuracy on the test split after every epoch.
- As the paper indicates, FCE does not improve results on Omniglot, but I implemented it anyway (as far as I know, no other repository fully implements the FCEs so far).
- The authors did not mention the value of the processing steps K in FCE_f; in the cited paper, K is tested with 0, 1, 5, and 10, as shown in Table 1. A rough sketch of the FCE_f processing loop follows these notes.
- When using the data I generated myself (through `utils.py`), the evaluation accuracy at epoch 100 is around 82.00% (training accuracy 83.14%) without data augmentation.
- However, when using the data provided by zergylord in his repo, this implementation can reach up to 96.61% evaluation accuracy (97.22% training) at epoch 100.
- Issues are welcome!
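For reference, here is a rough NumPy sketch of the FCE_f idea: the query embedding is refined by K read-and-attend steps over the support embeddings. It is only an illustration of the mechanism from the paper, not the code in this repo, and a plain tanh update stands in for the paper's LSTM cell:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def fce_f(query_emb, support_embs, K):
    """Refine a query embedding with K attention steps over the support set.
    A tanh update stands in for the LSTM cell used in the paper."""
    h = query_emb.copy()
    r = np.zeros_like(query_emb)                     # attention readout, r_0 = 0
    for _ in range(K):
        h = np.tanh(query_emb + h + r) + query_emb   # recurrent update + skip connection
        att = softmax(support_embs @ h)              # attention over support embeddings
        r = att @ support_embs                       # readout r_k
    return h

# Toy usage: 5 support embeddings of size 64, K = 5 processing steps.
refined = fce_f(np.random.randn(64), np.random.randn(5, 64), K=5)
```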

## Resources