@@ -10,8 +10,38 @@ namespace Tensorflow.Keras.Engine
10
10
{
11
11
public partial class Model
12
12
{
13
+ static Dictionary < string , List < ( string , NDArray ) > > weightsCache
14
+ = new Dictionary < string , List < ( string , NDArray ) > > ( ) ;
15
+
13
16
public void load_weights ( string filepath , bool by_name = false , bool skip_mismatch = false , object options = null )
14
17
{
18
+ // Get from cache
19
+ if ( weightsCache . ContainsKey ( filepath ) )
20
+ {
21
+ var filtered_layers = new List < ILayer > ( ) ;
22
+ foreach ( var layer in Layers )
23
+ {
24
+ var weights = hdf5_format . _legacy_weights ( layer ) ;
25
+ if ( weights . Count > 0 )
26
+ filtered_layers . append ( layer ) ;
27
+ }
28
+
29
+ var weight_value_tuples = new List < ( IVariableV1 , NDArray ) > ( ) ;
30
+ filtered_layers . Select ( ( layer , i ) =>
31
+ {
32
+ var symbolic_weights = hdf5_format . _legacy_weights ( layer ) ;
33
+ foreach ( var weight in symbolic_weights )
34
+ {
35
+ var weight_value = weightsCache [ filepath ] . First ( x => x . Item1 == weight . Name ) . Item2 ;
36
+ weight_value_tuples . Add ( ( weight , weight_value ) ) ;
37
+ }
38
+ return layer ;
39
+ } ) . ToList ( ) ;
40
+
41
+ keras . backend . batch_set_value ( weight_value_tuples ) ;
42
+ return ;
43
+ }
44
+
15
45
long fileId = Hdf5 . OpenFile ( filepath , true ) ;
16
46
if ( fileId < 0 )
17
47
{
@@ -29,8 +59,11 @@ public void load_weights(string filepath, bool by_name = false, bool skip_mismat
29
59
throw new NotImplementedException ( "" ) ;
30
60
else
31
61
{
32
- hdf5_format . load_weights_from_hdf5_group ( fileId , Layers ) ;
62
+ var weight_value_tuples = hdf5_format . load_weights_from_hdf5_group ( fileId , Layers ) ;
33
63
Hdf5 . CloseFile ( fileId ) ;
64
+
65
+ weightsCache [ filepath ] = weight_value_tuples . Select ( x => ( x . Item1 . Name , x . Item2 ) ) . ToList ( ) ;
66
+ keras . backend . batch_set_value ( weight_value_tuples ) ;
34
67
}
35
68
}
36
69
0 commit comments