Skip to content

Add set_weights and get_weights APIs #1030

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Training;

namespace Tensorflow.Keras
Expand All @@ -18,6 +19,8 @@ public interface ILayer: IWithTrackable, IKerasConfigable
List<IVariableV1> TrainableWeights { get; }
List<IVariableV1> NonTrainableWeights { get; }
List<IVariableV1> Weights { get; set; }
/// <summary>
/// Sets the values of the layer's weights from the supplied arrays.
/// </summary>
/// <param name="weights">The new weight values, one <see cref="NDArray"/> per weight.</param>
void set_weights(IEnumerable<NDArray> weights);
/// <summary>
/// Returns the layer's weights as a list of <see cref="NDArray"/> values.
/// </summary>
List<NDArray> get_weights();
Shape OutputShape { get; }
Shape BatchInputShape { get; }
TensorShapeConfig BuildInputShape { get; }
Expand Down
8 changes: 5 additions & 3 deletions src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Operations;
using Tensorflow.Train;
using Tensorflow.Util;
Expand Down Expand Up @@ -71,7 +72,10 @@ public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell

public List<IVariableV1> TrainableVariables => throw new NotImplementedException();
public List<IVariableV1> TrainableWeights => throw new NotImplementedException();
public List<IVariableV1> Weights => throw new NotImplementedException();
public List<IVariableV1> Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }

// Not implemented for RnnCell: calling this always throws NotImplementedException.
public List<NDArray> get_weights() => throw new NotImplementedException();
// Not implemented for RnnCell: calling this always throws NotImplementedException.
public void set_weights(IEnumerable<NDArray> weights) => throw new NotImplementedException();
public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException();

public Shape OutputShape => throw new NotImplementedException();
Expand All @@ -84,8 +88,6 @@ public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell
protected bool built = false;
public bool Built => built;

List<IVariableV1> ILayer.Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }

public RnnCell(bool trainable = true,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
Expand Down
59 changes: 59 additions & 0 deletions src/TensorFlowNET.Keras/Engine/Layer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ limitations under the License.
using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Util;
using static Tensorflow.Binding;
using Tensorflow.Framework;
using Tensorflow.Sessions;


namespace Tensorflow.Keras.Engine
{
Expand Down Expand Up @@ -134,6 +137,62 @@ public virtual List<IVariableV1> Weights
}
}

/// <summary>
/// Sets the values of this layer's weights from a sequence of arrays.
/// </summary>
/// <param name="weights">
/// The new values, one array per entry of <see cref="Weights"/> and in the same order.
/// </param>
/// <exception cref="ValueError">
/// Thrown when the number of supplied arrays differs from the number of layer weights,
/// or when a supplied array is not shape-compatible with the corresponding layer weight.
/// </exception>
/// <remarks>
/// Values are only assigned when executing eagerly. In graph mode the arguments are
/// still validated, but no assignment is performed yet (see the TODO in the body).
/// </remarks>
public virtual void set_weights(IEnumerable<NDArray> weights)
{
    // Materialize the input once: it is counted, validated and then assigned, and a
    // lazy or single-pass sequence must not be enumerated several times.
    var new_values = weights.ToList();

    // Read the (virtual) Weights property once so every step below sees the same list.
    var layer_weights = Weights;

    if (layer_weights.Count != new_values.Count)
    {
        throw new ValueError(
            $"You called `set_weights` on layer \"{this.name}\" " +
            $"with a weight list of length {new_values.Count}, but the layer was " +
            $"expecting {layer_weights.Count} weights.");
    }

    // Validate every shape before touching any variable, so a mismatch cannot leave
    // the layer partially updated.
    for (var i = 0; i < new_values.Count; i++)
    {
        if (!layer_weights[i].AsTensor().is_compatible_with(new_values[i]))
        {
            throw new ValueError(
                $"Layer weight shape {layer_weights[i].shape} not compatible with " +
                $"provided weight shape {new_values[i].shape}");
        }
    }

    if (tf.executing_eagerly())
    {
        for (var i = 0; i < new_values.Count; i++)
        {
            layer_weights[i].assign(new_values[i], read_value: true);
        }
    }
    else
    {
        // NOTE: graph mode is currently a silent no-op.
        // TODO(Wanglongzhi2001): there seems to be a bug in graph mode when defining a
        // model; uncomment the following once it is fixed.

        //Tensors assign_ops = new Tensors();
        //var feed_dict = new FeedDict();

        //Graph g = tf.Graph().as_default();
        //foreach (var (this_w, v_w) in zip(Weights, weights))
        //{
        //    var tf_dtype = this_w.dtype;
        //    var placeholder_shape = v_w.shape;
        //    var assign_placeholder = tf.placeholder(tf_dtype, placeholder_shape);
        //    var assign_op = this_w.assign(assign_placeholder);
        //    assign_ops.Add(assign_op);
        //    feed_dict.Add(assign_placeholder, v_w);
        //}
        //var sess = tf.Session().as_default();
        //sess.run(assign_ops, feed_dict);

        //g.Exit();
    }
}

/// <summary>
/// Returns the current values of this layer's weights as arrays.
/// </summary>
/// <returns>
/// A new list holding the result of <c>numpy()</c> for each entry of
/// <see cref="Weights"/>, in the same order.
/// </returns>
public List<NDArray> get_weights()
{
    var layer_weights = Weights;
    var values = new List<NDArray>(layer_weights.Count);
    foreach (var variable in layer_weights)
    {
        values.Add(variable.numpy());
    }
    return values;
}

protected int id;
public int Id => id;
protected string name;
Expand Down