-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.py
72 lines (59 loc) · 2.08 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
try:
    import cPickle
except ImportError:  # Python 3: cPickle was folded into the pickle module
    import pickle as cPickle

import torch
def convert_data(batch, vocab, device, reverse=False, unk=None, pad=None, sos=None, eos=None):
    """Convert a batch of token sequences into a padded id tensor and a mask.

    Args:
        batch: Sequence of token sequences (each must support ``len`` and
            slicing).
        vocab: Mapping whose ``'stoi'`` entry maps tokens to integer ids.
        device: Target handed to ``Tensor.to``.
        reverse: When true, each sequence is emitted as
            ``eos + reversed tokens + sos``; otherwise ``sos + tokens + eos``.
        unk: Token whose id is substituted for tokens absent from
            ``vocab['stoi']``.
        pad: Token appended to shorter sequences so all rows share a length.
        sos: Optional start marker; skipped when ``None``.
        eos: Optional end marker; skipped when ``None``.

    Returns:
        A ``(padded, mask)`` tuple. ``padded`` is a ``LongTensor`` of ids on
        ``device``; ``mask`` is a float tensor of the same shape holding 0.0
        wherever the id equals the pad id and 1.0 elsewhere.

    Raises:
        ValueError: If ``batch`` is empty (``max`` of an empty sequence).
        KeyError: If a needed ``unk`` or ``pad`` token is not in the vocab.
    """
    max_len = max(len(x) for x in batch)
    stoi = vocab['stoi']
    head = [] if sos is None else [sos]
    tail = [] if eos is None else [eos]

    rows = []
    for x in batch:
        if reverse:
            tokens = tail + list(x[::-1]) + head
        else:
            tokens = head + list(x) + tail
        # Pad relative to the raw length; the sos/eos markers add the same
        # number of slots to every row, so all rows end up equally long.
        tokens = tokens + [pad] * max(0, max_len - len(x))
        # A real list (not a lazy ``map``) so torch can build the tensor on
        # Python 3 as well. The unk lookup stays lazy: it is only evaluated
        # for out-of-vocabulary tokens.
        rows.append([stoi[t] if t in stoi else stoi[unk] for t in tokens])

    padded = torch.LongTensor(rows).to(device)
    mask = padded.ne(stoi[pad]).float()
    return padded, mask
def convert_str(batch, vocab):
    """Map each sequence of ids in ``batch`` back to its tokens.

    Args:
        batch: Iterable of id sequences.
        vocab: Mapping whose ``'itos'`` entry is indexable by id and yields
            the corresponding token.

    Returns:
        A list with one list of tokens per input sequence.
    """
    # Build concrete lists: on Python 3 ``map`` would leave lazy iterators
    # in the result instead of the token lists callers expect.
    return [[vocab['itos'][idx] for idx in seq] for seq in batch]
def invert_vocab(vocab):
    """Return a new dict mapping each value of ``vocab`` back to its key.

    If several keys share a value, the key visited last during iteration
    wins.
    """
    # ``items()`` exists on both Python 2 and 3; ``iteritems()`` is
    # Python 2 only and raises AttributeError on Python 3.
    return {idx: token for token, idx in vocab.items()}
def load_vocab(path):
    """Load and return the pickled vocabulary object stored at ``path``.

    Args:
        path: Filesystem path of the pickle file, opened in binary mode.

    Returns:
        Whatever object the pickle file contains.
    """
    # NOTE(security): unpickling executes arbitrary code embedded in the
    # file; only call this on vocabulary files from a trusted source.
    # The ``with`` block closes the file even when unpickling raises.
    with open(path, 'rb') as f:
        return cPickle.load(f)
def sort_batch(batch):
    """Sort parallel fields of a batch by descending length of the first field.

    Args:
        batch: Iterable of parallel sequences (for example sources and
            targets); element ``i`` of every field belongs to example ``i``.

    Returns:
        A list with one tuple per field, each reordered so that the
        first-field entries are in descending ``len`` order. Examples of
        equal length keep their original relative order.
    """
    # Regroup the fields into per-example tuples so they are sorted together.
    examples = sorted(zip(*batch), key=lambda ex: len(ex[0]), reverse=True)
    # Materialize the result: on Python 3 a bare ``zip`` is a one-shot
    # iterator, not the list of tuples this function returned on Python 2.
    return list(zip(*examples))
def reverse_padded_sequence(inputs, lengths, batch_first=False):
    """Reverse each sequence in a padded batch, leaving the padding in place.

    For every batch entry, the first ``length`` steps are reversed and the
    remaining steps keep their original positions.

    Args:
        inputs: Tensor with at least two dimensions; time is dim 0 and batch
            is dim 1, or the other way round when ``batch_first`` is true.
        lengths: One integer-like length per batch entry.
        batch_first: Whether ``inputs`` is laid out batch-major.

    Returns:
        A tensor with the same layout as ``inputs``, with each sequence
        reversed.

    Raises:
        ValueError: If ``len(lengths)`` differs from the batch size.
    """
    if batch_first:
        inputs = inputs.transpose(0, 1)
    max_length, batch_size = inputs.size(0), inputs.size(1)
    if len(lengths) != batch_size:
        raise ValueError('inputs is incompatible with lengths.')

    # Per batch entry: source time steps with the valid prefix reversed and
    # the padded tail left untouched. ``int()`` also accepts 0-dim tensors
    # and NumPy integers as lengths.
    ind = []
    for length in lengths:
        length = int(length)
        ind.append(list(range(length - 1, -1, -1)) + list(range(length, max_length)))

    # Build the index on the same device as ``inputs`` (not only CUDA), so
    # ``gather`` works wherever the data lives.
    ind = torch.tensor(ind, dtype=torch.long, device=inputs.device).transpose(0, 1)
    # Add trailing singleton dims so the (time, batch) index can be
    # broadcast over any feature dimensions.
    for dim in range(2, inputs.dim()):
        ind = ind.unsqueeze(dim)
    ind = ind.expand_as(inputs)

    reversed_inputs = torch.gather(inputs, 0, ind)
    if batch_first:
        reversed_inputs = reversed_inputs.transpose(0, 1)
    return reversed_inputs