Skip to content

Commit 1b1a503

Browse files
Visagan GuruparanOceania2018
Visagan Guruparan
authored andcommitted
np update square and dot product
1 parent e1ece66 commit 1b1a503

File tree

4 files changed

+84
-4
lines changed

4 files changed

+84
-4
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

+13-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Tensorflow.NumPy;
1718
using Tensorflow.Operations;
1819

1920
namespace Tensorflow
@@ -42,7 +43,6 @@ public Tensor erf(Tensor x, string name = null)
4243

4344
public Tensor multiply(Tensor x, Tensor y, string name = null)
4445
=> math_ops.multiply(x, y, name: name);
45-
4646
public Tensor divide_no_nan(Tensor a, Tensor b, string name = null)
4747
=> math_ops.div_no_nan(a, b);
4848

@@ -452,7 +452,18 @@ public Tensor multiply(Tensor x, Tensor y, string name = null)
452452
/// <returns></returns>
453453
public Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null)
454454
=> gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name);
455-
455+
/// <summary>
456+
/// return scalar product
457+
/// </summary>
458+
/// <typeparam name="Tx"></typeparam>
459+
/// <typeparam name="Ty"></typeparam>
460+
/// <param name="x"></param>
461+
/// <param name="y"></param>
462+
/// <param name="axes"></param>
463+
/// <param name="name"></param>
464+
/// <returns></returns>
465+
public Tensor dot_prod<Tx, Ty>(Tx x, Ty y, NDArray axes, string name = null)
466+
=> math_ops.tensordot(convert_to_tensor(x), convert_to_tensor(y), axes, name: name);
456467
public Tensor negative(Tensor x, string name = null)
457468
=> gen_math_ops.neg(x, name);
458469

src/TensorFlowNET.Core/Binding.Util.cs

+22-1
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,28 @@ public static Shape GetShape(this object data)
486486
throw new NotImplementedException("");
487487
}
488488
}
489-
489+
public static NDArray GetFlattenArray(NDArray x)
490+
{
491+
switch (x.GetDataType())
492+
{
493+
case TF_DataType.TF_FLOAT:
494+
x = x.ToArray<float>();
495+
break;
496+
case TF_DataType.TF_DOUBLE:
497+
x = x.ToArray<double>();
498+
break;
499+
case TF_DataType.TF_INT16:
500+
case TF_DataType.TF_INT32:
501+
x = x.ToArray<int>();
502+
break;
503+
case TF_DataType.TF_INT64:
504+
x = x.ToArray<long>();
505+
break;
506+
default:
507+
break;
508+
}
509+
return x;
510+
}
490511
public static TF_DataType GetDataType(this object data)
491512
{
492513
var type = data.GetType();

src/TensorFlowNET.Core/NumPy/Numpy.Math.cs

+21
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,30 @@ public static NDArray prod(NDArray array, Axis? axis = null, Type? dtype = null,
4949
[AutoNumPy]
5050
public static NDArray prod<T>(params T[] array) where T : unmanaged
5151
=> new NDArray(tf.reduce_prod(new NDArray(array)));
52+
[AutoNumPy]
53+
public static NDArray dot(NDArray x1, NDArray x2, NDArray? axes = null, string? name = null)
54+
{
55+
//if axes mentioned
56+
if (axes != null)
57+
{
58+
return new NDArray(tf.dot_prod(x1, x2, axes, name));
59+
}
60+
if (x1.shape.ndim > 1)
61+
{
62+
x1 = GetFlattenArray(x1);
63+
}
64+
if (x2.shape.ndim > 1)
65+
{
66+
x2 = GetFlattenArray(x2);
67+
}
68+
//if axes not mentioned, default 0,0
69+
return new NDArray(tf.dot_prod(x1, x2, axes: new int[] { 0, 0 }, name));
5270

71+
}
5372
[AutoNumPy]
5473
public static NDArray power(NDArray x, NDArray y) => new NDArray(tf.pow(x, y));
74+
[AutoNumPy]
75+
public static NDArray square(NDArray x) => new NDArray(tf.square(x));
5576

5677
[AutoNumPy]
5778
public static NDArray sin(NDArray x) => new NDArray(math_ops.sin(x));

test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs

+28-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,34 @@ public void power()
6565
var y = np.power(x, 3);
6666
Assert.AreEqual(y, new[] { 0, 1, 8, 27, 64, 125 });
6767
}
68-
[TestMethod]
68+
[TestMethod]
69+
public void square()
70+
{
71+
var x = np.arange(6);
72+
var y = np.square(x);
73+
Assert.AreEqual(y, new[] { 0, 1, 4, 9, 16, 25 });
74+
}
75+
[TestMethod]
76+
public void dotproduct()
77+
{
78+
var x1 = new NDArray(new[] { 1, 2, 3 });
79+
var x2 = new NDArray(new[] { 4, 5, 6 });
80+
double result1 = np.dot(x1, x2);
81+
NDArray y1 = new float[,] {
82+
{ 1.0f, 2.0f, 3.0f },
83+
{ 4.0f, 5.1f,6.0f },
84+
{ 4.0f, 5.1f,6.0f }
85+
};
86+
NDArray y2 = new float[,] {
87+
{ 3.0f, 2.0f, 1.0f },
88+
{ 6.0f, 5.1f, 4.0f },
89+
{ 6.0f, 5.1f, 4.0f }
90+
};
91+
double result2 = np.dot(y1, y2);
92+
Assert.AreEqual(result1, 32);
93+
Assert.AreEqual(Math.Round(result2, 2), 158.02);
94+
}
95+
[TestMethod]
6996
public void maximum()
7097
{
7198
var x1 = new NDArray(new[,] { { 1, 2, 3 }, { 4, 5.1, 6 } });

0 commit comments

Comments
 (0)