-
Notifications
You must be signed in to change notification settings - Fork 1
/
tester.py
67 lines (38 loc) · 1.33 KB
/
tester.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from trainer import *
from model import *
import pickle
# Tester script: loads (or optionally rebuilds) the cached corpora, runs the
# two training stages provided by `trainer`, then evaluates the result.

hidden_size = 256
# Length of the one-hot label vectors built below. The DNNDecoder's output
# size must match it, so both read this single constant.
n_classes = 81

# Pipeline switches. The defaults reproduce the standard run: reuse the cached
# pickles, train both stages, then evaluate.
# prepare_data=True rebuilds the caches via prepareData / preparepairs.
prepare_data = False
# Setting a train_* switch to False skips that stage and loads the checkpoint
# a previous run saved instead.
train_autoencoder = True
train_classifier = True


def _load_module(path):
    """Load a whole module that was saved with ``torch.save(module, path)``.

    The object is mapped onto ``device``, so a checkpoint written on one
    device can be loaded on another.
    """
    import inspect

    kwargs = {"map_location": device}
    # PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses to
    # unpickle full nn.Module objects. These checkpoints are written by this
    # script, so opt out explicitly wherever the keyword exists (added in 1.13).
    if "weights_only" in inspect.signature(torch.load).parameters:
        kwargs["weights_only"] = False
    return torch.load(path, **kwargs)


if prepare_data:
    lang, sentenses = prepareData("data")
    with open("lang", 'wb') as f:
        pickle.dump(lang, f)
    with open("sentenses", 'wb') as f:
        pickle.dump(sentenses, f)
    # Element [1] of each pair picks the hot position, so it must be an int
    # label in range(n_classes); element [0] is carried over unchanged.
    datapairs = []
    for pair in preparepairs('new 1'):
        onehot = [0] * n_classes
        onehot[pair[1]] = 1
        datapairs.append([pair[0], onehot])
    # Despite the .txt name, this file holds a pickle, not text.
    with open("datapairs.txt", 'wb') as f:
        pickle.dump(datapairs, f)
else:
    # NOTE: pickle.load can execute arbitrary code; only load caches that this
    # script produced.
    with open("lang", 'rb') as f:
        lang = pickle.load(f)
    with open("sentenses", 'rb') as f:
        sentenses = pickle.load(f)
    with open("datapairs.txt", 'rb') as f:
        datapairs = pickle.load(f)

encoder_ae = EncoderRNN(lang.n_words, hidden_size).to(device)
attn_decoder_ae = AttnDecoderRNN(hidden_size, lang.n_words, dropout_p=0.1).to(device)
decoder = DNNDecoder(hidden_size, n_classes).to(device)

# Checkpoints are not re-read right after saving: the trained objects are
# already in memory, and re-loading a whole module fails under PyTorch >= 2.6's
# weights_only default.

# Stage 1: train the encoder together with the attention decoder on
# `sentenses` (the `_ae` names suggest autoencoder pretraining).
if train_autoencoder:
    trainItersae(encoder_ae, attn_decoder_ae, lang, sentenses, 75000,
                 print_every=5000, learning_rate=0.005)
    torch.save(encoder_ae, 'encoder.pt')
    torch.save(attn_decoder_ae, 'atdecoder.pt')
else:
    # Only the encoder is used after this stage, so only it is restored.
    encoder_ae = _load_module('encoder.pt')

# Stage 2: train the DNN decoder on `datapairs`, feeding it the stage-1 encoder.
if train_classifier:
    trainIters(encoder_ae, decoder, lang, datapairs, 20000,
               print_every=2000, learning_rate=0.005)
    torch.save(decoder, 'dnndecoder.pt')
else:
    decoder = _load_module('dnndecoder.pt')

evaluateAll(encoder_ae, decoder, datapairs, lang)