-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhelper.py
31 lines (29 loc) · 1.18 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import tensorflow as tf
def last_relevant_output(output, sequence_length):
    """
    Given the outputs of an RNN, get the last output of each sequence
    that is not padding.

    The time axis is taken to be the second-to-last dimension of
    ``output`` and the hidden axis the last one. The flat indexing used
    here is consistent for an input of shape
    (batch_size, max_length, hidden_size).

    Parameters
    ----------
    output: Tensor
        A tensor, generally the output of a tensorflow RNN. For each
        instance, the time step at index ``sequence_length - 1`` is
        selected. The size of the last dimension must be statically
        known, because it is converted to a Python int.

    sequence_length: Tensor
        An integer tensor of dimension (batch_size, ) indicating the
        length of the sequences before padding was applied. Any integer
        dtype is accepted; values are cast to int32 internally.
        NOTE: every length is expected to be >= 1. A length of 0 yields
        index ``i * max_length - 1``, i.e. the last step of the previous
        instance, or -1 (out of range) for the first instance.

    Returns
    -------
    last_relevant_output: Tensor
        A tensor of shape (batch_size, hidden_size) holding, for each
        instance, the output at its last non-padded time step.
    """
    with tf.name_scope("last_relevant_output"):
        batch_size = tf.shape(output)[0]
        max_length = tf.shape(output)[-2]
        out_size = int(output.get_shape()[-1])

        # tf.shape / tf.range produce int32 values; sequence lengths often
        # arrive as int64, and mixing the two dtypes in the arithmetic
        # below would fail, so normalise the lengths to int32 first.
        lengths = tf.cast(sequence_length, tf.int32)

        # After flattening to (batch_size * max_length, hidden_size), the
        # rows of instance i start at i * max_length; its last valid step
        # sits (length - 1) rows after that.
        row_starts = tf.range(0, batch_size) * max_length
        index = row_starts + (lengths - 1)

        flat = tf.reshape(output, [-1, out_size])
        relevant = tf.gather(flat, index)
        return relevant