Skip to content

Fix the error when loading models saved before TF 2.5. #1032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/TensorFlowNET.Core/APIs/tf.saved_model.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Train;

namespace Tensorflow
{
    public partial class tensorflow
    {
        /// <summary>
        /// Entry point for SavedModel-related APIs, mirroring `tf.saved_model` in Python.
        /// </summary>
        public SavedModelAPI saved_model { get; } = new SavedModelAPI();
    }

    public class SavedModelAPI
    {
        /// <summary>
        /// Loads a SavedModel from the given directory.
        /// </summary>
        /// <param name="export_dir">Directory containing the SavedModel to load.</param>
        /// <param name="options">Optional load options; <c>null</c> uses the defaults.</param>
        /// <returns>The restored root trackable object.</returns>
        public Trackable load(string export_dir, LoadOptions? options = null)
            => Loader.load(export_dir, options);
    }
}
3 changes: 2 additions & 1 deletion src/TensorFlowNET.Core/Graphs/FuncGraph.cs
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
using Tensorflow.Framework;
using Tensorflow.Framework.Models;
using Tensorflow.Functions;
using Tensorflow.NumPy;
using Tensorflow.Operations;
using Tensorflow.Util;
using static Tensorflow.Binding;
@@ -181,7 +182,7 @@ public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType
const int _EAGER_CONST_THRESHOLD = 128;
public Tensor capture(Tensor tensor, string name = null, Shape shape = null)
{
if(tensor is EagerTensor)
if(tensor is EagerTensor or NDArray)
{
if (name == null)
name = ops.uid().ToString();
1 change: 1 addition & 0 deletions src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs
Original file line number Diff line number Diff line change
@@ -10,4 +10,5 @@ void apply_gradients((Tensor, IVariableV1) grads_and_vars,
void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true);
IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null);
}
8 changes: 5 additions & 3 deletions src/TensorFlowNET.Core/Operations/Operation.cs
Original file line number Diff line number Diff line change
@@ -216,10 +216,12 @@ public virtual T[] get_attr_list<T>(string name)
public virtual object get_attr(string name)
{
var buf = new Buffer();
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status);
tf.Status.Check(true);
Status status = new();
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
status.Check(true);
var tf_buffer = c_api.TF_GetBuffer(buf);

var x = AttrValue.Parser.ParseFrom(buf.ToArray());
var x = AttrValue.Parser.ParseFrom(tf_buffer.AsSpan<byte>());

var oneof_value = x.ValueCase;
if (oneof_value == AttrValue.ValueOneofCase.None)
48 changes: 40 additions & 8 deletions src/TensorFlowNET.Core/Tensors/tensor_util.cs
Original file line number Diff line number Diff line change
@@ -64,36 +64,68 @@ public static NDArray MakeNdarray(TensorProto tensor)
var num_elements = shape.size;
var tensor_dtype = tensor.Dtype.as_tf_dtype();

T[] ExpandArrayToSize<T>(IList<T> src)
{
if(src.Count == 0)
{
return new T[0];
}
var pad_count = num_elements - src.Count;
var pre = pad_count / 2;
var after = pad_count - pre;
var first_elem = src[0];
var last_elem = src[src.Count - 1];
T[] res = new T[num_elements];
for(long i = 0; i < num_elements; i++)
{
if (i < pre) res[i] = first_elem;
else if (i >= num_elements - after) res[i] = last_elem;
else res[i] = src[(int)(i - pre)];
}
return res;
}

if (shape.ndim > 0 && tensor.TensorContent.Length > 0)
{
return np.frombuffer(tensor.TensorContent.ToByteArray(), shape, tensor_dtype);
}
else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16)
NDArray values;
if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16)
{
return np.array(tensor.HalfVal.ToArray()).reshape(shape);
values = np.array(ExpandArrayToSize(tensor.HalfVal));
}
else if (tensor.Dtype == DataType.DtFloat)
{
return np.array(tensor.FloatVal.ToArray()).reshape(shape);
values = np.array(ExpandArrayToSize(tensor.FloatVal));
}
else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype))
{
return np.array(tensor.IntVal.ToArray()).reshape(shape);
values = np.array(ExpandArrayToSize(tensor.IntVal));
}
else if (new DataType[] { DataType.DtInt64 }.Contains(tensor.Dtype))
{
return np.array(tensor.Int64Val.ToArray()).reshape(shape);
values = np.array(ExpandArrayToSize(tensor.Int64Val));
}
else if (new DataType[] { DataType.DtUint64 }.Contains(tensor.Dtype))
{
return np.array(tensor.Uint64Val.ToArray()).reshape(shape);
values = np.array(ExpandArrayToSize(tensor.Uint64Val));
}
else if (tensor.Dtype == DataType.DtBool)
{
return np.array(tensor.BoolVal.ToArray()).reshape(shape);
values = np.array(ExpandArrayToSize(tensor.BoolVal));
}
else
{
throw new TypeError($"Unsupported tensor type: {tensor.Dtype}. See " +
$"https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.");
}

if(values.size == 0)
{
return np.zeros(shape, tensor_dtype);
}

throw new NotImplementedException("MakeNdarray");
return values.reshape(shape);
}

private static readonly TF_DataType[] quantized_types = new TF_DataType[]
18 changes: 15 additions & 3 deletions src/TensorFlowNET.Core/Trackables/TrackableConstant.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Google.Protobuf.Collections;
using Tensorflow.Train;
using static Tensorflow.Binding;

namespace Tensorflow.Trackables;

@@ -11,12 +12,23 @@ public TrackableConstant(Tensor constant)
_constant = constant;
}

public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto,
public static (Tensor, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto,
Dictionary<string, MapField<string, AttrValue>> operation_attributes)
{
var tensor_proto = operation_attributes[object_proto.Constant.Operation]["value"].Tensor;
var ndarray = tensor_util.MakeNdarray(tensor_proto);
var imported_constant = constant_op.constant(ndarray);
return (new TrackableConstant(imported_constant), null);
Tensor imported_constant;
if (tensor_proto.Dtype == DataType.DtString)
{
imported_constant = tf_with(ops.device("CPU"), _ =>
{
return constant_op.constant(ndarray);
});
}
else
{
imported_constant = constant_op.constant(ndarray);
}
return (imported_constant, null);
}
}
Original file line number Diff line number Diff line change
@@ -46,4 +46,9 @@ public static (Trackable, Action<object, object, object>) deserialize(SavedUserO
return (null, null);
}
}

// Registers a creator for a revived-type identifier (e.g. "optimizer") so the
// SavedModel deserializer can recreate objects of that type.
// A later registration with the same identifier overwrites the earlier one.
public static void RegisterRevivedTypeCreator(string identifier, ITrackableWrapper obj)
{
    _registered_revived_creator[identifier] = obj;
}
}
Original file line number Diff line number Diff line change
@@ -137,7 +137,7 @@ public List<ConcreteFunction> get_concrete_resource_initializers()
/// </summary>
public List<int> dependency_sorted_node_ids()
{
Dictionary<int, IEnumerable<int>> dependency_map = new();
Dictionary<int, List<int>> dependency_map = new();
foreach (var node in _nodes)
{
var node_id = _node_ids[node];
Original file line number Diff line number Diff line change
@@ -116,17 +116,23 @@ public static Dictionary<string, ConcreteFunction> load_function_def_library(Fun
}

Dictionary<string, ConcreteFunction> loaded_gradients = new();
foreach (var fdef in _sort_function_defs(library, function_deps))
// Debug(Rinne)
var temp = _sort_function_defs(library, function_deps);
int i = 0;
foreach (var fdef in temp)
{
i++;
var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types);

object structured_input_signature = null;
object structured_outputs = null;
if (saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name))
{
var proto = saved_object_graph.ConcreteFunctions[orig_name];
structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature);
structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature);
// TODO(Rinne): deal with structured_input_signature and structured_outputs.

//var proto = saved_object_graph.ConcreteFunctions[orig_name];
//structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature);
//structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature);
}

graph.as_default();
@@ -234,27 +240,41 @@ private static Func<Operation, Tensor[], Tensor[]> _gen_gradient_func(ConcreteFu

private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary<string, ConcreteFunction> renamed_functions, Dictionary<string, ConcreteFunction> loaded_gradients)
{
foreach(var op in func_graph.get_operations())
if(loaded_gradients is null || loaded_gradients.Count == 0)
{
if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
{
var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
op.op._gradient_function = function._get_gradient_function();
}
string gradient_op_type = null;
try
{
gradient_op_type = op.op.get_attr("_gradient_op_type") as string;
}
catch(InvalidArgumentError)
foreach (var op in func_graph.get_operations())
{
continue;
if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
{
var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
op.op._gradient_function = function._get_gradient_function();
}
}
if (loaded_gradients.ContainsKey(gradient_op_type))
}
else
{
foreach (var op in func_graph.get_operations())
{
var grad_fn = loaded_gradients[gradient_op_type];
grad_fn.NumPositionArgs = op.op.inputs.Length;
grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name);
if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
{
var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
op.op._gradient_function = function._get_gradient_function();
}
string gradient_op_type = null;
try
{
gradient_op_type = op.op.get_attr("_gradient_op_type") as string;
}
catch (InvalidArgumentError)
{
continue;
}
if (loaded_gradients.ContainsKey(gradient_op_type))
{
var grad_fn = loaded_gradients[gradient_op_type];
grad_fn.NumPositionArgs = op.op.inputs.Length;
grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name);
}
}
}
}
67 changes: 45 additions & 22 deletions src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Trackables;
using OneOf;
using Tensorflow.Keras.Engine;

namespace Tensorflow
{
@@ -34,7 +35,7 @@ public partial class Loader
private List<int>? _filtered_nodes;
private List<int> _ordered_node_ids;
private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes;
private List<Trackable> _nodes;
private List<object> _nodes;
private Dictionary<int, Action<object, object, object>> _node_setters;
private Dictionary<string, ConcreteFunction> _concrete_functions;
private HashSet<string> _restored_concrete_functions;
@@ -213,7 +214,13 @@ private List<int> _generate_ordered_node_ids()
continue;
}
var proto = _proto.Nodes[node_id];
foreach(var dep in _get_node_dependencies(proto).Values.Distinct())
if(node_id == 10522)
{
// Debug(Rinne)
Console.WriteLine();
}
var temp = _get_node_dependencies(proto);
foreach (var dep in _get_node_dependencies(proto).Values.Distinct())
{
deps.Add(dep);
if(_filtered_nodes is not null && !_filtered_nodes.Contains(dep))
@@ -232,7 +239,7 @@ private List<int> _generate_ordered_node_ids()
// The optimizer and original variable must be created before the slot
// variable, since the slot variable is generated using the Optimizer's
// add_slot API.
var slot_deps = dependency_map[slot_variable_node_id];
var slot_deps = dependency_map.SetDefault(slot_variable_node_id, new List<int>());
slot_deps.Add(node_id);
slot_deps.Add(slot_variable_proto.OriginalVariableNodeId);

@@ -245,7 +252,12 @@ private List<int> _generate_ordered_node_ids()
}
try
{
return TrackableUtils.order_by_dependency(dependency_map.ToDictionary(x => x.Key, x => x.Value as IEnumerable<int>));
int total = 0;
foreach(var v in dependency_map.Values)
{
total += v.Count;
}
return TrackableUtils.order_by_dependency(dependency_map);
}
catch (TrackableUtils.CyclicDependencyError ex)
{
@@ -339,9 +351,20 @@ private void _load_checkpoint_save_and_restore_functions()
var saveable_object_proto = item.Value;
var save_fn_id = saveable_object_proto.SaveFunction;
var restore_fn_id = saveable_object_proto.RestoreFunction;
saveable_fn_by_name[name] = (get(save_fn_id), get(restore_fn_id));
saveable_fn_by_name[name] = ((Trackable)get(save_fn_id), (Trackable)get(restore_fn_id));
}
var saveable_objects = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null);
if (saveable_objects is not null && saveable_objects.Count > 0)
{
if(node is Trackable trackable)
{
trackable.SelfSaveableObjectFactories = saveable_objects;
}
else
{
throw new TypeError();
}
}
node.SelfSaveableObjectFactories = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null);
}
}
}
@@ -379,12 +402,12 @@ private void _load_nodes()
{
// Use the public Optimizer interface when creating slot variables.
var (optimizer_node_id, slot_variable_proto) = slot_variable_node_ids[node_id];
var optimizer_object = nodes[optimizer_node_id];
var optimizer_object = nodes[optimizer_node_id] as IOptimizer;
var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId];

// TODO(Rinne): implement it.
throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." +
" Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues.");
var slot_variable = optimizer_object.add_slot(optimizer_variable as IVariableV1, slot_variable_proto.SlotName);
nodes[slot_variable_proto.SlotVariableNodeId] = slot_variable as Trackable;
node_setters[slot_variable_proto.SlotVariableNodeId] = setattr;
}
else
{
@@ -398,7 +421,7 @@ private void _load_nodes()
{
nodes[0] = _recreate_base_user_object().Item1;
}
_nodes = new List<Trackable>();
_nodes = new List<object>();
for(int i = 0; i < _proto.Nodes.Count; i++)
{
_nodes.Add(nodes[i]);
@@ -412,7 +435,7 @@ private void _load_nodes()
private void _restore_checkpoint()
{
var variables_path = SavedModelUtils.get_variables_path(_export_dir);
var saver = new TrackableSaver(new ObjectGraphView(get(0)));
var saver = new TrackableSaver(new ObjectGraphView((Trackable)get(0)));
tf_with(ops.device("CPU"), _ =>
{
saver.FilePrefixPlaceHolder = constant_op.constant(variables_path);
@@ -467,7 +490,7 @@ private void _load_edges()
}
}

private void _setup_function_captures(string concrete_function_name, IDictionary<OneOf<string, int>, Trackable> nodes)
private void _setup_function_captures(string concrete_function_name, IDictionary<OneOf<string, int>, object> nodes)
{
if (_restored_concrete_functions.Contains(concrete_function_name))
{
@@ -485,12 +508,12 @@ private void _setup_remaining_functions()
// TODO: implement it with concrete functions.
}

public Trackable get(int node_id)
public object get(int node_id)
{
return _nodes[node_id];
}

public Trackable get(string node_id)
public object get(string node_id)
{
return get(_node_path_to_id[node_id]);
}
@@ -512,9 +535,9 @@ private void _add_object_graph_edges(SavedObject proto, int node_id)
}
}

private (Dictionary<int, Trackable>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes()
private (Dictionary<int, object>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes()
{
Dictionary<int, Trackable> nodes = new();
Dictionary<int, object> nodes = new();
Dictionary<int, Action<object, object, object>> node_setters = new();
foreach(var item in _loaded_nodes)
{
@@ -534,10 +557,10 @@ private void _add_object_graph_edges(SavedObject proto, int node_id)
}
}

private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes)
private (object, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, object> nodes)
{
// skip the registered classes.
Dictionary<OneOf<string, int>, Trackable> dependencies = new();
Dictionary<OneOf<string, int>, object> dependencies = new();
foreach(var item in _get_node_dependencies(proto))
{
dependencies[item.Key] = nodes[item.Value];
@@ -558,7 +581,7 @@ private void _add_object_graph_edges(SavedObject proto, int node_id)
/// <param name="proto"></param>
/// <param name="node_id"></param>
/// <param name="dependencies"></param>
private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<OneOf<string, int>, Trackable> dependencies)
private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<OneOf<string, int>, object> dependencies)
{
return proto.KindCase switch
{
@@ -626,7 +649,7 @@ private void _add_object_graph_edges(SavedObject proto, int node_id)
}

private (Function, Action<object, object, object>) _recreate_function(SavedFunction proto,
IDictionary<OneOf<string, int>, Trackable> dependencies)
IDictionary<OneOf<string, int>, object> dependencies)
{
var fn = function_deserialization.recreate_function(proto, _concrete_functions);
foreach (var name in proto.ConcreteFunctions)
@@ -637,7 +660,7 @@ private void _add_object_graph_edges(SavedObject proto, int node_id)
}

private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto,
IDictionary<OneOf<string, int>, Trackable> dependencies)
IDictionary<OneOf<string, int>, object> dependencies)
{
var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions);
_setup_function_captures(proto.ConcreteFunctionName, dependencies);
Original file line number Diff line number Diff line change
@@ -78,7 +78,7 @@ public static IDictionary<string, Trackable> load_partial(string export_dir, IDi
tf_with(ops.init_scope(), x =>
{
loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters);
root = loader.get(0);
root = (Trackable)loader.get(0);
// skip the assignment of `graph_debug_info`.
});
// skip the assignment of `tensorflow_version`
@@ -99,7 +99,7 @@ public static IDictionary<string, Trackable> load_partial(string export_dir, IDi
}
if(filters != null && filters.Count > 0)
{
return filters.Keys.ToDictionary(x => x, x => loader.get(x));
return filters.Keys.ToDictionary(x => x, x => (Trackable)loader.get(x));
}
else
{
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Core/Training/TrackableUtils.cs
Original file line number Diff line number Diff line change
@@ -52,7 +52,7 @@ public static string checkpoint_key(string object_path, string local_name)
/// </summary>
/// <param name="dependency_map"></param>
/// <exception cref="ValueError"></exception>
public static List<int> order_by_dependency(IDictionary<int, IEnumerable<int>> dependency_map)
public static List<int> order_by_dependency(IDictionary<int, List<int>> dependency_map)
{
Dictionary<int, HashSet<int>> reverse_dependency_map = new();
foreach (var pair in dependency_map)
@@ -102,7 +102,7 @@ public static List<int> order_by_dependency(IDictionary<int, IEnumerable<int>> d
edges.Remove(x);
if (edges.Count == 0)
{
to_visit.Enqueue(dep);
to_visit.Enqueue(dep);
if (!reverse_dependency_map.Remove(dep))
{
throw new KeyError($"Cannot find the key {dep} in reverse_dependency_map");
18 changes: 18 additions & 0 deletions src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
Original file line number Diff line number Diff line change
@@ -333,5 +333,23 @@ public Tensor read_value_no_copy()
});
return array_ops.identity(value);
}

//public static Tensor operator +(BaseResourceVariable x, int y) => x.value() + y;
//public static Tensor operator +(BaseResourceVariable x, float y) => x.value() + y;
//public static Tensor operator +(BaseResourceVariable x, double y) => x.value() + y;
//public static Tensor operator +(BaseResourceVariable x, BaseResourceVariable y) => x.value() + y.value();
//public static Tensor operator -(BaseResourceVariable x, int y) => x.value() - y;
//public static Tensor operator -(BaseResourceVariable x, float y) => x.value() - y;
//public static Tensor operator -(BaseResourceVariable x, double y) => x.value() - y;
//public static Tensor operator -(BaseResourceVariable x, Tensor y) => x.value() - y;
//public static Tensor operator -(BaseResourceVariable x, BaseResourceVariable y) => x.value() - y.value();

//public static Tensor operator *(BaseResourceVariable x, BaseResourceVariable y) => x.value() * y.value();
//public static Tensor operator *(BaseResourceVariable x, Tensor y) => x.value() * y;
//public static Tensor operator *(BaseResourceVariable x, NDArray y) => x.value() * y;

//public static Tensor operator <(BaseResourceVariable x, Tensor y) => x.value() < y;

//public static Tensor operator >(BaseResourceVariable x, Tensor y) => x.value() > y;
}
}
19 changes: 3 additions & 16 deletions src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.NumPy;

namespace Tensorflow
6 changes: 6 additions & 0 deletions src/TensorFlowNET.Keras/BackendImpl.cs
Original file line number Diff line number Diff line change
@@ -169,6 +169,12 @@ public void set_learning_phase(bool value)
_GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0);
}

/// <summary>
/// Sets the value of a variable, analogous to Keras' <c>backend.set_value</c>.
/// </summary>
/// <param name="x">The variable to update.</param>
/// <param name="value">The new value, forwarded directly to <c>assign</c>.</param>
public void set_value(IVariableV1 x, object value)
{
    // TODO(Rinne): check the implementation.
    // NOTE(review): this is a direct assign; the Python backend feeds a
    // placeholder in graph mode — confirm graph-mode support is not needed here.
    x.assign(value);
}

public void batch_set_value(List<(IVariableV1, NDArray)> tuples)
{
if (ops.executing_eagerly_outside_functions())
5 changes: 5 additions & 0 deletions src/TensorFlowNET.Keras/KerasInterface.cs
Original file line number Diff line number Diff line change
@@ -36,6 +36,11 @@ public static KerasInterface Instance
}
}

// Static initializer: registers the "optimizer" revived-type creator so that
// optimizers stored inside SavedModels can be restored as RestoredOptimizer
// instances. Runs once, before the first use of KerasInterface.
static KerasInterface()
{
    RevivedTypes.RegisterRevivedTypeCreator("optimizer", new RestoredOptimizer());
}

public KerasDataset datasets { get; } = new KerasDataset();
public IInitializersApi initializers { get; } = new InitializersApi();
public Regularizers regularizers { get; } = new Regularizers();
8 changes: 4 additions & 4 deletions src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs
Original file line number Diff line number Diff line change
@@ -14,11 +14,11 @@ public class OptimizerV2 : Trackable, IOptimizer
protected bool _hypers_created;
protected virtual string _name { get; }

IVariableV1 _iterations;
protected IVariableV1 _iterations;
protected ResourceVariable iterations => _iterations as ResourceVariable;
List<IVariableV1> _weights;
Dictionary<string, float> _hyper;
Dictionary<string, IVariableV1> _hyper_variables;
protected Dictionary<string, float> _hyper;
protected Dictionary<string, IVariableV1> _hyper_variables;
protected bool _momentum;
protected float _initial_decay = 0.0f;
protected bool _use_locking = true;
@@ -224,7 +224,7 @@ protected virtual void _create_slots(IVariableV1[] var_list)
}
}

protected IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null)
public IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null)
{
if (initializer == null)
initializer = tf.zeros_initializer;
63 changes: 63 additions & 0 deletions src/TensorFlowNET.Keras/Optimizers/RestoredOptimizer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Saving;
using Tensorflow.Train;
using Tensorflow.Training;

namespace Tensorflow.Keras.Optimizers
{
    /// <summary>
    /// Placeholder optimizer revived from a SavedModel. It receives the saved
    /// hyper-parameters and tracked children via <see cref="SetValue"/>, but
    /// cannot be re-serialized to a functional config.
    /// </summary>
    public class RestoredOptimizer: OptimizerV2, ITrackableWrapper, IKerasConfig
    {
        // Revived-type metadata consumed by the SavedModel loader.
        public String Identifier { get; } = "optimizer";
        public int Version { get; } = 2;
        public int MinConsumerVersion { get; } = 1;
        public int MinProducerVersion { get; } = 1;

        public RestoredOptimizer(): base(new ArgsDefinition.OptimizerV2Args() { Name = "RestoredOptimizer" })
        {
            // Hyper variables are restored from the checkpoint rather than
            // created lazily, so mark them as already created.
            _hypers_created = true;
        }

        /// <summary>
        /// Not supported for revived optimizers.
        /// </summary>
        /// <exception cref="NotImplementedException">Always thrown.</exception>
        public IKerasConfig get_config()
        {
            throw new NotImplementedException("Restoring functional Optimizers from SavedModels is not currently " +
                "supported. Please file a feature request if this limitation bothers you.");
        }

        /// <summary>
        /// Restores a named attribute (hyper-parameter variable, float hyper, or
        /// tracked child) onto this optimizer during SavedModel loading.
        /// </summary>
        /// <param name="name">Attribute name; must be a string.</param>
        /// <param name="value">The restored value.</param>
        /// <exception cref="TypeError">If <paramref name="name"/> is not a string.</exception>
        public void SetValue(object name, object value)
        {
            if(name is not String str)
            {
                throw new TypeError($"The name of value to set must be string, but got {name.GetType()}");
            }
            if(value is Trackable trackable)
            {
                // Trackable children are tracked in addition to any hyper handling below.
                _track_trackable(trackable, str, overwrite: true);
            }
            if(value is IVariableV1 resource_variable)
            {
                if (!_hyper_variables.ContainsKey(str))
                {
                    _hyper_variables[str] = resource_variable;
                }
                else
                {
                    // Bug fix: update the variable already stored under `str`
                    // with the new value. The previous code called
                    // `set_value(resource_variable, value)`, which assigned the
                    // incoming variable to itself (mirrors Python's
                    // `backend.set_value(self._hyper[name], value)`).
                    keras.backend.set_value(_hyper_variables[str], value);
                }
            }
            else if (value is float f)
            {
                _hyper[str] = f;
            }
            else
            {
                throw new NotImplementedException();
            }
        }

        /// <summary>
        /// Creates a fresh RestoredOptimizer for the given saved proto; the
        /// proto contents are not inspected here.
        /// </summary>
        public Trackable FromProto(SavedUserObject proto)
        {
            return new RestoredOptimizer();
        }
    }
}
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
using System;
using System.Linq;
using Tensorflow;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.UnitTest.Helpers;
using Tensorflow.NumPy;
@@ -103,4 +104,13 @@ public void VGG19()

classify_model.fit(x, y, batch_size: 4);
}

[Ignore]
[TestMethod]
public void TestModelBeforeTF2_5()
{
    // NOTE(review): this seemingly-unused local presumably forces the
    // KerasInterface static constructor to run, which registers the
    // "optimizer" revived-type creator needed by the loader — confirm
    // before removing it.
    var a = keras.layers;
    // Machine-specific path; the test is [Ignore]d and meant for local debugging.
    var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Model;
    model.summary();
}
}