forked from Sarasra/models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lstm_ops.py
104 lines (86 loc) · 3.88 KB
/
lstm_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convolutional LSTM implementation."""
import tensorflow as tf
from tensorflow.contrib.slim import add_arg_scope
from tensorflow.contrib.slim import layers
def init_state(inputs,
               state_shape,
               state_initializer=tf.zeros_initializer(),
               dtype=tf.float32):
    """Builds an initial state Tensor for a recurrent cell.

    Args:
        inputs: input Tensor, at least 2D, whose leading dimension is the
            batch size; may be None.
        state_shape: per-example shape of the state (no batch dimension).
        state_initializer: Initializer(shape, dtype) used to create the state.
        dtype: dtype of the state; only consulted when inputs is None,
            otherwise inputs.dtype takes precedence.

    Returns:
        A Tensor of shape [batch_size] + state_shape holding the initial
        state.
    """
    if inputs is None:
        # No inputs to infer a batch from; fall back to an empty batch.
        batch_size = 0
    else:
        # Handles both statically and dynamically shaped inputs.
        batch_size = inputs.get_shape().with_rank_at_least(1)[0]
        dtype = inputs.dtype
    return state_initializer([batch_size] + state_shape, dtype=dtype)
@add_arg_scope
def basic_conv_lstm_cell(inputs,
                         state,
                         num_channels,
                         filter_size=5,
                         forget_bias=1.0,
                         scope=None,
                         reuse=None):
    """Runs one step of a basic LSTM cell with 2D-convolutional connections.

    forget_bias (default 1.0) is added to the forget-gate pre-activation so
    the cell forgets little at the start of training. This is the baseline
    variant: no cell clipping, no projection layer, no peephole connections.

    Args:
        inputs: input Tensor, 4D, batch x height x width x channels.
        state: state Tensor, 4D, batch x height x width x 2*num_channels
            (cell and hidden concatenated on the channel axis), or None to
            start from the all-zeros state.
        num_channels: number of output channels of the cell.
        filter_size: side length of each square convolution filter.
        forget_bias: value added to the forget gate before its sigmoid.
        scope: optional name for variable_scope (default 'BasicConvLstmCell').
        reuse: whether the layer and its variables should be reused.

    Returns:
        A (output, new_state) tuple; new_state concatenates the new cell and
        hidden tensors along the channel axis.
    """
    if state is None:
        spatial = inputs.get_shape()[1:3]
        state = init_state(inputs, list(spatial) + [2 * num_channels])
    with tf.variable_scope(scope, 'BasicConvLstmCell', [inputs, state],
                           reuse=reuse):
        inputs.get_shape().assert_has_rank(4)
        state.get_shape().assert_has_rank(4)
        cell, hidden = tf.split(axis=3, num_or_size_splits=2, value=state)
        stacked = tf.concat(axis=3, values=[inputs, hidden])
        # A single convolution yields all four gate pre-activations at once,
        # which is cheaper than four separate convs.
        gates = layers.conv2d(stacked,
                              4 * num_channels,
                              [filter_size, filter_size],
                              stride=1,
                              activation_fn=None,
                              scope='Gates')
        # Split into input gate, candidate input, forget gate, output gate.
        in_gate, candidate, forget_gate, out_gate = tf.split(
            axis=3, num_or_size_splits=4, value=gates)
        new_cell = (cell * tf.sigmoid(forget_gate + forget_bias) +
                    tf.sigmoid(in_gate) * tf.tanh(candidate))
        new_hidden = tf.tanh(new_cell) * tf.sigmoid(out_gate)
        return new_hidden, tf.concat(axis=3, values=[new_cell, new_hidden])