Skip to content

Commit d452d8c

Browse files
authored
Merge pull request #1144 from Wanglongzhi2001/master
fix: fix the bug of load LSTM model and add test
2 parents 3006c86 + 68772b2 commit d452d8c

34 files changed

+81
-40
lines changed

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
using System.Collections.Generic;
44
using System.Text;
55

6-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
6+
namespace Tensorflow.Keras.ArgsDefinition
77
{
88
public class GRUCellArgs : AutoSerializeLayerArgs
99
{

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
1+
namespace Tensorflow.Keras.ArgsDefinition
22
{
33
public class LSTMArgs : RNNArgs
44
{

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Newtonsoft.Json;
22
using static Tensorflow.Binding;
33

4-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
4+
namespace Tensorflow.Keras.ArgsDefinition
55
{
66
// TODO: complete the implementation
77
public class LSTMCellArgs : AutoSerializeLayerArgs

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using Newtonsoft.Json;
22
using System.Collections.Generic;
3-
using Tensorflow.Keras.Layers.Rnn;
3+
using Tensorflow.Keras.Layers;
44

5-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
5+
namespace Tensorflow.Keras.ArgsDefinition
66
{
77
// TODO(Rinne): add regularizers.
88
public class RNNArgs : AutoSerializeLayerArgs
@@ -23,16 +23,22 @@ public class RNNArgs : AutoSerializeLayerArgs
2323
public int? InputDim { get; set; }
2424
public int? InputLength { get; set; }
2525
// TODO: Add `num_constants` and `zero_output_for_mask`.
26-
26+
[JsonProperty("units")]
2727
public int Units { get; set; }
28+
[JsonProperty("activation")]
2829
public Activation Activation { get; set; }
30+
[JsonProperty("recurrent_activation")]
2931
public Activation RecurrentActivation { get; set; }
32+
[JsonProperty("use_bias")]
3033
public bool UseBias { get; set; } = true;
3134
public IInitializer KernelInitializer { get; set; }
3235
public IInitializer RecurrentInitializer { get; set; }
3336
public IInitializer BiasInitializer { get; set; }
37+
[JsonProperty("dropout")]
3438
public float Dropout { get; set; } = .0f;
39+
[JsonProperty("zero_output_for_mask")]
3540
public bool ZeroOutputForMask { get; set; } = false;
41+
[JsonProperty("recurrent_dropout")]
3642
public float RecurrentDropout { get; set; } = .0f;
3743
}
3844
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
using System.Text;
44
using Tensorflow.Common.Types;
55

6-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
6+
namespace Tensorflow.Keras.ArgsDefinition
77
{
88
public class RnnOptionalArgs: IOptionalArgs
99
{

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
1+
namespace Tensorflow.Keras.ArgsDefinition
22
{
33
public class SimpleRNNArgs : RNNArgs
44
{

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Newtonsoft.Json;
22

3-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
3+
namespace Tensorflow.Keras.ArgsDefinition
44
{
55
public class SimpleRNNCellArgs: AutoSerializeLayerArgs
66
{

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using System.Collections.Generic;
2-
using Tensorflow.Keras.Layers.Rnn;
2+
using Tensorflow.Keras.Layers;
33

4-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
4+
namespace Tensorflow.Keras.ArgsDefinition
55
{
66
public class StackedRNNCellsArgs : LayerArgs
77
{

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using System;
22
using Tensorflow.Framework.Models;
33
using Tensorflow.Keras.Engine;
4-
using Tensorflow.Keras.Layers.Rnn;
4+
using Tensorflow.Keras.Layers;
55
using Tensorflow.NumPy;
66
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;
77

src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
using System.Text;
44
using Tensorflow.Common.Types;
55

6-
namespace Tensorflow.Keras.Layers.Rnn
6+
namespace Tensorflow.Keras.Layers
77
{
88
public interface IRnnCell: ILayer
99
{

src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
using System.Collections.Generic;
33
using System.Text;
44

5-
namespace Tensorflow.Keras.Layers.Rnn
5+
namespace Tensorflow.Keras.Layers
66
{
77
public interface IStackedRnnCells : IRnnCell
88
{

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@ limitations under the License.
1919
using Tensorflow.Common.Types;
2020
using Tensorflow.Keras;
2121
using Tensorflow.Keras.ArgsDefinition;
22-
using Tensorflow.Keras.ArgsDefinition.Rnn;
2322
using Tensorflow.Keras.Engine;
24-
using Tensorflow.Keras.Layers.Rnn;
23+
using Tensorflow.Keras.Layers;
2524
using Tensorflow.Keras.Saving;
2625
using Tensorflow.NumPy;
2726
using Tensorflow.Operations;

src/TensorFlowNET.Core/ops.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,9 @@ public static bool executing_eagerly_outside_functions()
571571
if (tf.Context.executing_eagerly())
572572
return true;
573573
else
574-
throw new NotImplementedException("");
574+
// TODO(Wanglongzhi2001), implement the false case
575+
return true;
576+
//throw new NotImplementedException("");
575577
}
576578

577579
public static bool inside_function()

src/TensorFlowNET.Keras/Layers/LayersApi.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
using Tensorflow.Framework.Models;
33
using Tensorflow.Keras.ArgsDefinition;
44
using Tensorflow.Keras.ArgsDefinition.Core;
5-
using Tensorflow.Keras.ArgsDefinition.Rnn;
65
using Tensorflow.Keras.Engine;
7-
using Tensorflow.Keras.Layers.Rnn;
6+
using Tensorflow.Keras.Layers;
87
using Tensorflow.NumPy;
98
using static Tensorflow.Binding;
109
using static Tensorflow.KerasApi;

src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
using Tensorflow.Keras.Engine;
77
using Tensorflow.Keras.Utils;
88

9-
namespace Tensorflow.Keras.Layers.Rnn
9+
namespace Tensorflow.Keras.Layers
1010
{
1111
public abstract class DropoutRNNCellMixin: Layer, IRnnCell
1212
{

src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
using System.Diagnostics;
44
using System.Text;
55
using Tensorflow.Keras.ArgsDefinition;
6-
using Tensorflow.Keras.ArgsDefinition.Rnn;
76
using Tensorflow.Common.Extensions;
87
using Tensorflow.Common.Types;
98
using Tensorflow.Keras.Saving;
109

11-
namespace Tensorflow.Keras.Layers.Rnn
10+
namespace Tensorflow.Keras.Layers
1211
{
1312
/// <summary>
1413
/// Cell class for the GRU layer.

src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
using System.Linq;
2-
using Tensorflow.Keras.ArgsDefinition.Rnn;
2+
using Tensorflow.Keras.ArgsDefinition;
33
using Tensorflow.Keras.Engine;
44
using Tensorflow.Common.Types;
55
using Tensorflow.Common.Extensions;
66

7-
namespace Tensorflow.Keras.Layers.Rnn
7+
namespace Tensorflow.Keras.Layers
88
{
99
/// <summary>
1010
/// Long Short-Term Memory layer - Hochreiter 1997.

src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
using System.Diagnostics;
44
using Tensorflow.Common.Extensions;
55
using Tensorflow.Common.Types;
6-
using Tensorflow.Keras.ArgsDefinition.Rnn;
6+
using Tensorflow.Keras.ArgsDefinition;
77
using Tensorflow.Keras.Engine;
88
using Tensorflow.Keras.Saving;
99
using Tensorflow.Keras.Utils;
1010

11-
namespace Tensorflow.Keras.Layers.Rnn
11+
namespace Tensorflow.Keras.Layers
1212
{
1313
/// <summary>
1414
/// Cell class for the LSTM layer.

src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
using System.Collections.Generic;
44
using System.Reflection;
55
using Tensorflow.Keras.ArgsDefinition;
6-
using Tensorflow.Keras.ArgsDefinition.Rnn;
76
using Tensorflow.Keras.Engine;
87
using Tensorflow.Keras.Saving;
98
using Tensorflow.Util;
@@ -14,7 +13,7 @@
1413
using System.Runtime.CompilerServices;
1514
// from tensorflow.python.distribute import distribution_strategy_context as ds_context;
1615

17-
namespace Tensorflow.Keras.Layers.Rnn
16+
namespace Tensorflow.Keras.Layers
1817
{
1918
/// <summary>
2019
/// Base class for recurrent layers.
@@ -185,6 +184,7 @@ private Tensors compute_mask(Tensors inputs, Tensors mask)
185184

186185
public override void build(KerasShapesWrapper input_shape)
187186
{
187+
_buildInputShape = input_shape;
188188
input_shape = new KerasShapesWrapper(input_shape.Shapes[0]);
189189

190190
InputSpec get_input_spec(Shape shape)

src/TensorFlowNET.Keras/Layers/Rnn/RnnBase.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
using Tensorflow.Keras.ArgsDefinition;
55
using Tensorflow.Keras.Engine;
66

7-
namespace Tensorflow.Keras.Layers.Rnn
7+
namespace Tensorflow.Keras.Layers
88
{
99
public abstract class RnnBase: Layer
1010
{

src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
using System.Data;
2-
using Tensorflow.Keras.ArgsDefinition.Rnn;
2+
using Tensorflow.Keras.ArgsDefinition;
33
using Tensorflow.Keras.Saving;
44
using Tensorflow.Operations.Activation;
55
using static HDF.PInvoke.H5Z;
66
using static Tensorflow.ApiDef.Types;
77

8-
namespace Tensorflow.Keras.Layers.Rnn
8+
namespace Tensorflow.Keras.Layers
99
{
1010
public class SimpleRNN : RNN
1111
{

src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4-
using Tensorflow.Keras.ArgsDefinition.Rnn;
4+
using Tensorflow.Keras.ArgsDefinition;
55
using Tensorflow.Keras.Engine;
66
using Tensorflow.Keras.Saving;
77
using Tensorflow.Common.Types;
88
using Tensorflow.Common.Extensions;
99
using Tensorflow.Keras.Utils;
1010
using Tensorflow.Graphs;
1111

12-
namespace Tensorflow.Keras.Layers.Rnn
12+
namespace Tensorflow.Keras.Layers
1313
{
1414
/// <summary>
1515
/// Cell class for SimpleRNN.

src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
using System.Linq;
44
using Tensorflow.Common.Extensions;
55
using Tensorflow.Common.Types;
6-
using Tensorflow.Keras.ArgsDefinition.Rnn;
6+
using Tensorflow.Keras.ArgsDefinition;
77
using Tensorflow.Keras.Engine;
88
using Tensorflow.Keras.Saving;
99
using Tensorflow.Keras.Utils;
1010

11-
namespace Tensorflow.Keras.Layers.Rnn
11+
namespace Tensorflow.Keras.Layers
1212
{
1313
public class StackedRNNCells : Layer, IRnnCell
1414
{

src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs

-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
using Tensorflow.Keras.ArgsDefinition;
1414
using Tensorflow.Keras.Engine;
1515
using Tensorflow.Keras.Layers;
16-
using Tensorflow.Keras.Layers.Rnn;
1716
using Tensorflow.Keras.Losses;
1817
using Tensorflow.Keras.Metrics;
1918
using Tensorflow.Keras.Saving.SavedModel;

src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
using System.Linq;
44
using System.Text;
55
using Tensorflow.Keras.Engine;
6-
using Tensorflow.Keras.Layers.Rnn;
6+
using Tensorflow.Keras.Layers;
77
using Tensorflow.Keras.Metrics;
88
using Tensorflow.Train;
99

src/TensorFlowNET.Keras/Utils/RnnUtils.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
using System.Diagnostics;
44
using System.Text;
55
using Tensorflow.Common.Types;
6-
using Tensorflow.Keras.Layers.Rnn;
6+
using Tensorflow.Keras.Layers;
77
using Tensorflow.Common.Extensions;
88

99
namespace Tensorflow.Keras.Utils
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
����������̟땐͉��������� ��Σ�����(��ռ����2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
�&root"_tf_keras_sequential*�&{"name": "sequential", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": false, "class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 5, 3]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}}, {"class_name": "LSTM", "config": {"name": "lstm", "trainable": true, "dtype": "float32", "return_sequences": false, "return_state": false, "go_backwards": false, "stateful": false, "unroll": false, "time_major": false, "units": 32, "activation": "tanh", "recurrent_activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 1}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 2}}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 1, "activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "shared_object_id": 9, "input_spec": [{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, 5, 3]}, "ndim": 3, "max_ndim": null, "min_ndim": null, "axes": {}}}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 5, 3]}, "is_graph_network": true, "full_save_spec": {"class_name": "__tuple__", "items": [[{"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 5, 3]}, "float32", "input_1"]}], {}]}, "save_spec": {"class_name": "TypeSpec", "type_spec": "tf.TensorSpec", "serialized": [{"class_name": "TensorShape", "items": [null, 5, 3]}, "float32", "input_1"]}, "keras_version": "2.12.0", "backend": "tensorflow", "model_config": {"class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": {"class_name": "__tuple__", "items": [null, 5, 3]}, "dtype": "float32", "sparse": false, "ragged": false, "name": "input_1"}, "shared_object_id": 0}, {"class_name": "LSTM", "config": {"name": "lstm", "trainable": true, "dtype": "float32", "return_sequences": false, "return_state": false, "go_backwards": false, "stateful": false, "unroll": false, "time_major": false, "units": 32, "activation": "tanh", "recurrent_activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 1}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 2}, "shared_object_id": 5}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 1, "activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 6}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 7}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "shared_object_id": 8}]}}, "training_config": {"loss": "binary_crossentropy", "metrics": [[{"class_name": "MeanMetricWrapper", "config": {"name": "accuracy", "dtype": "float32", "fn": "binary_accuracy"}, "shared_object_id": 11}]], "weighted_metrics": null, "loss_weights": null, "optimizer_config": {"class_name": "Custom>Adam", "config": {"name": "Adam", "weight_decay": null, "clipnorm": null, "global_clipnorm": null, "clipvalue": null, "use_ema": false, "ema_momentum": 0.99, "ema_overwrite_frequency": null, "jit_compile": false, "is_legacy_optimizer": false, "learning_rate": 0.0010000000474974513, "beta_1": 0.9, "beta_2": 0.999, "epsilon": 1e-07, "amsgrad": false}}}}2
3+
� root.layer_with_weights-0"_tf_keras_rnn_layer*� {"name": "lstm", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "LSTM", "config": {"name": "lstm", "trainable": true, "dtype": "float32", "return_sequences": false, "return_state": false, "go_backwards": false, "stateful": false, "unroll": false, "time_major": false, "units": 32, "activation": "tanh", "recurrent_activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 1}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 2}, "shared_object_id": 5, "input_spec": [{"class_name": "InputSpec", "config": {"dtype": null, "shape": {"class_name": "__tuple__", "items": [null, null, 3]}, "ndim": 3, "max_ndim": null, "min_ndim": null, "axes": {}}, "shared_object_id": 12}], "build_input_shape": {"class_name": "TensorShape", "items": [null, 5, 3]}}2
4+
�root.layer_with_weights-1"_tf_keras_layer*�{"name": "dense", "trainable": true, "expects_training_arg": false, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 1, "activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 6}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 7}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "shared_object_id": 8, "input_spec": {"class_name": "InputSpec", "config": {"dtype": null, "shape": null, "ndim": null, "max_ndim": null, "min_ndim": 2, "axes": {"-1": 32}}, "shared_object_id": 13}, "build_input_shape": {"class_name": "TensorShape", "items": [null, 32]}}2
5+
�root.layer_with_weights-0.cell"_tf_keras_layer*�{"name": "lstm_cell", "trainable": true, "expects_training_arg": true, "dtype": "float32", "batch_input_shape": null, "stateful": false, "must_restore_from_config": false, "preserve_input_structure_in_config": false, "autocast": true, "class_name": "LSTMCell", "config": {"name": "lstm_cell", "trainable": true, "dtype": "float32", "units": 32, "activation": "tanh", "recurrent_activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}, "shared_object_id": 1}, "recurrent_initializer": {"class_name": "Orthogonal", "config": {"gain": 1.0, "seed": null}, "shared_object_id": 2}, "bias_initializer": {"class_name": "Zeros", "config": {}, "shared_object_id": 3}, "unit_forget_bias": true, "kernel_regularizer": null, "recurrent_regularizer": null, "bias_regularizer": null, "kernel_constraint": null, "recurrent_constraint": null, "bias_constraint": null, "dropout": 0.0, "recurrent_dropout": 0.0, "implementation": 2}, "shared_object_id": 4, "build_input_shape": {"class_name": "__tuple__", "items": [null, 3]}}2
6+
�Rroot.keras_api.metrics.0"_tf_keras_metric*�{"class_name": "Mean", "name": "loss", "dtype": "float32", "config": {"name": "loss", "dtype": "float32"}, "shared_object_id": 14}2
7+
�Sroot.keras_api.metrics.1"_tf_keras_metric*�{"class_name": "MeanMetricWrapper", "name": "accuracy", "dtype": "float32", "config": {"name": "accuracy", "dtype": "float32", "fn": "binary_accuracy"}, "shared_object_id": 11}2
Binary file not shown.

test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
using System.Threading.Tasks;
77
using Tensorflow.Common.Types;
88
using Tensorflow.Keras.Engine;
9-
using Tensorflow.Keras.Layers.Rnn;
9+
using Tensorflow.Keras.Layers;
1010
using Tensorflow.Keras.Saving;
1111
using Tensorflow.NumPy;
1212
using Tensorflow.Train;

test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs

+14-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
using Microsoft.VisualStudio.TestTools.UnitTesting;
1+
using Microsoft.VisualStudio.TestPlatform.Utilities;
2+
using Microsoft.VisualStudio.TestTools.UnitTesting;
23
using System.Linq;
4+
using Tensorflow.Keras.Engine;
35
using Tensorflow.Keras.Optimizers;
46
using Tensorflow.Keras.UnitTest.Helpers;
57
using Tensorflow.NumPy;
@@ -79,6 +81,17 @@ public void ModelWithSelfDefinedModule()
7981
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
8082
}
8183

84+
[TestMethod]
85+
public void LSTMLoad()
86+
{
87+
var model = tf.keras.models.load_model(@"Assets/lstm_from_sequential");
88+
model.summary();
89+
model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.MeanSquaredError(), new string[] { "accuracy" });
90+
var inputs = tf.random.normal(shape: (10, 5, 3));
91+
var outputs = tf.random.normal(shape: (10, 1));
92+
model.fit(inputs.numpy(), outputs.numpy(), batch_size: 10, epochs: 5, workers: 16, use_multiprocessing: true);
93+
}
94+
8295
[Ignore]
8396
[TestMethod]
8497
public void VGG19()

0 commit comments

Comments
 (0)