From f61ab520c91de2b25bf09356735b9617278f5a44 Mon Sep 17 00:00:00 2001
From: lingbai-kong
Date: Fri, 30 Jun 2023 21:25:35 +0800
Subject: [PATCH] fix inconsistent shape error while training Embedding layer.

The implicit IndexedSlices -> Tensor conversion now goes through
_indexed_slices_to_tensor (unsorted_segment_sum over values/indices)
instead of returning the raw values tensor. Because dense_shape is
optional in the IndexedSlices constructor, the conversion validates it
and fails with a descriptive exception rather than a
NullReferenceException. Adds an EmbeddingGrad test that fits a model
containing an Embedding layer.

---
 src/TensorFlowNET.Core/Framework/IndexedSlices.cs | 34 +++++++++++++++++++++++++++++++++-
 .../Layers/LayersTest.cs                          | 17 +++++++++++++++++
 2 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/src/TensorFlowNET.Core/Framework/IndexedSlices.cs b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
index 24d356fbb..bac5e6fb1 100644
--- a/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
+++ b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
@@ -49,12 +49,44 @@ public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null)
 
         public static implicit operator Tensor(IndexedSlices indexedSlices)
         {
-            return indexedSlices.values;
+            return _indexed_slices_to_tensor(indexedSlices);
         }
 
         public static implicit operator IndexedSlices(Tensor tensor)
         {
             return tensor.Tag as IndexedSlices;
         }
+
+        /// <summary>
+        /// Converts an IndexedSlices object to a Tensor by calling
+        /// gen_math_ops.unsorted_segment_sum with its values, its indices and
+        /// dense_shape.slice(0).
+        /// </summary>
+        /// <param name="indexedSlices">The slices to convert; its dense_shape must be set.</param>
+        /// <param name="dtype">Not used by this implementation.</param>
+        /// <param name="name">Not used by this implementation.</param>
+        /// <param name="as_ref">Not used by this implementation.</param>
+        /// <returns>The tensor produced by unsorted_segment_sum.</returns>
+        /// <exception cref="ArgumentNullException"><paramref name="indexedSlices"/> is null.</exception>
+        /// <exception cref="ArgumentException">The dense_shape of <paramref name="indexedSlices"/> is null.</exception>
+        public static Tensor _indexed_slices_to_tensor(IndexedSlices indexedSlices, TF_DataType dtype = TF_DataType.DtInvalid, String name = "", bool as_ref = false)
+        {
+            if (indexedSlices is null)
+            {
+                throw new ArgumentNullException(nameof(indexedSlices));
+            }
+
+            // dense_shape is optional in the constructor. Without it the size of the
+            // converted tensor is unknown, so fail with a clear message instead of a
+            // NullReferenceException.
+            if (indexedSlices.dense_shape is null)
+            {
+                throw new ArgumentException(
+                    "Tensor conversion requested for IndexedSlices without dense_shape.",
+                    nameof(indexedSlices));
+            }
+
+            return gen_math_ops.unsorted_segment_sum(indexedSlices.values, indexedSlices.indices, indexedSlices.dense_shape.slice(0));
+        }
     }
 }
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
index 98d909668..7ebb53db3 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
@@ -110,6 +110,23 @@ public void Embedding()
             var output_array = model.predict(input_array);
             Assert.AreEqual((32, 10, 64), output_array.shape);
         }
+
+        /// <summary>
+        /// Regression test for training a model that contains an Embedding layer.
+        /// There is no explicit assertion: the test passes when compile and fit
+        /// complete without throwing.
+        /// </summary>
+        [TestMethod]
+        public void EmbeddingGrad()
+        {
+            var inputs = keras.layers.Input(shape: new[] { 32, 10 });
+            var outputs = keras.layers.Embedding(1000, 64, input_length: 10).Apply(inputs);
+            var model = keras.Model(inputs: inputs, outputs: outputs);
+            var input_array = np.random.randint(1000, size: (1, 32, 10));
+            var output_array = np.random.random(size: (1, 32, 10, 64));
+            model.compile("rmsprop", "mse", new[] { "accuracy" });
+            model.fit(input_array, output_array);
+        }
 
         /// <summary>
         /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense