From 13e986d4277f7d754dd9c11ae8be175fd86820e8 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Tue, 2 May 2023 03:58:37 +0800 Subject: [PATCH] fix: partially fix the error when saving model after loading. --- .../Checkpoint/CheckPointUtils.cs | 3 +- .../Saving/SavedModel/SaveableView.cs | 19 ++++++----- .../Variables/BaseResourceVariable.cs | 9 ++++- .../Saving/KerasObjectLoader.cs | 33 +++++++++++++++++++ .../Saving/SavedModel/load.cs | 2 +- .../Model/ModelSaveTest.cs | 12 +++++++ 6 files changed, 66 insertions(+), 12 deletions(-) diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs index 490c284b7..071b41875 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.IO; using System.Linq; +using Tensorflow.Functions; using Tensorflow.Train; using Tensorflow.Training; using pbc = global::Google.Protobuf.Collections; @@ -13,7 +14,7 @@ public static class CheckPointUtils { private static string _ESCAPE_CHAR = "."; public static (IList, IDictionary>, IDictionary, - IDictionary>, + IDictionary>, IDictionary) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view) { var (trackable_objects, node_paths) = graph_view.breadth_first_traversal(); diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs index b7d987e71..44a627b67 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/SaveableView.cs @@ -93,13 +93,14 @@ private void initialize_nodes_and_concrete_functions() // // } - foreach (var obj in _nodes) - { - if (obj is ConcreteFunction) - { - _concrete_functions.Add((ConcreteFunction)obj); - } - } + //_concrete_functions = new(); + //foreach (var obj in _nodes) + //{ + // if (obj is ConcreteFunction) + // { + // _concrete_functions.Add((ConcreteFunction)obj); + // } + //} } public List get_concrete_resource_initializers() @@ -225,8 +226,8 @@ private static void write_object_proto(Trackable obj, SavedObject proto, } else if (obj is ConcreteFunction) { - // TODO: complete it. - throw new NotImplementedException(); + // TODO(Rinne): complete it. + // throw new NotImplementedException(); } // skip the process of type `_CapturedTensor` and `CapturableResource`. else diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 64fe0ec84..52ca328e3 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -17,7 +17,14 @@ public class BaseResourceVariable : DisposableTrackableObject { protected string _name; public virtual string Name => _handle_name; - public virtual string SharedName => _name; + public virtual string SharedName + { + get + { + // TODO(Rinne): optimize the implementation with refactor of variable. + return _handle_name.Substring(0, _handle_name.IndexOf(':') + 1); + } + } protected TF_DataType _dtype; public TF_DataType dtype => _dtype; protected string _handle_name; diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs index fee987294..a26879e0c 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -152,6 +152,39 @@ public void finalize_objects() _reconstruct_all_models(); } + /// + /// Removes tracked references that are only used when loading the model. + /// Now that the node object has been fully loaded, and the checkpoint has + /// been restored, the object no longer needs to track objects added from + /// SerializedAttributes. (Note that saving a training checkpoint still + /// functions correctly, because layers and variables are tracked + /// separately by the Layer object.) + /// + public void del_tracking() + { + foreach(var (node, _) in loaded_nodes.Values) + { + if(node is not Layer layer) + { + continue; + } + foreach(var name in PUBLIC_ATTRIBUTES.Keys) + { + layer._delete_tracking(name); + } + if(node is Functional functional) + { + foreach(var name in functional.UnconditionalDependencyNames.Keys) + { + if(Regex.Match(name, @"^layer(_with_weights)?-[\d+]").Success) + { + functional._delete_tracking(name); + } + } + } + } + } + private void _reconstruct_all_models() { HashSet all_initialized_models = new(); diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs index 362464d1f..aa763fc2e 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs @@ -77,7 +77,7 @@ private static Trackable load(string path, bool compile = true, LoadOptions? opt var loaded = Loader.load_partial(path, nodes_to_load, options); keras_loader.finalize_objects(); - // keras_loader.del_tracking(); + keras_loader.del_tracking(); var model = loaded["root"]; diff --git a/test/TensorFlowNET.Keras.UnitTest/Model/ModelSaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/Model/ModelSaveTest.cs index 19b59d821..0854a09da 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Model/ModelSaveTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Model/ModelSaveTest.cs @@ -196,5 +196,17 @@ public void AlexnetFromSequential() // ) #endregion } + + [TestMethod] + public void SaveAfterLoad() + { + var model = tf.keras.models.load_model(@"Assets/simple_model_from_auto_compile"); + model.summary(); + + model.save("Assets/saved_auto_compile_after_loading"); + + //model = tf.keras.models.load_model(@"Assets/saved_auto_compile_after_loading"); + //model.summary(); + } } }