Skip to content

Commit 4e78d3d

Browse files
authored
Merge pull request #1039 from AsakusaRinne/fix_1036
Fix (#1036) and adjust the keras unittest
2 parents e3a1843 + 3f6da21 commit 4e78d3d

31 files changed

+503
-448
lines changed

src/TensorFlowNET.Core/Operations/math_ops.cs

+104 -8
Original file line number | Diff line number | Diff line change
@@ -905,13 +905,29 @@ public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = n
905905
var (a_reshape, a_free_dims, a_free_dims_static) = _tensordot_reshape(a, a_axes);
906906
var (b_reshape, b_free_dims, b_free_dims_static) = _tensordot_reshape(b, b_axes, true);
907907
var ab_matmul = matmul(a_reshape, b_reshape);
908-
var dims = new List<int>();
909-
dims.AddRange(a_free_dims);
910-
dims.AddRange(b_free_dims);
911-
if (ab_matmul.shape.Equals(dims))
912-
return ab_matmul;
908+
if(a_free_dims is int[] a_free_dims_list && b_free_dims is int[] b_free_dims_list)
909+
{
910+
var total_free_dims = a_free_dims_list.Concat(b_free_dims_list).ToArray();
911+
if (ab_matmul.shape.IsFullyDefined && ab_matmul.shape.as_int_list().SequenceEqual(total_free_dims))
912+
{
913+
return ab_matmul;
914+
}
915+
else
916+
{
917+
return array_ops.reshape(ab_matmul, ops.convert_to_tensor(total_free_dims), name);
918+
}
919+
}
913920
else
914-
return array_ops.reshape(ab_matmul, tf.constant(dims.ToArray()), name: name);
921+
{
922+
var a_free_dims_tensor = ops.convert_to_tensor(a_free_dims, dtype: dtypes.int32);
923+
var b_free_dims_tensor = ops.convert_to_tensor(b_free_dims, dtype: dtypes.int32);
924+
var product = array_ops.reshape(ab_matmul, array_ops.concat(new[] { a_free_dims_tensor, b_free_dims_tensor }, 0), name);
925+
if(a_free_dims_static is not null && b_free_dims_static is not null)
926+
{
927+
product.shape = new Shape(a_free_dims_static.Concat(b_free_dims_static).ToArray());
928+
}
929+
return product;
930+
}
915931
});
916932
}
917933

@@ -927,14 +943,42 @@ public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = n
927943
return (Binding.range(a.shape.ndim - axe, a.shape.ndim).ToArray(),
928944
Binding.range(0, axe).ToArray());
929945
}
930-
else
946+
else if(axes.rank == 1)
931947
{
948+
if (axes.shape[0] != 2)
949+
{
950+
throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}.");
951+
}
932952
(int a_axe, int b_axe) = (axes[0], axes[1]);
933953
return (new[] { a_axe }, new[] { b_axe });
934954
}
955+
else if(axes.rank == 2)
956+
{
957+
if (axes.shape[0] != 2)
958+
{
959+
throw new ValueError($"`axes` must be an integer or have length 2. Received {axes}.");
960+
}
961+
int[] a_axes = new int[axes.shape[1]];
962+
int[] b_axes = new int[axes.shape[1]];
963+
for(int i = 0; i < a_axes.Length; i++)
964+
{
965+
a_axes[i] = axes[0, i];
966+
b_axes[i] = axes[1, i];
967+
if (a_axes[i] == -1 || b_axes[i] == -1)
968+
{
969+
throw new ValueError($"Different number of contraction axes `a` and `b`," +
970+
$"{len(a_axes)} != {len(b_axes)}.");
971+
}
972+
}
973+
return (a_axes, b_axes);
974+
}
975+
else
976+
{
977+
throw new ValueError($"Invalid rank {axes.rank} to make tensor dot.");
978+
}
935979
}
936980

937-
static (Tensor, int[], int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
981+
static (Tensor, object, int[]) _tensordot_reshape(Tensor a, int[] axes, bool flipped = false)
938982
{
939983
if (a.shape.IsFullyDefined && isinstance(axes, (typeof(int[]), typeof(Tuple))))
940984
{
@@ -977,6 +1021,58 @@ public static Tensor tensordot(Tensor a, Tensor b, NDArray axes, string name = n
9771021
var reshaped_a = array_ops.reshape(a_trans, new_shape);
9781022
return (reshaped_a, free_dims, free_dims);
9791023
}
1024+
else
1025+
{
1026+
int[] free_dims_static;
1027+
Tensor converted_shape_a, converted_axes, converted_free;
1028+
if (a.shape.ndim != -1)
1029+
{
1030+
var shape_a = a.shape.as_int_list();
1031+
for(int i = 0; i < axes.Length; i++)
1032+
{
1033+
if (axes[i] < 0)
1034+
{
1035+
axes[i] += shape_a.Length;
1036+
}
1037+
}
1038+
var free = Enumerable.Range(0, shape_a.Length).Where(i => !axes.Contains(i)).ToArray();
1039+
1040+
var axes_dims = axes.Select(i => shape_a[i]);
1041+
var free_dims = free.Select(i => shape_a[i]).ToArray();
1042+
free_dims_static = free_dims;
1043+
converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes");
1044+
converted_free = ops.convert_to_tensor(free, dtypes.int32, "free");
1045+
converted_shape_a = array_ops.shape(a);
1046+
}
1047+
else
1048+
{
1049+
free_dims_static = null;
1050+
converted_shape_a = array_ops.shape(a);
1051+
var rank_a = array_ops.rank(a);
1052+
converted_axes = ops.convert_to_tensor(axes, dtypes.int32, "axes");
1053+
converted_axes = array_ops.where_v2(converted_axes >= 0, converted_axes, converted_axes + rank_a);
1054+
(converted_free, var _) = gen_ops.list_diff(gen_math_ops.range(ops.convert_to_tensor(0), rank_a, ops.convert_to_tensor(1)),
1055+
converted_axes, dtypes.int32);
1056+
}
1057+
var converted_free_dims = array_ops.gather(converted_shape_a, converted_free);
1058+
var converted_axes_dims = array_ops.gather(converted_shape_a, converted_axes);
1059+
var prod_free_dims = reduce_prod(converted_free_dims);
1060+
var prod_axes_dims = reduce_prod(converted_axes_dims);
1061+
Tensor reshaped_a;
1062+
if (flipped)
1063+
{
1064+
var perm = array_ops.concat(new[] { converted_axes, converted_free }, 0);
1065+
var new_shape = array_ops.stack(new[] { prod_axes_dims, prod_free_dims });
1066+
reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape);
1067+
}
1068+
else
1069+
{
1070+
var perm = array_ops.concat(new[] { converted_free, converted_axes }, 0);
1071+
var new_shape = array_ops.stack(new[] { prod_free_dims, prod_axes_dims });
1072+
reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape);
1073+
}
1074+
return (reshaped_a, converted_free_dims, free_dims_static);
1075+
}
9801076

9811077
throw new NotImplementedException("_tensordot_reshape");
9821078
}

test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs

+4 -9
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,11 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
2-
using Tensorflow.Keras.UnitTest.Helpers;
3-
using static Tensorflow.Binding;
4-
using Tensorflow;
5-
using Tensorflow.Keras.Optimizers;
2+
using System.Collections.Generic;
63
using Tensorflow.Keras.Callbacks;
74
using Tensorflow.Keras.Engine;
8-
using System.Collections.Generic;
95
using static Tensorflow.KerasApi;
10-
using Tensorflow.Keras;
116

127

13-
namespace TensorFlowNET.Keras.UnitTest
8+
namespace Tensorflow.Keras.UnitTest.Callbacks
149
{
1510
[TestClass]
1611
public class EarlystoppingTest
@@ -31,7 +26,7 @@ public void Earlystopping()
3126
layers.Dense(10)
3227
});
3328

34-
29+
3530
model.summary();
3631

3732
model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
@@ -55,7 +50,7 @@ public void Earlystopping()
5550
var callbacks = new List<ICallback>();
5651
callbacks.add(earlystop);
5752

58-
model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs,callbacks:callbacks);
53+
model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs, callbacks: callbacks);
5954
}
6055

6156
}

test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs

+1 -3
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,8 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using System;
3-
using Tensorflow;
4-
using Tensorflow.Keras;
53
using static Tensorflow.Binding;
64

7-
namespace TensorFlowNET.Keras.UnitTest
5+
namespace Tensorflow.Keras.UnitTest
86
{
97
public class EagerModeTestBase
108
{

test/TensorFlowNET.Keras.UnitTest/GradientTest.cs

+2 -5
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,11 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using System.Linq;
3-
using Tensorflow;
43
using Tensorflow.Keras.Engine;
4+
using Tensorflow.NumPy;
55
using static Tensorflow.Binding;
66
using static Tensorflow.KerasApi;
7-
using Tensorflow.NumPy;
8-
using System;
9-
using Tensorflow.Keras.Optimizers;
107

11-
namespace TensorFlowNET.Keras.UnitTest;
8+
namespace Tensorflow.Keras.UnitTest;
129

1310
[TestClass]
1411
public class GradientTest : EagerModeTestBase

test/TensorFlowNET.Keras.UnitTest/InitializerTest.cs

+1 -6
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,7 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
2-
using System;
3-
using System.Collections.Generic;
4-
using System.Linq;
5-
using System.Text;
6-
using TensorFlowNET.Keras.UnitTest;
72
using static Tensorflow.Binding;
83

9-
namespace TensorFlowNET.Keras.UnitTest;
4+
namespace Tensorflow.Keras.UnitTest;
105

116
[TestClass]
127
public class InitializerTest : EagerModeTestBase

test/TensorFlowNET.Keras.UnitTest/Layers/ActivationTest.cs

+3 -5
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,10 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
2-
using System;
3-
using System.Collections.Generic;
4-
using static Tensorflow.Binding;
52
using Tensorflow.NumPy;
3+
using static Tensorflow.Binding;
64
using static Tensorflow.KerasApi;
7-
using Tensorflow;
85

9-
namespace TensorFlowNET.Keras.UnitTest {
6+
namespace Tensorflow.Keras.UnitTest.Layers
7+
{
108
[TestClass]
119
public class ActivationTest : EagerModeTestBase
1210
{

test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs

+6 -9
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,11 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
2-
using System;
3-
using System.Collections.Generic;
2+
using Tensorflow.Keras.Layers;
3+
using Tensorflow.Keras.Utils;
44
using Tensorflow.NumPy;
55
using static Tensorflow.Binding;
66
using static Tensorflow.KerasApi;
7-
using Tensorflow.Keras.Layers;
8-
using Tensorflow;
9-
using Tensorflow.Keras.ArgsDefinition;
10-
using Tensorflow.Keras.Utils;
117

12-
namespace TensorFlowNET.Keras.UnitTest
8+
namespace Tensorflow.Keras.UnitTest.Layers
139
{
1410
[TestClass]
1511
public class AttentionTest : EagerModeTestBase
@@ -118,7 +114,8 @@ public void test_calculate_scores_multi_dim_concat()
118114
} }, dtype: np.float32);
119115
var attention_layer = (Attention)keras.layers.Attention(score_mode: "concat");
120116
//attention_layer.concat_score_weight = 1;
121-
attention_layer.concat_score_weight = base_layer_utils.make_variable(new VariableArgs() {
117+
attention_layer.concat_score_weight = base_layer_utils.make_variable(new VariableArgs()
118+
{
122119
Name = "concat_score_weight",
123120
Shape = (1),
124121
DType = TF_DataType.TF_FLOAT,
@@ -156,7 +153,7 @@ public void test_masked_attention()
156153

157154
var query = keras.Input(shape: (4, 8));
158155
var value = keras.Input(shape: (2, 8));
159-
var mask_tensor = keras.Input(shape:(4, 2));
156+
var mask_tensor = keras.Input(shape: (4, 2));
160157
var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2);
161158
attention_layer.Apply(new Tensor[] { query, value, mask_tensor });
162159

test/TensorFlowNET.Keras.UnitTest/Layers/CosineSimilarity.Test.cs

+7 -9
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,9 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
2-
using Tensorflow.NumPy;
3-
using Tensorflow;
42
using Tensorflow.Keras.Losses;
5-
using static Tensorflow.Binding;
3+
using Tensorflow.NumPy;
64
using static Tensorflow.KerasApi;
75

8-
namespace TensorFlowNET.Keras.UnitTest
6+
namespace Tensorflow.Keras.UnitTest.Layers
97
{
108
[TestClass]
119
public class CosineSimilarity
@@ -16,7 +14,7 @@ public class CosineSimilarity
1614
NDArray y_pred_float = new float[,] { { 1.0f, 0.0f }, { 1.0f, 1.0f } };
1715

1816
[TestMethod]
19-
17+
2018
public void _Default()
2119
{
2220
//>>> # Using 'auto'/'sum_over_batch_size' reduction type.
@@ -27,7 +25,7 @@ public void _Default()
2725
//>>> # loss = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1))
2826
//>>> # = -((0. + 0.) + (0.5 + 0.5)) / 2
2927
//-0.5
30-
var loss = keras.losses.CosineSimilarity(axis : 1);
28+
var loss = keras.losses.CosineSimilarity(axis: 1);
3129
var call = loss.Call(y_true_float, y_pred_float);
3230
Assert.AreEqual((NDArray)(-0.49999997f), call.numpy());
3331
}
@@ -41,7 +39,7 @@ public void _Sample_Weight()
4139
//- 0.0999
4240
var loss = keras.losses.CosineSimilarity();
4341
var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.8f, 0.2f });
44-
Assert.AreEqual((NDArray) (- 0.099999994f), call.numpy());
42+
Assert.AreEqual((NDArray)(-0.099999994f), call.numpy());
4543
}
4644

4745
[TestMethod]
@@ -53,7 +51,7 @@ public void _SUM()
5351
//... reduction = tf.keras.losses.Reduction.SUM)
5452
//>>> cosine_loss(y_true, y_pred).numpy()
5553
//- 0.999
56-
var loss = keras.losses.CosineSimilarity(axis: 1,reduction : ReductionV2.SUM);
54+
var loss = keras.losses.CosineSimilarity(axis: 1, reduction: ReductionV2.SUM);
5755
var call = loss.Call(y_true_float, y_pred_float);
5856
Assert.AreEqual((NDArray)(-0.99999994f), call.numpy());
5957
}
@@ -67,7 +65,7 @@ public void _None()
6765
//... reduction = tf.keras.losses.Reduction.NONE)
6866
//>>> cosine_loss(y_true, y_pred).numpy()
6967
//array([-0., -0.999], dtype = float32)
70-
var loss = keras.losses.CosineSimilarity(axis :1, reduction: ReductionV2.NONE);
68+
var loss = keras.losses.CosineSimilarity(axis: 1, reduction: ReductionV2.NONE);
7169
var call = loss.Call(y_true_float, y_pred_float);
7270
Assert.AreEqual((NDArray)new float[] { -0f, -0.99999994f }, call.numpy());
7371
}

test/TensorFlowNET.Keras.UnitTest/Layers/Huber.Test.cs

+4 -6
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,9 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
2-
using Tensorflow.NumPy;
3-
using Tensorflow;
42
using Tensorflow.Keras.Losses;
5-
using static Tensorflow.Binding;
3+
using Tensorflow.NumPy;
64
using static Tensorflow.KerasApi;
75

8-
namespace TensorFlowNET.Keras.UnitTest
6+
namespace Tensorflow.Keras.UnitTest.Layers
97
{
108
[TestClass]
119
public class Huber
@@ -16,7 +14,7 @@ public class Huber
1614
NDArray y_pred_float = new float[,] { { 0.6f, 0.4f }, { 0.4f, 0.6f } };
1715

1816
[TestMethod]
19-
17+
2018
public void _Default()
2119
{
2220
//>>> # Using 'auto'/'sum_over_batch_size' reduction type.
@@ -49,7 +47,7 @@ public void _SUM()
4947
//... reduction = tf.keras.losses.Reduction.SUM)
5048
//>>> h(y_true, y_pred).numpy()
5149
//0.31
52-
var loss = keras.losses.Huber(reduction : ReductionV2.SUM);
50+
var loss = keras.losses.Huber(reduction: ReductionV2.SUM);
5351
var call = loss.Call(y_true_float, y_pred_float);
5452
Assert.AreEqual((NDArray)0.31f, call.numpy());
5553
}

test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Convolution.Test.cs

+2 -4
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,8 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using Tensorflow.NumPy;
3-
using Tensorflow;
4-
using Tensorflow.Operations;
53
using static Tensorflow.KerasApi;
64

7-
namespace TensorFlowNET.Keras.UnitTest
5+
namespace Tensorflow.Keras.UnitTest.Layers
86
{
97
[TestClass]
108
public class LayersConvolutionTest : EagerModeTestBase
@@ -14,7 +12,7 @@ public void BasicConv1D()
1412
{
1513
var filters = 8;
1614

17-
var conv = keras.layers.Conv1D(filters, kernel_size: 3, activation: "linear");
15+
var conv = keras.layers.Conv1D(filters, kernel_size: 3, activation: "linear");
1816

1917
var x = np.arange(256.0f).reshape((8, 8, 4));
2018
var y = conv.Apply(x);

0 commit comments

Comments
 (0)