Add masked LSTM support
Signed-off-by: Yu Cong <congyc@amazon.com>
Yu Cong committed Aug 27, 2022
1 parent bc677a1 commit 9775fa5
Showing 3 changed files with 320 additions and 26 deletions.
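
For context, a Keras model like the ones exercised in the new tests would typically be converted with tf2onnx's public API. The following is an illustrative sketch only, not part of this diff: the from_keras call, opset value, and layer parameters are assumptions. After this change, the Keras mask produced by Embedding(mask_zero=True) or Masking should be carried into the ONNX LSTM's sequence_lens input, provided the batch is post-padded.

import tensorflow as tf
import tf2onnx

# Minimal masked-LSTM model, mirroring the new test cases below (illustrative only).
model_in = tf.keras.layers.Input((4,), dtype="int32")
x = tf.keras.layers.Embedding(input_dim=10, output_dim=5, mask_zero=True)(model_in)
outputs = tf.keras.layers.LSTM(units=5, return_state=True)(x)
model = tf.keras.models.Model(inputs=model_in, outputs=outputs)

# Convert to ONNX; with this change the Keras mask should surface as the
# sequence_lens input of the ONNX LSTM node (post-padded batches only).
spec = (tf.TensorSpec((None, 4), tf.int32, name="input"),)
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13)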
207 changes: 207 additions & 0 deletions tests/test_lstm.py
@@ -793,5 +793,212 @@ def func(x):
return tf.identity(y[0], name="output")
self.run_test_case(func, {"input:0": x_val}, [], ["output:0"], rtol=1e-05, atol=1e-06)

@check_tf_min_version("2.0")
@skip_tf_versions("2.1", "Bug in TF 2.1")
def test_keras_masked_lstm_embedding_unidirectional(self):
for go_backwards in [True, False]:
for return_sequences in [True, False]:
timesteps = 4
# Note: masked LSTM only supports post-padded inputs after conversion
# test case sequence_lens = [4, 2, 0]
x_val = np.array([
[1, 2, 3, 4],
[5, 6, 0, 0],
[0, 0, 0, 0]
], dtype=np.int32)

model_in = tf.keras.layers.Input((timesteps,), dtype="int32")
x_embedding = tf.keras.layers.Embedding(
input_dim=10,
output_dim=5,
mask_zero=True,
embeddings_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=41),
)(model_in)

# RNN layer inherits the mask propagated from above embedding layer
model_out = tf.keras.layers.LSTM(
units=5,
go_backwards=go_backwards,
return_sequences=return_sequences,
return_state=True,
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
)(x_embedding)
model = tf.keras.models.Model(inputs=model_in, outputs=model_out)

def func(x):
y = model(x)
if return_sequences:
return (
# skipping output Y when return_sequences=True due to inconsistent
# ORT and TF behaviors: https://sim.amazon.com/issues/NEMORT-1712
tf.identity(y[1], name="output_yh"),
tf.identity(y[2], name="output_yc"))
return (
tf.identity(y[0], name="output_y"),
tf.identity(y[1], name="output_yh"),
tf.identity(y[2], name="output_yc"))

output_list = ["output_yh:0", "output_yc:0"] if return_sequences \
else ["output_y:0", "output_yh:0", "output_yc:0"]
self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06)

@check_tf_min_version("2.0")
@skip_tf_versions("2.1", "Bug in TF 2.1")
def test_keras_masked_lstm_embedding_bidirectional(self):
for return_sequences in [False, True]:
timesteps = 4
# Note: masked LSTM only supports post-padded inputs after conversion
# test case sequence_lens = [4, 2, 0]
x_val = np.array([
[1, 2, 3, 4],
[5, 6, 0, 0],
[0, 0, 0, 0]
], dtype=np.int32)

model_in = tf.keras.layers.Input((timesteps,), dtype="int32")
x_embedding = tf.keras.layers.Embedding(
input_dim=10,
output_dim=5,
mask_zero=True,
embeddings_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=41),
)(model_in)

# RNN layer inherits the mask propagated from above embedding layer
lstm_layer = tf.keras.layers.LSTM(
units=5,
go_backwards=False,
return_sequences=return_sequences,
return_state=True,
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
)
model_out = tf.keras.layers.Bidirectional(lstm_layer)(x_embedding)
model = tf.keras.models.Model(inputs=model_in, outputs=model_out)

def func(x):
y = model(x)
if return_sequences:
return (
# skipping output Y when return_sequences=True due to inconsistent
# ORT and TF behaviors: https://sim.amazon.com/issues/NEMORT-1712
tf.identity(y[1], name="output_yh_f"),
tf.identity(y[2], name="output_yc_f"),
tf.identity(y[3], name="output_yh_r"),
tf.identity(y[4], name="output_yc_r"))
return (
tf.identity(y[0], name="output_y_concat"),
tf.identity(y[1], name="output_yh_f"),
tf.identity(y[2], name="output_yc_f"),
tf.identity(y[3], name="output_yh_r"),
tf.identity(y[4], name="output_yc_r"))

output_list = ["output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"] if return_sequences \
else ["output_y_concat:0", "output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"]

# translate single BiLSTM to two forward LSTMs
self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06,
require_lstm_count=2)

@check_tf_min_version("2.0")
@skip_tf_versions("2.1", "Bug in TF 2.1")
def test_keras_masked_lstm_unidirectional(self):
for go_backwards in [True, False]:
for return_sequences in [True, False]:
batch_size, timesteps, feat = 3, 4, 5
in_shape = (timesteps, feat)
x_val = np.random.uniform(size=[batch_size, timesteps, feat]).astype(np.float32)
# Note: masked LSTM only supports post-padded inputs after conversion
# test case sequence_lens = [4, 2, 0]
x_val[1, 2:, :] = 0.
x_val[2, :, :] = 0.

model_in = tf.keras.layers.Input(shape=in_shape, dtype="float32")
x_masked = tf.keras.layers.Masking(mask_value=0.)(model_in)

# RNN layer inherits the mask propagated from above mask layer
model_out = tf.keras.layers.LSTM(
units=5,
go_backwards=go_backwards,
return_sequences=return_sequences,
return_state=True,
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
)(x_masked)
model = tf.keras.models.Model(inputs=model_in, outputs=model_out)

def func(x):
y = model(x)
if return_sequences:
return (
# skipping output Y when return_sequences=True due to inconsistent
# ORT and TF behaviors: https://sim.amazon.com/issues/NEMORT-1712
tf.identity(y[1], name="output_yh"),
tf.identity(y[2], name="output_yc"))
return (
tf.identity(y[0], name="output_y"),
tf.identity(y[1], name="output_yh"),
tf.identity(y[2], name="output_yc"))

output_list = ["output_yh:0", "output_yc:0"] if return_sequences \
else ["output_y:0", "output_yh:0", "output_yc:0"]
self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06)

@check_tf_min_version("2.0")
@skip_tf_versions("2.1", "Bug in TF 2.1")
def test_keras_masked_lstm_bidirectional(self):
for return_sequences in [False, True]:
batch_size, timesteps, feat = 3, 4, 5
in_shape = (timesteps, feat)
x_val = np.random.uniform(size=[batch_size, timesteps, feat]).astype(np.float32)
# Note: masked LSTM only supports post-padded inputs after conversion
# test case sequence_lens = [4, 2, 0]
x_val[1, 2:, :] = 0.
x_val[2, :, :] = 0.

model_in = tf.keras.layers.Input(shape=in_shape, dtype="float32")
x_masked = tf.keras.layers.Masking(mask_value=0.)(model_in)

# RNN layer inherits the mask propagated from above mask layer
lstm_layer = tf.keras.layers.LSTM(
units=5,
go_backwards=False,
return_sequences=return_sequences,
return_state=True,
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
)
model_out = tf.keras.layers.Bidirectional(lstm_layer)(x_masked)
model = tf.keras.models.Model(inputs=model_in, outputs=model_out)

def func(x):
y = model(x)
if return_sequences:
return (
# skipping output Y when return_sequences=True due to inconsistent
# ORT and TF behaviors: https://sim.amazon.com/issues/NEMORT-1712
tf.identity(y[1], name="output_yh_f"),
tf.identity(y[2], name="output_yc_f"),
tf.identity(y[3], name="output_yh_r"),
tf.identity(y[4], name="output_yc_r"))
return (
tf.identity(y[0], name="output_y_concat"),
tf.identity(y[1], name="output_yh_f"),
tf.identity(y[2], name="output_yc_f"),
tf.identity(y[3], name="output_yh_r"),
tf.identity(y[4], name="output_yc_r"))

output_list = ["output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"] if return_sequences \
else ["output_y_concat:0", "output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"]

# translate single BiLSTM to two forward LSTMs
self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06,
require_lstm_count=2)


if __name__ == '__main__':
unittest_main()
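
The "post-padded input" note repeated in the tests above boils down to this: the ONNX LSTM op consumes a sequence_lens input and expects each row's valid steps to precede its padding, so the Keras mask must reduce to a per-row length. A minimal numpy sketch, illustrative only and not part of the diff:

import numpy as np

# Post-padded batch from the tests above; with mask_zero=True the Keras mask is (x_val != 0).
x_val = np.array([
    [1, 2, 3, 4],
    [5, 6, 0, 0],
    [0, 0, 0, 0],
], dtype=np.int32)

# The per-row count of unmasked steps is what the converted model feeds to the
# ONNX LSTM's sequence_lens input -- [4, 2, 0] here. A pre-padded row such as
# [0, 0, 5, 6] would have the same count but its valid steps would not start at t=0,
# which is why only post-padded input is supported after conversion.
sequence_lens = (x_val != 0).sum(axis=1).astype(np.int32)
print(sequence_lens)  # [4 2 0]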
28 changes: 19 additions & 9 deletions tf2onnx/onnx_opset/tensor.py
@@ -2260,15 +2260,25 @@ def version_10(cls, ctx, node, **kwargs):
const_axis_name = utils.make_name(f'const_{axis}')
const_axis = ctx.make_const(name=const_axis_name, np_val=np.array([axis], dtype=np.int64))

# Add a Constant node (seq_len) for ReverseSequence.
# Index 1 for the shape should not return 0, since rank(input) >=2
input_shape = ctx.make_node("Shape", [inputs[-1]], op_name_scope=rv2_node_name)
batch_size = ctx.make_node("Gather", [input_shape.output[0], const_one.output[0]],
op_name_scope=rv2_node_name)
axis_dim = ctx.make_node("Gather", [input_shape_node.output[0], const_axis.output[0]],
op_name_scope=rv2_node_name)
seq_array = ctx.make_node("Expand", [axis_dim.output[0], batch_size.output[0]])
inputs.append(seq_array.output[0])
# Add sequence_lens as ReverseSequence input
has_sequence_lens = node.get_attr_value("has_sequence_lens", False)
if not has_sequence_lens:
# open-source impl: fill in dummy sequence_lens based on input shape
# Add a Constant node (seq_len) for ReverseSequence.
# Index 1 for the shape should not return 0, since rank(input) >=2
input_shape = ctx.make_node("Shape", [inputs[-1]], op_name_scope=rv2_node_name)
batch_size = ctx.make_node("Gather", [input_shape.output[0], const_one.output[0]],
op_name_scope=rv2_node_name)
axis_dim = ctx.make_node("Gather", [input_shape_node.output[0], const_axis.output[0]],
op_name_scope=rv2_node_name)
seq_array = ctx.make_node("Expand", [axis_dim.output[0], batch_size.output[0]])
inputs.append(seq_array.output[0])
else:
# masked backward LSTM:
# sequence_lens has been appended to ReverseV2's inputs by lstm_tf2_rewriter
# so that the tensor stays post-padded after the reverse
seq_lens_casted = ctx.make_node("Cast", [node.input[-1]], attr={'to': TensorProto.INT64}).output[0]
inputs.append(seq_lens_casted)

# Add a ReverseSequence node.
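
The new else branch passes sequence_lens to ReverseSequence instead of flipping the whole time axis. A minimal numpy illustration of why that keeps the tensor post-padded (time is the second axis here purely for brevity; not part of the diff):

import numpy as np

x = np.array([[1, 2, 0, 0]])   # one post-padded sequence, valid length 2
seq_lens = np.array([2])

# ReverseV2 semantics: flip the whole time axis, so the padding moves to the front
# and the sequence becomes pre-padded -- which the masked LSTM cannot consume.
full_reverse = x[:, ::-1]

# ReverseSequence semantics: reverse only the first seq_lens[i] steps of each row,
# so the tensor stays post-padded after the reverse.
rev_seq = x.copy()
for i, n in enumerate(seq_lens):
    rev_seq[i, :n] = rev_seq[i, :n][::-1]
print(full_reverse)  # [[0 0 2 1]]
print(rev_seq)       # [[2 1 0 0]]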

