Skip to content

Commit 6d1b459

Browse files
committed
Fix resize_nearest_neighbor_grad.
1 parent 6fe6057 commit 6d1b459

19 files changed

+241
-136
lines changed

src/TensorFlowNET.Core/Functions/ConcreteFunction.cs

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ public class ConcreteFunction : IDisposable
1414
{
1515
IntPtr _handle;
1616
FuncGraph func_graph;
17+
public Tensor[] Inputs => func_graph.Inputs;
1718
public Tensor[] CapturedInputs => func_graph.external_captures;
1819

1920
public string Name
@@ -127,30 +128,53 @@ public void Exit()
127128
func_graph.Exit();
128129
}
129130

130-
public Tensors Invoke(Tensors inputs)
131+
public Tensors FilteredCall(Tensors inputs)
131132
{
132-
var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly());
133-
var (forward_function, args_with_tangents) = forward_backward.Forward();
134-
Tensors flat_outputs = null;
135-
if (tf.Context.executing_eagerly())
136-
flat_outputs = forward_function.Call(args_with_tangents);
137-
forward_backward.Record(flat_outputs);
138-
return flat_outputs;
133+
return CallFlat(inputs, CapturedInputs);
139134
}
140135

136+
/// <summary>
137+
/// Executes the wrapped function.
138+
/// </summary>
139+
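/// <param name="args">Explicit arguments; each one is added to the input list ahead of the captured inputs.</param>
140+
/// <param name="captured_inputs">Captured tensors appended to the input list after <paramref name="args"/>.</param>
141+
/// <returns>The flat outputs returned by the runner or by the forward function.</returns>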
141142
public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs)
142143
{
143-
var new_args = new List<Tensor>();
144-
new_args.AddRange(args);
145-
new_args.AddRange(captured_inputs);
146-
args = new_args.ToArray();
144+
var executing_eagerly = tf.Context.executing_eagerly();
145+
var default_graph = ops.get_default_graph();
146+
var tensor_inputs = new Tensors();
147+
foreach (var (i, arg) in enumerate(args))
148+
{
149+
tensor_inputs.Add(arg);
150+
// If we're graph building, shape inference is on.
151+
if (!executing_eagerly)
152+
{
153+
}
154+
}
155+
tensor_inputs.AddRange(captured_inputs);
156+
157+
args = tensor_inputs.ToArray();
147158

148-
var attrs = new object[]
159+
var possible_gradient_type = tf.Runner.MustRecordGradient() ? 1 : 0;
160+
// No tape is watching; skip to running the function.
161+
if (possible_gradient_type == 0 && executing_eagerly)
149162
{
150-
"executor_type", "",
151-
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
152-
};
153-
return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs);
163+
var attrs = new object[]
164+
{
165+
"executor_type", "",
166+
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
167+
};
168+
return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs);
169+
}
170+
171+
var forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly);
172+
var (forward_function, args_with_tangents) = forward_backward.Forward();
173+
Tensors flat_outputs = null;
174+
if (executing_eagerly)
175+
flat_outputs = forward_function.Call(args_with_tangents);
176+
forward_backward.Record(flat_outputs);
177+
return flat_outputs;
154178
}
155179

156180
ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)

src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,17 @@ public EagerDefinedFunction(string name, FuncGraph graph,
3131

3232
public Tensors Call(Tensors args)
3333
{
34+
var attrs = new object[]
35+
{
36+
"executor_type", "",
37+
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
38+
};
39+
3440
var results = tf.Runner.TFE_Execute(tf.Context,
3541
tf.Context.DeviceName,
3642
_func_graph.FuncName,
3743
args,
38-
null,
44+
attrs,
3945
_num_outputs);
4046

4147
return results;

src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,24 +49,61 @@ public void Record(Tensors flat_outputs, Tensors inference_args)
4949
getBackwardFunction: () => backward_function);
5050
}
5151

52+
/// <summary>
53+
/// Create a backward function given `outputs` from the forward function.
54+
/// </summary>
55+
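/// <param name="forward_graph">Forward graph whose Outputs are paired by index with <paramref name="outputs"/>.</param>
56+
/// <param name="backward">Backward function that the returned wrapper calls through CallFlat.</param>
57+
/// <param name="outputs">Outputs of the forward function.</param>
58+
/// <returns>The backward function wrapper together with the recorded outputs.</returns>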
5259
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs)
5360
{
54-
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
61+
var capture_mapping = new Dictionary<long, Tensor>();
62+
foreach(var (i, output) in enumerate(outputs))
63+
capture_mapping[forward_graph.Outputs[i].Id] = output;
64+
65+
var remapped_captures = new Tensors();
66+
foreach(var capture in backward.CapturedInputs)
67+
{
68+
if (capture_mapping.ContainsKey(capture.Id))
69+
remapped_captures.Add(capture_mapping[capture.Id]);
70+
}
71+
72+
var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length;
73+
var recorded_outputs = new Tensors();
74+
var relevant_outputs = outputs;
75+
var trainable_recorded_outputs = 0;
76+
var skip_positions = new List<int>();
77+
foreach (var (output_index, output) in enumerate(relevant_outputs))
78+
{
79+
if (trainable_recorded_outputs < backward_function_inputs)
80+
recorded_outputs.Add(output);
81+
if (gradients_util.IsTrainable(output))
82+
trainable_recorded_outputs += 1;
83+
else
84+
skip_positions.Add(output_index);
85+
}
86+
87+
BackwardFunction _backward_function_wrapper = (args, unneeded_gradients) =>
5588
{
56-
var processed_args = new List<Tensor>();
89+
var processed_args = new Tensors();
5790
var input_index = 0;
58-
foreach (var (output_index, arg) in enumerate(output_grads))
91+
foreach (var (output_index, arg) in enumerate(args))
5992
{
60-
if (arg is null)
93+
if (skip_positions.Contains(output_index))
94+
continue;
95+
if (arg == null)
6196
throw new NotImplementedException("");
62-
processed_args.add(arg);
97+
processed_args.Add(arg);
6398
input_index += 1;
99+
if (input_index >= backward_function_inputs)
100+
break;
64101
}
65102
tf.Logger.Debug($"Invoke backward function: {backward.Name}");
66-
return backward.CallFlat(processed_args.ToArray(), outputs);
103+
return backward.CallFlat(processed_args, remapped_captures);
67104
};
68105

69-
return (_backward_function_wrapper, outputs);
106+
return (_backward_function_wrapper, recorded_outputs);
70107
}
71108

72109
protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
@@ -103,7 +140,7 @@ public void Record(Tensors flat_outputs, Tensors inference_args)
103140
}
104141
backwards_graph.Exit();
105142

106-
var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}";
143+
var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}";
107144
var backward_function_attr = new Dictionary<string, string>();
108145
backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
109146
gradients_wrt_outputs.append(backwards_graph.internal_captures);

src/TensorFlowNET.Core/Gradients/array_grad.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -228,13 +228,14 @@ public static Tensor[] _PadGrad(Operation op, Tensor[] grads)
228228
var grad = grads[0];
229229
var x = op.inputs[0];
230230
var a = op.inputs[1];
231-
var size = array_ops.stack(new object[] { array_ops.rank(x), 1 });
232-
var pad_before = array_ops.slice(a, new[] { 0, 0 }, size);
231+
var size = array_ops.stack(new Tensor[] { array_ops.rank(x), constant_op.constant(1) });
232+
var begin = constant_op.constant(new[] { 0, 0 });
233+
var pad_before = array_ops.slice(a, begin, size);
233234

234235
// Make it a 1-D tensor.
235-
var begin = array_ops.reshape(pad_before, new[] { -1 });
236-
var sizes = array_ops.shape(x);
237-
var x_grad = array_ops.slice(grad, begin, sizes);
236+
begin = array_ops.reshape(pad_before, new[] { -1 });
237+
size = array_ops.shape(x);
238+
var x_grad = array_ops.slice(grad, begin, size);
238239

239240
if (len(op.inputs) == 3)
240241
return new Tensor[] { x_grad, null, null };

src/TensorFlowNET.Core/Gradients/image_grad.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public static Tensor[] _ResizeNearestNeighborGrad(Operation op, Tensor[] grads)
3030
var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray());
3131
Tensor image_shape = null;
3232
if (shape.is_fully_defined())
33-
throw new NotImplementedException("_ResizeNearestNeighborGrad shape.is_fully_defined");
33+
image_shape = constant_op.constant(image.shape[1..3]);
3434
else
3535
image_shape = array_ops.shape(image)["1:3"];
3636

src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
namespace Tensorflow.Graphs
1010
{
11+
/// <summary>
12+
/// func_graph.py func_graph_from_py_func
13+
/// </summary>
1114
[AllowChangingInputArguments]
1215
public sealed class AutoGraphAttribute : OnMethodBoundaryAspect
1316
{
@@ -18,15 +21,16 @@ public sealed class AutoGraphAttribute : OnMethodBoundaryAspect
1821

1922
public override void OnEntry(MethodExecutionArgs args)
2023
{
21-
func_name = $"{args.Method.Name}_{Guid.NewGuid()}";
24+
// TODO: func_name could be cached, keyed on FullName + Args
25+
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{Guid.NewGuid()}";
2226

2327
if (functions.ContainsKey(func_name))
2428
{
2529
function = functions[func_name];
2630
if (args.Arguments[0] is Tensors tensor_inputs)
27-
args.ReturnValue = ConvertReturnValue(function.Invoke(tensor_inputs));
31+
args.ReturnValue = ConvertReturnValue(function.FilteredCall(tensor_inputs));
2832
else
29-
args.ReturnValue = ConvertReturnValue(function.Invoke(args.Arguments.Select(x => x as Tensor).ToArray()));
33+
args.ReturnValue = ConvertReturnValue(function.FilteredCall(args.Arguments.Select(x => x as Tensor).ToArray()));
3034
args.FlowBehavior = FlowBehavior.Return;
3135
return;
3236
}
@@ -62,22 +66,35 @@ public override void OnExit(MethodExecutionArgs args)
6266
{
6367
if (args.ReturnValue is Tensors outputs)
6468
{
65-
if (args.Arguments[0] is Tensors inputs)
66-
function.ToGraph(inputs, outputs);
69+
Tensors inputs = null;
70+
outputs = mark_as_return(outputs);
71+
if (args.Arguments[0] is Tensors inputs1)
72+
inputs = inputs1;
6773
else
68-
function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), outputs);
74+
inputs = args.Arguments.Select(x => x as Tensor).ToArray();
75+
76+
inputs = inputs.Where(x => x.op.OpType == "Placeholder"
77+
&& x.op.name.StartsWith("inputs")).ToArray();
78+
79+
function.ToGraph(inputs, outputs);
6980
}
70-
else
71-
function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor);
72-
81+
else if (args.ReturnValue is Tensor output)
82+
{
83+
var inputs = args.Arguments.Select(x => x as Tensor)
84+
.Where(x => x.op.type == "Placeholder" && x.op.name.StartsWith("inputs"))
85+
.ToArray();
86+
var outputs2 = array_ops.identity(output);
87+
function.ToGraph(inputs, outputs2);
88+
}
89+
7390
function.Exit();
7491

7592
// cache function.
7693
function.ReturnType = args.ReturnValue.GetType();
7794
functions[func_name] = function;
7895

7996
// run function
80-
args.ReturnValue = ConvertReturnValue(function.Invoke(originalInputs));
97+
args.ReturnValue = ConvertReturnValue(function.FilteredCall(originalInputs));
8198
}
8299

83100
object ConvertReturnValue(Tensors tensors)
@@ -87,5 +104,20 @@ object ConvertReturnValue(Tensors tensors)
87104
else
88105
return tensors;
89106
}
107+
108+
/// <summary>
109+
/// Acts like identity but marks the `Tensor` as a return value.
110+
/// </summary>
111+
/// <param name="tensors">Tensors to pass through array_ops.identity; may be null.</param>
112+
/// <returns>A new Tensors holding the identity of each input tensor in order, or null when <paramref name="tensors"/> is null.</returns>
113+
public Tensors mark_as_return(Tensors tensors)
114+
{
115+
if (tensors == null)
116+
return null;
117+
var result = new Tensors();
118+
foreach (var tensor in tensors)
119+
result.Add(array_ops.identity(tensor));
120+
return result;
121+
}
90122
}
91123
}

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,28 @@ public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string n
925925
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
926926
=> gen_array_ops.slice(input, begin, size, name: name);
927927

928-
public static Tensor stack(object values, int axis = 0, string name = "stack")
928+
/// <summary>
/// Runs the "Slice" op with tensor-valued <paramref name="begin"/> and <paramref name="size"/>.
/// Two alternative code paths are handed to tf.Context.RunInAutoMode2: one builds the op through
/// tf.OpDefLib._apply_op_helper, the other executes it through tf.Runner.TFE_FastPathExecute.
/// </summary>
/// <param name="input">Tensor passed as the op's input.</param>
/// <param name="begin">Tensor passed as the op's begin argument.</param>
/// <param name="size">Tensor passed as the op's size argument.</param>
/// <param name="name">Optional op name forwarded to both code paths.</param>
/// <returns>The op's output on the first path, or the first result (FirstOrDefault) on the second path.</returns>
public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null)
929+
=> tf.Context.RunInAutoMode2(
930+
// NOTE(review): presumably the graph-mode path — confirm against RunInAutoMode2.
() => tf.OpDefLib._apply_op_helper("Slice", name, new
931+
{
932+
input, begin, size
933+
}).output,
934+
// NOTE(review): presumably the eager-mode path — confirm against RunInAutoMode2.
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
935+
"Slice", name,
936+
null,
937+
input, begin, size).FirstOrDefault(),
938+
// Callback that receives the op and records its gradient with the "T" and "Index" attributes.
(op) =>
939+
{
940+
var attrs = new object[]
941+
{
942+
"T", op.get_attr<TF_DataType>("T"),
943+
// NOTE(review): "Index" is read as int while "T" is read as TF_DataType — confirm the intended attribute type.
"Index", op.get_attr<int>("Index")
944+
};
945+
tf.Runner.RecordGradient("Slice", op.inputs, attrs, op.outputs);
946+
},
947+
new Tensors(input, begin, size));
948+
949+
public static Tensor stack(object values, int axis = 0, string name = "stack")
929950
{
930951
if (axis == 0)
931952
// If the input is a constant list, it can be converted to a constant op

src/TensorFlowNET.Core/Operations/gen_image_ops.cs

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -238,18 +238,32 @@ public static Tensor resize_nearest_neighbor<Tsize>(Tensor images, Tsize size, b
238238
"half_pixel_centers", half_pixel_centers).FirstOrDefault(),
239239
images);
240240

241-
public static Tensor resize_nearest_neighbor_grad<Tsize>(Tensor grads, Tsize size, bool align_corners = false,
241+
public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tensor size, bool align_corners = false,
242242
bool half_pixel_centers = false, string name = null)
243-
{
244-
var op = tf.OpDefLib._apply_op_helper("ResizeNearestNeighborGrad", name: name, args: new
245-
{
246-
grads,
247-
size,
248-
align_corners,
249-
half_pixel_centers
250-
});
251-
252-
return op.output;
253-
}
243+
=> tf.Context.RunInAutoMode2(
244+
() => tf.OpDefLib._apply_op_helper("ResizeNearestNeighborGrad", name, new
245+
{
246+
grads,
247+
size,
248+
align_corners,
249+
half_pixel_centers
250+
}).output,
251+
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
252+
"ResizeNearestNeighborGrad", name,
253+
null,
254+
grads, size,
255+
"align_corners", align_corners,
256+
"half_pixel_centers", half_pixel_centers).FirstOrDefault(),
257+
(op) =>
258+
{
259+
var attrs = new object[]
260+
{
261+
"T", op.get_attr<TF_DataType>("T"),
262+
"align_corners", op.get_attr<bool>("align_corners"),
263+
"half_pixel_centers", op.get_attr<bool>("half_pixel_centers")
264+
};
265+
tf.Runner.RecordGradient("ResizeNearestNeighborGrad", op.inputs, attrs, op.outputs);
266+
},
267+
new Tensors(grads, size));
254268
}
255269
}

src/TensorFlowNET.Core/Operations/gen_random_ops.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,16 @@ public static Tensor random_uniform(Tensor shape, TF_DataType dtype, int? seed =
126126
public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0,
127127
string name = null)
128128
{
129+
if (tf.Context.executing_eagerly())
130+
{
131+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
132+
"RandomShuffle", name,
133+
null,
134+
value, seed, seed2);
135+
136+
return results[0];
137+
}
138+
129139
var _op = tf.OpDefLib._apply_op_helper("RandomShuffle",
130140
name: name,
131141
args: new { value, seed, seed2 });

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ TensorFlow .NET v0.30 is focused on making more Keras API work including:
8383

8484
<ItemGroup>
8585
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
86+
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="5.0.1" />
8687
<PackageReference Include="NumSharp.Lite" Version="0.1.10" />
8788
<PackageReference Include="Protobuf.Text" Version="0.4.0" />
8889
<PackageReference Include="Serilog.Sinks.Console" Version="3.1.1" />

0 commit comments

Comments
 (0)