diff --git a/src/TensorFlowNET.Core/Framework/IndexedSlices.cs b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
index 24d356fbb..bac5e6fb1 100644
--- a/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
+++ b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
@@ -49,12 +49,25 @@ public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null)
public static implicit operator Tensor(IndexedSlices indexedSlices)
{
- return indexedSlices.values;
+ return _indexed_slices_to_tensor(indexedSlices);
}
public static implicit operator IndexedSlices(Tensor tensor)
{
return tensor.Tag as IndexedSlices;
}
+
+ ///
+ /// Converts an IndexedSlices object `value` to a Tensor.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Tensor _indexed_slices_to_tensor(IndexedSlices indexedSlices, TF_DataType dtype = TF_DataType.DtInvalid, String name = "", bool as_ref = false)
+ {
+ return gen_math_ops.unsorted_segment_sum(indexedSlices.values, indexedSlices.indices, indexedSlices.dense_shape.slice(0));
+ }
}
}
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
index 98d909668..7ebb53db3 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
@@ -110,6 +110,17 @@ public void Embedding()
var output_array = model.predict(input_array);
Assert.AreEqual((32, 10, 64), output_array.shape);
}
+ [TestMethod]
+ public void EmbeddingGrad()
+ {
+ var inputs = keras.layers.Input(shape: new[] { 32, 10 });
+ var outputs = keras.layers.Embedding(1000, 64, input_length: 10).Apply(inputs);
+ var model = keras.Model(inputs: inputs, outputs: outputs);
+ var input_array = np.random.randint(1000, size: (1, 32, 10));
+ var output_array = np.random.random(size: (1, 32, 10, 64));
+ model.compile("rmsprop", "mse", new[] { "accuracy" });
+ model.fit(input_array, output_array);
+ }
///
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense