Skip to content

Commit f16902d

Browse files
authored
Merge pull request #1188 from hchen2020/master
Allow Model to cache weights.
2 parents 0ee9d42 + 0f02885 commit f16902d

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

src/TensorFlowNET.Keras/Engine/Model.Training.cs

+34-1
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,38 @@ namespace Tensorflow.Keras.Engine
1010
{
1111
public partial class Model
1212
{
13+
static Dictionary<string, List<(string, NDArray)>> weightsCache
14+
= new Dictionary<string, List<(string, NDArray)>>();
15+
1316
public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null)
1417
{
18+
// Get from cache
19+
if (weightsCache.ContainsKey(filepath))
20+
{
21+
var filtered_layers = new List<ILayer>();
22+
foreach (var layer in Layers)
23+
{
24+
var weights = hdf5_format._legacy_weights(layer);
25+
if (weights.Count > 0)
26+
filtered_layers.append(layer);
27+
}
28+
29+
var weight_value_tuples = new List<(IVariableV1, NDArray)>();
30+
filtered_layers.Select((layer, i) =>
31+
{
32+
var symbolic_weights = hdf5_format._legacy_weights(layer);
33+
foreach(var weight in symbolic_weights)
34+
{
35+
var weight_value = weightsCache[filepath].First(x => x.Item1 == weight.Name).Item2;
36+
weight_value_tuples.Add((weight, weight_value));
37+
}
38+
return layer;
39+
}).ToList();
40+
41+
keras.backend.batch_set_value(weight_value_tuples);
42+
return;
43+
}
44+
1545
long fileId = Hdf5.OpenFile(filepath, true);
1646
if(fileId < 0)
1747
{
@@ -29,8 +59,11 @@ public void load_weights(string filepath, bool by_name = false, bool skip_mismat
2959
throw new NotImplementedException("");
3060
else
3161
{
32-
hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
62+
var weight_value_tuples = hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
3363
Hdf5.CloseFile(fileId);
64+
65+
weightsCache[filepath] = weight_value_tuples.Select(x => (x.Item1.Name, x.Item2)).ToList();
66+
keras.backend.batch_set_value(weight_value_tuples);
3467
}
3568
}
3669

src/TensorFlowNET.Keras/Saving/hdf5_format.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public static void load_optimizer_weights_from_hdf5_group(long filepath = -1, Di
8282

8383
}
8484

85-
public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
85+
public static List<(IVariableV1, NDArray)> load_weights_from_hdf5_group(long f, List<ILayer> layers)
8686
{
8787
string original_keras_version = "2.5.0";
8888
string original_backend = null;
@@ -152,7 +152,7 @@ public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
152152
weight_value_tuples.AddRange(zip(symbolic_weights, weight_values));
153153
}
154154

155-
keras.backend.batch_set_value(weight_value_tuples);
155+
return weight_value_tuples;
156156
}
157157

158158
public static void toarrayf4(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)

0 commit comments

Comments
 (0)