Skip to content

Commit

Permalink
Merge pull request #1148 from Wanglongzhi2001/master
Browse files Browse the repository at this point in the history
fix: make the initialization of the layer's name correct
  • Loading branch information
Oceania2018 authored Jul 14, 2023
2 parents dffc465 + f6f792a commit 12e3f54
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/TensorFlowNET.Keras/Utils/generic_utils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Saving;
using Tensorflow.Train;
using System.Text.RegularExpressions;

namespace Tensorflow.Keras.Utils
{
Expand Down Expand Up @@ -126,12 +127,15 @@ public static FunctionalConfig deserialize_model_config(JToken json)

public static string to_snake_case(string name)
{
return string.Concat(name.Select((x, i) =>
string intermediate = Regex.Replace(name, "(.)([A-Z][a-z0-9]+)", "$1_$2");
string insecure = Regex.Replace(intermediate, "([a-z])([A-Z])", "$1_$2").ToLower();

if (insecure[0] != '_')
{
return i > 0 && char.IsUpper(x) && !Char.IsDigit(name[i - 1]) ?
"_" + x.ToString() :
x.ToString();
})).ToLower();
return insecure;
}

return "private" + insecure;
}

/// <summary>
Expand Down
33 changes: 33 additions & 0 deletions test/TensorFlowNET.Keras.UnitTest/InitLayerNameTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.Keras.Layers;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.UnitTest
{
[TestClass]
public class InitLayerNameTest
{
[TestMethod]
public void RNNLayerNameTest()
{
var simpleRnnCell = keras.layers.SimpleRNNCell(1);
Assert.AreEqual("simple_rnn_cell", simpleRnnCell.Name);
var simpleRnn = keras.layers.SimpleRNN(2);
Assert.AreEqual("simple_rnn", simpleRnn.Name);
var lstmCell = keras.layers.LSTMCell(2);
Assert.AreEqual("lstm_cell", lstmCell.Name);
var lstm = keras.layers.LSTM(3);
Assert.AreEqual("lstm", lstm.Name);
}

[TestMethod]
public void ConvLayerNameTest()
{
var conv2d = keras.layers.Conv2D(8, activation: "linear");
Assert.AreEqual("conv2d", conv2d.Name);
var conv2dTranspose = keras.layers.Conv2DTranspose(8);
Assert.AreEqual("conv2d_transpose", conv2dTranspose.Name);
}
}
}

0 comments on commit 12e3f54

Please sign in to comment.