Skip to content

Commit 079b9a3

Browse files
authored
Merge pull request #1202 from Wanglongzhi2001/master
fix: add the implementation of the tile's and GatherND's grad and add OptionalArgs
2 parents e79ecb7 + d0ec659 commit 079b9a3

File tree

7 files changed

+108
-6
lines changed

7 files changed

+108
-6
lines changed

src/TensorFlowNET.Core/APIs/tf.array.cs

+10
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,16 @@ public Tensor identity(Tensor input, string name = null)
140140
public Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0)
141141
=> array_ops.gather(@params, indices, name: name, axis: ops.convert_to_tensor(axis));
142142

143+
/// <summary>
/// Gathers slices of <paramref name="params"/> into a tensor whose shape is
/// determined by <paramref name="indices"/>.
/// </summary>
/// <param name="params">Tensor the slices are taken from.</param>
/// <param name="indices">Index tensor that selects the slices.</param>
/// <param name="name">Optional name forwarded to the underlying op.</param>
/// <returns>The tensor produced by <c>gen_array_ops.gather_nd</c>.</returns>
public Tensor gather_nd(Tensor @params, Tensor indices, string name = null)
{
    return gen_array_ops.gather_nd(@params, indices, name: name);
}
152+
143153
/// <summary>
144154
/// Return the elements, either from `x` or `y`, depending on the `condition`.
145155
/// </summary>

src/TensorFlowNET.Core/Gradients/array_grad.cs

+43
Original file line numberDiff line numberDiff line change
@@ -381,5 +381,48 @@ public static Tensor[] _ReverseV2Grad(Operation op, Tensor[] grads)
381381
var axis = op.inputs[1];
382382
return new Tensor[] { array_ops.reverse(grad, axis), null };
383383
}
384+
385+
/// <summary>
/// Gradient function registered for the "Tile" op.
/// </summary>
/// <param name="op">The forward op; its first two inputs are read here.</param>
/// <param name="grads">Upstream gradients; only the first entry is used.</param>
/// <returns>
/// A two-element array: the gradient for <c>op.inputs[0]</c>, and <c>null</c>
/// for <c>op.inputs[1]</c>.
/// </returns>
[RegisterGradient("Tile")]
public static Tensor[] _TileGrad(Operation op, Tensor[] grads)
{
    var upstream = grads[0];
    var tiled_input = op.inputs[0];
    // Second input of the op (the tile multiples, per the TensorFlow Tile op definition).
    var multiples = op.inputs[1];

    // Shape of the original input, expressed in the same dtype as the multiples
    // so the two can be stacked together.
    var input_shape = array_ops.shape(tiled_input, out_type: multiples.dtype);

    // Stack, transpose and flatten so the two vectors are interleaved:
    // [m0, s0, m1, s1, ...] for multiples m and input shape s.
    var stacked = array_ops.stack(new Tensor[] { multiples, input_shape });
    var split_shape = array_ops.reshape(array_ops.transpose(stacked), new Shape(-1));

    // Every even position (0, 2, 4, ...) of split_shape, i.e. the entries that
    // came from the multiples.
    var axes = math_ops.range(0, array_ops.size(split_shape), 2);

    // NOTE: the Python reference implementation has an extra branch for
    // IndexedSlices gradients (unsorted_segment_sum over grad.values keyed by
    // grad.indices mod input_shape[0], with split_shape's first entry replaced
    // by 1). That branch is not ported here.

    // Reshape the upstream gradient to the interleaved shape and sum over the
    // axes selected above.
    var reshaped = array_ops.reshape(upstream, split_shape);
    var input_grad = math_ops.reduce_sum(reshaped, axes);

    // Outside eager execution, pin the result's static shape to the input's.
    if (!tf.Context.executing_eagerly())
    {
        input_grad.set_shape(tiled_input.GetShape());
    }

    return new Tensor[] { input_grad, null };
}
407+
408+
/// <summary>
/// Gradient function registered for the "GatherNd" op.
/// </summary>
/// <param name="op">The forward op; its first two inputs are read here.</param>
/// <param name="grads">Upstream gradients; only the first entry is used.</param>
/// <returns>
/// A two-element array: the gradient for <c>op.inputs[0]</c>, and <c>null</c>
/// for the indices input.
/// </returns>
[RegisterGradient("GatherNd")]
public static Tensor[] _GatherNdGrad(Operation op, Tensor[] grads)
{
    var source = op.inputs[0];
    var indices = op.inputs[1];
    var upstream = grads[0];

    // Shape of the gathered-from tensor, in the same dtype as the indices.
    var ref_shape = array_ops.shape(source, out_type: indices.dtype);

    // General case: anything other than rank-2 indices whose last dimension is 1
    // is scattered straight back into a tensor of shape ref_shape.
    var index_shape = indices.shape;
    bool is_single_column = index_shape.ndim == 2
        && index_shape.dims[index_shape.Length - 1] == 1;
    if (!is_single_column)
    {
        return new Tensor[] { gen_array_ops.scatter_nd(indices, upstream, ref_shape), null };
    }

    // Rank-2 indices with a trailing dimension of 1: drop that dimension and
    // build the gradient as IndexedSlices converted to a Tensor.
    var flat_indices = array_ops.squeeze(indices, axis: -1);
    var sliced_grad = (Tensor)new IndexedSlices(upstream, flat_indices, ref_shape);
    return new Tensor[] { sliced_grad, null };
}
426+
384427
}
385428
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs

+1-3
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@
44

55
namespace Tensorflow.Keras.ArgsDefinition
66
{
7-
public class GRUOptionalArgs
7+
public class GRUOptionalArgs : RnnOptionalArgs
88
{
99
public string Identifier => "GRU";
10-
11-
public Tensor Mask { get; set; } = null;
1210
}
1311
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.ArgsDefinition.Rnn
6+
{
7+
/// <summary>
/// Optional-argument container derived from <see cref="RnnOptionalArgs"/>
/// and tagged with the "LSTM" identifier.
/// </summary>
public class LSTMOptionalArgs : RnnOptionalArgs
{
    /// <summary>Always returns the constant string "LSTM".</summary>
    public string Identifier
    {
        get { return "LSTM"; }
    }
}
11+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.ArgsDefinition.Rnn
6+
{
7+
/// <summary>
/// Optional-argument container derived from <see cref="RnnOptionalArgs"/>
/// and tagged with the "SimpleRNN" identifier.
/// </summary>
public class SimpleRNNOptionalArgs : RnnOptionalArgs
{
    /// <summary>Always returns the constant string "SimpleRNN".</summary>
    public string Identifier
    {
        get { return "SimpleRNN"; }
    }
}
11+
}

src/TensorFlowNET.Core/Operations/array_ops.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end,
829829
/// <returns>A `Tensor`. Has the same type as `input`.
830830
/// Contains the same data as `input`, but has one or more dimensions of
831831
/// size 1 removed.</returns>
832-
public static Tensor squeeze(Tensor input, int[] axis = null, string name = null)
832+
public static Tensor squeeze(Tensor input, Axis axis = null, string name = null)
833833
=> gen_array_ops.squeeze(input, axis, name);
834834

835835
public static Tensor identity(Tensor input, string name = null)
@@ -990,7 +990,7 @@ public static Tensor gather(ResourceVariable @params, Tensor indices, string nam
990990
return @params.sparse_read(indices, name);
991991
}
992992

993-
public static Tensor transpose<T1>(T1 a, Axis perm, string name = "transpose", bool conjugate = false)
993+
public static Tensor transpose<T1>(T1 a, Axis perm = null, string name = "transpose", bool conjugate = false)
994994
{
995995
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
996996
{

test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs

+30-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ public void SquaredDifference_1D()
6262
// Calculate the gradient of (x1-x2)^2
6363
// by Automatic Differentiation in Eager mode
6464
// Expected is 2*(abs(x1-x2))
65-
Tensor x1 = new NDArray( new float[] { 1, 3, 5, 21, 19, 17 });
65+
Tensor x1 = new NDArray(new float[] { 1, 3, 5, 21, 19, 17 });
6666
Tensor x2 = new NDArray(new float[] { 29, 27, 23, 7, 11, 13 });
6767
float[] expected = new float[]
6868
{
@@ -173,5 +173,34 @@ public void ConditionalMultiply()
173173
var result = grad(x, 4);
174174
Assert.AreEqual((float)result, 4.0f);
175175
}
176+
177+
/// <summary>
/// Checks the eager-mode gradient of <c>tf.tile</c> with respect to its input:
/// tiling a one-element tensor twice is expected to give a gradient of 2.
/// </summary>
[TestMethod]
public void Tile()
{
    var a = tf.constant(new int[] { 1 }, TF_DataType.TF_FLOAT);
    var b = tf.constant(new int[] { 2 });
    using (var tape = tf.GradientTape())
    {
        tape.watch(a);
        var y = tf.tile(a, b);
        var grad = tape.gradient(y, a);

        // MSTest's signature is AreEqual(expected, actual); passing the expected
        // value first keeps the failure message's labels correct.
        Assert.AreEqual(2.0f, (float)grad.numpy());
    }
}
190+
191+
/// <summary>
/// Checks the eager-mode gradient of <c>tf.gather_nd</c> with respect to its
/// input: the expected gradient is 1 at each gathered position and 0 elsewhere.
/// </summary>
[TestMethod]
public void GatherNdTest()
{
    // 3x3 input whose rows are all { 1, 2, 3 }.
    var x = tf.constant(new float[,] { { 1.0f, 2.0f, 3.0f }, { 1.0f, 2.0f, 3.0f }, { 1.0f, 2.0f, 3.0f } }, dtype: TF_DataType.TF_FLOAT);
    // One (row, column) pair per row, always pointing at column 1.
    var indices = tf.constant(new int[,] { { 0, 1 }, { 1, 1 }, { 2, 1 } }, dtype: TF_DataType.TF_INT32);

    using (var tape = tf.GradientTape())
    {
        tape.watch(x);
        var gathered = tf.gather_nd(x, indices);
        var grad = tape.gradient(gathered, x);

        var expected = np.array(new float[,] { { 0f, 1f, 0f }, { 0f, 1f, 0f }, { 0f, 1f, 0f } });

        // Compare the flattened element sequences of the two results.
        var actual_values = grad.ToArray<float>();
        var expected_values = expected.ToArray<float>();
        bool same = Enumerable.SequenceEqual(actual_values, expected_values);
        Assert.IsTrue(same);
    }
}
176205
}
177206
}

0 commit comments

Comments
 (0)