Skip to content

Commit 1d1657d

Browse files
committed
Use operation with customized C API.
1 parent fd1eb40 commit 1d1657d

File tree

4 files changed

+22
-32
lines changed

4 files changed

+22
-32
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Runtime.InteropServices;
4+
using System.Text;
5+
6+
namespace Tensorflow
7+
{
8+
public partial class c_api
9+
{
10+
[DllImport(TensorFlowLibName)]
11+
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);
12+
}
13+
}

src/TensorFlowNET.Core/Functions/ConcreteFunction.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible
223223
{
224224
input_tangents = new TangentInfo();
225225
}
226-
if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER)
226+
if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER || tf.Runner.MustRecordGradient())
227227
{
228228
if(input_tangents.Indices is not null || executing_eagerly)
229229
{

src/TensorFlowNET.Core/Operations/Operation.cs

+7-30
Original file line numberDiff line numberDiff line change
@@ -317,27 +317,18 @@ internal void _add_outputs(TF_DataType[] types, Shape[] shapes)
317317
{
318318
Debug.Assert(types.Length == shapes.Length);
319319
int orig_num_outputs = this.outputs.Length;
320-
//var new_outputs = new List<Tensor>(_outputs);
321-
322-
var old_outputs = _outputs;
323-
_outputs = new Tensor[orig_num_outputs + types.Length];
324-
for(int i = 0; i < orig_num_outputs; i++)
325-
{
326-
_outputs[i] = old_outputs[i];
327-
}
320+
var new_outputs = new List<Tensor>(_outputs);
328321

329322
// Since the `_outputs` is defined as `Array`, when we add new output, we
330323
// have to create a new array, which brings some performance concerns.
331324
// In the future maybe the type of `outputs` should be reconsidered.
332325
for(int i = 0; i < types.Length; i++)
333326
{
334-
var t = new Tensor(this, orig_num_outputs + 1, types[i]);
335-
_outputs[i] = t;
336-
//t = tf.ensure_shape(t, shapes[i]);
327+
var t = new Tensor(this, orig_num_outputs + i, types[i]);
337328
t.shape = shapes[i];
338-
//new_outputs.Add(t);
329+
new_outputs.Add(t);
339330
}
340-
//_outputs = new_outputs.ToArray();
331+
_outputs = new_outputs.ToArray();
341332
}
342333

343334
internal void _set_func_attr(string attr_name, string func_name)
@@ -372,23 +363,9 @@ internal void _set_attr(string attr_name, AttrValue attr_value)
372363

373364
internal void _set_attr_with_buf(string attr_name, Buffer attr_buf)
374365
{
375-
//if(_op_desc is null)
376-
//{
377-
// //var new_node_def = NodeDef.Parser.ParseFrom(node_def.ToByteArray());
378-
// //new_node_def.Name += "_temp";
379-
// //var op = new Operation(new_node_def, graph, inputs, _output_types, control_inputs, _input_types);
380-
// //Status status = new();
381-
// //c_api.TF_SetAttrBool(op._op_desc, "trainable", true);
382-
// ////c_api.TF_SetAttrValueProto(op._op_desc, attr_name, attr_buf.ToArray(), attr_buf.Length, status);
383-
// //status.Check(true);
384-
// // TODO(Rinne): deal with it. Give a warning or make the Operation always contains `op_desc`.
385-
//}
386-
//else
387-
//{
388-
// //Status status = new();
389-
// //c_api.TF_SetAttrValueProto(_op_desc, attr_name, attr_buf.ToArray(), attr_buf.Length, status);
390-
// //status.Check(true);
391-
//}
366+
Status status = new();
367+
c_api.TFC_SetAttr(graph, _handle, attr_name, attr_buf, status);
368+
status.Check(true);
392369
}
393370
}
394371
}

src/TensorFlowNET.Core/Tensors/Tensor.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ protected virtual Shape GetShapeInternal()
135135

136136
protected virtual void SetShapeInternal(Shape value)
137137
{
138-
if (value == null)
138+
if (value is null || value.ndim == 0 || value.ndim == -1)
139139
c_api.TF_GraphSetTensorShape(op.graph.c_graph, _as_tf_output(), null, -1, tf.Status);
140140
else
141141
c_api.TF_GraphSetTensorShape(op.graph.c_graph, _as_tf_output(), value.dims, value.ndim, tf.Status);

0 commit comments

Comments
 (0)