Skip to content

Commit 3d0e2d0

Browse files
authored
Merge pull request #1032 from AsakusaRinne/master
Fix the error of loading model saved before tf2.5.
2 parents e72024b + 7823b08 commit 3d0e2d0

File tree

19 files changed

+288
-83
lines changed

19 files changed

+288
-83
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Train;
5+
6+
namespace Tensorflow
7+
{
8+
public partial class tensorflow
{
    /// <summary>
    /// SavedModel helpers attached to this part of the <c>tensorflow</c> class.
    /// A single <see cref="SavedModelAPI"/> instance is created when the
    /// containing object is constructed and is read-only afterwards.
    /// </summary>
    public SavedModelAPI saved_model { get; } = new();
}
12+
13+
/// <summary>
/// Thin wrapper that forwards SavedModel loading requests to <c>Loader</c>.
/// </summary>
public class SavedModelAPI
{
    /// <summary>
    /// Loads the object stored under <paramref name="export_dir"/> by delegating
    /// to <c>Loader.load</c>.
    /// </summary>
    /// <param name="export_dir">Directory passed through to the loader unchanged.</param>
    /// <param name="options">Optional load options; <c>null</c> is forwarded as-is.</param>
    /// <returns>Whatever <c>Loader.load</c> returns for the given arguments.</returns>
    public Trackable load(string export_dir, LoadOptions? options = null)
        => Loader.load(export_dir, options);
}
20+
}

src/TensorFlowNET.Core/Graphs/FuncGraph.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using Tensorflow.Framework;
99
using Tensorflow.Framework.Models;
1010
using Tensorflow.Functions;
11+
using Tensorflow.NumPy;
1112
using Tensorflow.Operations;
1213
using Tensorflow.Util;
1314
using static Tensorflow.Binding;
@@ -181,7 +182,7 @@ public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType
181182
const int _EAGER_CONST_THRESHOLD = 128;
182183
public Tensor capture(Tensor tensor, string name = null, Shape shape = null)
183184
{
184-
if(tensor is EagerTensor)
185+
if(tensor is EagerTensor or NDArray)
185186
{
186187
if (name == null)
187188
name = ops.uid().ToString();

src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs

+1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ void apply_gradients((Tensor, IVariableV1) grads_and_vars,
1010
void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
1111
string name = null,
1212
bool experimental_aggregate_gradients = true);
13+
IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null);
1314
}

src/TensorFlowNET.Core/Operations/Operation.cs

+5-3
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,12 @@ public virtual T[] get_attr_list<T>(string name)
216216
public virtual object get_attr(string name)
217217
{
218218
var buf = new Buffer();
219-
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status);
220-
tf.Status.Check(true);
219+
Status status = new();
220+
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
221+
status.Check(true);
222+
var tf_buffer = c_api.TF_GetBuffer(buf);
221223

222-
var x = AttrValue.Parser.ParseFrom(buf.ToArray());
224+
var x = AttrValue.Parser.ParseFrom(tf_buffer.AsSpan<byte>());
223225

224226
var oneof_value = x.ValueCase;
225227
if (oneof_value == AttrValue.ValueOneofCase.None)

src/TensorFlowNET.Core/Tensors/tensor_util.cs

+40-8
Original file line numberDiff line numberDiff line change
@@ -64,36 +64,68 @@ public static NDArray MakeNdarray(TensorProto tensor)
6464
var num_elements = shape.size;
6565
var tensor_dtype = tensor.Dtype.as_tf_dtype();
6666

67+
// Expands a (possibly compressed) list of tensor values to exactly
// `num_elements` entries.
//
// TensorProto may store fewer values than the shape requires; in that case the
// missing trailing entries are filled by repeating the LAST stored value
// (equivalent to numpy's `np.pad(values, (0, k), "edge")`). The previous
// version split the padding between the front (first value) and the back
// (last value), which shifted the real data and produced wrong tensors, e.g.
// [1, 2] with 4 elements became [1, 1, 2, 2] instead of [1, 2, 2, 2].
//
// Returns an empty array when `src` is empty so the caller can substitute a
// zero-filled tensor.
T[] ExpandArrayToSize<T>(IList<T> src)
{
    if (src.Count == 0)
    {
        return new T[0];
    }

    T[] res = new T[num_elements];

    // Copy the stored values to the front. If the proto carries more values
    // than the shape holds, only the leading `res.Length` values are kept.
    // NOTE(review): TensorFlow's Python implementation raises in that case;
    // truncation is used here to stay non-throwing - confirm this is desired.
    int copy_count = src.Count < res.Length ? src.Count : res.Length;
    for (int i = 0; i < copy_count; i++)
    {
        res[i] = src[i];
    }

    // Repeat the last stored value for every remaining slot.
    var last_elem = src[src.Count - 1];
    for (int i = copy_count; i < res.Length; i++)
    {
        res[i] = last_elem;
    }

    return res;
}
87+
6788
if (shape.ndim > 0 && tensor.TensorContent.Length > 0)
6889
{
6990
return np.frombuffer(tensor.TensorContent.ToByteArray(), shape, tensor_dtype);
7091
}
71-
else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16)
92+
NDArray values;
93+
if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16)
7294
{
73-
return np.array(tensor.HalfVal.ToArray()).reshape(shape);
95+
values = np.array(ExpandArrayToSize(tensor.HalfVal));
7496
}
7597
else if (tensor.Dtype == DataType.DtFloat)
7698
{
77-
return np.array(tensor.FloatVal.ToArray()).reshape(shape);
99+
values = np.array(ExpandArrayToSize(tensor.FloatVal));
78100
}
79101
else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype))
80102
{
81-
return np.array(tensor.IntVal.ToArray()).reshape(shape);
103+
values = np.array(ExpandArrayToSize(tensor.IntVal));
82104
}
83105
else if (new DataType[] { DataType.DtInt64 }.Contains(tensor.Dtype))
84106
{
85-
return np.array(tensor.Int64Val.ToArray()).reshape(shape);
107+
values = np.array(ExpandArrayToSize(tensor.Int64Val));
86108
}
87109
else if (new DataType[] { DataType.DtUint64 }.Contains(tensor.Dtype))
88110
{
89-
return np.array(tensor.Uint64Val.ToArray()).reshape(shape);
111+
values = np.array(ExpandArrayToSize(tensor.Uint64Val));
90112
}
91113
else if (tensor.Dtype == DataType.DtBool)
92114
{
93-
return np.array(tensor.BoolVal.ToArray()).reshape(shape);
115+
values = np.array(ExpandArrayToSize(tensor.BoolVal));
116+
}
117+
else
118+
{
119+
throw new TypeError($"Unsupported tensor type: {tensor.Dtype}. See " +
120+
$"https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.");
121+
}
122+
123+
if(values.size == 0)
124+
{
125+
return np.zeros(shape, tensor_dtype);
94126
}
95127

96-
throw new NotImplementedException("MakeNdarray");
128+
return values.reshape(shape);
97129
}
98130

99131
private static readonly TF_DataType[] quantized_types = new TF_DataType[]

src/TensorFlowNET.Core/Trackables/TrackableConstant.cs

+15-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Google.Protobuf.Collections;
22
using Tensorflow.Train;
3+
using static Tensorflow.Binding;
34

45
namespace Tensorflow.Trackables;
56

@@ -11,12 +12,23 @@ public TrackableConstant(Tensor constant)
1112
_constant = constant;
1213
}
1314

14-
public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto,
15+
public static (Tensor, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto,
1516
Dictionary<string, MapField<string, AttrValue>> operation_attributes)
1617
{
1718
var tensor_proto = operation_attributes[object_proto.Constant.Operation]["value"].Tensor;
1819
var ndarray = tensor_util.MakeNdarray(tensor_proto);
19-
var imported_constant = constant_op.constant(ndarray);
20-
return (new TrackableConstant(imported_constant), null);
20+
Tensor imported_constant;
21+
if (tensor_proto.Dtype == DataType.DtString)
22+
{
23+
imported_constant = tf_with(ops.device("CPU"), _ =>
24+
{
25+
return constant_op.constant(ndarray);
26+
});
27+
}
28+
else
29+
{
30+
imported_constant = constant_op.constant(ndarray);
31+
}
32+
return (imported_constant, null);
2133
}
2234
}

src/TensorFlowNET.Core/Training/Saving/SavedModel/RevivedTypes.cs

+5
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,9 @@ public static (Trackable, Action<object, object, object>) deserialize(SavedUserO
4646
return (null, null);
4747
}
4848
}
49+
50+
/// <summary>
/// Stores <paramref name="obj"/> in <c>_registered_revived_creator</c> under
/// <paramref name="identifier"/> using the collection's indexer.
/// </summary>
/// <param name="identifier">Key under which the creator is stored.</param>
/// <param name="obj">Wrapper instance associated with the identifier.</param>
public static void RegisterRevivedTypeCreator(string identifier, ITrackableWrapper obj)
    => _registered_revived_creator[identifier] = obj;
4954
}

src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ public List<ConcreteFunction> get_concrete_resource_initializers()
137137
/// </summary>
138138
public List<int> dependency_sorted_node_ids()
139139
{
140-
Dictionary<int, IEnumerable<int>> dependency_map = new();
140+
Dictionary<int, List<int>> dependency_map = new();
141141
foreach (var node in _nodes)
142142
{
143143
var node_id = _node_ids[node];

src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs

+41-21
Original file line numberDiff line numberDiff line change
@@ -116,17 +116,23 @@ public static Dictionary<string, ConcreteFunction> load_function_def_library(Fun
116116
}
117117

118118
Dictionary<string, ConcreteFunction> loaded_gradients = new();
119-
foreach (var fdef in _sort_function_defs(library, function_deps))
119+
// Debug(Rinne)
120+
var temp = _sort_function_defs(library, function_deps);
121+
int i = 0;
122+
foreach (var fdef in temp)
120123
{
124+
i++;
121125
var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types);
122126

123127
object structured_input_signature = null;
124128
object structured_outputs = null;
125129
if (saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name))
126130
{
127-
var proto = saved_object_graph.ConcreteFunctions[orig_name];
128-
structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature);
129-
structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature);
131+
// TODO(Rinne): deal with structured_input_signature and structured_outputs.
132+
133+
//var proto = saved_object_graph.ConcreteFunctions[orig_name];
134+
//structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature);
135+
//structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature);
130136
}
131137

132138
graph.as_default();
@@ -234,27 +240,41 @@ private static Func<Operation, Tensor[], Tensor[]> _gen_gradient_func(ConcreteFu
234240

235241
private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary<string, ConcreteFunction> renamed_functions, Dictionary<string, ConcreteFunction> loaded_gradients)
236242
{
237-
foreach(var op in func_graph.get_operations())
243+
if(loaded_gradients is null || loaded_gradients.Count == 0)
238244
{
239-
if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
240-
{
241-
var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
242-
op.op._gradient_function = function._get_gradient_function();
243-
}
244-
string gradient_op_type = null;
245-
try
246-
{
247-
gradient_op_type = op.op.get_attr("_gradient_op_type") as string;
248-
}
249-
catch(InvalidArgumentError)
245+
foreach (var op in func_graph.get_operations())
250246
{
251-
continue;
247+
if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
248+
{
249+
var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
250+
op.op._gradient_function = function._get_gradient_function();
251+
}
252252
}
253-
if (loaded_gradients.ContainsKey(gradient_op_type))
253+
}
254+
else
255+
{
256+
foreach (var op in func_graph.get_operations())
254257
{
255-
var grad_fn = loaded_gradients[gradient_op_type];
256-
grad_fn.NumPositionArgs = op.op.inputs.Length;
257-
grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name);
258+
if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
259+
{
260+
var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
261+
op.op._gradient_function = function._get_gradient_function();
262+
}
263+
string gradient_op_type = null;
264+
try
265+
{
266+
gradient_op_type = op.op.get_attr("_gradient_op_type") as string;
267+
}
268+
catch (InvalidArgumentError)
269+
{
270+
continue;
271+
}
272+
if (loaded_gradients.ContainsKey(gradient_op_type))
273+
{
274+
var grad_fn = loaded_gradients[gradient_op_type];
275+
grad_fn.NumPositionArgs = op.op.inputs.Length;
276+
grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name);
277+
}
258278
}
259279
}
260280
}

0 commit comments

Comments
 (0)