Implementation of neural network models that generate natural language captions for images. Three architectures are implemented and compared: the first uses vanilla recurrent neural networks (RNNs), the second long short-term memory networks (LSTMs), and the third attention-based LSTMs.
The goal is to build a model that can generate a descriptive caption for an image we provide it.
In this project, we implemented vanilla RNNs, LSTMs, and attention-based LSTMs to train models that generate natural language captions for images.
The models presented are highly similar to early works in neural-network-based image captioning. If you are interested in learning more, check out these two papers:
- Show and Tell: A Neural Image Caption Generator
- Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
The Attention model learns where to look.
As you generate a caption, word by word, you can see the model's gaze shifting across the image.
This is possible because of its Attention mechanism, which allows it to focus on the part of the image most relevant to the word it is going to utter next.
Here is an example of a generated caption:
- Encoder-Decoder architecture. Typically, a model that generates sequences will use an Encoder to encode the input into a fixed form and a Decoder to decode it, word by word, into a sequence.
- Attention. The use of Attention networks is widespread in deep learning, and with good reason. This is a way for a model to choose only those parts of the encoding that it thinks are relevant to the task at hand. The same mechanism you see employed here can be used in any model where the Encoder's output has multiple points in space or time. In image captioning, you consider some pixels more important than others. In sequence-to-sequence tasks like machine translation, you consider some words more important than others.
- Transfer Learning. This is when you borrow from an existing model by using parts of it in a new model. This is almost always better than training a new model from scratch (i.e., knowing nothing). As you will see, you can always fine-tune this second-hand knowledge to the specific task at hand. Using pretrained word embeddings is a simple but valid example. For our image captioning problem, we will use a pretrained Encoder, and then fine-tune it as needed.
The pipeline for the project looks as follows:
- The input is a dataset of images, each paired with 5 sentence descriptions collected with Amazon Mechanical Turk. We will use the 2014 release of the COCO Captions dataset, which has become the standard testbed for image captioning. The dataset consists of 80,000 training images and 40,000 validation images, each annotated with 5 captions.
- In the training stage, the images are fed as input to the RNN (or the LSTM / attention LSTM, depending on the model), and the network is asked to predict the words of the sentence, conditioned on the current word and the previous context as mediated by the hidden layers of the network. In this stage, the parameters of the networks are trained with backpropagation (a sketch of one such training step follows this list).
- In the prediction stage, a withheld set of images is passed to the RNN, and the RNN generates the sentence one word at a time. The code also includes utilities for visualizing the results.
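To make the training stage concrete, here is a minimal, self-contained sketch of one teacher-forcing training step. The modules, dimensions, and variable names below are illustrative toy stand-ins, not the exact API of this repo:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins so the sketch runs end to end; dimensions and modules are illustrative.
B, T, V, D, NULL = 4, 15, 1000, 512, 0
encoder = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, D))
embed, rnn, to_vocab = nn.Embedding(V, D), nn.RNN(D, D, batch_first=True), nn.Linear(D, V)

images = torch.randn(B, 3, 112, 112)
captions = torch.randint(1, V, (B, T))     # integer token IDs; 0 is reserved for <NULL>

params = [p for m in (encoder, embed, rnn, to_vocab) for p in m.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-3)

# One teacher-forcing step: predict word t+1 from word t, conditioned on image features.
feats = encoder(images)                    # (B, D) pooled image features
inputs, targets = captions[:, :-1], captions[:, 1:]
hidden0 = feats.unsqueeze(0)               # initialize the hidden state from the image
out, _ = rnn(embed(inputs), hidden0)       # (B, T-1, D) hidden states at every timestep
scores = to_vocab(out)                     # (B, T-1, V) vocabulary scores

loss = F.cross_entropy(scores.reshape(-1, V), targets.reshape(-1), ignore_index=NULL)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```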
Python 3.10, a modern version of PyTorch, and the numpy and scipy modules are required. Most of these can be installed with pip. To install all dependencies at once, run `pip install -r requirements.txt`.
I only tested this code on Ubuntu 20.04, but I tried to make it as generic as possible (e.g. by using the os module for file-system interactions), so it should work on Windows and Mac relatively easily.
- Get the code. `$ git clone` the repo and install the Python dependencies.
- Train the models. Run `$ python train_rnn.py`, `$ python train_lstm.py`, or `$ python train_lstm_attention.py`, depending on the model you want to try (see the many additional argument settings inside each file), and wait. You'll see that the training code writes checkpoints into `checkpoints/` and periodically prints its status.
- Evaluate the model checkpoints and visualize the predictions. To evaluate a checkpoint from `checkpoints/`, run `$ python test_rnn.py`, `$ python test_lstm.py`, or `$ python test_lstm_attention.py` and pass it the path to a checkpoint (by adding `--checkpoint /path/to/the/checkpoint` after the python command).
For this project we used the 2014 release of the COCO Captions dataset, which has become the standard testbed for image captioning. The dataset consists of 80,000 training images and 40,000 validation images, each annotated with 5 captions written by workers on Amazon Mechanical Turk.
We have preprocessed the data and saved it into a serialized data file. It contains 10,000 image-caption pairs for training and 500 for testing. The images have been downsampled to 112x112 for computational efficiency, and the captions are tokenized and numericalized, clamped to 15 words. You can download the file, named `coco.pt` (378MB), with the link below and run some useful stats on it.
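For a quick sanity check after downloading, you can load the file with `torch.load` and print what it contains. This is a generic sketch that assumes the file stores a dictionary of tensors and metadata; the actual key names are whatever the preprocessing script saved:

```python
import torch

data = torch.load("coco.pt")   # path to the downloaded file
# Print each top-level entry to get a feel for the data; the key names depend on
# how the file was serialized, so treat this purely as an inspection helper.
for key, value in data.items():
    if torch.is_tensor(value):
        print(f"{key}: tensor of shape {tuple(value.shape)}, dtype {value.dtype}")
    else:
        print(f"{key}: {type(value).__name__}")
```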
We used the RegNetX-400MF model to extract features for the images. A few notes on the caption preprocessing:
Dealing with strings is inefficient, so we worked with an encoded version of the captions. Each word is assigned an integer ID, allowing us to represent a caption by a sequence of integers. The mapping between integer IDs and words is saved in an entry named `vocab` (containing both `idx_to_token` and `token_to_idx`), and we used the function `decode_captions` from `utils.py` to convert tensors of integer IDs back into strings.
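For reference, a minimal version of such a decoding helper might look like the following; this is a sketch of the behavior described above, not necessarily the exact code in `utils.py`:

```python
def decode_captions(captions, idx_to_token):
    """Convert an (N, T) tensor/array of integer token IDs into N caption strings.

    Sketch: special tokens other than <END> are skipped, and decoding stops at <END>.
    """
    decoded = []
    for row in captions.tolist():
        words = []
        for idx in row:
            token = idx_to_token[idx]
            if token == "<END>":
                break
            if token not in ("<NULL>", "<START>", "<UNK>"):
                words.append(token)
        decoded.append(" ".join(words))
    return decoded
```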
There are a couple of special tokens that we added to the vocabulary. We prepended a special `<START>` token and appended an `<END>` token to the beginning and end of each caption, respectively. Rare words are replaced with a special `<UNK>` token (for "unknown"). In addition, since we wanted to train with minibatches containing captions of different lengths, we padded short captions with a special `<NULL>` token after the `<END>` token and did not compute loss or gradients for `<NULL>` tokens.
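As a concrete, made-up example with a hypothetical vocabulary, the caption "a dog plays fetch" padded out to length 8 would be encoded as follows:

```python
# Hypothetical token_to_idx mapping, for illustration only.
token_to_idx = {"<NULL>": 0, "<START>": 1, "<END>": 2, "<UNK>": 3,
                "a": 4, "dog": 5, "plays": 6, "fetch": 7}

caption = ["<START>", "a", "dog", "plays", "fetch", "<END>", "<NULL>", "<NULL>"]
encoded = [token_to_idx[w] for w in caption]
print(encoded)   # [1, 4, 5, 6, 7, 2, 0, 0]
```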
The first essential component of an image captioning model is an encoder that takes an image as input and produces features for decoding the caption. Here, we used a small RegNetX-400MF as the backbone so we can train in a reasonable time. It accepts image batches of shape `(B, C, H, W)` and outputs spatial features from the final layer with shape `(B, C, H/32, W/32)`.
For the vanilla RNN and LSTM, we used the average-pooled features (shape `(B, C)`) for decoding captions, whereas for the attention LSTM we aggregated the spatial features by learning attention weights.
Check out the `ImageEncoder` module in `rnn_lstm_captioning.py` to see how the model is initialized. We used the implementation from torchvision and put a very thin wrapper module around it for our use case.
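For intuition, such a thin wrapper might look roughly like the sketch below, which assumes torchvision's RegNetX-400MF backbone; the actual `ImageEncoder` in `rnn_lstm_captioning.py` may differ in its details:

```python
import torch
import torch.nn as nn
import torchvision

class TinyImageEncoder(nn.Module):
    """Sketch of a thin wrapper around torchvision's RegNetX-400MF (illustrative)."""

    def __init__(self, pretrained: bool = True):
        super().__init__()
        weights = torchvision.models.RegNet_X_400MF_Weights.DEFAULT if pretrained else None
        backbone = torchvision.models.regnet_x_400mf(weights=weights)
        # Keep the convolutional stem and stages; drop the avgpool/fc classification head.
        self.stem = backbone.stem
        self.trunk = backbone.trunk_output

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # (B, 3, H, W) -> (B, C, H/32, W/32) spatial features from the final stage.
        return self.trunk(self.stem(images))

encoder = TinyImageEncoder(pretrained=False)      # pretrained=True downloads ImageNet weights
spatial = encoder(torch.randn(2, 3, 112, 112))    # (2, 400, 4, 4)
pooled = spatial.mean(dim=(2, 3))                 # (2, 400), used by the RNN/LSTM decoders
```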
In deep learning systems, we commonly represent words using vectors. Each word of the vocabulary is associated with a vector, and these vectors are learned jointly with the rest of the system.
In an RNN language model, at every timestep we produce a score for each word in the vocabulary.
This score is obtained by applying an affine transform to the hidden state (think `nn.Linear` module).
We know the ground-truth word at each timestep, so we use a cross-entropy loss at each timestep.
We sum the losses over time and average them over the minibatch.
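Concretely, both the word vectors and the per-timestep scoring can be expressed with standard PyTorch modules (the dimensions below are illustrative):

```python
import torch
import torch.nn as nn

vocab_size, wordvec_dim, hidden_dim = 1000, 256, 512

embed = nn.Embedding(vocab_size, wordvec_dim)   # one learned vector per vocabulary word
score = nn.Linear(hidden_dim, vocab_size)       # affine transform: hidden state -> vocab scores

tokens = torch.randint(0, vocab_size, (4, 15))  # (B, T) integer token IDs
word_vectors = embed(tokens)                    # (B, T, wordvec_dim) inputs to the RNN
hidden_states = torch.randn(4, 15, hidden_dim)  # (B, T, hidden_dim) RNN hidden states
scores = score(hidden_states)                   # (B, T, vocab_size) scores at each timestep
```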
However, there is one wrinkle: since we operate over minibatches and different captions may have different lengths, we append `<NULL>` tokens to the end of each caption so they all have the same length. We don't want these `<NULL>` tokens to count toward the loss or gradient, so in addition to the scores and ground-truth labels, our loss function also accepts an `ignore_index` that tells it which index in the captions should be ignored when computing the loss.
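A minimal version of such a loss, assuming scores of shape `(B, T, V)` and captions of shape `(B, T)` (the function and argument names are illustrative):

```python
import torch
import torch.nn.functional as F

def temporal_softmax_loss(scores, captions, ignore_index):
    """Cross-entropy loss summed over timesteps and averaged over the minibatch.

    scores:   (B, T, V) vocabulary scores at every timestep
    captions: (B, T)    ground-truth token IDs, padded with <NULL>
    """
    B, T, V = scores.shape
    # F.cross_entropy skips positions whose target equals ignore_index, so the
    # <NULL> padding contributes neither loss nor gradient.
    loss = F.cross_entropy(
        scores.reshape(B * T, V),
        captions.reshape(B * T),
        ignore_index=ignore_index,
        reduction="sum",
    )
    return loss / B
```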
Finally, we wrapped everything into a captioning module. This module has a generic structure for the RNN, LSTM, and attention-based LSTM, which we control by providing a `cell_type` argument (one of `["rnn", "lstm", "attn"]`).
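Usage then looks roughly like the snippet below; the class name and constructor arguments are assumptions for illustration, so check `rnn_lstm_captioning.py` for the real signature:

```python
# Hypothetical instantiation -- verify the class name and arguments in rnn_lstm_captioning.py.
from rnn_lstm_captioning import CaptioningRNN   # class name assumed

model = CaptioningRNN(
    cell_type="attn",              # one of "rnn", "lstm", "attn"
    word_to_idx=token_to_idx,      # the vocabulary mapping from the preprocessed data
    wordvec_dim=256,               # illustrative embedding size
    hidden_dim=512,                # illustrative hidden state size
)
```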
LSTM stands for Long Short-Term Memory, a variant of the vanilla Recurrent Neural Network. Vanilla RNNs can be tough to train on long sequences due to vanishing and exploding gradients caused by repeated matrix multiplication. LSTMs solve this problem by replacing the simple update rule of the vanilla RNN with a gating mechanism.
LSTM update rule: Similar to the vanilla RNN, at each timestep we receive an input $x_t \in \mathbb{R}^D$ and the previous hidden state $h_{t-1} \in \mathbb{R}^H$; the LSTM also maintains an $H$-dimensional cell state, so we additionally receive the previous cell state $c_{t-1} \in \mathbb{R}^H$. The learnable parameters of the LSTM are an input-to-hidden matrix $W_x \in \mathbb{R}^{4H \times D}$, a hidden-to-hidden matrix $W_h \in \mathbb{R}^{4H \times H}$, and a bias vector $b \in \mathbb{R}^{4H}$.
At each timestep we first compute an activation vector $a = W_x x_t + W_h h_{t-1} + b \in \mathbb{R}^{4H}$, which we split into four vectors $a_i, a_f, a_o, a_g \in \mathbb{R}^H$ ($a_i$ is the first $H$ elements of $a$, $a_f$ the next $H$, and so on). From these we compute the input gate $i = \sigma(a_i)$, forget gate $f = \sigma(a_f)$, output gate $o = \sigma(a_o)$, and block input $g = \tanh(a_g)$, where $\sigma$ is the elementwise sigmoid function.
Finally we compute the next cell state $c_t = f \odot c_{t-1} + i \odot g$ and the next hidden state $h_t = o \odot \tanh(c_t)$, where $\odot$ denotes the elementwise product of vectors.
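In code, the same update rule might look like the following sketch (weights are stored as `(D, 4H)` and `(H, 4H)` matrices so the affine transform is a plain matrix multiply):

```python
import torch

def lstm_step(x, prev_h, prev_c, Wx, Wh, b):
    """One LSTM timestep, following the update rule above (illustrative sketch).

    x:  (B, D) input;  prev_h, prev_c: (B, H) previous hidden / cell state
    Wx: (D, 4H) input-to-hidden weights;  Wh: (H, 4H) hidden-to-hidden;  b: (4H,)
    """
    H = prev_h.shape[1]
    a = x @ Wx + prev_h @ Wh + b                 # activation vector, shape (B, 4H)
    ai, af, ao, ag = torch.split(a, H, dim=1)    # slice into the four gate pre-activations
    i, f, o = torch.sigmoid(ai), torch.sigmoid(af), torch.sigmoid(ao)
    g = torch.tanh(ag)
    next_c = f * prev_c + i * g                  # next cell state
    next_h = o * torch.tanh(next_c)              # next hidden state
    return next_h, next_c
```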
The attention LSTM essentially adds an attention input $x_{attn}^t \in \mathbb{R}^H$ to the LSTM update.
To get the attention input $x_{attn}^t$ we use scaled dot-product attention, as covered in the lecture. We first project the CNN feature activation from $\mathbb{R}^{C \times 4 \times 4}$ to $A \in \mathbb{R}^{H \times 4 \times 4}$ using an affine layer.
To simplify the formulation here, we flattened the spatial dimensions of $A$, giving $\tilde{A} \in \mathbb{R}^{H \times 16}$, and compute the attention weights as $\tilde{M}_{attn}^t = \mathrm{softmax}(\tilde{A}^\top h_{t-1} / \sqrt{H}) \in \mathbb{R}^{16}$.
The attention embedding given the attention weights is then $x_{attn}^t = \tilde{A}\,\tilde{M}_{attn}^t \in \mathbb{R}^H$.
- Scaled dot-product attention. Given the LSTM hidden state from the previous timestep `prev_h` (or $h_{t-1}$) and the projected CNN feature activation `A`, the attention weights `attn_weights` (or $\tilde{M}_{attn}^t$, reshaped to $\mathbb{R}^{4 \times 4}$) and the attention embedding output `attn` (or $x_{attn}^t$) are computed using the formulation given above.
Hence, at each timestep the activation vector becomes $a = W_x x_t + W_h h_{t-1} + W_{attn} x_{attn}^t + b \in \mathbb{R}^{4H}$, where $W_{attn} \in \mathbb{R}^{4H \times H}$ is an additional learnable weight matrix; the rest of the LSTM update rule is unchanged.
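Under this formulation, the attention step itself can be sketched as follows (batched; the function and variable names are illustrative):

```python
import torch

def dot_product_attention(prev_h, A):
    """Scaled dot-product attention over spatial CNN features (sketch).

    prev_h: (B, H)       previous LSTM hidden state
    A:      (B, H, 4, 4) projected CNN feature activation
    Returns the attention embedding (B, H) and the attention weights (B, 4, 4).
    """
    B, H = prev_h.shape
    A_flat = A.reshape(B, H, -1)                             # (B, H, 16) flatten spatial dims
    scores = A_flat.transpose(1, 2) @ prev_h.unsqueeze(2)    # (B, 16, 1) dot products
    attn_weights = torch.softmax(scores / H ** 0.5, dim=1)   # softmax over the 16 locations
    attn = (A_flat @ attn_weights).squeeze(2)                # (B, H) attention embedding
    return attn, attn_weights.reshape(B, 4, 4)

# Tiny usage example with random tensors.
attn, weights = dot_product_attention(torch.randn(2, 512), torch.randn(2, 512, 4, 4))
```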