Skip to content

Commit e72024b

Browse files
authored
Merge pull request #1030 from Wanglongzhi2001/master
Add set_weights and get_weighst APIs
2 parents 793ec4a + cd54e0f commit e72024b

File tree

3 files changed

+67
-3
lines changed

3 files changed

+67
-3
lines changed

src/TensorFlowNET.Core/Keras/Layers/ILayer.cs

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Tensorflow.Keras.Engine;
22
using Tensorflow.Keras.Saving;
3+
using Tensorflow.NumPy;
34
using Tensorflow.Training;
45

56
namespace Tensorflow.Keras
@@ -18,6 +19,8 @@ public interface ILayer: IWithTrackable, IKerasConfigable
1819
List<IVariableV1> TrainableWeights { get; }
1920
List<IVariableV1> NonTrainableWeights { get; }
2021
List<IVariableV1> Weights { get; set; }
22+
void set_weights(IEnumerable<NDArray> weights);
23+
List<NDArray> get_weights();
2124
Shape OutputShape { get; }
2225
Shape BatchInputShape { get; }
2326
TensorShapeConfig BuildInputShape { get; }

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

+5-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
using Tensorflow.Keras.ArgsDefinition.Rnn;
2222
using Tensorflow.Keras.Engine;
2323
using Tensorflow.Keras.Saving;
24+
using Tensorflow.NumPy;
2425
using Tensorflow.Operations;
2526
using Tensorflow.Train;
2627
using Tensorflow.Util;
@@ -71,7 +72,10 @@ public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell
7172

7273
public List<IVariableV1> TrainableVariables => throw new NotImplementedException();
7374
public List<IVariableV1> TrainableWeights => throw new NotImplementedException();
74-
public List<IVariableV1> Weights => throw new NotImplementedException();
75+
public List<IVariableV1> Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
76+
77+
public List<NDArray> get_weights() => throw new NotImplementedException();
78+
public void set_weights(IEnumerable<NDArray> weights) => throw new NotImplementedException();
7579
public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException();
7680

7781
public Shape OutputShape => throw new NotImplementedException();
@@ -84,8 +88,6 @@ public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell
8488
protected bool built = false;
8589
public bool Built => built;
8690

87-
List<IVariableV1> ILayer.Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
88-
8991
public RnnCell(bool trainable = true,
9092
string name = null,
9193
TF_DataType dtype = TF_DataType.DtInvalid,

src/TensorFlowNET.Keras/Engine/Layer.cs

+59
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ limitations under the License.
3030
using Tensorflow.Training.Saving.SavedModel;
3131
using Tensorflow.Util;
3232
using static Tensorflow.Binding;
33+
using Tensorflow.Framework;
34+
using Tensorflow.Sessions;
35+
3336

3437
namespace Tensorflow.Keras.Engine
3538
{
@@ -134,6 +137,62 @@ public virtual List<IVariableV1> Weights
134137
}
135138
}
136139

140+
public virtual void set_weights(IEnumerable<NDArray> weights)
141+
{
142+
if (Weights.Count() != weights.Count()) throw new ValueError(
143+
$"You called `set_weights` on layer \"{this.name}\"" +
144+
$"with a weight list of length {len(weights)}, but the layer was " +
145+
$"expecting {len(Weights)} weights.");
146+
147+
148+
149+
// check if the shapes are compatible
150+
var weight_index = 0;
151+
foreach(var w in weights)
152+
{
153+
if (!Weights[weight_index].AsTensor().is_compatible_with(w))
154+
{
155+
throw new ValueError($"Layer weight shape {w.shape} not compatible with provided weight shape {Weights[weight_index].shape}");
156+
}
157+
weight_index++;
158+
}
159+
160+
if (tf.executing_eagerly())
161+
{
162+
foreach (var (this_w, v_w) in zip(Weights, weights))
163+
this_w.assign(v_w, read_value: true);
164+
}
165+
else
166+
{
167+
// TODO(Wanglongzhi2001):seems like there exist some bug in graph mode when define model, so uncomment the following when it fixed.
168+
169+
//Tensors assign_ops = new Tensors();
170+
//var feed_dict = new FeedDict();
171+
172+
//Graph g = tf.Graph().as_default();
173+
//foreach (var (this_w, v_w) in zip(Weights, weights))
174+
//{
175+
// var tf_dtype = this_w.dtype;
176+
// var placeholder_shape = v_w.shape;
177+
// var assign_placeholder = tf.placeholder(tf_dtype, placeholder_shape);
178+
// var assign_op = this_w.assign(assign_placeholder);
179+
// assign_ops.Add(assign_op);
180+
// feed_dict.Add(assign_placeholder, v_w);
181+
//}
182+
//var sess = tf.Session().as_default();
183+
//sess.run(assign_ops, feed_dict);
184+
185+
//g.Exit();
186+
}
187+
}
188+
189+
public List<NDArray> get_weights()
190+
{
191+
List<NDArray > weights = new List<NDArray>();
192+
weights.AddRange(Weights.ConvertAll(x => x.numpy()));
193+
return weights;
194+
}
195+
137196
protected int id;
138197
public int Id => id;
139198
protected string name;

0 commit comments

Comments
 (0)