From 0c9437afcb9cc5852abcbd31bcb85c08afef0ab7 Mon Sep 17 00:00:00 2001
From: Wanglongzhi2001 <583087864@qq.com>
Date: Tue, 18 Jul 2023 23:31:45 +0800
Subject: [PATCH] feat: add Bidirectional layer
---
.../ArgsDefinition/Rnn/BidirectionalArgs.cs | 20 ++
.../Keras/ArgsDefinition/Rnn/LSTMArgs.cs | 5 +
.../Keras/ArgsDefinition/Rnn/RNNArgs.cs | 5 +
.../Keras/ArgsDefinition/Rnn/WrapperArgs.cs | 24 ++
.../Keras/Layers/ILayersApi.cs | 14 +-
src/TensorFlowNET.Keras/Layers/LayersApi.cs | 14 +
.../Layers/Rnn/BaseWrapper.cs | 33 +++
.../Layers/Rnn/Bidirectional.cs | 276 ++++++++++++++++++
src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs | 31 +-
src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs | 11 +-
.../Layers/Rnn.Test.cs | 13 +-
11 files changed, 428 insertions(+), 18 deletions(-)
create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/BidirectionalArgs.cs
create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/WrapperArgs.cs
create mode 100644 src/TensorFlowNET.Keras/Layers/Rnn/BaseWrapper.cs
create mode 100644 src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/BidirectionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/BidirectionalArgs.cs
new file mode 100644
index 000000000..d658a82e9
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/BidirectionalArgs.cs
@@ -0,0 +1,20 @@
+using Newtonsoft.Json;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.NumPy;
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
+ public class BidirectionalArgs : AutoSerializeLayerArgs
+ {
+ [JsonProperty("layer")]
+ public ILayer Layer { get; set; }
+ [JsonProperty("merge_mode")]
+ public string? MergeMode { get; set; }
+ [JsonProperty("backward_layer")]
+ public ILayer BackwardLayer { get; set; }
+ public NDArray Weights { get; set; }
+ }
+
+}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs
index d816b0ff7..a6beb77e8 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs
@@ -5,5 +5,10 @@ public class LSTMArgs : RNNArgs
// TODO: maybe change the `RNNArgs` and implement this class.
public bool UnitForgetBias { get; set; }
public int Implementation { get; set; }
+
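+    // Note: MemberwiseClone performs a shallow copy; reference-type
+    // members are shared with the original args instance.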
+ public LSTMArgs Clone()
+ {
+ return (LSTMArgs)MemberwiseClone();
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
index b84d30d3d..d0b73ba44 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
@@ -40,5 +40,10 @@ public class RNNArgs : AutoSerializeLayerArgs
public bool ZeroOutputForMask { get; set; } = false;
[JsonProperty("recurrent_dropout")]
public float RecurrentDropout { get; set; } = .0f;
+
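+    // Shallow copy used by wrappers such as `Bidirectional` to recreate
+    // the wrapped layer without sharing mutable flag state.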
+ public RNNArgs Clone()
+ {
+ return (RNNArgs)MemberwiseClone();
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/WrapperArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/WrapperArgs.cs
new file mode 100644
index 000000000..ec8e16d59
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/WrapperArgs.cs
@@ -0,0 +1,24 @@
+using Newtonsoft.Json;
+using System;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
+using System.Text;
+
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
+ public class WrapperArgs : AutoSerializeLayerArgs
+ {
+ [JsonProperty("layer")]
+ public ILayer Layer { get; set; }
+
+ public WrapperArgs(ILayer layer)
+ {
+ Layer = layer;
+ }
+
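+    // Implicit conversion so a `BidirectionalArgs` can be passed straight
+    // to the `Wrapper` base-class constructor, which expects `WrapperArgs`.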
+ public static implicit operator WrapperArgs(BidirectionalArgs args)
+ => new WrapperArgs(args.Layer);
+ }
+
+}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
index 1670f9d1d..b8aff5fb6 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
@@ -258,7 +258,19 @@ public IRnnCell GRUCell(
float dropout = 0f,
float recurrent_dropout = 0f,
bool reset_after = true);
-
+
+    /// <summary>
+    /// Bidirectional wrapper for RNNs.
+    /// </summary>
+    /// <param name="layer">`keras.layers.RNN` instance, such as `keras.layers.LSTM` or `keras.layers.GRU`.</param>
+    /// <param name="merge_mode">Mode by which outputs of the forward and backward RNNs will be combined. One of {"sum", "mul", "concat", "ave", null}. If null, the outputs will not be combined and are returned as a list.</param>
+    /// <param name="weights">Initial weights to load in the Bidirectional layer.</param>
+    /// <param name="backward_layer">Optional `keras.layers.RNN` instance to handle backwards input processing. If not provided, it will be generated from `layer` automatically.</param>
+ public ILayer Bidirectional(
+ ILayer layer,
+ string merge_mode = "concat",
+ NDArray weights = null,
+ ILayer backward_layer = null);
+
public ILayer Subtract();
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index cb85bbba1..a04a9c051 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -908,6 +908,20 @@ public IRnnCell GRUCell(
ResetAfter = reset_after
});
+ public ILayer Bidirectional(
+ ILayer layer,
+ string merge_mode = "concat",
+ NDArray weights = null,
+ ILayer backward_layer = null)
+ => new Bidirectional(new BidirectionalArgs
+ {
+ Layer = layer,
+ MergeMode = merge_mode,
+ Weights = weights,
+ BackwardLayer = backward_layer
+ });
+
+
    /// <summary>
    ///
    /// </summary>
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/BaseWrapper.cs b/src/TensorFlowNET.Keras/Layers/Rnn/BaseWrapper.cs
new file mode 100644
index 000000000..737f88cd4
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/BaseWrapper.cs
@@ -0,0 +1,33 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Saving;
+
+namespace Tensorflow.Keras.Layers
+{
+    /// <summary>
+    /// Abstract wrapper base class. Wrappers take another layer and augment it in various ways.
+    /// Do not use this class as a layer; it is only an abstract base class.
+    /// Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
+    /// </summary>
+ public abstract class Wrapper: Layer
+ {
+ public ILayer _layer;
+ public Wrapper(WrapperArgs args):base(args)
+ {
+ _layer = args.Layer;
+ }
+
+ public virtual void Build(KerasShapesWrapper input_shape)
+ {
+ if (!_layer.Built)
+ {
+ _layer.build(input_shape);
+ }
+ built = true;
+ }
+
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs b/src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs
new file mode 100644
index 000000000..6114d9c7c
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/Bidirectional.cs
@@ -0,0 +1,276 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Common.Types;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Saving;
+
+namespace Tensorflow.Keras.Layers
+{
+    /// <summary>
+    /// Bidirectional wrapper for RNNs.
+    /// </summary>
+ public class Bidirectional: Wrapper
+ {
+ BidirectionalArgs _args;
+ RNN _forward_layer;
+ RNN _backward_layer;
+ RNN _layer;
+ bool _support_masking = true;
+ int _num_constants = 0;
+ bool _return_state;
+ bool _stateful;
+ bool _return_sequences;
+ InputSpec _input_spec;
+ RNNArgs _layer_args_copy;
+ public Bidirectional(BidirectionalArgs args):base(args)
+ {
+ _args = args;
+ if (_args.Layer is not ILayer)
+ throw new ValueError(
+ "Please initialize `Bidirectional` layer with a " +
+ $"`tf.keras.layers.Layer` instance. Received: {_args.Layer}");
+
+ if (_args.BackwardLayer is not null && _args.BackwardLayer is not ILayer)
+ throw new ValueError(
+ "`backward_layer` need to be a `tf.keras.layers.Layer` " +
+ $"instance. Received: {_args.BackwardLayer}");
+            if (!new List<string> { "sum", "mul", "ave", "concat", null }.Contains(_args.MergeMode))
+ {
+ throw new ValueError(
+ $"Invalid merge mode. Received: {_args.MergeMode}. " +
+ "Merge mode should be one of " +
+ "{\"sum\", \"mul\", \"ave\", \"concat\", null}"
+ );
+ }
+ if (_args.Layer is RNN)
+ {
+ _layer = _args.Layer as RNN;
+ }
+ else
+ {
+ throw new ValueError(
+ "Bidirectional only support RNN instance such as LSTM or GRU");
+ }
+ _return_state = _layer.Args.ReturnState;
+ _return_sequences = _layer.Args.ReturnSequences;
+ _stateful = _layer.Args.Stateful;
+ _layer_args_copy = _layer.Args.Clone();
+ // We don't want to track `layer` since we're already tracking the two
+ // copies of it we actually run.
+            // TODO(Wanglongzhi2001): `_setattr_tracking` has not been implemented yet,
+            // so the lines below stay commented out.
+ // _setattr_tracking = false;
+ // super().__init__(layer, **kwargs)
+ // _setattr_tracking = true;
+
+ // Recreate the forward layer from the original layer config, so that it
+ // will not carry over any state from the layer.
+ var actualType = _layer.GetType();
+ if (actualType == typeof(LSTM))
+ {
+ var arg = _layer_args_copy as LSTMArgs;
+ _forward_layer = new LSTM(arg);
+ }
+            // TODO(Wanglongzhi2001): add the GRU case.
+ else
+ {
+ _forward_layer = new RNN(_layer.Cell, _layer_args_copy);
+ }
+ //_forward_layer = _recreate_layer_from_config(_layer);
+ if (_args.BackwardLayer is null)
+ {
+ _backward_layer = _recreate_layer_from_config(_layer, go_backwards:true);
+ }
+ else
+ {
+ _backward_layer = _args.BackwardLayer as RNN;
+ }
+ _forward_layer.Name = "forward_" + _forward_layer.Name;
+ _backward_layer.Name = "backward_" + _backward_layer.Name;
+ _verify_layer_config();
+
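+            // When the wrapped layer returns full sequences, masked timesteps must
+            // yield zero outputs so the merged forward/backward results stay aligned.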
+ void force_zero_output_for_mask(RNN layer)
+ {
+ layer.Args.ZeroOutputForMask = layer.Args.ReturnSequences;
+ }
+
+ force_zero_output_for_mask(_forward_layer);
+ force_zero_output_for_mask(_backward_layer);
+
+ if (_args.Weights is not null)
+ {
+ var nw = len(_args.Weights);
+                _forward_layer.set_weights(_args.Weights[$":{nw / 2}"]);
+                _backward_layer.set_weights(_args.Weights[$"{nw / 2}:"]);
+ }
+
+ _input_spec = _layer.InputSpec;
+ }
+
+ private void _verify_layer_config()
+ {
+ if (_forward_layer.Args.GoBackwards == _backward_layer.Args.GoBackwards)
+ {
+ throw new ValueError(
+ "Forward layer and backward layer should have different " +
+ "`go_backwards` value." +
+ "forward_layer.go_backwards = " +
+ $"{_forward_layer.Args.GoBackwards}," +
+ "backward_layer.go_backwards = " +
+ $"{_backward_layer.Args.GoBackwards}");
+ }
+ if (_forward_layer.Args.Stateful != _backward_layer.Args.Stateful)
+ {
+ throw new ValueError(
+ "Forward layer and backward layer are expected to have "+
+ $"the same value for attribute stateful, got "+
+ $"{_forward_layer.Args.Stateful} for forward layer and "+
+ $"{_backward_layer.Args.Stateful} for backward layer");
+ }
+ if (_forward_layer.Args.ReturnState != _backward_layer.Args.ReturnState)
+ {
+ throw new ValueError(
+ "Forward layer and backward layer are expected to have " +
+ $"the same value for attribute return_state, got " +
+ $"{_forward_layer.Args.ReturnState} for forward layer and " +
+ $"{_backward_layer.Args.ReturnState} for backward layer");
+ }
+ if (_forward_layer.Args.ReturnSequences != _backward_layer.Args.ReturnSequences)
+ {
+ throw new ValueError(
+ "Forward layer and backward layer are expected to have " +
+ $"the same value for attribute return_sequences, got " +
+ $"{_forward_layer.Args.ReturnSequences} for forward layer and " +
+ $"{_backward_layer.Args.ReturnSequences} for backward layer");
+ }
+ }
+
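+        /// <summary>
+        /// Recreates a layer from the wrapped layer's config so the copy carries no
+        /// state over from the original; optionally flips `go_backwards`.
+        /// </summary>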
+ private RNN _recreate_layer_from_config(RNN layer, bool go_backwards = false)
+ {
+ var config = layer.get_config() as RNNArgs;
+ var cell = layer.Cell;
+ if (go_backwards)
+ {
+ config.GoBackwards = !config.GoBackwards;
+ }
+ var actualType = layer.GetType();
+ if (actualType == typeof(LSTM))
+ {
+ var arg = config as LSTMArgs;
+ return new LSTM(arg);
+ }
+ else
+ {
+ return new RNN(cell, config);
+ }
+ }
+
+ public override void build(KerasShapesWrapper input_shape)
+ {
+ _buildInputShape = input_shape;
+ tf_with(ops.name_scope(_forward_layer.Name), scope=>
+ {
+ _forward_layer.build(input_shape);
+ });
+ tf_with(ops.name_scope(_backward_layer.Name), scope =>
+ {
+ _backward_layer.build(input_shape);
+ });
+ built = true;
+ }
+
+ protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
+ {
+ // `Bidirectional.call` implements the same API as the wrapped `RNN`.
+
+ Tensors forward_inputs;
+ Tensors backward_inputs;
+ Tensors forward_state;
+ Tensors backward_state;
+ // if isinstance(inputs, list) and len(inputs) > 1:
+ if (inputs.Length > 1)
+ {
+                // initial_states are keras tensors, which means they are passed in
+                // together with the inputs as a list. They need to be split into
+                // forward and backward sections and fed to the layers accordingly.
+ forward_inputs = new Tensors { inputs[0] };
+ backward_inputs = new Tensors { inputs[0] };
+ var pivot = (len(inputs) - _num_constants) / 2 + 1;
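+                // `pivot` marks where the forward initial states end; backward
+                // states (and trailing constants, if any) follow.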
+ // add forward initial state
+ forward_inputs.Concat(new Tensors { inputs[$"1:{pivot}"] });
+                if (_num_constants == 0)
+                {
+                    // add backward initial state
+                    backward_inputs.Concat(new Tensors { inputs[$"{pivot}:"] });
+                }
+                else
+                {
+                    // add backward initial state
+                    backward_inputs.Concat(new Tensors { inputs[$"{pivot}:{-_num_constants}"] });
+                    // add constants for forward and backward layers
+                    forward_inputs.Concat(new Tensors { inputs[$"{-_num_constants}:"] });
+                    backward_inputs.Concat(new Tensors { inputs[$"{-_num_constants}:"] });
+                }
+ forward_state = null;
+ backward_state = null;
+ }
+ else if (state is not null)
+ {
+                // initial_states are not keras tensors, e.g. an eager tensor created
+                // from a numpy array. They are only passed in via the `initial_state`
+                // kwarg and should be passed to the forward/backward layers the same way.
+ forward_inputs = inputs;
+ backward_inputs = inputs;
+ var half = len(state) / 2;
+ forward_state = state[$":{half}"];
+ backward_state = state[$"{half}:"];
+ }
+ else
+ {
+ forward_inputs = inputs;
+ backward_inputs = inputs;
+ forward_state = null;
+ backward_state = null;
+ }
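+            // Run both wrapped layers; when `return_state` is set, each result is
+            // the output tensor followed by the final states.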
+ var y = _forward_layer.Apply(forward_inputs, forward_state);
+ var y_rev = _backward_layer.Apply(backward_inputs, backward_state);
+
+ Tensors states = new();
+ if (_return_state)
+ {
+ states = y["1:"] + y_rev["1:"];
+ y = y[0];
+ y_rev = y_rev[0];
+ }
+
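+            // The backward layer consumed the input in reverse, so its output
+            // sequence must be flipped along the time axis before merging.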
+ if (_return_sequences)
+ {
+ int time_dim = _forward_layer.Args.TimeMajor ? 0 : 1;
+ y_rev = keras.backend.reverse(y_rev, time_dim);
+ }
+ Tensors output;
+ if (_args.MergeMode == "concat")
+ output = keras.backend.concatenate(new Tensors { y.Single(), y_rev.Single() });
+ else if (_args.MergeMode == "sum")
+ output = y.Single() + y_rev.Single();
+ else if (_args.MergeMode == "ave")
+ output = (y.Single() + y_rev.Single()) / 2;
+ else if (_args.MergeMode == "mul")
+ output = y.Single() * y_rev.Single();
+ else if (_args.MergeMode is null)
+ output = new Tensors { y.Single(), y_rev.Single() };
+ else
+ throw new ValueError(
+ "Unrecognized value for `merge_mode`. " +
+ $"Received: {_args.MergeMode}" +
+ "Expected values are [\"concat\", \"sum\", \"ave\", \"mul\"]");
+ if (_return_state)
+ {
+ if (_args.MergeMode is not null)
+ return new Tensors { output.Single(), states.Single()};
+ }
+ return output;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs
index b5d583248..c766e8d69 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs
@@ -3,6 +3,7 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Common.Types;
using Tensorflow.Common.Extensions;
+using Tensorflow.Keras.Saving;
namespace Tensorflow.Keras.Layers
{
@@ -14,15 +15,15 @@ namespace Tensorflow.Keras.Layers
///
public class LSTM : RNN
{
- LSTMArgs args;
+ LSTMArgs _args;
InputSpec[] _state_spec;
InputSpec _input_spec;
bool _could_use_gpu_kernel;
-
+ public LSTMArgs Args { get => _args; }
public LSTM(LSTMArgs args) :
base(CreateCell(args), args)
{
- this.args = args;
+ _args = args;
_input_spec = new InputSpec(ndim: 3);
_state_spec = new[] { args.Units, args.Units }.Select(dim => new InputSpec(shape: (-1, dim))).ToArray();
_could_use_gpu_kernel = args.Activation == keras.activations.Tanh
@@ -71,7 +72,7 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
var single_input = inputs.Single;
var input_shape = single_input.shape;
- var timesteps = args.TimeMajor ? input_shape[0] : input_shape[1];
+ var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];
_maybe_reset_cell_dropout_mask(Cell);
@@ -87,26 +88,26 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
inputs,
initial_state,
constants: null,
- go_backwards: args.GoBackwards,
+ go_backwards: _args.GoBackwards,
mask: mask,
- unroll: args.Unroll,
+ unroll: _args.Unroll,
input_length: ops.convert_to_tensor(timesteps),
- time_major: args.TimeMajor,
- zero_output_for_mask: args.ZeroOutputForMask,
- return_all_outputs: args.ReturnSequences
+ time_major: _args.TimeMajor,
+ zero_output_for_mask: _args.ZeroOutputForMask,
+ return_all_outputs: _args.ReturnSequences
);
Tensor output;
- if (args.ReturnSequences)
+ if (_args.ReturnSequences)
{
- output = keras.backend.maybe_convert_to_ragged(false, outputs, (int)timesteps, args.GoBackwards);
+ output = keras.backend.maybe_convert_to_ragged(false, outputs, (int)timesteps, _args.GoBackwards);
}
else
{
output = last_output;
}
- if (args.ReturnState)
+ if (_args.ReturnState)
{
return new Tensor[] { output }.Concat(states).ToArray().ToTensors();
}
@@ -115,5 +116,11 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
return output;
}
}
+
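+        // The layer's `LSTMArgs` doubles as its serializable config.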
+ public override IKerasConfig get_config()
+ {
+ return _args;
+ }
+
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
index 0e81d20e3..c19222614 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
@@ -31,7 +31,9 @@ public class RNN : RnnBase
protected IVariableV1 _kernel;
protected IVariableV1 _bias;
private IRnnCell _cell;
- protected IRnnCell Cell
+
+ public RNNArgs Args { get => _args; }
+ public IRnnCell Cell
{
get
{
@@ -570,10 +572,13 @@ protected Tensors get_initial_state(Tensors inputs)
var input_shape = array_ops.shape(inputs);
var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
var dtype = input.dtype;
-
Tensors init_state = Cell.GetInitialState(null, batch_size, dtype);
-
return init_state;
}
+
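+        // Expose `RNNArgs` as the layer config; `Bidirectional` relies on this
+        // to clone the wrapped layer.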
+ public override IKerasConfig get_config()
+ {
+ return _args;
+ }
}
}
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
index 5f7bd574e..03159346a 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
@@ -5,6 +5,7 @@
using System.Text;
using System.Threading.Tasks;
using Tensorflow.Common.Types;
+using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Saving;
@@ -38,8 +39,6 @@ public void StackedRNNCell()
var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) };
var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells);
var (output, state) = stackedRNNCell.Apply(inputs, states);
- Console.WriteLine(output);
- Console.WriteLine(state.shape);
Assert.AreEqual((32, 5), output.shape);
Assert.AreEqual((32, 4), state[0].shape);
}
@@ -108,6 +107,7 @@ public void RNNForSimpleRNNCell()
var inputs = tf.random.normal((32, 10, 8));
var cell = tf.keras.layers.SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f);
var rnn = tf.keras.layers.RNN(cell: cell);
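+            // Exercise the get_config override added in this patch; it returns the layer's RNNArgs.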
+            var cfg = rnn.get_config();
var output = rnn.Apply(inputs);
Assert.AreEqual((32, 10), output.shape);
@@ -145,5 +145,14 @@ public void GRUCell()
Assert.AreEqual((32, 4), output.shape);
}
+
+ [TestMethod]
+ public void Bidirectional()
+ {
+ var bi = tf.keras.layers.Bidirectional(keras.layers.LSTM(10, return_sequences:true));
+ var inputs = tf.random.normal((32, 10, 8));
+ var outputs = bi.Apply(inputs);
+ Assert.AreEqual((32, 10, 20), outputs.shape);
+ }
}
}