From c43c60d75f16af0a26709391d28a9c98c1c8eff3 Mon Sep 17 00:00:00 2001
From: Visagan Guruparan <103048@smsassist.com>
Date: Sun, 18 Jun 2023 22:46:36 -0500
Subject: [PATCH] np update square and dot product
---
src/TensorFlowNET.Core/APIs/tf.math.cs | 15 ++++++++--
src/TensorFlowNET.Core/Binding.Util.cs | 23 ++++++++++++++-
src/TensorFlowNET.Core/NumPy/Numpy.Math.cs | 21 ++++++++++++++
.../TensorFlowNET.UnitTest/Numpy/Math.Test.cs | 29 ++++++++++++++++++-
4 files changed, 84 insertions(+), 4 deletions(-)
diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index 75253700a..0e53d938a 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
limitations under the License.
******************************************************************************/
+using Tensorflow.NumPy;
using Tensorflow.Operations;
namespace Tensorflow
@@ -42,7 +43,6 @@ public Tensor erf(Tensor x, string name = null)
public Tensor multiply(Tensor x, Tensor y, string name = null)
=> math_ops.multiply(x, y, name: name);
-
public Tensor divide_no_nan(Tensor a, Tensor b, string name = null)
=> math_ops.div_no_nan(a, b);
@@ -452,7 +452,18 @@ public Tensor multiply(Tensor x, Tensor y, string name = null)
///
public Tensor multiply(Tx x, Ty y, string name = null)
=> gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name);
-
+ ///
+ /// return scalar product
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public Tensor dot_prod(Tx x, Ty y, NDArray axes, string name = null)
+ => math_ops.tensordot(convert_to_tensor(x), convert_to_tensor(y), axes, name: name);
public Tensor negative(Tensor x, string name = null)
=> gen_math_ops.neg(x, name);
diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index 8df39334a..e414ef6e8 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -486,7 +486,28 @@ public static Shape GetShape(this object data)
throw new NotImplementedException("");
}
}
-
+ public static NDArray GetFlattenArray(NDArray x)
+ {
+ switch (x.GetDataType())
+ {
+ case TF_DataType.TF_FLOAT:
+ x = x.ToArray();
+ break;
+ case TF_DataType.TF_DOUBLE:
+ x = x.ToArray();
+ break;
+ case TF_DataType.TF_INT16:
+ case TF_DataType.TF_INT32:
+ x = x.ToArray();
+ break;
+ case TF_DataType.TF_INT64:
+ x = x.ToArray();
+ break;
+ default:
+ break;
+ }
+ return x;
+ }
public static TF_DataType GetDataType(this object data)
{
var type = data.GetType();
diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
index ea85048f8..5bc97952b 100644
--- a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
+++ b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
@@ -49,9 +49,30 @@ public static NDArray prod(NDArray array, Axis? axis = null, Type? dtype = null,
[AutoNumPy]
public static NDArray prod(params T[] array) where T : unmanaged
=> new NDArray(tf.reduce_prod(new NDArray(array)));
+ [AutoNumPy]
+ public static NDArray dot(NDArray x1, NDArray x2, NDArray? axes = null, string? name = null)
+ {
+ //if axes mentioned
+ if (axes != null)
+ {
+ return new NDArray(tf.dot_prod(x1, x2, axes, name));
+ }
+ if (x1.shape.ndim > 1)
+ {
+ x1 = GetFlattenArray(x1);
+ }
+ if (x2.shape.ndim > 1)
+ {
+ x2 = GetFlattenArray(x2);
+ }
+ //if axes not mentioned, default 0,0
+ return new NDArray(tf.dot_prod(x1, x2, axes: new int[] { 0, 0 }, name));
+ }
[AutoNumPy]
public static NDArray power(NDArray x, NDArray y) => new NDArray(tf.pow(x, y));
+ [AutoNumPy]
+ public static NDArray square(NDArray x) => new NDArray(tf.square(x));
[AutoNumPy]
public static NDArray sin(NDArray x) => new NDArray(math_ops.sin(x));
diff --git a/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs b/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs
index 32b517e4f..65cdaedd9 100644
--- a/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs
+++ b/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs
@@ -65,7 +65,34 @@ public void power()
var y = np.power(x, 3);
Assert.AreEqual(y, new[] { 0, 1, 8, 27, 64, 125 });
}
- [TestMethod]
+ [TestMethod]
+ public void square()
+ {
+ var x = np.arange(6);
+ var y = np.square(x);
+ Assert.AreEqual(y, new[] { 0, 1, 4, 9, 16, 25 });
+ }
+ [TestMethod]
+ public void dotproduct()
+ {
+ var x1 = new NDArray(new[] { 1, 2, 3 });
+ var x2 = new NDArray(new[] { 4, 5, 6 });
+ double result1 = np.dot(x1, x2);
+ NDArray y1 = new float[,] {
+ { 1.0f, 2.0f, 3.0f },
+ { 4.0f, 5.1f,6.0f },
+ { 4.0f, 5.1f,6.0f }
+ };
+ NDArray y2 = new float[,] {
+ { 3.0f, 2.0f, 1.0f },
+ { 6.0f, 5.1f, 4.0f },
+ { 6.0f, 5.1f, 4.0f }
+ };
+ double result2 = np.dot(y1, y2);
+ Assert.AreEqual(result1, 32);
+ Assert.AreEqual(Math.Round(result2, 2), 158.02);
+ }
+ [TestMethod]
public void maximum()
{
var x1 = new NDArray(new[,] { { 1, 2, 3 }, { 4, 5.1, 6 } });