forked from salesforce/awd-lstm-lm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pointer.py
135 lines (122 loc) · 5.55 KB
/
pointer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import argparse
import time
import math
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import data
import model
from utils import batchify, get_batch, repackage_hidden
parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model')
parser.add_argument('--data', type=str, default='data/penn',
help='location of the data corpus')
parser.add_argument('--model', type=str, default='LSTM',
help='type of recurrent net (LSTM, QRNN)')
parser.add_argument('--save', type=str,default='best.pt',
help='model to use the pointer over')
parser.add_argument('--cuda', action='store_false',
help='use CUDA')
parser.add_argument('--bptt', type=int, default=5000,
help='sequence length')
parser.add_argument('--window', type=int, default=3785,
help='pointer window length')
parser.add_argument('--theta', type=float, default=0.6625523432485668,
help='mix between uniform distribution and pointer softmax distribution over previous words')
parser.add_argument('--lambdasm', type=float, default=0.12785920428335693,
help='linear mix between only pointer (1) and only vocab (0) distribution')
args = parser.parse_args()
###############################################################################
# Load data
###############################################################################
corpus = data.Corpus(args.data)
eval_batch_size = 1
test_batch_size = 1
#train_data = batchify(corpus.train, args.batch_size)
val_data = batchify(corpus.valid, test_batch_size, args)
test_data = batchify(corpus.test, test_batch_size, args)
###############################################################################
# Build the model
###############################################################################
ntokens = len(corpus.dictionary)
criterion = nn.CrossEntropyLoss()
def one_hot(idx, size, cuda=True):
a = np.zeros((1, size), np.float32)
a[0][idx] = 1
v = Variable(torch.from_numpy(a))
if cuda: v = v.cuda()
return v
def evaluate(data_source, batch_size=10, window=args.window):
# Turn on evaluation mode which disables dropout.
if args.model == 'QRNN': model.reset()
model.eval()
total_loss = 0
ntokens = len(corpus.dictionary)
hidden = model.init_hidden(batch_size)
next_word_history = None
pointer_history = None
for i in range(0, data_source.size(0) - 1, args.bptt):
if i > 0: print(i, len(data_source), math.exp(total_loss / i))
data, targets = get_batch(data_source, i, evaluation=True, args=args)
output, hidden, rnn_outs, _ = model(data, hidden, return_h=True, decode=True)
rnn_out = rnn_outs[-1].squeeze()
output_flat = output.view(-1, ntokens)
###
# Fill pointer history
start_idx = len(next_word_history) if next_word_history is not None else 0
next_word_history = torch.cat([one_hot(t.data, ntokens) for t in targets]) if next_word_history is None else torch.cat([next_word_history, torch.cat([one_hot(t.data, ntokens) for t in targets])])
#print(next_word_history)
pointer_history = Variable(rnn_out.data) if pointer_history is None else torch.cat([pointer_history, Variable(rnn_out.data)], dim=0)
#print(pointer_history)
###
# Built-in cross entropy
# total_loss += len(data) * criterion(output_flat, targets).data[0]
###
# Manual cross entropy
# softmax_output_flat = torch.nn.functional.softmax(output_flat)
# soft = torch.gather(softmax_output_flat, dim=1, index=targets.view(-1, 1))
# entropy = -torch.log(soft)
# total_loss += len(data) * entropy.mean().data[0]
###
# Pointer manual cross entropy
loss = 0
softmax_output_flat = torch.nn.functional.softmax(output_flat)
for idx, vocab_loss in enumerate(softmax_output_flat):
p = vocab_loss
if start_idx + idx > window:
valid_next_word = next_word_history[start_idx + idx - window:start_idx + idx]
valid_pointer_history = pointer_history[start_idx + idx - window:start_idx + idx]
logits = torch.mv(valid_pointer_history, rnn_out[idx])
theta = args.theta
ptr_attn = torch.nn.functional.softmax(theta * logits).view(-1, 1)
ptr_dist = (ptr_attn.expand_as(valid_next_word) * valid_next_word).sum(0).squeeze()
lambdah = args.lambdasm
p = lambdah * ptr_dist + (1 - lambdah) * vocab_loss
###
target_loss = p[targets[idx].data]
loss += (-torch.log(target_loss)).data
total_loss += loss / batch_size
###
hidden = repackage_hidden(hidden)
next_word_history = next_word_history[-window:]
pointer_history = pointer_history[-window:]
return total_loss / len(data_source)
# Load the best saved model.
with open(args.save, 'rb') as f:
if not args.cuda:
model = torch.load(f, map_location=lambda storage, loc: storage)
else:
model = torch.load(f)
print(model)
# Run on val data.
val_loss = evaluate(val_data, test_batch_size)
print('=' * 89)
print('| End of pointer | val loss {:5.2f} | val ppl {:8.2f}'.format(
val_loss, math.exp(val_loss)))
print('=' * 89)
# Run on test data.
test_loss = evaluate(test_data, test_batch_size)
print('=' * 89)
print('| End of pointer | test loss {:5.2f} | test ppl {:8.2f}'.format(
test_loss, math.exp(test_loss)))
print('=' * 89)