-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
193 lines (158 loc) · 7.03 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.models as models
from nltk.translate.bleu_score import corpus_bleu
from load_dataset import Flickr8kDataset
from model import CaptionDecoder, Encoder
from utils.utils import save_checkpoint, log_gradient_norm, set_up_causal_mask, greedy_decoding
def evaluate(subset, encoder, decoder, config, device):
    """Compute corpus-level BLEU-{1..4} scores for the caption model on a subset.

    Arguments:
        subset (Flickr8kDataset): Train/Val/Test subset
        encoder (nn.Module): CNN which generates image features
        decoder (nn.Module): Transformer decoder which generates captions
        config (object): Contains configuration for the evaluation pipeline
        device (torch.device): Device on which to port used tensors
    Returns:
        list[float]: [BLEU-1, BLEU-2, BLEU-3, BLEU-4] over the entire subset
                     (corpus BLEU, scaled to percentages).
    """
    batch_size = config["batch_size"]["eval"]
    max_len = config["max_len"]
    bleu_w = config["bleu_weights"]

    # Vocabulary lookup and special-token ids consumed by greedy decoding
    idx2word = subset.idx2token
    sos_id, eos_id, pad_id = subset.start_idx, subset.end_idx, subset.pad_idx

    references_total, predictions_total = [], []

    print("Evaluating model.")
    for x_img, y_caption in subset.inference_batch(batch_size):
        x_img = x_img.to(device)
        # CNN feature map -> (batch, spatial_positions, channels) for the decoder
        feats = encoder(x_img)
        feats = feats.view(feats.size(0), feats.size(1), -1).permute(0, 2, 1).detach()
        # Greedily decode one caption per image in the mini-batch
        references_total.extend(y_caption)
        predictions_total.extend(
            greedy_decoding(decoder, feats, sos_id, eos_id, pad_id, idx2word, max_len, device)
        )

    # Corpus BLEU for each n-gram order, as percentages
    return [
        corpus_bleu(references_total, predictions_total, weights=bleu_w[f"bleu-{n}"]) * 100
        for n in range(1, 5)
    ]
def train(config, writer, device):
    """Run the full training loop for the caption decoder.

    The CNN encoder is used as a frozen feature extractor; only the
    transformer decoder is optimized.

    Arguments:
        config (object): Contains configuration of the pipeline
        writer: tensorboardX writer object
        device: device on which to map the model and data
    """
    torch.manual_seed(config["seed"])
    np.random.seed(config["seed"])

    # Define dataloader hyper-parameters
    train_hyperparams = {
        "batch_size": config["batch_size"]["train"],
        "shuffle": True,
        "num_workers": 1,
        "drop_last": True
    }

    # Create datasets and the training dataloader
    train_set = Flickr8kDataset(config, config["split_save"]["train"], training=True)
    valid_set = Flickr8kDataset(config, config["split_save"]["validation"], training=False)
    train_loader = DataLoader(train_set, **train_hyperparams)

    #######################
    # Set up the encoder
    #######################
    encoder = Encoder()
    encoder = encoder.to(device)
    # The encoder is frozen (features are always extracted under no_grad), so
    # keep it in eval mode for the whole run. Previously it stayed in train
    # mode until the first evaluation, which let BatchNorm running statistics
    # drift on training batches and silently changed features mid-training.
    encoder.eval()

    ######################
    # Set up the decoder
    ######################
    decoder = CaptionDecoder(config)
    decoder = decoder.to(device)
    if config["checkpoint"]["load"]:
        checkpoint_path = config["checkpoint"]["path"]
        decoder.load_state_dict(torch.load(checkpoint_path))
    decoder.train()

    # Causal mask keeps the decoder from attending to future tokens
    causal_mask = set_up_causal_mask(config["max_len"], device)

    # Load training configuration
    train_config = config["train_config"]
    learning_rate = train_config["learning_rate"]

    # Prepare the model optimizer
    optimizer = torch.optim.AdamW(
        decoder.parameters(),
        lr=learning_rate,
        weight_decay=train_config["l2_penalty"]
    )
    # Label smoothing regularizes the cross-entropy objective
    loss_fcn = nn.CrossEntropyLoss(label_smoothing=0.1)

    start_time = time.strftime("%b-%d_%H-%M-%S")
    train_step = 0
    for epoch in range(train_config["num_of_epochs"]):
        print("Epoch:", epoch)
        decoder.train()
        for x_img, x_words, y, tgt_padding_mask in train_loader:
            optimizer.zero_grad()
            train_step += 1

            # Move the used tensors to the defined device
            x_img, x_words = x_img.to(device), x_words.to(device)
            y = y.to(device)
            tgt_padding_mask = tgt_padding_mask.to(device)

            # Extract image features; encoder is frozen, so no gradients
            with torch.no_grad():
                img_features = encoder(x_img)
            # (batch, channels, H, W) -> (batch, spatial_positions, channels)
            img_features = img_features.view(img_features.size(0), img_features.size(1), -1)
            img_features = img_features.permute(0, 2, 1)
            img_features = img_features.detach()

            # Get the prediction of the decoder
            y_pred = decoder(x_words, img_features, tgt_padding_mask, causal_mask)

            # Keep only non-padding positions when computing the loss
            tgt_padding_mask = torch.logical_not(tgt_padding_mask)
            y_pred = y_pred[tgt_padding_mask]
            y = y[tgt_padding_mask]

            # Calculate the loss
            loss = loss_fcn(y_pred, y.long())

            # Update model weights, clipping gradients and logging their
            # norms before and after clipping
            loss.backward()
            log_gradient_norm(decoder, writer, train_step, "Before")
            torch.nn.utils.clip_grad_norm_(decoder.parameters(), train_config["gradient_clipping"])
            log_gradient_norm(decoder, writer, train_step, "After")
            optimizer.step()

            writer.add_scalar("Train/Step-Loss", loss.item(), train_step)
            writer.add_scalar("Train/Learning-Rate", learning_rate, train_step)

        # Save the model and optimizer state
        save_checkpoint(decoder, optimizer, start_time, epoch)

        # Periodically evaluate BLEU on the train and validation subsets
        if (epoch + 1) % train_config["eval_period"] == 0:
            with torch.no_grad():
                decoder.eval()
                train_bleu = evaluate(train_set, encoder, decoder, config, device)
                valid_bleu = evaluate(valid_set, encoder, decoder, config, device)
                # Log the evaluated BLEU scores
                for i, t_b in enumerate(train_bleu):
                    writer.add_scalar(f"Train/BLEU-{i+1}", t_b, epoch)
                for i, v_b in enumerate(valid_bleu):
                    writer.add_scalar(f"Valid/BLEU-{i+1}", v_b, epoch)
                decoder.train()
        print()