Skip to content

Commit adc90af

Browse files
authored
Merge pull request #1129 from Beacontownfc/mybranch3
fix: optimize some APIs
2 parents dfd9dd0 + a76cd67 commit adc90af

File tree

2 files changed

+2
-11
lines changed

2 files changed

+2
-11
lines changed

src/TensorFlowNET.Core/APIs/tf.nn.cs

+2-10
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,8 @@ public Tensor batch_normalization(Tensor x,
144144
Tensor offset,
145145
Tensor scale,
146146
float variance_epsilon,
147-
string name = null)
148-
{
149-
var inv = math_ops.rsqrt(variance + variance_epsilon);
150-
tf_with(ops.name_scope(name, "batchnorm", (x, mean, variance, scale, offset)), scope =>
151-
{
152-
if (scale != null) inv *= scale;
153-
});
154-
if (offset != null) return x * math_ops.cast(inv, x.dtype) + math_ops.cast(offset - mean * inv, dtype: x.dtype);
155-
else return x * math_ops.cast(inv, x.dtype) + math_ops.cast(-mean * inv, dtype: x.dtype);
156-
}
147+
string name = null) => nn_impl.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name);
148+
157149

158150
public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
159151
=> nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);

src/TensorFlowNET.Core/Operations/array_ops.cs

-1
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,6 @@ public static Tensor stop_gradient(Tensor input, string name = null)
678678
var tape = tf.GradientTape().stop_recording();
679679
var result = gen_array_ops.stop_gradient(input, name);
680680
tape.StartRecord();
681-
tf.GradientTape().PushTape(tape);
682681
return result;
683682
}
684683

0 commit comments

Comments
 (0)