Skip to content

Commit 991c6b6

Browse files
authored
Merge pull request #1125 from lingbai-kong/bug-IndexedSlices
fix: inconsistent shape error while training embedding layer
2 parents 6264c79 + f61ab52 commit 991c6b6

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

src/TensorFlowNET.Core/Framework/IndexedSlices.cs

+14-1
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,25 @@ public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null)
4949

5050
public static implicit operator Tensor(IndexedSlices indexedSlices)
5151
{
52-
return indexedSlices.values;
52+
return _indexed_slices_to_tensor(indexedSlices);
5353
}
5454

5555
public static implicit operator IndexedSlices(Tensor tensor)
5656
{
5757
return tensor.Tag as IndexedSlices;
5858
}
59+
60+
/// <summary>
61+
/// Converts an IndexedSlices object `value` to a Tensor.
62+
/// </summary>
63+
/// <param name="indexedSlices"></param>
64+
/// <param name="dtype"></param>
65+
/// <param name="name"></param>
66+
/// <param name="as_ref"></param>
67+
/// <returns></returns>
68+
public static Tensor _indexed_slices_to_tensor(IndexedSlices indexedSlices, TF_DataType dtype = TF_DataType.DtInvalid, String name = "", bool as_ref = false)
69+
{
70+
return gen_math_ops.unsorted_segment_sum(indexedSlices.values, indexedSlices.indices, indexedSlices.dense_shape.slice(0));
71+
}
5972
}
6073
}

test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs

+11
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,17 @@ public void Embedding()
110110
var output_array = model.predict(input_array);
111111
Assert.AreEqual((32, 10, 64), output_array.shape);
112112
}
113+
[TestMethod]
114+
public void EmbeddingGrad()
115+
{
116+
var inputs = keras.layers.Input(shape: new[] { 32, 10 });
117+
var outputs = keras.layers.Embedding(1000, 64, input_length: 10).Apply(inputs);
118+
var model = keras.Model(inputs: inputs, outputs: outputs);
119+
var input_array = np.random.randint(1000, size: (1, 32, 10));
120+
var output_array = np.random.random(size: (1, 32, 10, 64));
121+
model.compile("rmsprop", "mse", new[] { "accuracy" });
122+
model.fit(input_array, output_array);
123+
}
113124

114125
/// <summary>
115126
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense

0 commit comments

Comments
 (0)