Skip to content

Commit

Permalink
Merge pull request #1125 from lingbai-kong/bug-IndexedSlices
Browse files: browse the repository at this point in the history
fix: inconsistent shape error while training embedding layer
  • Loading branch information
Oceania2018 authored Jun 30, 2023
2 parents 6264c79 + f61ab52 commit 991c6b6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/TensorFlowNET.Core/Framework/IndexedSlices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,25 @@ public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null)

/// <summary>
/// Implicitly converts an <see cref="IndexedSlices"/> to a <see cref="Tensor"/>
/// by delegating to <c>_indexed_slices_to_tensor</c>.
/// </summary>
/// <param name="indexedSlices">The slices to convert.</param>
/// <returns>The tensor returned by <c>_indexed_slices_to_tensor</c>.</returns>
public static implicit operator Tensor(IndexedSlices indexedSlices)
{
    // Do not return indexedSlices.values directly: an earlier return of the raw
    // values tensor made this conversion call unreachable. The conversion helper
    // is the single intended result of this operator.
    return _indexed_slices_to_tensor(indexedSlices);
}

/// <summary>
/// Implicitly converts a <see cref="Tensor"/> to the <see cref="IndexedSlices"/>
/// held in its <c>Tag</c>.
/// </summary>
/// <param name="tensor">The tensor whose <c>Tag</c> is read.</param>
/// <returns>
/// <c>tensor.Tag</c> cast with <c>as</c>; <c>null</c> when the tag is not an
/// <see cref="IndexedSlices"/> instance.
/// </returns>
public static implicit operator IndexedSlices(Tensor tensor)
    => tensor.Tag as IndexedSlices;

/// <summary>
/// Converts an <see cref="IndexedSlices"/> object to a <see cref="Tensor"/> by calling
/// <c>gen_math_ops.unsorted_segment_sum</c> with its <c>values</c>, its <c>indices</c>
/// and <c>dense_shape.slice(0)</c>.
/// </summary>
/// <param name="indexedSlices">The slices to convert. Its <c>dense_shape</c> must not be null.</param>
/// <param name="dtype">Accepted for signature compatibility; not used by this implementation.</param>
/// <param name="name">Accepted for signature compatibility; not used by this implementation.</param>
/// <param name="as_ref">Accepted for signature compatibility; not used by this implementation.</param>
/// <returns>The tensor produced by <c>gen_math_ops.unsorted_segment_sum</c>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="indexedSlices"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="indexedSlices"/> has no <c>dense_shape</c>.</exception>
public static Tensor _indexed_slices_to_tensor(IndexedSlices indexedSlices, TF_DataType dtype = TF_DataType.DtInvalid, String name = "", bool as_ref = false)
{
    if (indexedSlices == null)
    {
        throw new ArgumentNullException(nameof(indexedSlices));
    }

    // dense_shape is optional when constructing an IndexedSlices (it defaults to null),
    // so fail with a descriptive error instead of a bare NullReferenceException.
    if (indexedSlices.dense_shape == null)
    {
        throw new ArgumentException(
            "Tensor conversion requested for IndexedSlices without dense_shape.",
            nameof(indexedSlices));
    }

    return gen_math_ops.unsorted_segment_sum(
        indexedSlices.values,
        indexedSlices.indices,
        indexedSlices.dense_shape.slice(0));
}
}
}
11 changes: 11 additions & 0 deletions test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,17 @@ public void Embedding()
var output_array = model.predict(input_array);
Assert.AreEqual((32, 10, 64), output_array.shape);
}
/// <summary>
/// Smoke test: builds a model consisting of a single Embedding layer, compiles it
/// and fits it on random data. There are no assertions; the test passes as long
/// as no exception is thrown.
/// </summary>
[TestMethod]
public void EmbeddingGrad()
{
    // Model: Input -> Embedding(1000, 64, input_length: 10).
    var input_tensor = keras.layers.Input(shape: new[] { 32, 10 });
    var embedding_layer = keras.layers.Embedding(1000, 64, input_length: 10);
    var embedded = embedding_layer.Apply(input_tensor);
    var model = keras.Model(inputs: input_tensor, outputs: embedded);

    // Random indices and random regression targets.
    var features = np.random.randint(1000, size: (1, 32, 10));
    var targets = np.random.random(size: (1, 32, 10, 64));

    model.compile("rmsprop", "mse", new[] { "accuracy" });
    model.fit(features, targets);
}

/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense
Expand Down

0 comments on commit 991c6b6

Please sign in to comment.