From 9775fa55d625aa580eb76212d7c3cd64a06b06e0 Mon Sep 17 00:00:00 2001 From: Yu Cong Date: Sat, 27 Aug 2022 00:33:54 +0000 Subject: [PATCH] Add masked LSTM support Signed-off-by: Yu Cong --- tests/test_lstm.py | 207 ++++++++++++++++++++++++++ tf2onnx/onnx_opset/tensor.py | 28 ++-- tf2onnx/rewriter/lstm_tf2_rewriter.py | 111 +++++++++++--- 3 files changed, 320 insertions(+), 26 deletions(-) diff --git a/tests/test_lstm.py b/tests/test_lstm.py index f33bf470c..489df374b 100644 --- a/tests/test_lstm.py +++ b/tests/test_lstm.py @@ -793,5 +793,212 @@ def func(x): return tf.identity(y[0], name="output") self.run_test_case(func, {"input:0": x_val}, [], ["output:0"], rtol=1e-05, atol=1e-06) + @check_tf_min_version("2.0") + @skip_tf_versions("2.1", "Bug in TF 2.1") + def test_keras_masked_lstm_embedding_unidirectional(self): + for go_backwards in [True, False]: + for return_sequences in [True, False]: + timesteps = 4 + # Note: masked LSTM only support post-padded input after conversion + # test case sequence_lens = [4, 2, 0] + x_val = np.array([ + [1, 2, 3, 4], + [5, 6, 0, 0], + [0, 0, 0, 0] + ], dtype=np.int32) + + model_in = tf.keras.layers.Input((timesteps,), dtype="int32") + x_embedding = tf.keras.layers.Embedding( + input_dim=10, + output_dim=5, + mask_zero=True, + embeddings_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=41), + )(model_in) + + # RNN layer inherits the mask propagated from above embedding layer + model_out = tf.keras.layers.LSTM( + units=5, + go_backwards=go_backwards, + return_sequences=return_sequences, + return_state=True, + kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42), + bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43), + recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44), + )(x_embedding) + model = tf.keras.models.Model(inputs=model_in, outputs=model_out) + + def func(x): + y = model(x) + if return_sequences: + return ( + # skipping output Y when return_sequences=True due to inconsistent + # ORT and TF behaviors: https://sim.amazon.com/issues/NEMORT-1712 + tf.identity(y[1], name="output_yh"), + tf.identity(y[2], name="output_yc")) + return( + tf.identity(y[0], name="output_y"), + tf.identity(y[1], name="output_yh"), + tf.identity(y[2], name="output_yc")) + + output_list = ["output_yh:0", "output_yc:0"] if return_sequences \ + else ["output_y:0", "output_yh:0", "output_yc:0"] + self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06) + + @check_tf_min_version("2.0") + @skip_tf_versions("2.1", "Bug in TF 2.1") + def test_keras_masked_lstm_embedding_bidirectional(self): + for return_sequences in [False, True]: + timesteps = 4 + # Note: masked LSTM only support post-padded input after conversion + # test case sequence_lens = [4, 2, 0] + x_val = np.array([ + [1, 2, 3, 4], + [5, 6, 0, 0], + [0, 0, 0, 0] + ], dtype=np.int32) + + model_in = tf.keras.layers.Input((timesteps,), dtype="int32") + x_embedding = tf.keras.layers.Embedding( + input_dim=10, + output_dim=5, + mask_zero=True, + embeddings_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=41), + )(model_in) + + # RNN layer inherits the mask propagated from above embedding layer + lstm_layer = tf.keras.layers.LSTM( + units=5, + go_backwards=False, + return_sequences=return_sequences, + return_state=True, + kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42), + bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43), + recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44), + ) + 
model_out = tf.keras.layers.Bidirectional(lstm_layer)(x_embedding) + model = tf.keras.models.Model(inputs=model_in, outputs=model_out) + + def func(x): + y = model(x) + if return_sequences: + return ( + # skipping output Y when return_sequences=True due to inconsistent + # ORT and TF behaviors: https://sim.amazon.com/issues/NEMORT-1712 + tf.identity(y[1], name="output_yh_f"), + tf.identity(y[2], name="output_yc_f"), + tf.identity(y[3], name="output_yh_r"), + tf.identity(y[4], name="output_yc_r")) + return( + tf.identity(y[0], name="output_y_concat"), + tf.identity(y[1], name="output_yh_f"), + tf.identity(y[2], name="output_yc_f"), + tf.identity(y[3], name="output_yh_r"), + tf.identity(y[4], name="output_yc_r")) + + output_list = ["output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"] if return_sequences \ + else ["output_y_concat:0", "output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"] + + # translate single BiLSTM to two forward LSTMs + self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06, + require_lstm_count=2) + + @check_tf_min_version("2.0") + @skip_tf_versions("2.1", "Bug in TF 2.1") + def test_keras_masked_lstm_unidirectional(self): + for go_backwards in [True, False]: + for return_sequences in [True, False]: + batch_size, timesteps, feat = 3, 4, 5 + in_shape = (timesteps, feat) + x_val = np.random.uniform(size=[batch_size, timesteps, feat]).astype(np.float32) + # Note: masked LSTM only support post-padded input after conversion + # test case sequence_lens = [4, 2, 0] + x_val[1, 2:, :] = 0. + x_val[2, :, :] = 0. + + model_in = tf.keras.layers.Input(shape=in_shape, dtype="float32") + x_masked = tf.keras.layers.Masking(mask_value=0.)(model_in) + + # RNN layer inherits the mask propagated from above mask layer + model_out = tf.keras.layers.LSTM( + units=5, + go_backwards=go_backwards, + return_sequences=return_sequences, + return_state=True, + kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42), + bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43), + recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44), + )(x_masked) + model = tf.keras.models.Model(inputs=model_in, outputs=model_out) + + def func(x): + y = model(x) + if return_sequences: + return ( + # skipping output Y when return_sequences=True due to inconsistent + # ORT and TF behaviors: https://sim.amazon.com/issues/NEMORT-1712 + tf.identity(y[1], name="output_yh"), + tf.identity(y[2], name="output_yc")) + return( + tf.identity(y[0], name="output_y"), + tf.identity(y[1], name="output_yh"), + tf.identity(y[2], name="output_yc")) + + output_list = ["output_yh:0", "output_yc:0"] if return_sequences \ + else ["output_y:0", "output_yh:0", "output_yc:0"] + self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06) + + @check_tf_min_version("2.0") + @skip_tf_versions("2.1", "Bug in TF 2.1") + def test_keras_masked_lstm_bidirectional(self): + for return_sequences in [False, True]: + batch_size, timesteps, feat = 3, 4, 5 + in_shape = (timesteps, feat) + x_val = np.random.uniform(size=[batch_size, timesteps, feat]).astype(np.float32) + # Note: masked LSTM only support post-padded input after conversion + # test case sequence_lens = [4, 2, 0] + x_val[1, 2:, :] = 0. + x_val[2, :, :] = 0. 
+ + model_in = tf.keras.layers.Input(shape=in_shape, dtype="float32") + x_masked = tf.keras.layers.Masking(mask_value=0.)(model_in) + + # RNN layer inherits the mask propagated from above mask layer + lstm_layer = tf.keras.layers.LSTM( + units=5, + go_backwards=False, + return_sequences=return_sequences, + return_state=True, + kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42), + bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43), + recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44), + ) + model_out = tf.keras.layers.Bidirectional(lstm_layer)(x_masked) + model = tf.keras.models.Model(inputs=model_in, outputs=model_out) + + def func(x): + y = model(x) + if return_sequences: + return ( + # skipping output Y when return_sequences=True due to inconsistent + # ORT and TF behaviors: https://sim.amazon.com/issues/NEMORT-1712 + tf.identity(y[1], name="output_yh_f"), + tf.identity(y[2], name="output_yc_f"), + tf.identity(y[3], name="output_yh_r"), + tf.identity(y[4], name="output_yc_r")) + return( + tf.identity(y[0], name="output_y_concat"), + tf.identity(y[1], name="output_yh_f"), + tf.identity(y[2], name="output_yc_f"), + tf.identity(y[3], name="output_yh_r"), + tf.identity(y[4], name="output_yc_r")) + + output_list = ["output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"] if return_sequences \ + else ["output_y_concat:0", "output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"] + + # translate single BiLSTM to two forward LSTMs + self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06, + require_lstm_count=2) + + if __name__ == '__main__': unittest_main() diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py index 044a253d4..ca18d5888 100644 --- a/tf2onnx/onnx_opset/tensor.py +++ b/tf2onnx/onnx_opset/tensor.py @@ -2260,15 +2260,25 @@ def version_10(cls, ctx, node, **kwargs): const_axis_name = utils.make_name(f'const_{axis}') const_axis = ctx.make_const(name=const_axis_name, np_val=np.array([axis], dtype=np.int64)) - # Add a Constant node (seq_len) for ReverseSequence. - # Index 1 for the shape should not return 0, since rank(input) >=2 - input_shape = ctx.make_node("Shape", [inputs[-1]], op_name_scope=rv2_node_name) - batch_size = ctx.make_node("Gather", [input_shape.output[0], const_one.output[0]], - op_name_scope=rv2_node_name) - axis_dim = ctx.make_node("Gather", [input_shape_node.output[0], const_axis.output[0]], - op_name_scope=rv2_node_name) - seq_array = ctx.make_node("Expand", [axis_dim.output[0], batch_size.output[0]]) - inputs.append(seq_array.output[0]) + # Add sequence_lens as ReverseSequence input + has_sequence_lens = node.get_attr_value("has_sequence_lens", False) + if not has_sequence_lens: + # open-source impl: fill in dummy sequence_lens based on input shape + # Add a Constant node (seq_len) for ReverseSequence. 
+            # Index 1 for the shape should not return 0, since rank(input) >= 2
+            input_shape = ctx.make_node("Shape", [inputs[-1]], op_name_scope=rv2_node_name)
+            batch_size = ctx.make_node("Gather", [input_shape.output[0], const_one.output[0]],
+                                       op_name_scope=rv2_node_name)
+            axis_dim = ctx.make_node("Gather", [input_shape_node.output[0], const_axis.output[0]],
+                                     op_name_scope=rv2_node_name)
+            seq_array = ctx.make_node("Expand", [axis_dim.output[0], batch_size.output[0]])
+            inputs.append(seq_array.output[0])
+        else:
+            # masked backward LSTM:
+            # sequence_lens is appended to ReverseV2's input by lstm_tf2_rewriter
+            # to keep the tensor post-padded after the reverse
+            seq_lens_casted = ctx.make_node("Cast", [node.input[-1]], attr={'to': TensorProto.INT64}).output[0]
+            inputs.append(seq_lens_casted)
 
         # Add a ReverseSequence node.
 
diff --git a/tf2onnx/rewriter/lstm_tf2_rewriter.py b/tf2onnx/rewriter/lstm_tf2_rewriter.py
index 845bb2a84..6bb3facee 100644
--- a/tf2onnx/rewriter/lstm_tf2_rewriter.py
+++ b/tf2onnx/rewriter/lstm_tf2_rewriter.py
@@ -4,8 +4,10 @@
 """
 tf2onnx.rewriter.lstm_tf2_rewriter - Rewrites LSTM pattern used by tf2.
 """
-
+import logging
 import numpy as np
+from onnx import onnx_pb
+
 from tf2onnx.graph_matcher import GraphMatcher
 from tf2onnx.rewriter.rnn_utils import make_lstm_pattern
 from tf2onnx.tf_loader import find_function
@@ -79,21 +81,35 @@ def rewriter_lstm_tf2(g, ops):
         # extract output h_t
         ht_mul = match_result.get_op("ht")
         final_consumers = g.find_output_consumers(ht_mul.output[0])
-        select_ops = [n for n in final_consumers if n.type == "Select"]
+        select_ops = [n for n in final_consumers if n.type == "Select" or n.type == "SelectV2"]
         def has_tensor_list_consumer(n):
             return any(c.type == "TensorListSetItem" for c in g.find_output_consumers(n.output[0]))
         select_ops = [n for n in select_ops if has_tensor_list_consumer(n)]
+
+        # extract sequence length
+        seq_len_idx, mask_idx = None, None
         if len(select_ops) == 1:
-            greater_eq = select_ops[0].inputs[0]
-            if greater_eq.type != "GreaterEqual":
-                continue
-            seq_len = greater_eq.inputs[1]
-            if not seq_len.is_graph_input():
+            select_op_condition = select_ops[0].inputs[0]
+            while select_op_condition.type == "Identity":
+                select_op_condition = select_op_condition.inputs[0]
+
+            # open-source impl: skip timesteps based on a specific sequence length
+            if select_op_condition.type == "GreaterEqual":
+                seq_len = select_op_condition.inputs[1]
+                if not seq_len.is_graph_input():
+                    continue
+                seq_len_idx = g.input_names.index(seq_len.output[0])
+
+            # masked LSTM: skip timesteps based on a dynamically-computed boolean mask tensor
+            elif select_op_condition.type == "TensorListGetItem":
+                mask = select_op_condition.inputs[0]
+                if not mask.is_graph_input():
+                    continue
+                mask_idx = g.input_names.index(mask.output[0])
+            else:
                 continue
-            seq_len_idx = g.input_names.index(seq_len.output[0])
+
             final_consumers = g.find_output_consumers(select_ops[0].output[0])
-        else:
-            seq_len_idx = None
 
         tensor_set_items = [n for n in final_consumers if n.type == "TensorListSetItem"]
         if len(tensor_set_items) != 1:
@@ -209,6 +225,7 @@ def has_tensor_list_consumer(n):
             # Keras
             "w_idx": gk_idx,
             "r_idx": hk_idx,
+            "mask_idx": mask_idx,
         }
 
     for op in ops:
@@ -276,15 +293,63 @@ def has_tensor_list_consumer(n):
         tensor_array_inp = op.inputs[body_context["x_idx"]]
         if not tensor_array_inp.type == "TensorListFromTensor":
             continue
+
         context.onnx_input_ids[0]["X"] = tensor_array_inp.input[0]
 
-        final_consumers = g.find_output_consumers(op.output[body_context["out_idx"]])
-        output_ys = [n.output[0] for n in final_consumers if n.type == "TensorListStack"]
+        # parse sequence length
+        seq_len_idx = body_context["seq_len_idx"]
+        mask_idx = body_context["mask_idx"]
+        if seq_len_idx is not None:
+            context.onnx_input_ids[0]["sequence_lens"] = op.input[seq_len_idx]
+        elif mask_idx is not None:
+            logging.warning(
+                "Found mask-enabled LSTM. Converted ONNX model will only support post-padded LSTM input. "
+                "If input is pre- or randomly-padded, masked timesteps will not be correctly skipped.")
+
+            # parse boolean mask
+            tensor_array_mask = op.inputs[body_context["mask_idx"]]
+            if not tensor_array_mask.type == "TensorListFromTensor":
+                continue
+            mask_mat = tensor_array_mask.input[0]
+            mask_mat_node = g.get_node_by_output(mask_mat)
+            is_mask_reverse = mask_mat_node.type == "ReverseV2"
+            # no need to reverse the mask sequence:
+            # the positions of the skipped timesteps per batch are irrelevant assuming post-padded input
+            if is_mask_reverse:
+                mask_mat = mask_mat_node.input[0]
+
+            # reduce mask tensor to sequence_lens assuming post-padded input
+            # transpose (1, 0, 2)  -> boolean mask tensor (N, timesteps, 1)
+            # squeeze on dim(-1)   -> boolean mask matrix (N, timesteps)
+            # reduceSum on dim(-1) -> sequence_lens (N)
+            mask_transpose_node = g.make_node(op_type="Transpose", inputs=[mask_mat], attr={"perm": [1, 0, 2]})
+            mask_squeeze = GraphBuilder(g).make_squeeze({"data": mask_transpose_node.output[0], "axes": [-1]})
+            mask_cast_node = g.make_node(op_type="Cast", inputs=[mask_squeeze],
+                                         attr={"to": onnx_pb.TensorProto.INT32})
+            sequence_lens = GraphBuilder(g).make_reduce_sum({"data": mask_cast_node.output[0],
+                                                             "axes": [-1], "keepdims": 0})
+            context.onnx_input_ids[0]["sequence_lens"] = sequence_lens
+
+            # handle backward LSTM
+            tensor_array_inp_producer = tensor_array_inp.inputs[0]
+            is_input_reverse = tensor_array_inp_producer.type == "ReverseV2"
+            # a backward LSTM is identified by both the input and the mask tensors being reversed before the LSTM
+            if is_mask_reverse != is_input_reverse:
+                continue
+            if is_input_reverse:
+                # TF uses a plain "ReverseV2" to reverse the input tensor with no assumption on padding position,
+                # because the reversed mask of shape (batch_size, timesteps) is explicit per timestep.
+                # ONNX requires "ReverseSequence" to keep the reversed input tensor post-padded, because the mask
+                # is implied by sequence_lens. This requires passing sequence_lens to that "ReverseSequence" op.
+
+                # Note: tensor op conversions run after rewriters. Appending sequence_lens as a "ReverseV2" input
+                # signals the alternative behavior in the "ReverseV2" conversion in onnx_opset/tensor.py.
+ tensor_array_inp_producer.set_attr("has_sequence_lens", True) + inp_reverse_inputs = tensor_array_inp_producer.input + inp_reverse_inputs.append(sequence_lens) - context.onnx_input_ids[0]["X"] = tensor_array_inp.input[0] - if body_context["seq_len_idx"] is None: - context.onnx_input_ids[0]["sequence_lens"] = "" else: - context.onnx_input_ids[0]["sequence_lens"] = op.input[body_context["seq_len_idx"]] + context.onnx_input_ids[0]["sequence_lens"] = "" + context.onnx_input_ids[0]["initial_c"] = initial_c context.onnx_input_ids[0]["initial_h"] = initial_h @@ -294,10 +359,22 @@ def has_tensor_list_consumer(n): lstm_rewriter.process_weights_and_bias(context) lstm_node = lstm_rewriter.create_rnn_node(context)[0] - squeeze_output = GraphBuilder(g).make_squeeze({"data": lstm_node.output[0], "axes": [1]}) + # connect LSTM output Y to downstream nodes + squeeze_output_y = GraphBuilder(g).make_squeeze({"data": lstm_node.output[0], "axes": [1]}) + squeeze_output_yh = GraphBuilder(g).make_squeeze({"data": lstm_node.output[1], "axes": [0]}) + final_consumers = g.find_output_consumers(op.output[body_context["out_idx"]]) + output_ys = [n.output[0] for n in final_consumers if n.type == "TensorListStack"] for output in output_ys: + output_y_consumers = g.find_output_consumers(output) + # if hidden output from LSTM is sliced on the last timestep in TF, i.e. return_sequences=False, + # do not slice on the last timestep of y in ONNX, instead propagate y_h directly + # to avoid exposing inference issue: https://sim.amazon.com/issues/NEMORT-1712 + y_is_yh = len(output_y_consumers) == 1 and output_y_consumers[0].type == "StridedSlice" + output = output_y_consumers[0].output[0] if y_is_yh else output + squeeze_output = squeeze_output_yh if y_is_yh else squeeze_output_y g.replace_all_inputs(output, squeeze_output) + # connect LSTM output Y_h and Y_c to downstream nodes if body_context["state_is_tuple"]: c_squeeze = GraphBuilder(g).make_squeeze({"data": lstm_node.output[2], "axes": [0]}) h_squeeze = GraphBuilder(g).make_squeeze({"data": lstm_node.output[1], "axes": [0]})
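
Editor's note, not part of the patch: the rewriter above reduces the propagated Keras mask to ONNX sequence_lens with a Transpose -> Squeeze -> Cast -> ReduceSum chain. Below is a minimal NumPy sketch of that reduction, assuming a post-padded boolean mask of shape (timesteps, batch, 1) as carried by the TF while-loop; the helper name is illustrative only.

import numpy as np

def mask_to_sequence_lens(mask):
    # mask: bool array of shape (timesteps, batch, 1); returns int32 array of shape (batch,)
    mask = np.transpose(mask, (1, 0, 2))            # -> (batch, timesteps, 1)
    mask = np.squeeze(mask, axis=-1)                # -> (batch, timesteps)
    return np.sum(mask.astype(np.int32), axis=-1)   # count of valid steps per batch entry

# mirrors the test input above: sequence_lens = [4, 2, 0]
mask = np.array([[True, True, False],
                 [True, True, False],
                 [True, False, False],
                 [True, False, False]]).reshape(4, 3, 1)
print(mask_to_sequence_lens(mask))                  # [4 2 0]

Counting True entries only recovers the valid prefix length when the input is post-padded, which is why the rewriter logs a warning for pre- or randomly-padded input.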
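
Editor's note, not part of the patch: for the backward direction, TF's plain ReverseV2 flips the whole time axis, which would move the padding to the front, whereas ONNX ReverseSequence reverses only the first sequence_lens[i] steps of each batch entry and therefore keeps the tensor post-padded. A small NumPy illustration of the difference, with shape (batch, timesteps) and a hypothetical helper standing in for ONNX ReverseSequence (batch_axis=0, time_axis=1):

import numpy as np

def reverse_sequence(x, sequence_lens):
    # reverse only the first sequence_lens[i] elements of each row, like ONNX ReverseSequence
    out = x.copy()
    for i, n in enumerate(sequence_lens):
        out[i, :n] = x[i, :n][::-1]
    return out

x = np.array([[1, 2, 3, 4],
              [5, 6, 0, 0],
              [0, 0, 0, 0]])
print(np.flip(x, axis=1))              # full flip: [[4 3 2 1] [0 0 6 5] [0 0 0 0]] -- padding moves to the front
print(reverse_sequence(x, [4, 2, 0]))  # [[4 3 2 1] [6 5 0 0] [0 0 0 0]] -- stays post-padded

This is the motivation for appending sequence_lens to the ReverseV2 node and setting has_sequence_lens, so the ReverseV2 conversion in onnx_opset/tensor.py emits a ReverseSequence driven by the real per-batch lengths instead of a dummy full-length reverse.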
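
Editor's note, not part of the patch: an end-to-end sketch of how a converted masked LSTM might be exercised, assuming a tf2onnx build that includes this change; the model layout mirrors the unit tests above, and the opset and provider choices are illustrative only.

import numpy as np
import tensorflow as tf
import tf2onnx
import onnxruntime as ort

inp = tf.keras.layers.Input(shape=(4, 5), dtype="float32")
masked = tf.keras.layers.Masking(mask_value=0.0)(inp)   # mask is propagated into the LSTM
out = tf.keras.layers.LSTM(units=5)(masked)
model = tf.keras.models.Model(inputs=inp, outputs=out)

model_proto, _ = tf2onnx.convert.from_keras(model, opset=13)

x = np.random.uniform(size=(3, 4, 5)).astype(np.float32)
x[1, 2:, :] = 0.0   # post-padded input: sequence_lens = [4, 2, 0]
x[2, :, :] = 0.0

sess = ort.InferenceSession(model_proto.SerializeToString(), providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
print(sess.run(None, {input_name: x})[0].shape)          # (3, 5)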