-
Notifications
You must be signed in to change notification settings - Fork 26
/
train.py
128 lines (97 loc) · 3.4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import pickle
import sys
from collections import Counter

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import CLEVR, collate_data, transform
from model import MACNetwork
# Hyperparameters shared by the training and validation code in this file.
batch_size = 64  # samples per DataLoader batch
n_epoch = 20  # number of train + validate passes run by the script entry point
dim = 512  # second positional argument given to MACNetwork

# Use the GPU when PyTorch reports one as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def accumulate(model1, model2, decay=0.999):
    """Update ``model1``'s parameters in place as a running average of ``model2``'s.

    For every named parameter ``p`` of ``model1`` this computes
    ``p = decay * p + (1 - decay) * q`` where ``q`` is the parameter with the
    same name in ``model2``.  With ``decay=0`` the result is a plain copy of
    ``model2``'s parameter values into ``model1``.

    Only entries returned by ``named_parameters()`` are touched; ``model2`` is
    never modified.

    Args:
        model1: Module whose parameters receive the averaged values.
        model2: Module providing the new values; it must have a parameter
            for every parameter name in ``model1``.
        decay: Weight kept from ``model1``'s current values.

    Raises:
        KeyError: If ``model2`` lacks a parameter name present in ``model1``.
    """
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    # no_grad lets us mutate leaf parameters in place without autograd
    # recording the update.
    with torch.no_grad():
        for name, target in par1.items():
            # Keyword ``alpha`` replaces the deprecated positional form
            # ``add_(scalar, tensor)``.
            target.mul_(decay).add_(par2[name], alpha=1 - decay)
def train(epoch):
    """Run one optimisation pass over the training data.

    Builds a ``CLEVR`` dataset from ``sys.argv[1]``, iterates it in batches,
    and for every batch performs a forward/backward pass on ``net``, steps
    ``optimizer`` and folds the new weights into ``net_running`` via
    ``accumulate``.  A tqdm progress bar shows the current loss and a
    smoothed batch accuracy.

    Relies on the module-level globals ``net``, ``net_running``,
    ``criterion``, ``optimizer``, ``device`` and ``batch_size``.

    Args:
        epoch: Zero-based epoch index; only used (as ``epoch + 1``) in the
            progress-bar text.
    """
    clevr = CLEVR(sys.argv[1], transform=transform)

    # try/finally so the dataset is closed even if a training step raises.
    try:
        train_set = DataLoader(
            clevr, batch_size=batch_size, num_workers=4, collate_fn=collate_data
        )
        pbar = tqdm(train_set)

        # Exponentially smoothed accuracy for display only.  It stays at 0
        # until the first batch with a non-zero accuracy seeds it.
        moving_acc = 0

        net.train(True)
        for image, question, q_len, answer, _ in pbar:
            image, question, answer = (
                image.to(device),
                question.to(device),
                answer.to(device),
            )

            net.zero_grad()
            output = net(image, question, q_len)
            loss = criterion(output, answer)
            loss.backward()
            optimizer.step()

            # Mean over the samples actually in this batch, so a short final
            # batch is not under-reported by dividing by ``batch_size``.
            hits = output.detach().argmax(1) == answer
            batch_acc = hits.float().mean().item()

            if moving_acc == 0:
                moving_acc = batch_acc
            else:
                moving_acc = moving_acc * 0.99 + batch_acc * 0.01

            pbar.set_description(
                'Epoch: {}; Loss: {:.5f}; Acc: {:.5f}'.format(
                    epoch + 1, loss.item(), moving_acc
                )
            )

            # Track the running average of the weights after every step.
            accumulate(net_running, net)
    finally:
        clevr.close()
def valid(epoch):
    """Evaluate ``net_running`` on the validation split and log the accuracy.

    Builds a ``CLEVR`` dataset from ``sys.argv[1]`` with split ``'val'``,
    runs ``net_running`` over it without gradients, and counts correct
    predictions per value of the ``family`` field produced by the data
    loader.  Per-family accuracies are written to
    ``log/log_<NN>.txt`` (``NN`` is ``epoch + 1`` zero-padded to two digits)
    and the overall accuracy is printed.

    Relies on the module-level globals ``net_running``, ``device`` and
    ``batch_size``.

    Args:
        epoch: Zero-based epoch index, used to name the log file.
    """
    clevr = CLEVR(sys.argv[1], 'val', transform=None)

    family_correct = Counter()
    family_total = Counter()

    # try/finally so the dataset is closed even if evaluation raises.
    try:
        valid_set = DataLoader(
            clevr, batch_size=batch_size, num_workers=4, collate_fn=collate_data
        )

        net_running.train(False)
        with torch.no_grad():
            for image, question, q_len, answer, family in tqdm(valid_set):
                image, question = image.to(device), question.to(device)

                output = net_running(image, question, q_len)
                correct = output.detach().argmax(1) == answer.to(device)

                # One bulk conversion to Python bools instead of converting
                # each tensor element separately inside the loop.
                for is_hit, fam in zip(correct.tolist(), family):
                    if is_hit:
                        family_correct[fam] += 1
                    family_total[fam] += 1
    finally:
        clevr.close()

    # Make sure the output directory exists so a finished validation pass is
    # not lost to a FileNotFoundError.
    os.makedirs('log', exist_ok=True)
    with open('log/log_{}.txt'.format(str(epoch + 1).zfill(2)), 'w') as w:
        for k, v in family_total.items():
            w.write('{}: {:.5f}\n'.format(k, family_correct[k] / v))

    n_seen = sum(family_total.values())
    n_hit = sum(family_correct.values())
    # Guard against an empty validation split (would divide by zero).
    avg_acc = n_hit / n_seen if n_seen else 0.0
    print('Avg Acc: {:.5f}'.format(avg_acc))
# Script entry point: build the networks, then alternate training and
# validation for ``n_epoch`` epochs, saving the averaged weights each time.
# ``net``, ``net_running``, ``criterion`` and ``optimizer`` are deliberately
# module-level names because ``train`` and ``valid`` read them as globals.
if __name__ == '__main__':
    # train() and valid() both pass sys.argv[1] to CLEVR; fail early with a
    # clear message instead of an IndexError deep inside the first epoch.
    if len(sys.argv) < 2:
        sys.exit('usage: python {} <dataset path>'.format(sys.argv[0]))

    # NOTE: pickle.load executes arbitrary code from the file; only use a
    # dictionary file you generated yourself.
    with open('data/dic.pkl', 'rb') as f:
        dic = pickle.load(f)

    n_words = len(dic['word_dic']) + 1
    # Computed but not passed to MACNetwork below.
    n_answers = len(dic['answer_dic'])

    net = MACNetwork(n_words, dim).to(device)
    net_running = MACNetwork(n_words, dim).to(device)
    # decay=0 copies net's initial parameters into net_running.
    accumulate(net_running, net, 0)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=1e-4)

    # Create the output directory up front so the first save cannot fail
    # after a whole epoch of training.
    os.makedirs('checkpoint', exist_ok=True)

    for epoch in range(n_epoch):
        train(epoch)
        valid(epoch)

        checkpoint_path = 'checkpoint/checkpoint_{}.model'.format(
            str(epoch + 1).zfill(2)
        )
        with open(checkpoint_path, 'wb') as f:
            torch.save(net_running.state_dict(), f)