demo.py
import argparse
import numpy as np
import copy
import torch
from scipy.spatial.distance import cosine
from scipy.spatial import KDTree
from allennlp.commands.elmo import ElmoEmbedder

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
    '--elmo_weights_path',
    type=str,
    default='models/$l_weights.hdf5',
    help="Path to the ELMo weights files. Use $l as a placeholder for language.")
parser.add_argument(
    '--elmo_options_path',
    type=str,
    default='models/options262.json',
    help="Path to the ELMo options file. n_characters in the file should be 262.")
parser.add_argument(
    '--align_path',
    type=str,
    default='models/align/$l_best_mapping.pth',
    help="Path to the alignment matrix saved in PyTorch format. Use $l as a placeholder for language.")
parser.add_argument(
    '-l1',
    '--language1',
    type=str,
    default='en',
    help="Language of sentence 1.")
parser.add_argument(
    '-s1',
    '--sent1',
    type=str,
    default='A house cat is valued by humans for companionship and for its ability to hunt rodents.',
    help="Sentence in language 1.")
parser.add_argument(
    '-w1',
    '--word1',
    type=str,
    default='cat',
    help="Examined word from the sentence of language 1 (first occurrence will be used).")
parser.add_argument(
    '-l2',
    '--language2',
    type=str,
    default='es',
    help="Language of sentence 2.")
parser.add_argument(
    '-s2',
    '--sent2',
    type=str,
    default='el gato doméstico está incluido en la lista 100 de las especies exóticas invasoras más dañinas del mundo.',
    help="Sentence in language 2.")
parser.add_argument(
    '-w2',
    '--word2',
    type=str,
    default='gato',
    help="Examined word from the sentence of language 2 (first occurrence will be used).")
parser.add_argument(
    '--layer', type=int, default=1, help="ELMo layer to take the embeddings from.")
parser.add_argument(
    '-c', '--cuda_device', type=int, default=-1, help="CUDA device (-1 for CPU).")
args = parser.parse_args()
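
# A typical invocation, for illustration only (assumes the weights, options and
# alignment files have been downloaded to the default paths above):
#   python demo.py -l1 en -w1 cat -l2 es -w2 gato -c 0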


def parse_config(args):
    '''
    Replaces $l with each of the two languages and prints the args.
    '''
    new_args = copy.deepcopy(args)
    for k in vars(args):
        val = getattr(args, k)
        if type(val) is str and "$l" in val:
            new_val = val.replace("$l", args.language1)
            new_k = "{}_{}".format(k, "l1")
            setattr(new_args, new_k, new_val)
            new_val = val.replace("$l", args.language2)
            new_k = "{}_{}".format(k, "l2")
            setattr(new_args, new_k, new_val)
    print('-' * 30)
    for k in vars(new_args):
        print("{}: {}".format(k, getattr(new_args, k)))
    print('-' * 30)
    return new_args
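
# With the defaults above, parse_config adds e.g. elmo_weights_path_l1 =
# 'models/en_weights.hdf5' and elmo_weights_path_l2 = 'models/es_weights.hdf5'
# (and the corresponding align_path_l1/align_path_l2 entries), while keeping the
# original $l arguments untouched.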


def get_sent_embeds(sent, elmo_options_file, elmo_weights_file, layer,
                    cuda_device):
    '''
    Get the embeddings of the sentence tokens.
    sent - list of tokens
    elmo_options_file - json options file for the model. n_characters should be 262
    elmo_weights_file - saved model weights
    layer - which layer of ELMo to output
    cuda_device - cuda device
    Returns a numpy array with the embedding of each token for the selected layer.
    '''
    elmo = ElmoEmbedder(elmo_options_file, elmo_weights_file, cuda_device)
    s_embeds = elmo.embed_sentence(sent)
    layer_embeds = s_embeds[layer, :, :]
    return layer_embeds
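
# Note: ElmoEmbedder.embed_sentence returns an array of shape
# (num_layers, num_tokens, embedding_dim), so the slice above has shape
# (num_tokens, embedding_dim); embedding_dim is typically 1024 for standard ELMo,
# but ultimately depends on the options file.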


def analyze_sents(embeds_l1, embeds_l2, sent1, sent2, w1_ind, w2_ind, k=5):
    '''
    Prints the k nearest neighbors of the examined word of sentence 2 among the
    tokens of sentence 1, then the cosine distance between the two examined words.
    '''
    kdt = KDTree(embeds_l1)
    emb2 = embeds_l2[w2_ind]
    top_k_inds = kdt.query(emb2, k)[1]
    top_k_words = [sent1[i] for i in top_k_inds]
    print('Nearest {} neighbors for {} in "{}":\n{}'.format(k, sent2[w2_ind], ' '.join(sent1), ', '.join(top_k_words)))
    emb1 = embeds_l1[w1_ind]
    dist = cosine(emb1, emb2)
    print("Cosine distance between {} and {}: {}".format(sent1[w1_ind], sent2[w2_ind], dist))


if __name__ == '__main__':
    args = parse_config(args)

    # Language 1
    sent1_tokens = args.sent1.strip().split()
    w1_ind = sent1_tokens.index(args.word1)
    s1_embeds = get_sent_embeds(sent1_tokens, args.elmo_options_path,
                                args.elmo_weights_path_l1, args.layer,
                                args.cuda_device)
    align1 = torch.load(args.align_path_l1)
    s1_embeds_aligned = np.matmul(s1_embeds, align1.transpose())

    # Language 2
    sent2_tokens = args.sent2.strip().split()
    w2_ind = sent2_tokens.index(args.word2)
    s2_embeds = get_sent_embeds(sent2_tokens, args.elmo_options_path,
                                args.elmo_weights_path_l2, args.layer,
                                args.cuda_device)
    align2 = torch.load(args.align_path_l2)
    s2_embeds_aligned = np.matmul(s2_embeds, align2.transpose())

    # Analyse
    print("--- Before alignment:")
    analyze_sents(s1_embeds, s2_embeds, sent1_tokens, sent2_tokens, w1_ind, w2_ind)
    print("\n--- After alignment:")
    analyze_sents(s1_embeds_aligned, s2_embeds_aligned, sent1_tokens, sent2_tokens, w1_ind, w2_ind)
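
    # If the alignment matrices are good, the cosine distance between the two
    # examined words is expected to shrink after alignment, and the nearest
    # neighbors retrieved from sentence 1 should become more semantically related
    # to the word from sentence 2.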