Skip to content

Commit 52b513d

Browse files
Support loading of SavedModel format (#989)
* Add CheckpointReader and corresponding C APIs. * Add essential components of SavedModel format loading. * Add checkpoint reading for SavedModel format loading. * Revise customized json converters. * Add support for loading models from python. * Fix the duplicated weights in Keras.Model. * Add alexnet loading test and check for loaded weights. * Fix ci error caused by branch merge. * Resolve the comments and errors. * Fix the stucking of training when loading model. * Fix the stucking of training when loading model. * fix intptr. --------- Co-authored-by: Haiping Chen <[email protected]>
1 parent 45f2626 commit 52b513d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+3118
-209
lines changed

src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,4 +149,22 @@ public static void add_checkpoint_values_check(TrackableObjectGraph object_graph
149149
// object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i);
150150
// }
151151
}
152+
153+
/// <summary>
154+
/// Traverse the object graph and list all accessible objects.
155+
/// </summary>
156+
/// <param name="object_graph_view"></param>
157+
public static IList<Trackable> list_objects(ObjectGraphView graph_view)
158+
{
159+
return objects_ids_and_slot_variables_and_paths(graph_view).Item1;
160+
}
161+
162+
internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list)
163+
{
164+
return full_list.TakeWhile(x =>
165+
{
166+
var saveables = x.gather_saveables_for_checkpoint();
167+
return saveables is not null && saveables.Count > 0;
168+
});
169+
}
152170
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
using Tensorflow.Util;
2+
3+
namespace Tensorflow.Checkpoint
4+
{
5+
sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle
6+
{
7+
public SafeCheckpointReaderHandle(): base()
8+
{
9+
10+
}
11+
public SafeCheckpointReaderHandle(IntPtr handle): base(handle)
12+
{
13+
14+
}
15+
16+
protected override bool ReleaseHandle()
17+
{
18+
c_api.TF_DeleteCheckpointReader(handle);
19+
SetHandle(IntPtr.Zero);
20+
return true;
21+
}
22+
}
23+
public class CheckpointReader
24+
{
25+
private SafeCheckpointReaderHandle _handle;
26+
public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; }
27+
public Dictionary<string, Shape> VariableToShapeMap { get; set; }
28+
29+
public CheckpointReader(string filename)
30+
{
31+
Status status = new Status();
32+
_handle = c_api.TF_NewCheckpointReader(filename, status.Handle);
33+
status.Check(true);
34+
ReadAllShapeAndType();
35+
}
36+
37+
public int HasTensor(string name)
38+
{
39+
return c_api.TF_CheckpointReaderHasTensor(_handle, name);
40+
}
41+
42+
/// <summary>
43+
/// Get the variable name.
44+
/// </summary>
45+
/// <param name="index"></param>
46+
/// <returns></returns>
47+
public string GetVariable(int index)
48+
{
49+
return c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index));
50+
}
51+
52+
public int Size()
53+
{
54+
return c_api.TF_CheckpointReaderSize(_handle);
55+
}
56+
57+
public TF_DataType GetVariableDataType(string name)
58+
{
59+
return c_api.TF_CheckpointReaderGetVariableDataType(_handle, name);
60+
}
61+
62+
public Shape GetVariableShape(string name)
63+
{
64+
int num_dims = GetVariableNumDims(name);
65+
long[] dims = new long[num_dims];
66+
Status status = new Status();
67+
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle);
68+
status.Check(true);
69+
return new Shape(dims);
70+
}
71+
72+
public int GetVariableNumDims(string name)
73+
{
74+
return c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name);
75+
}
76+
77+
public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
78+
{
79+
Status status = new Status();
80+
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle);
81+
status.Check(true);
82+
return new Tensor(tensor);
83+
}
84+
85+
private void ReadAllShapeAndType()
86+
{
87+
VariableToDataTypeMap = new Dictionary<string, TF_DataType>();
88+
VariableToShapeMap = new Dictionary<string, Shape>();
89+
int size = Size();
90+
for(int i = 0; i < size; i++)
91+
{
92+
var name = GetVariable(i);
93+
var shape = GetVariableShape(name);
94+
var dtype = GetVariableDataType(name);
95+
VariableToDataTypeMap[name] = dtype;
96+
VariableToShapeMap[name] = shape;
97+
}
98+
}
99+
}
100+
}

src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,9 @@ public static (IList<MySaveableObject>, object?) generate_saveable_objects(
175175
{
176176
var name = factory_data.name;
177177
var key = factory_data.checkpoint_key;
178-
var maybe_saveable = factory_data.factory;
178+
var maybe_saveable = saveable_object_util.create_saveable_object(name, key, factory_data.factory);
179179

180-
// TODO: oneflow python has a process with callable `saveable_factory`.
180+
// TODO: tensorflow python has a process with callable `saveable_factory`.
181181
List<MySaveableObject> saveables = new();
182182
if (maybe_saveable.TryGet<MySaveableObject>(out var s))
183183
{
@@ -217,7 +217,7 @@ public static (IList<MySaveableObject>, object?) generate_saveable_objects(
217217

218218
public record class CheckpointFactoryData
219219
(
220-
Maybe<BaseResourceVariable, MySaveableObject> factory,
220+
Func<string, Maybe<BaseResourceVariable, MySaveableObject>> factory,
221221
string name,
222222
string checkpoint_key
223223
);
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using System.Runtime.InteropServices;
2+
using Tensorflow.Checkpoint;
3+
4+
namespace Tensorflow
5+
{
6+
public unsafe partial class c_api
7+
{
8+
[DllImport(TensorFlowLibName)]
9+
internal static extern SafeCheckpointReaderHandle TF_NewCheckpointReader(string filename, SafeStatusHandle status);
10+
[DllImport(TensorFlowLibName)]
11+
internal static extern void TF_DeleteCheckpointReader(IntPtr reader);
12+
[DllImport(TensorFlowLibName)]
13+
internal static extern int TF_CheckpointReaderHasTensor(SafeCheckpointReaderHandle reader, string name);
14+
[DllImport(TensorFlowLibName)]
15+
internal static extern IntPtr TF_CheckpointReaderGetVariable(SafeCheckpointReaderHandle reader, int index);
16+
[DllImport(TensorFlowLibName)]
17+
internal static extern int TF_CheckpointReaderSize(SafeCheckpointReaderHandle reader);
18+
[DllImport(TensorFlowLibName)]
19+
internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(SafeCheckpointReaderHandle reader, string name);
20+
[DllImport(TensorFlowLibName)]
21+
internal static extern void TF_CheckpointReaderGetVariableShape(SafeCheckpointReaderHandle reader, string name, long[] dims, int num_dims, SafeStatusHandle status);
22+
[DllImport(TensorFlowLibName)]
23+
internal static extern int TF_CheckpointReaderGetVariableNumDims(SafeCheckpointReaderHandle reader, string name);
24+
[DllImport(TensorFlowLibName)]
25+
internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(SafeCheckpointReaderHandle reader, string name, SafeStatusHandle status);
26+
}
27+
}

0 commit comments

Comments
 (0)