Skip to content

Commit a883898

Browse files
committed
Add graph to function test.
1 parent d4f1c34 commit a883898

16 files changed

+182
-9
lines changed

test/TensorFlowNET.UnitTest/Keras/LayersTest.cs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,23 @@ public void Sequential()
3030
public void Embedding()
3131
{
3232
var model = tf.keras.Sequential();
33-
var layer = tf.keras.layers.Embedding(1000, 64, input_length: 10);
33+
var layer = tf.keras.layers.Embedding(7, 2, input_length: 4);
3434
model.add(layer);
3535
// the model will take as input an integer matrix of size (batch,
3636
// input_length).
3737
// the largest integer (i.e. word index) in the input should be no larger
3838
// than 999 (vocabulary size).
3939
// now model.output_shape == (None, 10, 64), where None is the batch
4040
// dimension.
41-
var input_array = np.random.randint(1000, size: (32, 10));
41+
var input_array = np.array(new int[,]
42+
{
43+
{ 1, 2, 3, 4 },
44+
{ 2, 3, 4, 5 },
45+
{ 3, 4, 5, 6 }
46+
});
4247
model.compile("rmsprop", "mse");
4348
var output_array = model.predict(input_array);
49+
Assert.AreEqual((32, 10, 64), output_array.TensorShape);
4450
}
4551

4652
/// <summary>

test/TensorFlowNET.UnitTest/TF_API/BitwiseApiTest.cs renamed to test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
using Tensorflow;
44
using static Tensorflow.Binding;
55

6-
namespace TensorFlowNET.UnitTest.TF_API
6+
namespace TensorFlowNET.UnitTest.ManagedAPI
77
{
88
[TestClass]
99
public class BitwiseApiTest : TFNetApiTest

test/TensorFlowNET.UnitTest/TF_API/GradientTest.cs renamed to test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
using Tensorflow;
66
using static Tensorflow.Binding;
77

8-
namespace TensorFlowNET.UnitTest.TF_API
8+
namespace TensorFlowNET.UnitTest.ManagedAPI
99
{
1010
[TestClass]
1111
public class GradientTest

test/TensorFlowNET.UnitTest/TF_API/LinalgTest.cs renamed to test/TensorFlowNET.UnitTest/ManagedAPI/LinalgTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
using System.Text;
55
using static Tensorflow.Binding;
66

7-
namespace TensorFlowNET.UnitTest.TF_API
7+
namespace TensorFlowNET.UnitTest.ManagedAPI
88
{
99
[TestClass]
1010
public class LinalgTest

test/TensorFlowNET.UnitTest/TF_API/MathApiTest.cs renamed to test/TensorFlowNET.UnitTest/ManagedAPI/MathApiTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
using Tensorflow;
77
using static Tensorflow.Binding;
88

9-
namespace TensorFlowNET.UnitTest.TF_API
9+
namespace TensorFlowNET.UnitTest.ManagedAPI
1010
{
1111
[TestClass]
1212
public class MathApiTest : TFNetApiTest

test/TensorFlowNET.UnitTest/TF_API/StringsApiTest.cs renamed to test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
using Tensorflow;
66
using static Tensorflow.Binding;
77

8-
namespace TensorFlowNET.UnitTest.TF_API
8+
namespace TensorFlowNET.UnitTest.ManagedAPI
99
{
1010
[TestClass]
1111
public class StringsApiTest

test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs renamed to test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
using Tensorflow;
66
using static Tensorflow.Binding;
77

8-
namespace TensorFlowNET.UnitTest.TF_API
8+
namespace TensorFlowNET.UnitTest.ManagedAPI
99
{
1010
[TestClass]
1111
public class TensorOperate
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Runtime.InteropServices;
6+
using Tensorflow;
7+
using Tensorflow.Functions;
8+
using static TensorFlowNET.UnitTest.c_test_util;
9+
10+
namespace TensorFlowNET.UnitTest.NativeAPI
11+
{
12+
/// <summary>
13+
/// tensorflow\c\c_api_function_test.cc
14+
/// `class CApiColocationTest`
15+
/// </summary>
16+
[TestClass]
17+
public class CApiFunctionTest : CApiTest, IDisposable
18+
{
19+
Graph func_graph_;
20+
Graph host_graph_;
21+
string func_name_ = "MyFunc";
22+
string func_node_name_ = "MyFunc_0";
23+
Status s_;
24+
IntPtr func_;
25+
26+
[TestInitialize]
27+
public void Initialize()
28+
{
29+
func_graph_ = new Graph();
30+
host_graph_ = new Graph();
31+
s_ = new Status();
32+
}
33+
34+
[TestMethod]
35+
public void OneOp_ZeroInputs_OneOutput()
36+
{
37+
var c = ScalarConst(10, func_graph_, s_, "scalar10");
38+
// Define
39+
Define(-1, new Operation[0], new Operation[0], new[] { c }, new string[0]);
40+
41+
// Use, run, and verify
42+
var func_op = Use(new Operation[0]);
43+
Run(new KeyValuePair<Operation, Tensor>[0], func_op, 10);
44+
VerifyFDef(new[] { "scalar10_0" });
45+
}
46+
47+
void Define(int num_opers, Operation[] opers,
48+
Operation[] inputs, Operation[] outputs,
49+
string[] output_names, bool expect_failure = false)
50+
=> DefineT(num_opers, opers,
51+
inputs.Select(x => new TF_Output(x, 0)).ToArray(),
52+
outputs.Select(x => new TF_Output(x, 0)).ToArray(),
53+
output_names, expect_failure);
54+
55+
void DefineT(int num_opers, Operation[] opers,
56+
TF_Output[] inputs, TF_Output[] outputs,
57+
string[] output_names, bool expect_failure = false)
58+
{
59+
IntPtr output_names_ptr = IntPtr.Zero;
60+
61+
func_ = c_api.TF_GraphToFunction(func_graph_, func_name_, false,
62+
num_opers, num_opers == -1 ? new IntPtr[0] : opers.Select(x => (IntPtr)x).ToArray(),
63+
inputs.Length, inputs.ToArray(),
64+
outputs.Length, outputs.ToArray(),
65+
output_names_ptr, IntPtr.Zero, null, s_.Handle);
66+
67+
// delete output_names_ptr
68+
69+
if (expect_failure)
70+
{
71+
ASSERT_EQ(IntPtr.Zero, func_);
72+
return;
73+
}
74+
75+
ASSERT_EQ(TF_OK, s_.Code, s_.Message);
76+
ASSERT_EQ(func_name_, c_api.StringPiece(c_api.TF_FunctionName(func_)));
77+
c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_.Handle);
78+
ASSERT_EQ(TF_OK, s_.Code, s_.Message);
79+
}
80+
81+
Operation Use(Operation[] inputs)
82+
=> UseT(inputs.Select(x => new TF_Output(x, 0)).ToArray());
83+
84+
Operation UseT(TF_Output[] inputs)
85+
=> UseHelper(inputs);
86+
87+
Operation UseHelper(TF_Output[] inputs)
88+
{
89+
var desc = TF_NewOperation(host_graph_, func_name_, func_node_name_);
90+
foreach (var input in inputs)
91+
TF_AddInput(desc, input);
92+
c_api.TF_SetDevice(desc, "/cpu:0");
93+
var op = TF_FinishOperation(desc, s_);
94+
ASSERT_EQ(TF_OK, s_.Code, s_.Message);
95+
ASSERT_NE(op, IntPtr.Zero);
96+
97+
return op;
98+
}
99+
100+
void Run(KeyValuePair<Operation, Tensor>[] inputs, Operation output, int expected_result)
101+
=> Run(inputs, new[] { new TF_Output(output, 0) }, new[] { expected_result });
102+
103+
unsafe void Run(KeyValuePair<Operation, Tensor>[] inputs, TF_Output[] outputs, int[] expected_results)
104+
{
105+
var csession = new CSession(host_graph_, s_);
106+
ASSERT_EQ(TF_OK, s_.Code, s_.Message);
107+
108+
csession.SetInputs(inputs);
109+
csession.SetOutputs(outputs);
110+
csession.Run(s_);
111+
ASSERT_EQ(TF_OK, s_.Code, s_.Message);
112+
113+
for (int i = 0; i < expected_results.Length; ++i)
114+
{
115+
var output = csession.output_tensor(i);
116+
ASSERT_NE(output, IntPtr.Zero);
117+
EXPECT_EQ(TF_DataType.TF_INT32, c_api.TF_TensorType(output));
118+
EXPECT_EQ(0, c_api.TF_NumDims(output));
119+
ASSERT_EQ(sizeof(int), (int)c_api.TF_TensorByteSize(output));
120+
var output_contents = c_api.TF_TensorData(output);
121+
EXPECT_EQ(expected_results[i], *(int*)output_contents.ToPointer());
122+
}
123+
}
124+
125+
void VerifyFDef(string[] nodes)
126+
{
127+
var fdef = GetFunctionDef(func_);
128+
EXPECT_NE(fdef, IntPtr.Zero);
129+
VerifyFDefNodes(fdef, nodes);
130+
}
131+
132+
void VerifyFDefNodes(FunctionDef fdef, string[] nodes)
133+
{
134+
ASSERT_EQ(nodes.Length, fdef.NodeDef.Count);
135+
}
136+
137+
public void Dispose()
138+
{
139+
140+
}
141+
}
142+
}

test/TensorFlowNET.UnitTest/NativeAPI/CApiTest.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ protected void EXPECT_GE(int expected, int actual, string msg = "")
3535
protected void ASSERT_EQ(object expected, object actual, string msg = "")
3636
=> Assert.AreEqual(expected, actual, msg);
3737

38+
protected void ASSERT_NE(object expected, object actual, string msg = "")
39+
=> Assert.AreNotEqual(expected, actual, msg);
40+
3841
protected void ASSERT_TRUE(bool condition, string msg = "")
3942
=> Assert.IsTrue(condition, msg);
4043

test/TensorFlowNET.UnitTest/NativeAPI/CSession.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ public void SetInputs(Dictionary<Operation, Tensor> inputs)
4141
}
4242
}
4343

44+
public void SetInputs(KeyValuePair<Operation, Tensor>[] inputs)
45+
{
46+
DeleteInputValues();
47+
inputs_.Clear();
48+
foreach (var input in inputs)
49+
{
50+
inputs_.Add(new TF_Output(input.Key, 0));
51+
input_values_.Add(input.Value);
52+
}
53+
}
54+
4455
private void DeleteInputValues()
4556
{
4657
//clearing is enough as they will be disposed by the GC unless they are referenced else-where.

test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Diagnostics.CodeAnalysis;
1+
using System;
2+
using System.Diagnostics.CodeAnalysis;
23
using Tensorflow;
34
using Tensorflow.Util;
45
using Buffer = Tensorflow.Buffer;
@@ -60,6 +61,16 @@ public static GraphDef GetGraphDef(Graph graph)
6061
}
6162
}
6263

64+
public static FunctionDef GetFunctionDef(IntPtr func)
65+
{
66+
using var s = new Status();
67+
using var buffer = new Buffer();
68+
c_api.TF_FunctionToFunctionDef(func, buffer.Handle, s.Handle);
69+
s.Check(true);
70+
var func_def = FunctionDef.Parser.ParseFrom(buffer.ToArray());
71+
return func_def;
72+
}
73+
6374
public static bool IsAddN(NodeDef node_def, int n)
6475
{
6576
if (node_def.Op != "AddN" || node_def.Name != "add" ||

0 commit comments

Comments
 (0)