-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.py
106 lines (90 loc) · 3.31 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
'''
Utility functions
'''
import sys
import json
import cPickle as pkl
import numpy
# batch preparation
def prepare_data(seqs_x, seqs_y, n_factors, maxlen=None):
    '''
    Pad a batch of sentence pairs into time-major arrays plus masks.

    seqs_x:    list of source sentences; every sentence is a sequence of
               tokens and every token is a sequence of `n_factors` ids.
    seqs_y:    list of target sentences; every sentence is a sequence of ids.
    n_factors: number of factors per source token (first axis of `x`).
    maxlen:    if not None, a pair is dropped when either side has a
               length >= maxlen.

    Returns (x, x_mask, y, y_mask):
        x      int64,   shape (n_factors, max source length + 1, n_samples)
        x_mask float32, shape (max source length + 1, n_samples)
        y      int64,   shape (max target length + 1, n_samples)
        y_mask float32, shape (max target length + 1, n_samples)
    or (None, None, None, None) when no sentence pair is left.
    '''
    if maxlen is not None:
        # Keep only the pairs where both sides are shorter than maxlen.
        kept = [(s_x, s_y) for s_x, s_y in zip(seqs_x, seqs_y)
                if len(s_x) < maxlen and len(s_y) < maxlen]
        seqs_x = [s_x for s_x, _ in kept]
        seqs_y = [s_y for _, s_y in kept]

    lengths_x = [len(s) for s in seqs_x]
    lengths_y = [len(s) for s in seqs_y]

    # Checked whether or not maxlen was given, so that numpy.max below is
    # never called on an empty list (which raises ValueError).
    if not lengths_x or not lengths_y:
        return None, None, None, None

    n_samples = len(seqs_x)
    # One extra time step per sentence: it keeps id 0 but is covered by the
    # mask (presumably the end-of-sentence position -- confirm with caller).
    maxlen_x = numpy.max(lengths_x) + 1
    maxlen_y = numpy.max(lengths_y) + 1

    x = numpy.zeros((n_factors, maxlen_x, n_samples), dtype='int64')
    y = numpy.zeros((maxlen_y, n_samples), dtype='int64')
    x_mask = numpy.zeros((maxlen_x, n_samples), dtype='float32')
    y_mask = numpy.zeros((maxlen_y, n_samples), dtype='float32')

    for idx, (s_x, s_y) in enumerate(zip(seqs_x, seqs_y)):
        # Transpose (token, factor) -> (factor, token). The zip result is
        # materialised because numpy cannot assign from a lazy iterator,
        # which is what zip returns on Python 3.
        x[:, :lengths_x[idx], idx] = list(zip(*s_x))
        x_mask[:lengths_x[idx] + 1, idx] = 1.
        y[:lengths_y[idx], idx] = s_y
        y_mask[:lengths_y[idx] + 1, idx] = 1.

    return x, x_mask, y, y_mask
#json loads strings as unicode; we currently still work with Python 2 strings, and need conversion
def unicode_to_utf8(d):
    '''
    Return a new dict with every key of `d` encoded as UTF-8.
    Values are carried over unchanged.
    '''
    encoded = {}
    for key, value in d.items():
        encoded[key.encode("UTF-8")] = value
    return encoded
def load_dict(filename):
    '''
    Load a dictionary from `filename`.

    The file is first read as JSON, with the keys converted through
    unicode_to_utf8. If that fails for any reason, the file is re-opened
    and read as a pickle instead. Errors from the pickle attempt (including
    a missing file) propagate to the caller.
    '''
    try:
        with open(filename, 'rb') as f:
            return unicode_to_utf8(json.load(f))
    except Exception:
        # Deliberately not a bare `except:` -- KeyboardInterrupt and
        # SystemExit must not trigger the fallback.
        # NOTE(review): unpickling can execute arbitrary code; only load
        # dictionary files from trusted sources.
        with open(filename, 'rb') as f:
            return pkl.load(f)
def load_config(basename):
    '''
    Load a configuration stored next to `basename`.

    Tries '<basename>.json' first and falls back to '<basename>.pkl'.
    If neither can be loaded, an error message is written to stderr and
    the process exits with status 1.
    '''
    try:
        with open('%s.json' % basename, 'rb') as f:
            return json.load(f)
    except Exception:
        # Not a bare `except:` so that KeyboardInterrupt / SystemExit
        # still propagate instead of silently triggering the fallback.
        try:
            # NOTE(review): unpickling can execute arbitrary code; only
            # load config files from trusted sources.
            with open('%s.pkl' % basename, 'rb') as f:
                return pkl.load(f)
        except Exception:
            sys.stderr.write('Error: config file {0}.json is missing\n'.format(basename))
            sys.exit(1)
def seq2words(seq, inverse_dictionary, join=True):
    '''
    Convert a one-dimensional sequence of ids to words.

    The sequence is turned into a single-column int64 array and handed to
    factoredseq2words together with a one-element dictionary list, so the
    result has the same form as that function's output.
    '''
    ids = numpy.array(seq, dtype='int64')
    assert ids.ndim == 1
    column = ids.reshape([ids.shape[0], 1])
    return factoredseq2words(column, [inverse_dictionary], join)
def factoredseq2words(seq, inverse_dictionaries, join=True):
    '''
    Convert a two-dimensional array of ids (one row per word, one column
    per factor) to words.

    Each non-zero id is looked up in the dictionary of its column; ids
    missing from that dictionary become 'UNK'. The factors of one word are
    joined with '|'. A zero id contributes nothing to its word and must be
    followed by a zero in the same column of the next row (if any).

    Returns the words joined by single spaces when `join` is true,
    otherwise the list of words.
    '''
    assert len(seq.shape) == 2
    assert len(inverse_dictionaries) == seq.shape[1]
    last_row = len(seq) - 1
    words = []
    for row_idx, row in enumerate(seq):
        tokens = []
        for col_idx, symbol in enumerate(row):
            if symbol == 0:
                # Zeros may only appear as trailing padding of a column.
                assert (row_idx == last_row) or (seq[row_idx+1][col_idx] == 0), \
                    ('Zero not at the end of sequence', seq)
                continue
            vocab = inverse_dictionaries[col_idx]
            tokens.append(vocab[symbol] if symbol in vocab else 'UNK')
        words.append('|'.join(tokens))
    if join:
        return ' '.join(words)
    return words
def reverse_dict(dictt):
    '''
    Return a new dict mapping each value of `dictt` back to its key.

    If several keys share a value, the key seen last while iterating over
    `dictt` wins. An empty dict yields an empty dict.
    '''
    # Built pair by pair instead of unpacking zip(*items), which raises
    # ValueError when the dict is empty.
    return dict((value, key) for key, value in dictt.items())