diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs
index 22d3c641b..89699d6bc 100644
--- a/src/TensorFlowNET.Core/Gradients/math_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs
@@ -840,7 +840,7 @@ public static Tensor[] _PowGrad(Operation op, Tensor[] grads)
///
///
///
- private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad)
+ public static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad)
{
Tensor sx, sy;
if (x.shape.IsFullyDefined &&
diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
index 15b72f55c..e95163930 100644
--- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
@@ -15,6 +15,7 @@ limitations under the License.
******************************************************************************/
using System;
+using System.Diagnostics;
using System.Linq;
using Tensorflow.Operations;
using static Tensorflow.Binding;
@@ -135,13 +136,35 @@ public static Tensor[] _SquaredDifferenceGrad(Operation op, Tensor[] grads)
{
Tensor x = op.inputs[0];
Tensor y = op.inputs[1];
+ var grad = grads[0]; // Incoming gradient; only the first entry of grads is read in this body.
var scale = ops.convert_to_tensor(2.0f, dtype: x.dtype);
- var x_grad = math_ops.scalar_mul(scale, grads[0]) * (x - y);
- return new Tensor[]
+ var x_grad = math_ops.scalar_mul(scale, grad) * (x - y); // Shared term: scalar_mul of scale (2.0 in x.dtype) with grad, times (x - y).
+ if (math_grad._ShapesFullySpecifiedAndEqual(x, y, grad)) // Fast path: when this helper accepts x, y and grad, return the term with no reduction.
{
- x_grad,
- -x_grad
- };
+ return new Tensor[] { x_grad, -x_grad }; // x gets x_grad, y gets its negation.
+ }
+ var broadcast_info = math_grad.SmartBroadcastGradientArgs(x, y, grad); // Otherwise fetch per-input broadcast info from the shared math_grad helper.
+ Debug.Assert(broadcast_info.Length == 2); // Debug-build sanity check: exactly two entries are consumed below (index 0 for x, index 1 for y).
+ var (sx, rx, must_reduce_x) = broadcast_info[0]; // sx is later passed to reshape and rx to reduce_sum; presumably x's shape and reduction axes - confirm against the helper.
+ var (sy, ry, must_reduce_y) = broadcast_info[1]; // Same triple for y.
+ Tensor gx, gy;
+ if (must_reduce_x) // Only sum and reshape when the helper flagged x as needing it.
+ {
+ gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx);
+ }
+ else
+ {
+ gx = x_grad; // Not flagged: pass the shared term through unchanged.
+ }
+ if (must_reduce_y) // y mirrors x, but the result is negated in both branches.
+ {
+ gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy);
+ }
+ else
+ {
+ gy = -x_grad;
+ }
+ return new Tensor[] { gx, gy }; // Order follows the inputs read above: x (op.inputs[0]) first, then y (op.inputs[1]).
}
///