From 0c9437afcb9cc5852abcbd31bcb85c08afef0ab7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CWanglongzhi2001=E2=80=9D?= <“583087864@qq.com”> Date: Tue, 18 Jul 2023 23:31:45 +0800 Subject: [PATCH] feat: add Bidirectional layer --- .../ArgsDefinition/Rnn/BidirectionalArgs.cs | 20 ++ .../Keras/ArgsDefinition/Rnn/LSTMArgs.cs | 5 + .../Keras/ArgsDefinition/Rnn/RNNArgs.cs | 5 + .../Keras/ArgsDefinition/Rnn/WrapperArgs.cs | 24 ++ .../Keras/Layers/ILayersApi.cs | 14 +- src/TensorFlowNET.Keras/Layers/LayersApi.cs | 14 + .../Layers/Rnn/BaseWrapper.cs | 33 +++ .../Layers/Rnn/Bidirectional.cs | 276 ++++++++++++++++++ src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs | 31 +- src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs | 11 +- .../Layers/Rnn.Test.cs | 13 +- 11 files changed, 428 insertions(+), 18 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/BidirectionalArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/WrapperArgs.cs create mode 100644 src/TensorFlowNET.Keras/Layers/Rnn/BaseWrapper.cs create mode 100644 src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/BidirectionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/BidirectionalArgs.cs new file mode 100644 index 000000000..d658a82e9 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/BidirectionalArgs.cs @@ -0,0 +1,20 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class BidirectionalArgs : AutoSerializeLayerArgs + { + [JsonProperty("layer")] + public ILayer Layer { get; set; } + [JsonProperty("merge_mode")] + public string? MergeMode { get; set; } + [JsonProperty("backward_layer")] + public ILayer BackwardLayer { get; set; } + public NDArray Weights { get; set; } + } + +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs index d816b0ff7..a6beb77e8 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs @@ -5,5 +5,10 @@ public class LSTMArgs : RNNArgs // TODO: maybe change the `RNNArgs` and implement this class. 
public bool UnitForgetBias { get; set; } public int Implementation { get; set; } + + public LSTMArgs Clone() + { + return (LSTMArgs)MemberwiseClone(); + } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs index b84d30d3d..d0b73ba44 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs @@ -40,5 +40,10 @@ public class RNNArgs : AutoSerializeLayerArgs public bool ZeroOutputForMask { get; set; } = false; [JsonProperty("recurrent_dropout")] public float RecurrentDropout { get; set; } = .0f; + + public RNNArgs Clone() + { + return (RNNArgs)MemberwiseClone(); + } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/WrapperArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/WrapperArgs.cs new file mode 100644 index 000000000..ec8e16d59 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/WrapperArgs.cs @@ -0,0 +1,24 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text; + + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class WrapperArgs : AutoSerializeLayerArgs + { + [JsonProperty("layer")] + public ILayer Layer { get; set; } + + public WrapperArgs(ILayer layer) + { + Layer = layer; + } + + public static implicit operator WrapperArgs(BidirectionalArgs args) + => new WrapperArgs(args.Layer); + } + } diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs index 1670f9d1d..b8aff5fb6 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs @@ -258,7 +258,19 @@ public IRnnCell GRUCell( float dropout = 0f, float recurrent_dropout = 0f, bool reset_after = true); - + + /// <summary> + /// Bidirectional wrapper for RNNs. + /// </summary> + /// <param name="layer">`keras.layers.RNN` instance, such as `keras.layers.LSTM` or `keras.layers.GRU`.</param> + /// <param name="backward_layer">Optional `keras.layers.RNN` instance to be used as the backward layer; if not provided, the layer instance passed in will be used to generate the backward layer automatically.</param> + /// <returns></returns> + public ILayer Bidirectional( + ILayer layer, + string merge_mode = "concat", + NDArray weights = null, + ILayer backward_layer = null); + public ILayer Subtract(); } } diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index cb85bbba1..a04a9c051 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -908,6 +908,20 @@ public IRnnCell GRUCell( ResetAfter = reset_after }); + public ILayer Bidirectional( + ILayer layer, + string merge_mode = "concat", + NDArray weights = null, + ILayer backward_layer = null) + => new Bidirectional(new BidirectionalArgs + { + Layer = layer, + MergeMode = merge_mode, + Weights = weights, + BackwardLayer = backward_layer + }); + + + /// + /// + /// diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/BaseWrapper.cs b/src/TensorFlowNET.Keras/Layers/Rnn/BaseWrapper.cs new file mode 100644 index 000000000..737f88cd4 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/BaseWrapper.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Layers +{ + /// <summary> + /// Abstract wrapper base class. Wrappers take another layer and augment it in various ways. + /// Do not use this class as a layer; it is only an abstract base class.
+ /// Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers. + /// </summary> + public abstract class Wrapper: Layer + { + public ILayer _layer; + public Wrapper(WrapperArgs args):base(args) + { + _layer = args.Layer; + } + + public virtual void Build(KerasShapesWrapper input_shape) + { + if (!_layer.Built) + { + _layer.build(input_shape); + } + built = true; + } + + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs b/src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs new file mode 100644 index 000000000..6114d9c7c --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs @@ -0,0 +1,276 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Layers +{ + /// <summary> + /// Bidirectional wrapper for RNNs. + /// </summary> + public class Bidirectional: Wrapper + { + BidirectionalArgs _args; + RNN _forward_layer; + RNN _backward_layer; + RNN _layer; + bool _support_masking = true; + int _num_constants = 0; + bool _return_state; + bool _stateful; + bool _return_sequences; + InputSpec _input_spec; + RNNArgs _layer_args_copy; + public Bidirectional(BidirectionalArgs args):base(args) + { + _args = args; + if (_args.Layer is not ILayer) + throw new ValueError( + "Please initialize `Bidirectional` layer with a " + + $"`tf.keras.layers.Layer` instance. Received: {_args.Layer}"); + + if (_args.BackwardLayer is not null && _args.BackwardLayer is not ILayer) + throw new ValueError( + "`backward_layer` needs to be a `tf.keras.layers.Layer` " + + $"instance. Received: {_args.BackwardLayer}"); + if (!new List<string> { "sum", "mul", "ave", "concat", null }.Contains(_args.MergeMode)) + { + throw new ValueError( + $"Invalid merge mode. Received: {_args.MergeMode}. " + + "Merge mode should be one of " + + "{\"sum\", \"mul\", \"ave\", \"concat\", null}" + ); + } + if (_args.Layer is RNN) + { + _layer = _args.Layer as RNN; + } + else + { + throw new ValueError( + "Bidirectional only supports RNN instances such as LSTM or GRU"); + } + _return_state = _layer.Args.ReturnState; + _return_sequences = _layer.Args.ReturnSequences; + _stateful = _layer.Args.Stateful; + _layer_args_copy = _layer.Args.Clone(); + // We don't want to track `layer` since we're already tracking the two + // copies of it we actually run. + // TODO(Wanglongzhi2001): setattr_tracking has not been implemented yet. + // _setattr_tracking = false; + // super().__init__(layer, **kwargs) + // _setattr_tracking = true; + + // Recreate the forward layer from the original layer config, so that it + // will not carry over any state from the layer. + var actualType = _layer.GetType(); + if (actualType == typeof(LSTM)) + { + var arg = _layer_args_copy as LSTMArgs; + _forward_layer = new LSTM(arg); + } + // TODO(Wanglongzhi2001): add the GRU case.
+ else + { + _forward_layer = new RNN(_layer.Cell, _layer_args_copy); + } + //_forward_layer = _recreate_layer_from_config(_layer); + if (_args.BackwardLayer is null) + { + _backward_layer = _recreate_layer_from_config(_layer, go_backwards:true); + } + else + { + _backward_layer = _args.BackwardLayer as RNN; + } + _forward_layer.Name = "forward_" + _forward_layer.Name; + _backward_layer.Name = "backward_" + _backward_layer.Name; + _verify_layer_config(); + + void force_zero_output_for_mask(RNN layer) + { + layer.Args.ZeroOutputForMask = layer.Args.ReturnSequences; + } + + force_zero_output_for_mask(_forward_layer); + force_zero_output_for_mask(_backward_layer); + + if (_args.Weights is not null) + { + var nw = len(_args.Weights); + _forward_layer.set_weights(_args.Weights[$":{nw / 2}"]); + _backward_layer.set_weights(_args.Weights[$"{nw / 2}:"]); + } + + _input_spec = _layer.InputSpec; + } + + private void _verify_layer_config() + { + if (_forward_layer.Args.GoBackwards == _backward_layer.Args.GoBackwards) + { + throw new ValueError( + "Forward layer and backward layer should have different " + + "`go_backwards` value. " + + "forward_layer.go_backwards = " + + $"{_forward_layer.Args.GoBackwards}, " + + "backward_layer.go_backwards = " + + $"{_backward_layer.Args.GoBackwards}"); + } + if (_forward_layer.Args.Stateful != _backward_layer.Args.Stateful) + { + throw new ValueError( + "Forward layer and backward layer are expected to have "+ + $"the same value for attribute stateful, got "+ + $"{_forward_layer.Args.Stateful} for forward layer and "+ + $"{_backward_layer.Args.Stateful} for backward layer"); + } + if (_forward_layer.Args.ReturnState != _backward_layer.Args.ReturnState) + { + throw new ValueError( + "Forward layer and backward layer are expected to have " + + $"the same value for attribute return_state, got " + + $"{_forward_layer.Args.ReturnState} for forward layer and " + + $"{_backward_layer.Args.ReturnState} for backward layer"); + } + if (_forward_layer.Args.ReturnSequences != _backward_layer.Args.ReturnSequences) + { + throw new ValueError( + "Forward layer and backward layer are expected to have " + + $"the same value for attribute return_sequences, got " + + $"{_forward_layer.Args.ReturnSequences} for forward layer and " + + $"{_backward_layer.Args.ReturnSequences} for backward layer"); + } + } + + private RNN _recreate_layer_from_config(RNN layer, bool go_backwards = false) + { + var config = layer.get_config() as RNNArgs; + var cell = layer.Cell; + if (go_backwards) + { + config.GoBackwards = !config.GoBackwards; + } + var actualType = layer.GetType(); + if (actualType == typeof(LSTM)) + { + var arg = config as LSTMArgs; + return new LSTM(arg); + } + else + { + return new RNN(cell, config); + } + } + + public override void build(KerasShapesWrapper input_shape) + { + _buildInputShape = input_shape; + tf_with(ops.name_scope(_forward_layer.Name), scope=> + { + _forward_layer.build(input_shape); + }); + tf_with(ops.name_scope(_backward_layer.Name), scope => + { + _backward_layer.build(input_shape); + }); + built = true; + } + + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + { + // `Bidirectional.call` implements the same API as the wrapped `RNN`.
+ + Tensors forward_inputs; + Tensors backward_inputs; + Tensors forward_state; + Tensors backward_state; + // if isinstance(inputs, list) and len(inputs) > 1: + if (inputs.Length > 1) + { + // initial_states are keras tensors, which means they are passed + // in together with inputs as a list. The initial_states need to be + // split into forward and backward sections, and be fed to the layers + // accordingly. + forward_inputs = new Tensors { inputs[0] }; + backward_inputs = new Tensors { inputs[0] }; + var pivot = (len(inputs) - _num_constants) / 2 + 1; + // add forward initial state + forward_inputs.Concat(new Tensors { inputs[$"1:{pivot}"] }); + if (_num_constants == 0) + // add backward initial state + backward_inputs.Concat(new Tensors { inputs[$"{pivot}:"] }); + else + { + // add backward initial state + backward_inputs.Concat(new Tensors { inputs[$"{pivot}:{-_num_constants}"] }); + // add constants for forward and backward layers + forward_inputs.Concat(new Tensors { inputs[$"{-_num_constants}:"] }); + backward_inputs.Concat(new Tensors { inputs[$"{-_num_constants}:"] }); + } + forward_state = null; + backward_state = null; + } + else if (state is not null) + { + // initial_states are not keras tensors, e.g. an eager tensor from an np + // array. They are only passed in from kwarg initial_state, and + // should be passed to forward/backward layer via kwarg + // initial_state as well. + forward_inputs = inputs; + backward_inputs = inputs; + var half = len(state) / 2; + forward_state = state[$":{half}"]; + backward_state = state[$"{half}:"]; + } + else + { + forward_inputs = inputs; + backward_inputs = inputs; + forward_state = null; + backward_state = null; + } + var y = _forward_layer.Apply(forward_inputs, forward_state); + var y_rev = _backward_layer.Apply(backward_inputs, backward_state); + + Tensors states = new(); + if (_return_state) + { + states = y["1:"] + y_rev["1:"]; + y = y[0]; + y_rev = y_rev[0]; + } + + if (_return_sequences) + { + int time_dim = _forward_layer.Args.TimeMajor ? 0 : 1; + y_rev = keras.backend.reverse(y_rev, time_dim); + } + Tensors output; + if (_args.MergeMode == "concat") + output = keras.backend.concatenate(new Tensors { y.Single(), y_rev.Single() }); + else if (_args.MergeMode == "sum") + output = y.Single() + y_rev.Single(); + else if (_args.MergeMode == "ave") + output = (y.Single() + y_rev.Single()) / 2; + else if (_args.MergeMode == "mul") + output = y.Single() * y_rev.Single(); + else if (_args.MergeMode is null) + output = new Tensors { y.Single(), y_rev.Single() }; + else + throw new ValueError( + "Unrecognized value for `merge_mode`. " + + $"Received: {_args.MergeMode}. " + "Expected values are [\"concat\", \"sum\", \"ave\", \"mul\"]"); + if (_return_state) + { + if (_args.MergeMode is not null) + return new Tensors { output.Single(), states.Single()}; + } + return output; + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs index b5d583248..c766e8d69 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs @@ -3,6 +3,7 @@ using Tensorflow.Keras.Engine; using Tensorflow.Common.Types; using Tensorflow.Common.Extensions; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers { @@ -14,15 +15,15 @@ namespace Tensorflow.Keras.Layers /// public class LSTM : RNN { - LSTMArgs args; + LSTMArgs _args; InputSpec[] _state_spec; InputSpec _input_spec; bool _could_use_gpu_kernel; - + public LSTMArgs Args { get => _args; } public LSTM(LSTMArgs args) : base(CreateCell(args), args) { - this.args = args; + _args = args; _input_spec = new InputSpec(ndim: 3); _state_spec = new[] { args.Units, args.Units }.Select(dim => new InputSpec(shape: (-1, dim))).ToArray(); _could_use_gpu_kernel = args.Activation == keras.activations.Tanh @@ -71,7 +72,7 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo var single_input = inputs.Single; var input_shape = single_input.shape; - var timesteps = args.TimeMajor ? input_shape[0] : input_shape[1]; + var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1]; _maybe_reset_cell_dropout_mask(Cell); @@ -87,26 +88,26 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo inputs, initial_state, constants: null, - go_backwards: args.GoBackwards, + go_backwards: _args.GoBackwards, mask: mask, - unroll: args.Unroll, + unroll: _args.Unroll, input_length: ops.convert_to_tensor(timesteps), - time_major: args.TimeMajor, - zero_output_for_mask: args.ZeroOutputForMask, - return_all_outputs: args.ReturnSequences + time_major: _args.TimeMajor, + zero_output_for_mask: _args.ZeroOutputForMask, + return_all_outputs: _args.ReturnSequences ); Tensor output; - if (args.ReturnSequences) + if (_args.ReturnSequences) { - output = keras.backend.maybe_convert_to_ragged(false, outputs, (int)timesteps, args.GoBackwards); + output = keras.backend.maybe_convert_to_ragged(false, outputs, (int)timesteps, _args.GoBackwards); } else { output = last_output; } - if (args.ReturnState) + if (_args.ReturnState) { return new Tensor[] { output }.Concat(states).ToArray().ToTensors(); } @@ -115,5 +116,11 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo return output; } } + + public override IKerasConfig get_config() + { + return _args; + } + } } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index 0e81d20e3..c19222614 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -31,7 +31,9 @@ public class RNN : RnnBase protected IVariableV1 _kernel; protected IVariableV1 _bias; private IRnnCell _cell; - protected IRnnCell Cell + + public RNNArgs Args { get => _args; } + public IRnnCell Cell { get { @@ -570,10 +572,13 @@ protected Tensors get_initial_state(Tensors inputs) var input_shape = array_ops.shape(inputs); var batch_size = _args.TimeMajor ?
input_shape[1] : input_shape[0]; var dtype = input.dtype; - Tensors init_state = Cell.GetInitialState(null, batch_size, dtype); - return init_state; } + + public override IKerasConfig get_config() + { + return _args; + } } } diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs index 5f7bd574e..03159346a 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs @@ -5,6 +5,7 @@ using System.Text; using System.Threading.Tasks; using Tensorflow.Common.Types; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers; using Tensorflow.Keras.Saving; @@ -38,8 +39,6 @@ public void StackedRNNCell() var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) }; var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells); var (output, state) = stackedRNNCell.Apply(inputs, states); - Console.WriteLine(output); - Console.WriteLine(state.shape); Assert.AreEqual((32, 5), output.shape); Assert.AreEqual((32, 4), state[0].shape); } @@ -108,6 +107,7 @@ public void RNNForSimpleRNNCell() var inputs = tf.random.normal((32, 10, 8)); var cell = tf.keras.layers.SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f); var rnn = tf.keras.layers.RNN(cell: cell); + var cfg = rnn.get_config(); var output = rnn.Apply(inputs); Assert.AreEqual((32, 10), output.shape); @@ -145,5 +145,14 @@ public void GRUCell() Assert.AreEqual((32, 4), output.shape); } + + [TestMethod] + public void Bidirectional() + { + var bi = tf.keras.layers.Bidirectional(keras.layers.LSTM(10, return_sequences:true)); + var inputs = tf.random.normal((32, 10, 8)); + var outputs = bi.Apply(inputs); + Assert.AreEqual((32, 10, 20), outputs.shape); + } } }
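
For reference, a minimal usage sketch of the API this patch adds, mirroring the shapes in `Rnn.Test.cs` above. The `merge_mode: "sum"` call is an assumption based on the merge modes the constructor accepts; only the default "concat" case is exercised by the unit test.

    using static Tensorflow.Binding;
    using static Tensorflow.KerasApi;

    // 32 sequences, 10 timesteps, 8 features, as in the unit test.
    var inputs = tf.random.normal((32, 10, 8));

    // Default merge_mode "concat" joins the forward and backward outputs along
    // the feature axis: two (32, 10, 10) outputs become (32, 10, 20).
    var bi = tf.keras.layers.Bidirectional(keras.layers.LSTM(10, return_sequences: true));
    var concatenated = bi.Apply(inputs);

    // merge_mode "sum" adds the two directions element-wise, keeping (32, 10, 10).
    var biSum = tf.keras.layers.Bidirectional(
        keras.layers.LSTM(10, return_sequences: true),
        merge_mode: "sum");
    var summed = biSum.Apply(inputs);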