utils.py
import torch
import numpy as np
from collections import Counter, OrderedDict


class OrderedCounter(Counter, OrderedDict):
    """Counter that remembers the order elements are first encountered."""

    def __repr__(self):
        return '%s(%r)' % (self.__class__.__name__, OrderedDict(self))

    def __reduce__(self):
        return self.__class__, (OrderedDict(self),)
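

# Illustrative sketch (not part of the original module): OrderedCounter counts
# items like collections.Counter while keeping first-seen order, which is what
# vocabulary construction typically relies on.
def _demo_ordered_counter():
    counts = OrderedCounter("abracadabra")
    return list(counts.items())  # [('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)]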


def to_var(x):
    # Move a tensor onto the GPU when CUDA is available; otherwise return it unchanged.
    if torch.cuda.is_available():
        x = x.cuda()
    return x
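

# Minimal usage sketch (an assumption about call sites, not taken from the original code):
def _demo_to_var():
    return to_var(torch.zeros(2, 3)).device  # cuda:0 when a GPU is visible, else cpu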


def idx2word(idx, i2w, pad_idx):
    # Decode a batch of token-id sequences into whitespace-joined strings,
    # cutting each sentence off at the first padding token.
    sent_str = [str()] * len(idx)
    for i, sent in enumerate(idx):
        for word_id in sent:
            if word_id == pad_idx:
                break
            sent_str[i] += i2w[str(word_id.item())] + " "
        sent_str[i] = sent_str[i].strip()
    return sent_str
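

# Illustrative sketch (the vocabulary below is invented for the example): i2w is
# expected to map *string* ids to words and idx to be a 2-D LongTensor of ids.
def _demo_idx2word():
    i2w = {"0": "<pad>", "1": "hello", "2": "world"}
    batch = torch.tensor([[1, 2, 0, 0]])
    return idx2word(batch, i2w, pad_idx=0)  # ['hello world']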


def interpolate(start, end, steps):
    # Linearly interpolate between two vectors, returning steps + 2 points
    # (the two endpoints plus `steps` intermediate points), one per row.
    interpolation = np.zeros((start.shape[0], steps + 2))
    for dim, (s, e) in enumerate(zip(start, end)):
        interpolation[dim] = np.linspace(s, e, steps + 2)
    return interpolation.T
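

# Usage sketch (the latent vectors are made up for illustration): walking the line
# between two latent codes before decoding each intermediate point.
def _demo_interpolate():
    z1 = np.array([0.0, 0.0])
    z2 = np.array([1.0, 2.0])
    return interpolate(z1, z2, steps=3)  # shape (5, 2): z1, three midpoints, z2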


def expierment_name(args, ts):
    # Build a run-identifier string from the hyperparameters and a timestamp.
    exp_name = str()
    exp_name += "BS=%i_" % args.batch_size
    exp_name += "LR={}_".format(args.learning_rate)
    exp_name += "EB=%i_" % args.embedding_size
    exp_name += "%s_" % args.rnn_type.upper()
    exp_name += "HS=%i_" % args.hidden_size
    exp_name += "L=%i_" % args.num_layers
    exp_name += "BI=%i_" % args.bidirectional
    exp_name += "LS=%i_" % args.latent_size
    exp_name += "WD={}_".format(args.word_dropout)
    exp_name += "ANN=%s_" % args.anneal_function.upper()
    exp_name += "K={}_".format(args.k)
    exp_name += "X0=%i_" % args.x0
    exp_name += "TS=%s" % ts
    return exp_name
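

# Illustrative call (every hyperparameter value below is an assumption, chosen only
# to show the shape of the resulting name):
def _demo_expierment_name():
    from argparse import Namespace
    args = Namespace(batch_size=32, learning_rate=0.001, embedding_size=300,
                     rnn_type="gru", hidden_size=256, num_layers=1, bidirectional=False,
                     latent_size=16, word_dropout=0.0, anneal_function="logistic",
                     k=0.0025, x0=2500)
    # e.g. 'BS=32_LR=0.001_EB=300_GRU_HS=256_L=1_BI=0_LS=16_WD=0.0_ANN=LOGISTIC_K=0.0025_X0=2500_TS=2024-01-01'
    return expierment_name(args, ts="2024-01-01")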


def similarity(a, b, do_len=True):
    # Compare two '|'-delimited bar sequences. The first three space-separated
    # tokens of each string are dropped before splitting into bars; within each
    # shared bar, tokens are compared by count, so 0 means the counts match
    # exactly and lower scores mean more similar sequences.
    a = ' '.join(a.split(' ')[3:]).strip().split('|')
    b = ' '.join(b.split(' ')[3:]).strip().split('|')
    while '' in a:
        a.remove('')
    while '' in b:
        b.remove('')

    bars = []
    for i in range(min(len(a), len(b))):
        original = a[i].split(' ')
        sampled = b[i].split(' ')
        while '' in original:
            original.remove('')
        while '' in sampled:
            sampled.remove('')

        # Accumulate, per occurrence, the relative count difference of each token,
        # weighted by that token's frequency within the original bar.
        tot = 0
        for c in original:
            num = original.count(c)
            weight = num / len(original)
            diff = abs(num - sampled.count(c)) / num
            tot += diff * weight
        bars.append(tot)

    # correct bar measure: mean per-bar difference over the bars both sequences share
    correct = np.array(bars).sum() / len(bars)
    # correct weight: num of 'correct' bars / num of generated bars
    c_w = min(len(a), len(b)) / len(b)
    # wrong bars: how many bars one sequence has beyond the other
    wrong = abs(len(a) - len(b))
    # wrong weight: num of wrong bars / num of expected bars
    if do_len:
        w_w = wrong / len(a)
    else:
        w_w = 0
    return correct * c_w + wrong * w_w
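

# Usage sketch (the strings below are invented; they only mimic the expected format
# of a three-token prefix followed by '|'-separated bars of tokens):
def _demo_similarity():
    reference = "id 4/4 C c4 e4 g4 | g4 g4 e4"
    generated = "id 4/4 C c4 e4 g4 | g4 a4 e4"
    return similarity(reference, generated)  # ~0.33 here; 0.0 means identical bar-wise token counts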