diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index e0c29bfa7..e5cd4e569 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -126,6 +126,44 @@ public Tensor[] fused_batch_norm(Tensor x,
name: name,
exponential_avg_factor: exponential_avg_factor);
+ /// <summary>
+ /// Normalizes a tensor by `mean` and `variance`, and applies (optionally) a `scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\).
+ /// </summary>
+ /// <param name="x">A floating point tensor.</param>
+ /// <param name="mean">A mean `Tensor`.</param>
+ /// <param name="variance">A variance `Tensor`.</param>
+ /// <param name="offset">An offset `Tensor`, often denoted \\(\beta\\) in equations, or null. If present, it is added to the normalized tensor.</param>
+ /// <param name="scale">A scale `Tensor`, often denoted \\(\gamma\\) in equations, or null. If present, the scale is applied to the normalized tensor.</param>
+ /// <param name="variance_epsilon">A small float number to avoid dividing by 0.</param>
+ /// <param name="name">A name for this operation.</param>
+ /// <returns>The normalized, scaled, offset tensor.</returns>
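+ /// <remarks>
+ /// In pseudo-notation, the result is `scale * (x - mean) / sqrt(variance + variance_epsilon) + offset`;
+ /// the implementation below folds this into `x * inv + (offset - mean * inv)`,
+ /// where `inv = scale * rsqrt(variance + variance_epsilon)`.
+ /// </remarks>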
+ public Tensor batch_normalization(Tensor x,
+     Tensor mean,
+     Tensor variance,
+     Tensor offset,
+     Tensor scale,
+     float variance_epsilon,
+     string name = null)
+ {
+     return tf_with(ops.name_scope(name, "batchnorm", (x, mean, variance, scale, offset)), scope =>
+     {
+         // Fold the scale into the reciprocal standard deviation:
+         // inv = scale / sqrt(variance + variance_epsilon).
+         var inv = math_ops.rsqrt(variance + variance_epsilon);
+         if (scale != null)
+             inv *= scale;
+         // y = x * inv + (offset - mean * inv).
+         if (offset != null)
+             return x * math_ops.cast(inv, x.dtype) + math_ops.cast(offset - mean * inv, x.dtype);
+         return x * math_ops.cast(inv, x.dtype) + math_ops.cast(-mean * inv, x.dtype);
+     });
+ }
+
public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
=> nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);
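
For reference, a minimal usage sketch of the new API (illustrative only, not part of the diff; the shape, axis, and epsilon value are arbitrary assumptions), pairing it with `tf.nn.moments` the same way the Keras layer below does:

    var x = tf.random.normal((4, 3));
    (Tensor mean, Tensor variance) = tf.nn.moments(x, new[] { 1 }, keep_dims: true);
    var y = tf.nn.batch_normalization(x, mean, variance, offset: null, scale: null, variance_epsilon: 1e-3f);
    // y now has approximately zero mean and unit variance along the last axis.
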
diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs
index 1898f24c8..69bdfbaa0 100644
--- a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs
+++ b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs
@@ -153,9 +153,25 @@ protected override Tensors Call(Tensors inputs, Tensors state = null, bool? trai
}
else
{
- }
+ var input_dtype = inputs.dtype;
+ // Compute the statistics in float32 when float16 inputs flow through a float32 layer.
+ if (input_dtype == tf.float16 && DType == tf.float32) inputs = tf.cast(inputs, tf.float32);
+ (Tensor mean, Tensor variance) = tf.nn.moments(inputs, axis, keep_dims: true);
+ // Broadcast gamma and beta to the rank expected by batch_normalization.
+ (Tensor scale, Tensor offset) = (_broadcast(gamma), _broadcast(beta));
+
+ outputs = tf.nn.batch_normalization(
+ inputs,
+ mean,
+ variance,
+ offset: offset,
+ scale: scale,
+ variance_epsilon: epsilon);
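+ // Cast the normalized output back to the original input dtype.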
+ outputs = tf.cast(outputs, input_dtype);
+ }
// If some components of the shape got lost due to adjustments, fix that.
outputs.shape = input_shape;
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
index f4980b82d..98d909668 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
@@ -1,5 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
using System.Collections.Generic;
+using System.Linq;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
@@ -161,6 +163,28 @@ public void LayerNormalization()
Tensor output = layer.Apply(inputs);
Assert.AreEqual((5, 2), output.shape);
Assert.IsTrue(output[0].numpy().Equals(new[] { -0.99998f, 0.99998f }));
+
+ // test_layernorm_weights
+ Assert.AreEqual(len(layer.TrainableWeights), 2);
+ Assert.AreEqual(len(layer.Weights), 2);
+
+ var beta = layer.Weights.Where(x => x.Name.StartsWith("beta")).Single();
+ var gamma = layer.Weights.Where(x => x.Name.StartsWith("gamma")).Single();
+
+ // correctness_test: beta and gamma still hold their default values (0 and 1),
+ // so they can be used below to undo the affine part of the layer's output.
+ layer = keras.layers.LayerNormalization(axis: -1, epsilon: 1e-12f);
+ var x = np.random.normal(loc: 5.0f, scale: 10.0f, size: (1000, 2, 2, 2)).astype(tf.float32);
+
+ output = layer.Apply(x);
+
+ var y = (output - beta.numpy()) / gamma.numpy();
+
+ var y_mean = np.mean(y.numpy());
+ // Population standard deviation over all 1000 * 2 * 2 * 2 = 8000 elements.
+ var y_std = np.sqrt(np.sum(np.power(y.numpy() - np.mean(y.numpy()), 2)) / 8000);
+ Assert.IsTrue(tf.greater(np.array(0.1f), tf.abs(y_std - 1.0)).ToArray<bool>()[0]);
+ Assert.IsTrue(tf.greater(np.array(0.1f), tf.abs(y_mean)).ToArray<bool>()[0]);
}
///