Skip to content

Commit b5c3578

Browse files
authored
Merge pull request #1 from SciSharp/master
Pulling from SciSharp
2 parents 4205c50 + 1edf86a commit b5c3578

File tree

12 files changed

+113
-25
lines changed

12 files changed

+113
-25
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,24 @@ public static Tensor arg_min(Tensor input, int dimension, TF_DataType output_typ
5454
public static Tensor ceil(Tensor x, string name = null)
5555
=> gen_math_ops.ceil(x, name);
5656

57+
/// <summary>
58+
/// Computes sin of x element-wise.
59+
/// </summary>
60+
/// <param name="x"></param>
61+
/// <param name="name"></param>
62+
/// <returns></returns>
63+
public static Tensor sin(Tensor x, string name = null)
64+
=> gen_math_ops.sin(x, name);
65+
66+
/// <summary>
67+
/// Computes hyperbolic sine of x element-wise.
68+
/// </summary>
69+
/// <param name="x"></param>
70+
/// <param name="name"></param>
71+
/// <returns></returns>
72+
public static Tensor sinh(Tensor x, string name = null)
73+
=> gen_math_ops.sinh(x, name);
74+
5775
/// <summary>
5876
/// Computes cos of x element-wise.
5977
/// </summary>
@@ -72,6 +90,12 @@ public static Tensor cos(Tensor x, string name = null)
7290
public static Tensor cosh(Tensor x, string name = null)
7391
=> gen_math_ops.cosh(x, name);
7492

93+
public static Tensor tan(Tensor x, string name = null)
94+
=> gen_math_ops.tan(x, name);
95+
96+
public static Tensor tanh(Tensor x, string name = null)
97+
=> gen_math_ops.tanh(x, name);
98+
7599
/// <summary>
76100
/// Returns element-wise largest integer not greater than x.
77101
/// </summary>
@@ -257,6 +281,9 @@ public static Tensor reduce_sum(Tensor input, int axis, int? reduction_indices =
257281
return math_ops.reduce_sum(input, axis);
258282
}
259283

284+
public static Tensor sum(Tensor input, int axis, bool keep_dims = false, string name = null)
285+
=> gen_math_ops._sum(input, axis, keep_dims: keep_dims, name: name);
286+
260287
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
261288
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices);
262289

src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,6 @@ public _InitializeClustersOpFactory(Tensor[] inputs,
5252
_num_data = math_ops.add_n(_inputs.Select(i => array_ops.shape(i)[0]).ToArray());
5353
}
5454

55-
public Tensor[] op()
56-
{
57-
return control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0),
58-
() => new Operation[] { check_ops.assert_equal(_cluster_centers_initialized, true) },
59-
_initialize);
60-
}
61-
6255
private Operation[] _initialize()
6356
{
6457
with(ops.control_dependencies(new Operation[]
@@ -72,6 +65,17 @@ private Operation[] _initialize()
7265
throw new NotImplementedException("_InitializeClustersOpFactory _initialize");
7366
}
7467

68+
public Tensor[] op()
69+
{
70+
return control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0),
71+
() =>
72+
{
73+
var op = check_ops.assert_equal(_cluster_centers_initialized, true);
74+
return new Operation[] { op };
75+
},
76+
_initialize);
77+
}
78+
7579
/*private int _add_new_centers()
7680
{
7781
var new_centers = _choose_initial_centers();

src/TensorFlowNET.Core/Operations/Operation.Control.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public void _control_flow_post_processing()
2525

2626
public void _add_control_input(Operation op)
2727
{
28-
c_api.TF_AddControlInput(_handle, op);
28+
c_api.TF_AddControlInput(_operDesc, op);
2929
}
3030

3131
public void _add_control_inputs(Operation[] ops)

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ namespace Tensorflow
1111
public partial class Operation : ITensorOrOperation
1212
{
1313
private readonly IntPtr _handle; // _c_op in python
14+
private readonly IntPtr _operDesc;
1415

1516
private Graph _graph;
1617
//[JsonIgnore]
@@ -58,9 +59,9 @@ public Operation(Graph g, string opType, string oper_name)
5859
{
5960
_graph = g;
6061

61-
var desc = c_api.TF_NewOperation(g, opType, oper_name);
62-
c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32);
63-
c_api.TF_FinishOperation(desc, status);
62+
_operDesc = c_api.TF_NewOperation(g, opType, oper_name);
63+
c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32);
64+
_handle = c_api.TF_FinishOperation(_operDesc, status);
6465
}
6566

6667
/// <summary>
@@ -112,7 +113,7 @@ public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[
112113
op_def = g.GetOpDef(node_def.Op);
113114

114115
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
115-
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
116+
(_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
116117

117118
// Initialize self._outputs.
118119
output_types = new TF_DataType[NumOutputs];

src/TensorFlowNET.Core/Operations/check_ops.cs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,29 @@ public class check_ops : Python
1212
/// <param name="t1"></param>
1313
/// <param name="t2"></param>
1414
/// <param name="name"></param>
15-
public static Operation assert_equal(object t1, object t2, object[] data = null, string name = null)
15+
public static Operation assert_equal(object t1, object t2, object[] data = null, string message = null, string name = null)
1616
{
17+
if (message == null)
18+
message = "";
19+
1720
return with(ops.name_scope(name, "assert_equal", new { t1, t2, data }), delegate
1821
{
1922
var x = ops.convert_to_tensor(t1, name: "x");
2023
var y = ops.convert_to_tensor(t2, name: "y");
24+
25+
if (data == null)
26+
{
27+
data = new object[]
28+
{
29+
message,
30+
"Condition x == y did not hold element-wise:",
31+
$"x (%s) = {x.name}",
32+
x,
33+
$"y (%s) = {y.name}",
34+
y
35+
};
36+
}
37+
2138
var condition = math_ops.reduce_all(gen_math_ops.equal(x, y));
2239
var x_static = tensor_util.constant_value(x);
2340
var y_static = tensor_util.constant_value(y);

src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,16 @@ public static Operation Assert(Tensor condition, object[] data, int? summarize =
1616
name = scope;
1717
var xs = ops.convert_n_to_tensor(data);
1818
condition = ops.convert_to_tensor(condition, name: "Condition");
19-
Func<Operation[]> true_assert = () => new Operation[]
19+
Func<Operation[]> true_assert = () =>
2020
{
21-
gen_logging_ops._assert(condition, data, summarize, name: "Assert")
21+
var assert = gen_logging_ops._assert(condition, data, summarize, name: "Assert");
22+
return new Operation[] { assert };
2223
};
2324

24-
Func<Operation[]> false_assert = () => new Operation[]
25+
Func<Operation[]> false_assert = () =>
2526
{
26-
gen_control_flow_ops.no_op()
27+
var op = gen_control_flow_ops.no_op();
28+
return new Operation[] { op };
2729
};
2830

2931
var guarded_assert = cond(condition, false_assert, true_assert, name: "AssertGuard");

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,20 @@ public static Tensor ceil(Tensor x, string name = null)
101101
return _op.outputs[0];
102102
}
103103

104+
public static Tensor sin(Tensor x, string name = null)
105+
{
106+
var _op = _op_def_lib._apply_op_helper("Sin", name, args: new { x });
107+
108+
return _op.outputs[0];
109+
}
110+
111+
public static Tensor sinh(Tensor x, string name = null)
112+
{
113+
var _op = _op_def_lib._apply_op_helper("Sinh", name, args: new { x });
114+
115+
return _op.outputs[0];
116+
}
117+
104118
public static Tensor cos(Tensor x, string name = null)
105119
{
106120
var _op = _op_def_lib._apply_op_helper("Cos", name, args: new { x });
@@ -115,6 +129,20 @@ public static Tensor cosh(Tensor x, string name = null)
115129
return _op.outputs[0];
116130
}
117131

132+
public static Tensor tan(Tensor x, string name = null)
133+
{
134+
var _op = _op_def_lib._apply_op_helper("Tan", name, args: new { x });
135+
136+
return _op.outputs[0];
137+
}
138+
139+
public static Tensor tanh(Tensor x, string name = null)
140+
{
141+
var _op = _op_def_lib._apply_op_helper("Tanh", name, args: new { x });
142+
143+
return _op.outputs[0];
144+
}
145+
118146
public static Tensor floor(Tensor x, string name = null)
119147
{
120148
var _op = _op_def_lib._apply_op_helper("Floor", name, args: new { x });

src/TensorFlowNET.Core/Tensors/dtypes.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ public static Type as_numpy_datatype(this TF_DataType type)
1010
{
1111
switch (type)
1212
{
13+
case TF_DataType.TF_BOOL:
14+
return typeof(bool);
1315
case TF_DataType.TF_INT32:
1416
return typeof(int);
1517
case TF_DataType.TF_INT16:

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,23 @@ public static NDArray MakeNdarray(TensorProto tensor)
4747
var tensor_dtype = tensor.Dtype.as_numpy_dtype();
4848

4949
if (tensor.TensorContent.Length > 0)
50-
return np.frombuffer(tensor.TensorContent.ToByteArray(), tensor_dtype)
51-
.reshape(shape);
50+
{
51+
return np.frombuffer(tensor.TensorContent.ToByteArray(), tensor_dtype).reshape(shape);
52+
}
5253
else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16)
5354
;
5455
else if (tensor.Dtype == DataType.DtFloat)
5556
;
5657
else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype))
58+
{
5759
if (tensor.IntVal.Count == 1)
58-
return np.repeat(np.array(tensor.IntVal[0]), Convert.ToInt32(num_elements))
59-
.reshape(shape);
60+
return np.repeat(np.array(tensor.IntVal[0]), num_elements).reshape(shape);
61+
}
62+
else if (tensor.Dtype == DataType.DtBool)
63+
{
64+
if (tensor.BoolVal.Count == 1)
65+
return np.repeat(np.array(tensor.BoolVal[0]), num_elements).reshape(shape);
66+
}
6067

6168
throw new NotImplementedException("MakeNdarray");
6269
}

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ public ITensorOrOperation assign(object value, bool use_locking = false, string
265265

266266
public override string ToString()
267267
{
268-
return $"tf.Variable '{name}' shape={shape} dtype={dtype}";
268+
return $"tf.RefVariable '{name}' shape={shape} dtype={dtype}";
269269
}
270270

271271
public VariableDef to_proto(string export_scope)

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ public static _ControlDependenciesController control_dependencies(Operation[] co
122122
/// </param>
123123
/// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param>
124124
/// <returns>A wrapped TF_Operation*.</returns>
125-
public static IntPtr _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs)
125+
public static (IntPtr, IntPtr) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs)
126126
{
127127
var op_desc = graph.NewOperation(node_def.Op, node_def.Name);
128128

@@ -164,7 +164,7 @@ public static IntPtr _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs,
164164

165165
status.Check(true);
166166

167-
return c_op;
167+
return (c_op, op_desc);
168168
}
169169

170170
public static OpDef _get_op_def(Graph graph, string type)

test/TensorFlowNET.Examples/KMeansClustering.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace TensorFlowNET.Examples
1616
public class KMeansClustering : Python, IExample
1717
{
1818
public int Priority => 8;
19-
public bool Enabled => false;
19+
public bool Enabled => true;
2020
public string Name => "K-means Clustering";
2121

2222
Datasets mnist;

0 commit comments

Comments
 (0)