-
Notifications
You must be signed in to change notification settings - Fork 4
/
predict.py
86 lines (75 loc) · 3.65 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
from torch import nn
from data_util import voc
from seq2seq import EncoderRNN, LuongAttnDecoderRNN, evaluateInput
from config import *
######################################################################
# Initialize the model parameters.
# This script only runs prediction, so it always loads the trained
# checkpoint named by `model_checkpoint`.
#
# map_location=device remaps every stored tensor onto the device this process
# uses, so a checkpoint saved on a GPU can be loaded on a CPU-only machine (and
# vice versa) without editing this file.
checkpoint = torch.load(model_checkpoint, map_location=device)
encoder_sd = checkpoint['en']
decoder_sd = checkpoint['de']
# Optimizer states are stored in the checkpoint; they are not used for inference.
encoder_optimizer_sd = checkpoint['en_opt']
decoder_optimizer_sd = checkpoint['de_opt']
embedding_sd = checkpoint['embedding']
# Replace voc's attributes with the ones saved alongside the model, so that
# voc.num_words (used below to size the embedding and decoder) matches training.
voc.__dict__ = checkpoint['voc_dict']

print('Building encoder and decoder ...')
# One embedding object is passed to both the encoder and the decoder below.
embedding = nn.Embedding(voc.num_words, hidden_size)
embedding.load_state_dict(embedding_sd)
# Build the encoder and decoder, then restore their trained weights.
# (The checkpoint was loaded unconditionally above, so no `if` guard is needed.)
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
encoder.load_state_dict(encoder_sd)
decoder.load_state_dict(decoder_sd)
# Move both models to the configured device.
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')
class GreedySearchDecoder(nn.Module):
    """Greedy decoder: at each step, feed back the single highest-scoring token.

    Wraps a trained encoder/decoder pair. ``forward`` always runs exactly
    ``max_length`` decoder steps; it does not stop at EOS (per the original
    notes, EOS is removed when the result is post-processed).
    """

    def __init__(self, encoder, decoder, sos_token=None):
        """
        Args:
            encoder: module called as ``encoder(input_seq, input_length)`` and
                returning ``(encoder_outputs, encoder_hidden)``.
            decoder: module called as
                ``decoder(decoder_input, decoder_hidden, encoder_outputs)`` and
                returning ``(decoder_output, decoder_hidden)``; it must expose
                an ``n_layers`` attribute.
            sos_token: id of the start-of-sequence token fed as the first
                decoder input. When None, the module-level ``SOS_token`` is used.
        """
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Only look up the global when no explicit token was given.
        self.sos_token = SOS_token if sos_token is None else sos_token

    def forward(self, input_seq, input_length, max_length):
        """Decode greedily for ``max_length`` steps.

        Returns:
            (all_tokens, all_scores): 1-D tensors of length ``max_length``
            holding the chosen token ids (long) and their scores.
        """
        # Create new tensors on the same device as the input, instead of
        # relying on a module-level ``device`` global.
        dev = input_seq.device
        # Encoder forward pass.
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Initialize the decoder state from the leading ``n_layers`` slices of
        # the encoder's final hidden state. This must read the *wrapped*
        # decoder, not a global, so any decoder passed in works.
        decoder_hidden = encoder_hidden[:self.decoder.n_layers]
        # The models expect (time, batch) inputs, so even a single start token
        # is shaped (1, 1).
        decoder_input = torch.full((1, 1), self.sos_token, device=dev, dtype=torch.long)
        # Accumulators for the decoded token ids and their scores.
        all_tokens = torch.zeros([0], device=dev, dtype=torch.long)
        all_scores = torch.zeros([0], device=dev)
        # Only a length limit is applied here; EOS is not checked.
        for _ in range(max_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # decoder_output is expected to be (batch=1, vocab_size); take the
            # highest-scoring token id and its score.
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # max() removed a dimension; add the batch dimension back so the
            # next decoder input is (1, 1) again.
            decoder_input = torch.unsqueeze(decoder_input, 0)
        return all_tokens, all_scores
######################################################################
# Run evaluation
#
# Switch both models to eval mode so dropout is disabled.
encoder.eval()
decoder.eval()
# Build the greedy-search decoder that wraps the trained encoder/decoder.
searcher = GreedySearchDecoder(encoder, decoder)
# Run evaluation. This is inference only, so autograd is turned off: no
# computation graph is recorded for any decoding step, which saves memory and time.
with torch.no_grad():
    evaluateInput(encoder, decoder, searcher, voc)