## train.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import tqdm
import wandb

from data import jsonToDict, dictToJson, TextData
from models import SeqToProb
class LanguageModel:
    def __init__(self, params):
        self.params = params
        self.model = SeqToProb(
            num_probs=params['num_probs'],
            emb_dim=params['emb_dim'],
            h_dim=params['h_dim'],
            num_rnn=params['num_rnn'],
            dropout=params['dropout'],
            lr=params['lr'])
        self.criterion = nn.CrossEntropyLoss()

    def decode(self, decoder, sequence):
        ## Map token ids back to words
        return [decoder[str(word)] for word in sequence]

    def encode(self, encoder, sequence):
        ## Map words to token ids
        return [encoder[word] for word in sequence]

    def accuracy(self, probs, y):
        ## Fraction of positions where the argmax prediction matches the target
        predicted_words = torch.argmax(probs, dim=-1)
        total_correct = (predicted_words == y).sum()
        return (total_correct / predicted_words.shape[0]).item()
    def generate(self, encoder, decoder, seq_len, num_words):
        ## Sample up to num_words tokens, starting from a window of <sos> tokens
        self.model.eval()
        seq = ['<sos>'] * seq_len
        generated = []
        for i in range(num_words):
            ## Encode the most recent seq_len words as the model input
            x = self.encode(encoder, seq[i:i + seq_len])
            x = torch.LongTensor([x]).to(self.model.device)
            with torch.no_grad():
                logits = self.model(x)
            ## Sample the next word from the predicted distribution
            dist = torch.distributions.Categorical(logits=logits)
            word = dist.sample().item()
            word = decoder[str(word)]
            if word == '<eos>':
                break
            seq.append(word)
            generated.append(word)
        return generated
    def train(self, train, test, num_epochs, path):
        ## Put the data into loaders
        train_loader = DataLoader(
            TextData(train),
            batch_size=self.params['batch_size'],
            num_workers=1,
            shuffle=True)
        test_loader = DataLoader(
            TextData(test),
            batch_size=self.params['batch_size'],
            num_workers=1,
            shuffle=False)
        ## Begin training
        for epoch in range(num_epochs):
            ## Train cycle
            self.model.train()
            train_loss = 0
            train_batches = 0
            for X, y in tqdm.tqdm(train_loader):
                X, y = X.cuda(), y.cuda()
                self.model.opt.zero_grad()
                ## Forward
                logits = self.model(X)
                loss = self.criterion(logits, y)
                ## Backward
                loss.backward()
                self.model.opt.step()
                ## Measure stats
                train_loss += loss.item()
                train_batches += 1
                ## Log train loss for each batch so we get a good graph
                wandb.log({'train_loss': loss.item()})
            ## Testing cycle
            self.model.eval()
            test_loss = 0
            test_acc = 0
            test_batches = 0
            for X, y in tqdm.tqdm(test_loader):
                X, y = X.cuda(), y.cuda()
                ## Forward
                with torch.no_grad():
                    logits = self.model(X)
                    loss = self.criterion(logits, y)
                ## Measure stats
                test_loss += loss.item()
                probs = torch.softmax(logits, dim=-1)
                test_acc += self.accuracy(probs, y)
                test_batches += 1
            ## Report statistics
            avg_train_loss = train_loss / train_batches
            avg_test_loss = test_loss / test_batches
            avg_test_acc = test_acc / test_batches
            print(f'Train loss: {avg_train_loss}, Test loss: {avg_test_loss}, Test acc: {avg_test_acc}')
            wandb.log({
                'test_loss': avg_test_loss,
                'test_acc': avg_test_acc})
            ## Save the model each epoch
            torch.save(self.model.state_dict(), path / f'model{epoch}.pt')
        ## Save model parameters at the end
        dictToJson(self.params, path / 'model_params.json')
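
## A minimal usage sketch, kept as comments since it is an assumption rather than part of
## the original script: it assumes the JSON read by jsonToDict contains the keys used in
## __init__, that wandb.init is called before training so wandb.log works, and that
## train_seqs/test_seqs are whatever TextData expects. 'params.json', 'checkpoints',
## and the 'num_epochs' key are placeholders.
##
## if __name__ == '__main__':
##     from pathlib import Path
##     params = jsonToDict('params.json')
##     wandb.init(project='seq-to-prob-lm', config=params)
##     lm = LanguageModel(params)
##     train_seqs, test_seqs = ...  # load tokenized corpora here
##     lm.train(train_seqs, test_seqs, num_epochs=params['num_epochs'], path=Path('checkpoints'))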