Skip to content

Commit a73694a

Browse files
fix: add the implementation of the tile's grad
1 parent d5f5c57 commit a73694a

File tree

3 files changed

+39
-1
lines changed

3 files changed

+39
-1
lines changed

src/TensorFlowNET.Core/Gradients/array_grad.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,5 +381,29 @@ public static Tensor[] _ReverseV2Grad(Operation op, Tensor[] grads)
381381
var axis = op.inputs[1];
382382
return new Tensor[] { array_ops.reverse(grad, axis), null };
383383
}
384+
385+
[RegisterGradient("Tile")]
386+
public static Tensor[] _TileGrad(Operation op, Tensor[] grads)
387+
{
388+
var grad = grads[0];
389+
var input_shape = array_ops.shape(op.inputs[0], out_type: op.inputs[1].dtype);
390+
var split_shape = array_ops.reshape(array_ops.transpose(array_ops.stack(new Tensor[] { op.inputs[1], input_shape })), new Shape(-1));
391+
var axes = math_ops.range(0, array_ops.size(split_shape), 2);
392+
393+
//# Sum reduces grad along the first dimension for IndexedSlices
394+
//if isinstance(grad, indexed_slices_lib.IndexedSlices):
395+
//input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
396+
//grad = math_ops.unsorted_segment_sum(
397+
// grad.values, math_ops.mod(grad.indices, input_shape_0), input_shape_0)
398+
//split_shape = array_ops.concat([[1], split_shape[1:]], axis = 0)
399+
400+
var input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes);
401+
if (!tf.Context.executing_eagerly())
402+
{
403+
input_grad.set_shape(op.inputs[0].GetShape());
404+
}
405+
return new Tensor[] { input_grad, null };
406+
407+
}
384408
}
385409
}

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -990,7 +990,7 @@ public static Tensor gather(ResourceVariable @params, Tensor indices, string nam
990990
return @params.sparse_read(indices, name);
991991
}
992992

993-
public static Tensor transpose<T1>(T1 a, Axis perm, string name = "transpose", bool conjugate = false)
993+
public static Tensor transpose<T1>(T1 a, Axis perm = null, string name = "transpose", bool conjugate = false)
994994
{
995995
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
996996
{

test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,5 +173,19 @@ public void ConditionalMultiply()
173173
var result = grad(x, 4);
174174
Assert.AreEqual((float)result, 4.0f);
175175
}
176+
177+
[TestMethod]
178+
public void Tile()
179+
{
180+
var a = tf.constant(new int[] { 1 }, TF_DataType.TF_FLOAT);
181+
var b = tf.constant(new int[] { 2 });
182+
using (var tape = tf.GradientTape())
183+
{
184+
tape.watch(a);
185+
var y = tf.tile(a, b);
186+
var grad = tape.gradient(y, a);
187+
Assert.AreEqual((float)grad.numpy(), 2.0f);
188+
}
189+
}
176190
}
177191
}

0 commit comments

Comments
 (0)