Skip to content

Commit

Permalink
Merge pull request #1149 from Wanglongzhi2001/master
Browse files Browse the repository at this point in the history
feat: add Bidirectional layer
  • Loading branch information
Oceania2018 authored Jul 19, 2023
2 parents fa5d19d + 0c9437a commit fa2d2dc
Show file tree
Hide file tree
Showing 11 changed files with 428 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition
{
public class BidirectionalArgs : AutoSerializeLayerArgs
{
[JsonProperty("layer")]
public ILayer Layer { get; set; }
[JsonProperty("merge_mode")]
public string? MergeMode { get; set; }
[JsonProperty("backward_layer")]
public ILayer BackwardLayer { get; set; }
public NDArray Weights { get; set; }
}

}
5 changes: 5 additions & 0 deletions src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,10 @@ public class LSTMArgs : RNNArgs
// TODO: maybe change the `RNNArgs` and implement this class.
public bool UnitForgetBias { get; set; }
public int Implementation { get; set; }

public LSTMArgs Clone()
{
return (LSTMArgs)MemberwiseClone();
}
}
}
5 changes: 5 additions & 0 deletions src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,10 @@ public class RNNArgs : AutoSerializeLayerArgs
public bool ZeroOutputForMask { get; set; } = false;
[JsonProperty("recurrent_dropout")]
public float RecurrentDropout { get; set; } = .0f;

public RNNArgs Clone()
{
return (RNNArgs)MemberwiseClone();
}
}
}
24 changes: 24 additions & 0 deletions src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/WrapperArgs.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;


namespace Tensorflow.Keras.ArgsDefinition
{
public class WrapperArgs : AutoSerializeLayerArgs
{
[JsonProperty("layer")]
public ILayer Layer { get; set; }

public WrapperArgs(ILayer layer)
{
Layer = layer;
}

public static implicit operator WrapperArgs(BidirectionalArgs args)
=> new WrapperArgs(args.Layer);
}

}
14 changes: 13 additions & 1 deletion src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,19 @@ public IRnnCell GRUCell(
float dropout = 0f,
float recurrent_dropout = 0f,
bool reset_after = true);


/// <summary>
/// Bidirectional wrapper for RNNs.
/// </summary>
/// <param name="layer">`keras.layers.RNN` instance, such as `keras.layers.LSTM` or `keras.layers.GRU`</param>
/// automatically.</param>
/// <returns></returns>
public ILayer Bidirectional(
ILayer layer,
string merge_mode = "concat",
NDArray weights = null,
ILayer backward_layer = null);

public ILayer Subtract();
}
}
14 changes: 14 additions & 0 deletions src/TensorFlowNET.Keras/Layers/LayersApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,20 @@ public IRnnCell GRUCell(
ResetAfter = reset_after
});

public ILayer Bidirectional(
ILayer layer,
string merge_mode = "concat",
NDArray weights = null,
ILayer backward_layer = null)
=> new Bidirectional(new BidirectionalArgs
{
Layer = layer,
MergeMode = merge_mode,
Weights = weights,
BackwardLayer = backward_layer
});


/// <summary>
///
/// </summary>
Expand Down
33 changes: 33 additions & 0 deletions src/TensorFlowNET.Keras/Layers/Rnn/BaseWrapper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Abstract wrapper base class. Wrappers take another layer and augment it in various ways.
/// Do not use this class as a layer, it is only an abstract base class.
/// Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
/// </summary>
public abstract class Wrapper: Layer
{
public ILayer _layer;
public Wrapper(WrapperArgs args):base(args)
{
_layer = args.Layer;
}

public virtual void Build(KerasShapesWrapper input_shape)
{
if (!_layer.Built)
{
_layer.build(input_shape);
}
built = true;
}

}
}
Loading

0 comments on commit fa2d2dc

Please sign in to comment.