act_model.py
"""TensorFlow 1.x RNN classification model with optional Adaptive Computation
Time (ACT), enabled by wrapping the RNN cell in an ACTWrapper."""

import functools

import tensorflow as tf
import tensorflow.contrib.rnn as rnn
import tensorflow.contrib.seq2seq as s2s

from act_wrapper import ACTWrapper


def lazy_property(func):
    """Decorator that computes a property once and caches it on the instance."""
    attribute = '_cache_' + func.__name__

    @property
    @functools.wraps(func)
    def decorator(self):
        if not hasattr(self, attribute):
            setattr(self, attribute, func(self))
        return getattr(self, attribute)

    return decorator


class ACTModel:
    """RNN sequence classifier that adds a ponder cost to the loss when the
    cell is wrapped in ACTWrapper."""

    def __init__(self, data, target, time_steps, num_classes, rnn_cell,
                 num_outputs=1, time_penalty=0.001, seq_length=None,
                 target_offset=None, optimizer=None):

        self.data = data
        self.target = target
        self.time_steps = time_steps
        self.num_classes = num_classes
        self.cell = rnn_cell
        self.num_outputs = num_outputs
        self.time_penalty = time_penalty
        self.seq_length = seq_length
        self.target_offset = target_offset
        self.optimizer = optimizer if optimizer \
            else tf.train.AdamOptimizer(0.001)
        self.num_hidden = rnn_cell.output_size

        self._softmax_loss = None
        self._ponder_loss = None
        self._ponder_steps = None

        self._boolean_mask = None
        self._numerical_mask = None

        if self.seq_length is not None:
            # Mask out padded time steps and, optionally, an initial span of
            # target positions that should not contribute to loss or evaluation.
            self._boolean_mask = tf.sequence_mask(self.seq_length, self.time_steps)

            if self.target_offset is not None:
                offset_mask = tf.logical_not(tf.sequence_mask(self.target_offset, self.time_steps))
                self._boolean_mask = tf.logical_and(self._boolean_mask, offset_mask)

            self._numerical_mask = tf.cast(self._boolean_mask, data.dtype)

        # Touch the lazy properties so the graph is built at construction time.
        self.logits
        self.training
        self.evaluation

    @lazy_property
    def logits(self):
        # Unroll the RNN over time; the input is batch-major [batch, time, features].
        rnn_outputs, rnn_state = rnn.static_rnn(
            self.cell, tf.unstack(self.data, axis=1),
            dtype=tf.float32
        )
        # Stack the per-step outputs and flatten to [time * batch, hidden].
        rnn_outputs = tf.reshape(rnn_outputs, [-1, self.num_hidden])

        logits_per_output = []
        for i in range(self.num_outputs):
            # Separate softmax projection for each output head.
            initial_weights = tf.truncated_normal([self.num_hidden, self.num_classes], stddev=0.01)
            initial_biases = tf.constant(0.1, shape=[self.num_classes])
            output_weights = tf.Variable(initial_weights, name="output_weights_" + str(i))
            output_biases = tf.Variable(initial_biases, name="output_biases_" + str(i))

            logits = tf.matmul(rnn_outputs, output_weights) + output_biases
            reshaped = tf.reshape(logits, [self.time_steps, -1, self.num_classes])
            # Back to batch-major: [batch, time, classes].
            logits_per_output.append(tf.transpose(reshaped, perm=(1, 0, 2)))

        return logits_per_output

    @lazy_property
    def evaluation(self):
        # Error rate: the fraction of sequences with at least one
        # mis-predicted (unmasked) time step.
        mistakes_per_output = []
        for i in range(len(self.logits)):
            if self.seq_length is not None:
                mistakes = tf.reduce_any(
                    tf.logical_and(
                        tf.not_equal(self.target[:, :, i], tf.argmax(self.logits[i], 2)),
                        self._boolean_mask
                    ), axis=1
                )
            else:
                mistakes = tf.reduce_any(
                    tf.not_equal(self.target[:, :, i], tf.argmax(self.logits[i], 2)), axis=1
                )
            mistakes_per_output.append(mistakes)

        if len(mistakes_per_output) == 1:
            all_mistakes = mistakes_per_output[0]
        else:
            # A sequence counts as wrong if any of its output heads is wrong.
            stacked_mistakes = tf.stack(mistakes_per_output)
            all_mistakes = tf.reduce_any(stacked_mistakes, axis=0)

        return tf.reduce_mean(tf.cast(all_mistakes, tf.float32))

    @lazy_property
    def training(self):
        # Per-head cross-entropy, masked by sequence length when provided.
        softmax_loss_per_output = []
        for i in range(len(self.logits)):
            if self.seq_length is not None:
                softmax_loss = s2s.sequence_loss(
                    self.logits[i], self.target[:, :, i], self._numerical_mask
                )
            else:
                softmax_loss = s2s.sequence_loss(
                    self.logits[i], self.target[:, :, i],
                    tf.ones_like(self.target[:, :, i], self.logits[i].dtype)
                )
            softmax_loss_per_output.append(softmax_loss)

        if len(softmax_loss_per_output) == 1:
            self._softmax_loss = softmax_loss_per_output[0]
        else:
            self._softmax_loss = tf.add_n(softmax_loss_per_output)

        if isinstance(self.cell, ACTWrapper):
            # With ACT: total loss = cross-entropy + time_penalty * ponder cost.
            self._ponder_loss = self.time_penalty * self.cell.get_ponder_cost(self.seq_length)
            self._ponder_steps = self.cell.get_ponder_steps(self.seq_length)
            total_loss = self._softmax_loss + self._ponder_loss
        else:
            total_loss = self._softmax_loss

        return self.optimizer.minimize(total_loss)

    @lazy_property
    def softmax_loss(self):
        return self._softmax_loss

    @lazy_property
    def ponder_loss(self):
        if isinstance(self.cell, ACTWrapper):
            return self._ponder_loss
        else:
            raise TypeError("ACT wrapper is not used")

    @lazy_property
    def ponder_steps(self):
        if isinstance(self.cell, ACTWrapper):
            return self._ponder_steps
        else:
            raise TypeError("ACT wrapper is not used")
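

# --- Usage sketch (illustration only, not part of the original module) ---
# A minimal example of how ACTModel might be wired up for a sequence
# classification task. The placeholder shapes, the GRU cell size, and the
# ACTWrapper constructor call below are assumptions; check act_wrapper.py
# for the wrapper's actual signature before running.
if __name__ == "__main__":
    TIME_STEPS = 20
    NUM_FEATURES = 8
    NUM_CLASSES = 10

    # Batch-major inputs: [batch, time, features]; targets: [batch, time, num_outputs].
    data = tf.placeholder(tf.float32, [None, TIME_STEPS, NUM_FEATURES])
    target = tf.placeholder(tf.int64, [None, TIME_STEPS, 1])
    seq_length = tf.placeholder(tf.int32, [None])

    base_cell = rnn.GRUCell(64)
    act_cell = ACTWrapper(base_cell)  # assumed constructor signature

    model = ACTModel(
        data, target, TIME_STEPS, NUM_CLASSES, act_cell,
        seq_length=seq_length, time_penalty=0.001
    )
    # model.training is the optimizer op, model.evaluation the error rate,
    # and model.ponder_loss / model.ponder_steps expose the ACT statistics.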