Commit 8630438

Merge pull request #1114 from Beacontownfc/mybranch
Improve the API of LayerNormalization
2 parents 4c6063d + 786b266

3 files changed: +65 −1


src/TensorFlowNET.Core/APIs/tf.nn.cs (+29)

@@ -14,6 +14,7 @@ You may obtain a copy of the License at
 limitations under the License.
 ******************************************************************************/
 
+using System.Xml.Linq;
 using Tensorflow.Operations;
 using Tensorflow.Operations.Activation;
 using static Tensorflow.Binding;
@@ -126,6 +127,34 @@ public Tensor[] fused_batch_norm(Tensor x,
                 name: name,
                 exponential_avg_factor: exponential_avg_factor);
 
+            /// <summary>
+            /// Normalizes a tensor by `mean` and `variance`, and optionally applies a `scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\).
+            /// </summary>
+            /// <param name="x">A floating point tensor.</param>
+            /// <param name="mean">A mean `Tensor`.</param>
+            /// <param name="variance">A variance `Tensor`.</param>
+            /// <param name="offset">An offset `Tensor`, often denoted \\(\beta\\) in equations, or null. If present, will be added to the normalized tensor.</param>
+            /// <param name="scale">A scale `Tensor`, often denoted \\(\gamma\\) in equations, or null. If present, the scale is applied to the normalized tensor.</param>
+            /// <param name="variance_epsilon">A small float number to avoid dividing by 0.</param>
+            /// <param name="name">A name for this operation.</param>
+            /// <returns>The normalized, scaled, offset tensor.</returns>
+            public Tensor batch_normalization(Tensor x,
+                Tensor mean,
+                Tensor variance,
+                Tensor offset,
+                Tensor scale,
+                float variance_epsilon,
+                string name = null)
+            {
+                var inv = math_ops.rsqrt(variance + variance_epsilon);
+                tf_with(ops.name_scope(name, "batchnorm", (x, mean, variance, scale, offset)), scope =>
+                {
+                    if (scale != null) inv *= scale;
+                });
+                if (offset != null) return x * math_ops.cast(inv, x.dtype) + math_ops.cast(offset - mean * inv, dtype: x.dtype);
+                else return x * math_ops.cast(inv, x.dtype) + math_ops.cast(-mean * inv, dtype: x.dtype);
+            }
+
             public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
                 => nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);
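
The implementation mirrors the Python `tf.nn.batch_normalization`: it precomputes `inv = scale * rsqrt(variance + variance_epsilon)` and folds the mean into the offset, so the whole normalization collapses into a single multiply-add, `x * inv + (offset - mean * inv)`. The `inv *= scale` inside the `tf_with` lambda updates the outer local because C# closures capture variables by reference. Below is a minimal usage sketch of the new overload; the input values and axis choice are illustrative only, and it assumes `tf.nn.moments` accepts an `int[]` of axes, as in the layer code further down.

```csharp
using Tensorflow;
using static Tensorflow.Binding;

// Three samples with two features each (illustrative values).
var x = tf.constant(new float[,] { { 1f, 2f }, { 3f, 4f }, { 5f, 6f } });

// Per-feature statistics over the batch axis.
(Tensor mean, Tensor variance) = tf.nn.moments(x, new[] { 0 }, keep_dims: true);

// Normalize; `offset` and `scale` may both be null.
var normalized = tf.nn.batch_normalization(
    x, mean, variance,
    offset: null,
    scale: null,
    variance_epsilon: 1e-5f);
```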

src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs (+14 −1)

@@ -153,9 +153,22 @@ protected override Tensors Call(Tensors inputs, Tensors state = null, bool? trai
             }
             else
             {
+                var input_dtype = inputs.dtype;
+                if ((input_dtype == tf.float16) && DType == tf.float32) inputs = tf.cast(inputs, tf.float32);
+                (Tensor mean, Tensor variance) = tf.nn.moments(inputs, axis, keep_dims: true);
 
-            }
+                (Tensor scale, Tensor offset) = (_broadcast(gamma), _broadcast(beta));
+
+                outputs = tf.nn.batch_normalization(
+                    inputs,
+                    mean,
+                    variance,
+                    offset: offset,
+                    scale: scale,
+                    variance_epsilon: epsilon);
 
+                outputs = tf.cast(outputs, input_dtype);
+            }
             // If some components of the shape got lost due to adjustments, fix that.
             outputs.shape = input_shape;
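
With this change the non-fused branch of `Call` follows the usual layer-norm recipe: take `moments` over the normalization axes, broadcast `gamma` and `beta` up to the input rank, and delegate to the new `tf.nn.batch_normalization`, round-tripping `float16` inputs through `float32` for numerical stability. A minimal sketch of the resulting user-facing behavior (values illustrative):

```csharp
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

// Two samples of four features; each row is normalized independently.
var inputs = tf.constant(new float[,] { { 1f, 2f, 3f, 4f }, { 10f, 20f, 30f, 40f } });

var layer = keras.layers.LayerNormalization(axis: -1);
var outputs = layer.Apply(inputs);

// With the default gamma = 1 and beta = 0, each row of `outputs`
// now has approximately zero mean and unit variance.
```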

test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs (+22)

@@ -1,5 +1,7 @@
 using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
 using System.Collections.Generic;
+using System.Linq;
 using Tensorflow.NumPy;
 using static Tensorflow.Binding;
 using static Tensorflow.KerasApi;
@@ -161,6 +163,26 @@ public void LayerNormalization()
             Tensor output = layer.Apply(inputs);
             Assert.AreEqual((5, 2), output.shape);
             Assert.IsTrue(output[0].numpy().Equals(new[] { -0.99998f, 0.99998f }));
+
+            // test_layernorm_weights
+            Assert.AreEqual(len(layer.TrainableWeights), 2);
+            Assert.AreEqual(len(layer.Weights), 2);
+
+            var beta = layer.Weights.Where(x => x.Name.StartsWith("beta")).Single();
+            var gamma = layer.Weights.Where(x => x.Name.StartsWith("gamma")).Single();
+
+            // correctness_test
+            layer = keras.layers.LayerNormalization(axis: -1, epsilon: (float) 1e-12);
+            var x = np.random.normal(loc: 5.0f, scale: 10.0f, size: (1000, 2, 2, 2)).astype(tf.float32);
+
+            output = layer.Apply(x);
+
+            var y = (output - beta.numpy()) / gamma.numpy();
+
+            var y_mean = np.mean(y.numpy());
+            var y_std = np.sqrt(np.sum(np.power(y.numpy() - np.mean(y.numpy()), 2)) / 8000);
+            Assert.IsTrue(tf.greater(np.array(0.1f), tf.abs(y_std - 1.0)).ToArray<bool>()[0]);
+            Assert.IsTrue(tf.greater(np.array(0.1f), tf.abs(y_mean)).ToArray<bool>()[0]);
         }
 
         /// <summary>
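
The correctness test draws 1000 × 2 × 2 × 2 = 8000 values from \\(\mathcal{N}(5, 10^2)\\), applies the layer, strips the affine part via `(output - beta) / gamma`, and checks that the remaining signal is standardized, using population statistics over all 8000 elements: \\(\bar{y} = \tfrac{1}{8000}\sum_i y_i \approx 0\\) and \\(\sigma_y = \sqrt{\tfrac{1}{8000}\sum_i (y_i - \bar{y})^2} \approx 1\\), each within an absolute tolerance of 0.1.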
