forked from amasky/ram

Recurrent Models of Visual Attention (RAM) with Chainer


kencharo/ram

 
 


Recurrent Attention Model

A Chainer implementation of the Recurrent Attention Model (RAM), based on the following paper:
arXiv:1406.6247: Recurrent Models of Visual Attention [Volodymyr Mnih et al., 2014]

Features

  • RAM model implemented in Chainer
  • scripts to train a RAM model and to classify digits with it
  • supports the Original MNIST, Translated MNIST, and Cluttered Translated MNIST tasks
  • 1.13% error on 28x28 MNIST (RAM, 6 glimpses, 8x8, 1 scale)
  • 1.58% error on 60x60 Translated MNIST (RAM, 6 glimpses, 12x12, 3 scales)
  • 5.30% error on 60x60 Cluttered Translated MNIST (RAM, 6 glimpses, 12x12, 3 scales)
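The glimpse settings above (for example "12x12, 3 scales") refer to the retina-like sensor described in the paper: around each glimpse location it extracts several square patches, the first g x g and each successive one twice as wide, and resizes all of them to g x g. The following is a minimal NumPy sketch of that idea, not the code in this repository. The function name `extract_glimpse`, the [-1, 1] coordinate convention with (0, 0) at the image center, zero padding at the borders, and average pooling as the resizing step are all assumptions made for illustration.

```python
import numpy as np


def extract_glimpse(image, loc, size, n_scales):
    """Multi-resolution glimpse in the spirit of Mnih et al. (2014). Sketch only.

    image:    (H, W) array
    loc:      (y, x) center in [-1, 1]; (0, 0) is the image center (assumed convention)
    size:     side g of the base patch, e.g. 8 or 12
    n_scales: number of patches; patch k covers a (g * 2**k)-pixel square
    Returns an (n_scales, size, size) array, every patch pooled down to size x size.
    """
    h, w = image.shape
    max_side = size * 2 ** (n_scales - 1)
    # zero padding large enough for the widest patch centered anywhere in the image
    pad = max_side // 2 + 1
    padded = np.pad(image, pad, mode="constant")
    # map [-1, 1] coordinates to pixel indices in the padded image
    cy = int(round((loc[0] + 1) / 2 * h)) + pad
    cx = int(round((loc[1] + 1) / 2 * w)) + pad
    patches = []
    for k in range(n_scales):
        factor = 2 ** k
        side = size * factor
        top, left = cy - side // 2, cx - side // 2
        crop = padded[top:top + side, left:left + side]
        # average-pool factor x factor blocks so every scale ends up size x size
        pooled = crop.reshape(size, factor, size, factor).mean(axis=(1, 3))
        patches.append(pooled)
    return np.stack(patches)
```

For the 60x60 settings, `extract_glimpse(image, (0.0, 0.0), 12, 3)` returns three 12x12 patches summarizing 12x12, 24x24 and 48x48 pixel windows around the center, so the coarser scales trade resolution for context.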

Examples

Glimpses and output probabilities at each time step:

  • Original MNIST

[figure: example on original MNIST]

  • Translated MNIST

[figure: example on translated MNIST]

  • Cluttered Translated MNIST

[figure: example on cluttered translated MNIST]

Dependencies

Python (2 or 3), Chainer, PIL, matplotlib, tqdm

Usage

train.py: optimizes the weights of a RAM model and saves the learned weights to a *.chainermodel file every 100 epochs

  • use "--original" for the 28x28 original MNIST task, "--translated" for 60x60 translated MNIST, or "--cluttered" for 60x60 cluttered translated MNIST
  • training on Cluttered Translated MNIST takes around half a day on a CPU
  • training runs on the CPU by default; pass "-g YOUR_GPU_DEVICE_ID" to train on a GPU
python train.py --original  
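How train.py optimizes the weights is not described here. In the paper, the classifier is trained with cross-entropy, while the location network, whose output is the mean of a Gaussian location policy with fixed variance, is trained with REINFORCE: the reward is 1 when the final classification is correct and 0 otherwise, and a learned baseline is subtracted to reduce variance. Below is a minimal NumPy sketch of that estimator for the location means, assuming exactly those choices. The function names are hypothetical, and the repository's Chainer code may compute the same quantity differently (for example through a surrogate loss and automatic differentiation).

```python
import numpy as np


def gaussian_log_prob(l, mu, sigma):
    """Log-density of the location policy N(mu, sigma^2 I), summed over (y, x)."""
    d = l.shape[-1]
    sq_dist = np.sum((l - mu) ** 2, axis=-1)
    return -0.5 * sq_dist / sigma ** 2 - d * np.log(sigma * np.sqrt(2.0 * np.pi))


def reinforce_grad_mu(l, mu, sigma, reward, baseline):
    """REINFORCE estimate of dE[R]/d(mu) for every sampled glimpse location.

    l, mu:    (batch, T, 2) sampled locations and the policy means that produced them
    sigma:    fixed standard deviation of the policy (assumed scalar)
    reward:   (batch,) 1.0 if the final classification is correct, else 0.0
    baseline: (batch,) predicted reward b, treated as a constant here
    """
    advantage = (reward - baseline)[:, None, None]  # broadcast over steps and coordinates
    # d/d(mu) of log N(l | mu, sigma^2 I) is (l - mu) / sigma^2
    return advantage * (l - mu) / sigma ** 2


def baseline_loss(reward, baseline):
    """Squared error that trains the baseline to predict the expected reward."""
    return np.mean((reward - baseline) ** 2)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    batch, steps, sigma = 4, 6, 0.1                  # e.g. 6 glimpses per image
    mu = rng.uniform(-0.5, 0.5, size=(batch, steps, 2))
    l = mu + sigma * rng.standard_normal(mu.shape)   # sample locations from the policy
    reward = np.array([1.0, 0.0, 1.0, 1.0])          # which images were classified correctly
    baseline = np.full(batch, 0.75)
    grad = reinforce_grad_mu(l, mu, sigma, reward, baseline)
    print(grad.shape, baseline_loss(reward, baseline))
```

A positive advantage moves each mean toward the location that was actually sampled and a negative one moves it away; subtracting the baseline leaves the expected gradient unchanged while lowering its variance.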

infer.py: plots the result of inference by a trained RAM model (the figure is shown with your matplotlib backend)

python infer.py --original -m models/ram_original_epoch800.chainermodel  
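Drawing glimpses like those in the example figures requires converting each glimpse location into pixel rectangles. A minimal sketch, assuming locations in [-1, 1] with (0, 0) at the image center and patches that double in width at each scale as in the paper; `glimpse_boxes` is a hypothetical helper, not part of infer.py:

```python
def glimpse_boxes(loc, image_shape, size, n_scales):
    """Return (left, top, side) for each scale of one glimpse, in imshow pixel coordinates.

    loc:         (y, x) in [-1, 1]; (0, 0) is the image center (assumed convention)
    image_shape: (H, W) of the input image
    size:        side of the base patch in pixels; scale k covers size * 2**k pixels
    """
    h, w = image_shape
    cy = round((loc[0] + 1) / 2 * h)
    cx = round((loc[1] + 1) / 2 * w)
    boxes = []
    for k in range(n_scales):
        side = size * 2 ** k
        # imshow centers pixel (r, c) at (c, r), so pixel edges sit at +/- 0.5
        left = cx - side // 2 - 0.5
        top = cy - side // 2 - 0.5
        boxes.append((left, top, side))
    return boxes
```

Each tuple can be passed to matplotlib as `Rectangle((left, top), side, side, fill=False)` and added with `ax.add_patch(...)` on an axis that displays the image via `ax.imshow(...)`.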
