@@ -30,6 +30,9 @@ limitations under the License.
30
30
using Tensorflow . Training . Saving . SavedModel ;
31
31
using Tensorflow . Util ;
32
32
using static Tensorflow . Binding ;
33
+ using Tensorflow . Framework ;
34
+ using Tensorflow . Sessions ;
35
+
33
36
34
37
namespace Tensorflow . Keras . Engine
35
38
{
@@ -134,6 +137,62 @@ public virtual List<IVariableV1> Weights
134
137
}
135
138
}
136
139
140
+ public virtual void set_weights ( IEnumerable < NDArray > weights )
141
+ {
142
+ if ( Weights . Count ( ) != weights . Count ( ) ) throw new ValueError (
143
+ $ "You called `set_weights` on layer \" { this . name } \" " +
144
+ $ "with a weight list of length { len ( weights ) } , but the layer was " +
145
+ $ "expecting { len ( Weights ) } weights.") ;
146
+
147
+
148
+
149
+ // check if the shapes are compatible
150
+ var weight_index = 0 ;
151
+ foreach ( var w in weights )
152
+ {
153
+ if ( ! Weights [ weight_index ] . AsTensor ( ) . is_compatible_with ( w ) )
154
+ {
155
+ throw new ValueError ( $ "Layer weight shape { w . shape } not compatible with provided weight shape { Weights [ weight_index ] . shape } ") ;
156
+ }
157
+ weight_index ++ ;
158
+ }
159
+
160
+ if ( tf . executing_eagerly ( ) )
161
+ {
162
+ foreach ( var ( this_w , v_w ) in zip ( Weights , weights ) )
163
+ this_w . assign ( v_w , read_value : true ) ;
164
+ }
165
+ else
166
+ {
167
+ // TODO(Wanglongzhi2001):seems like there exist some bug in graph mode when define model, so uncomment the following when it fixed.
168
+
169
+ //Tensors assign_ops = new Tensors();
170
+ //var feed_dict = new FeedDict();
171
+
172
+ //Graph g = tf.Graph().as_default();
173
+ //foreach (var (this_w, v_w) in zip(Weights, weights))
174
+ //{
175
+ // var tf_dtype = this_w.dtype;
176
+ // var placeholder_shape = v_w.shape;
177
+ // var assign_placeholder = tf.placeholder(tf_dtype, placeholder_shape);
178
+ // var assign_op = this_w.assign(assign_placeholder);
179
+ // assign_ops.Add(assign_op);
180
+ // feed_dict.Add(assign_placeholder, v_w);
181
+ //}
182
+ //var sess = tf.Session().as_default();
183
+ //sess.run(assign_ops, feed_dict);
184
+
185
+ //g.Exit();
186
+ }
187
+ }
188
+
189
+ public List < NDArray > get_weights ( )
190
+ {
191
+ List < NDArray > weights = new List < NDArray > ( ) ;
192
+ weights . AddRange ( Weights . ConvertAll ( x => x . numpy ( ) ) ) ;
193
+ return weights ;
194
+ }
195
+
137
196
protected int id ;
138
197
public int Id => id ;
139
198
protected string name ;
0 commit comments