Skip to content

Commit b968fd7

Browse files
committed
add avg_pool_grad function
1 parent 8574881 commit b968fd7

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

src/TensorFlowNET.Core/Gradients/nn_grad.cs

+17
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,23 @@ public static Tensor[] _MaxPoolGrad(Operation op, Tensor[] grads)
365365
};
366366
}
367367

368+
[RegisterGradient("AvgPool")]
369+
public static Tensor[] _AvgPoolGrad(Operation op, Tensor[] grads)
370+
{
371+
Tensor grad = grads[0];
372+
373+
return new Tensor[]
374+
{
375+
gen_nn_ops.avg_pool_grad(
376+
array_ops.shape(op.inputs[0]),
377+
grad,
378+
op.get_attr_list<int>("ksize"),
379+
op.get_attr_list<int>("strides"),
380+
op.get_attr("padding").ToString(),
381+
op.get_attr("data_format").ToString())
382+
};
383+
}
384+
368385
/// <summary>
369386
/// Return the gradients for TopK.
370387
/// </summary>

0 commit comments

Comments
 (0)