diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index 2b864f902..55409df36 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -17,7 +17,7 @@ public interface ILayer: IWithTrackable, IKerasConfigable List<IVariableV1> TrainableVariables { get; } List<IVariableV1> TrainableWeights { get; } List<IVariableV1> NonTrainableWeights { get; } - List<IVariableV1> Weights { get; } + List<IVariableV1> Weights { get; set} Shape OutputShape { get; } Shape BatchInputShape { get; } TensorShapeConfig BuildInputShape { get; } diff --git a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs index 0aa5006c2..cba621fae 100644 --- a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs +++ b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs @@ -77,7 +77,7 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs) // Restore the weights after first epoch if no progress is ever made. if (_restore_best_weights && _best_weights == null) { - _best_weights = _parameters.Model.TrainableWeights; + _best_weights = _parameters.Model.Weights; } _wait += 1; @@ -103,9 +103,7 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs) Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}"); } } - // Because loading the weight variable into the model has not yet been implemented, so Earlystopping can't load best_weight yet. - // TODO(Wanglongzhi2001): implement it. - // _parameters.Model.load_weights(best_weights); + _parameters.Model.Weights = _best_weights; } } public void on_train_end()