diff --git a/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs b/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
index e05d3deb3..c61d04b25 100644
--- a/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
+++ b/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
@@ -12,9 +12,14 @@ public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>
         /// <summary>
         /// create a single-dim generalized Tensor shape.
         /// </summary>
         /// <param name="dim"></param>
-        public GeneralizedTensorShape(int dim)
+        /// <param name="size">the number of times the single-dim shape is repeated</param>
+        public GeneralizedTensorShape(int dim, int size = 1)
         {
-            Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
+            var elem = new TensorShapeConfig() { Items = new long?[] { dim } };
+            Shapes = Enumerable.Repeat(elem, size).ToArray();
         }
 
         public GeneralizedTensorShape(Shape shape)
@@ -113,6 +118,11 @@ public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func)
                 return new Nest<TOut>(Shapes.Select(s => DealWithSingleShape(s)));
             }
         }
+
+        public static implicit operator GeneralizedTensorShape(int dims)
+            => new GeneralizedTensorShape(dims);
 
         public IEnumerator<long?[]> GetEnumerator()
         {
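Note on the new constructor: Enumerable.Repeat(elem, size) fills every slot of Shapes with the same TensorShapeConfig reference, so mutating the Items of one entry is visible through all of them. A minimal illustration (plain C#, not part of the patch; assumes TensorShapeConfig stays a mutable class):

    var shape = new GeneralizedTensorShape(8, size: 3);
    shape.Shapes[0].Items[0] = 42;
    // shape.Shapes[1].Items[0] and shape.Shapes[2].Items[0] now read 42 as well.

    // If independent entries are ever needed, materialize one instance per slot:
    var shapes = Enumerable.Range(0, 3)
        .Select(_ => new TensorShapeConfig { Items = new long?[] { 8 } })
        .ToArray();

This is harmless for callers that only overwrite whole Shapes[i] entries (as StackedRNNCells.StateSize below does), but worth knowing before anyone mutates an element in place.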
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
index ed5a1d6dd..116ff7a2f 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
@@ -10,6 +10,9 @@ public class RNNArgs : AutoSerializeLayerArgs
         [JsonProperty("cell")]
         // TODO: the cell should be serialized with `serialize_keras_object`.
         public IRnnCell Cell { get; set; } = null;
+        [JsonProperty("cells")]
+        public IList<IRnnCell> Cells { get; set; } = null;
+
         [JsonProperty("return_sequences")]
         public bool ReturnSequences { get; set; } = false;
         [JsonProperty("return_state")]
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
index fdfadab85..ea6f830b8 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
@@ -1,10 +1,11 @@
 using System.Collections.Generic;
+using Tensorflow.Keras.Layers.Rnn;
 
 namespace Tensorflow.Keras.ArgsDefinition.Rnn
 {
     public class StackedRNNCellsArgs : LayerArgs
     {
-        public IList<RnnCell> Cells { get; set; }
+        public IList<IRnnCell> Cells { get; set; }
         public Dictionary<string, object> Kwargs { get; set; } = null;
     }
 }
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
index 6a29f9e5e..3b2238164 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
@@ -1,5 +1,6 @@
 using System;
 using Tensorflow.Framework.Models;
+using Tensorflow.Keras.Layers.Rnn;
 using Tensorflow.NumPy;
 using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;
 
@@ -192,6 +193,19 @@ public ILayer Rescaling(float scale,
             float offset = 0,
             Shape input_shape = null);
 
+        public IRnnCell SimpleRNNCell(
+            int units,
+            string activation = "tanh",
+            bool use_bias = true,
+            string kernel_initializer = "glorot_uniform",
+            string recurrent_initializer = "orthogonal",
+            string bias_initializer = "zeros",
+            float dropout = 0f,
+            float recurrent_dropout = 0f);
+
+        public IRnnCell StackedRNNCells(
+            IEnumerable<IRnnCell> cells);
+
         public ILayer SimpleRNN(int units,
             string activation = "tanh",
             string kernel_initializer = "glorot_uniform",
@@ -200,6 +214,26 @@ public ILayer SimpleRNN(int units,
             bool return_sequences = false,
             bool return_state = false);
 
+        public ILayer RNN(
+            IRnnCell cell,
+            bool return_sequences = false,
+            bool return_state = false,
+            bool go_backwards = false,
+            bool stateful = false,
+            bool unroll = false,
+            bool time_major = false);
+
+        public ILayer RNN(
+            IEnumerable<IRnnCell> cells,
+            bool return_sequences = false,
+            bool return_state = false,
+            bool go_backwards = false,
+            bool stateful = false,
+            bool unroll = false,
+            bool time_major = false);
+
         public ILayer Subtract();
     }
 }
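For reference, the intended call pattern of the new API surface, mirroring the unit tests added at the end of this patch:

    // A single cell driven by the generic RNN layer.
    var cell = tf.keras.layers.SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f);
    var rnn = tf.keras.layers.RNN(cell: cell);
    var output = rnn.Apply(tf.random.normal((32, 10, 8)));   // -> (32, 10)

    // Stacked cells: the last cell's units determine the output width.
    var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) };
    var stacked = tf.keras.layers.StackedRNNCells(cells);
    var output2 = tf.keras.layers.RNN(cell: stacked).Apply(tf.random.normal((32, 10, 8)));   // -> (32, 5)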
diff --git a/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
index ed65a08d7..08e73fe67 100644
--- a/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
+++ b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
@@ -109,7 +109,19 @@ public TensorArray scatter(Tensor indices, Tensor value, string name = null)
                 return ta;
             });*/
 
-            throw new NotImplementedException("");
+            // TODO: eager-mode scatter is not implemented yet. The intended logic is to
+            // write each unstacked slice of `value` at its corresponding index.
+            // Until then, return the array unchanged so callers can proceed.
+            return this;
         }
 
         public void _merge_element_shape(Shape shape)
diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
index 16870e9f6..dde2624af 100644
--- a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
+++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
@@ -17,6 +17,7 @@ limitations under the License.
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using Tensorflow.Eager;
 using static Tensorflow.Binding;
 
 namespace Tensorflow.Operations
@@ -146,7 +147,9 @@ public TensorArray scatter(Tensor indices, Tensor value, string name = null)
                 return ta;
             });*/
 
-            throw new NotImplementedException("");
+            // TODO: graph-mode scatter is still unimplemented; return the array unchanged for now.
+            return this;
         }
 
         public void _merge_element_shape(Shape shape)
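Both scatter implementations above now return the array unchanged instead of throwing, which silently drops the requested writes. The commented-out fragment hints at the eager-mode intent; a sketch of that intent follows (hedged: it leans on the write/unstack/zip helpers this codebase already uses, and assumes int32 indices):

    // Sketch only, not part of the patch: write each unstacked slice of
    // `value` to its target index, then return the updated array.
    public TensorArray scatter(Tensor indices, Tensor value, string name = null)
    {
        foreach (var (index, val) in zip(indices.ToArray<int>(), array_ops.unstack(value)))
            write(index, val);
        return this;
    }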
diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs
index 144910669..1336e9af5 100644
--- a/src/TensorFlowNET.Keras/BackendImpl.cs
+++ b/src/TensorFlowNET.Keras/BackendImpl.cs
@@ -510,7 +510,7 @@ Tensor swap_batch_timestep(Tensor input_t)
             }
         }
 
-        
+
         // tf.where needs its condition tensor to be the same shape as its two
         // result tensors, but in our case the condition (mask) tensor is
         // (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
@@ -535,7 +535,7 @@ Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
             {
                 mask_t = tf.expand_dims(mask_t, -1);
             }
-            var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().ToList().GetRange(fixed_dim, input_t.rank));
+            // Skip avoids the out-of-range count that GetRange(fixed_dim, input_t.rank) could request.
+            var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().Skip(fixed_dim).ToArray());
             return tf.tile(mask_t, multiples);
         }
@@ -570,9 +570,6 @@ Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
             // individually. The result of this will be a tuple of lists, each of
             // the item in tuple is list of the tensor with shape (batch, feature)
-
-
-
             Tensors _process_single_input_t(Tensor input_t)
             {
                 var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim
@@ -609,7 +606,7 @@ object _get_input_tensor(int time)
                     var mask_list = tf.unstack(mask);
                     if (go_backwards)
                     {
-                        mask_list.Reverse();
+                        // Reverse() is lazy and its result was being discarded; assign it back.
+                        mask_list = mask_list.Reverse().ToArray();
                     }
 
                     for (int i = 0; i < time_steps; i++)
@@ -629,9 +626,10 @@ object _get_input_tensor(int time)
                         }
                         else
                         {
-                            prev_output = successive_outputs[successive_outputs.Length - 1];
+                            prev_output = successive_outputs.Last();
                         }
 
+                        // Where the mask is off, carry the previous output forward.
                         output = tf.where(tiled_mask_t, output, prev_output);
 
                         var flat_states = Nest.Flatten(states).ToList();
@@ -661,13 +659,13 @@ object _get_input_tensor(int time)
                         }
                     }
-                    last_output = successive_outputs[successive_outputs.Length - 1];
-                    new_states = successive_states[successive_states.Length - 1];
+                    last_output = successive_outputs.Last();
+                    new_states = successive_states.Last();
                     outputs = tf.stack(successive_outputs);
 
                     if (zero_output_for_mask)
                     {
-                        last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output));
+                        last_output = tf.where(_expand_mask(mask_list.Last(), last_output), last_output, tf.zeros_like(last_output));
                         outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs));
                     }
                     else // mask is null
@@ -689,8 +687,8 @@ object _get_input_tensor(int time)
                             successive_states = new Tensors { newStates };
                         }
                     }
-                    last_output = successive_outputs[successive_outputs.Length - 1];
-                    new_states = successive_states[successive_states.Length - 1];
+                    last_output = successive_outputs.Last();
+                    new_states = successive_states.Last();
                     outputs = tf.stack(successive_outputs);
                 }
             }
@@ -773,7 +774,7 @@ object _get_input_tensor(int time)
                     return res;
                 };
             }
-            // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)?
+            // input_length may be either an integer or a scalar tensor; this branch handles the tensor case.
             else if (input_length is Tensor)
             {
                 if (go_backwards)
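The _expand_mask change replaces GetRange (whose count argument could run past the end of the list) with Skip. Concretely, for a (batch, time, feature) input and fixed_dim = 2, the tile multiples come out as follows (plain LINQ, no TF involved):

    int fixed_dim = 2;
    var inputShape = new[] { 32, 10, 8 };              // (batch, time, feature)
    var multiples = Enumerable.Repeat(1, fixed_dim)    // keep batch and time as-is
        .Concat(inputShape.Skip(fixed_dim))            // tile across the feature dim
        .ToArray();                                    // -> { 1, 1, 8 }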
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index 3b095bc2a..dd25122d5 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -685,6 +685,34 @@ public ILayer LeakyReLU(float alpha = 0.3f)
                 Alpha = alpha
             });
 
+        public IRnnCell SimpleRNNCell(
+            int units,
+            string activation = "tanh",
+            bool use_bias = true,
+            string kernel_initializer = "glorot_uniform",
+            string recurrent_initializer = "orthogonal",
+            string bias_initializer = "zeros",
+            float dropout = 0f,
+            float recurrent_dropout = 0f)
+            => new SimpleRNNCell(new SimpleRNNCellArgs
+            {
+                Units = units,
+                Activation = keras.activations.GetActivationFromName(activation),
+                UseBias = use_bias,
+                KernelInitializer = GetInitializerByName(kernel_initializer),
+                RecurrentInitializer = GetInitializerByName(recurrent_initializer),
+                BiasInitializer = GetInitializerByName(bias_initializer),
+                Dropout = dropout,
+                RecurrentDropout = recurrent_dropout
+            });
+
+        public IRnnCell StackedRNNCells(
+            IEnumerable<IRnnCell> cells)
+            => new StackedRNNCells(new StackedRNNCellsArgs
+            {
+                Cells = cells.ToList()
+            });
+
         /// <summary>
         ///
         /// </summary>
@@ -709,6 +737,55 @@ public ILayer SimpleRNN(int units,
                 ReturnSequences = return_sequences,
                 ReturnState = return_state
             });
 
+        /// <summary>
+        /// Applies a recurrent cell (or stack of cells) over the time dimension of the input.
+        /// </summary>
+        /// <param name="cell">The recurrent cell to drive across timesteps.</param>
+        /// <param name="return_sequences">Whether to return the full output sequence or only the last output.</param>
+        /// <param name="return_state">Whether to return the final state in addition to the output.</param>
+        /// <param name="go_backwards">If true, process the input sequence in reverse order.</param>
+        /// <param name="stateful">If true, carry state across batches.</param>
+        /// <param name="unroll">If true, unroll the loop instead of using a symbolic loop.</param>
+        /// <param name="time_major">If true, inputs are shaped (time, batch, ...) instead of (batch, time, ...).</param>
+        public ILayer RNN(
+            IRnnCell cell,
+            bool return_sequences = false,
+            bool return_state = false,
+            bool go_backwards = false,
+            bool stateful = false,
+            bool unroll = false,
+            bool time_major = false)
+            => new RNN(new RNNArgs
+            {
+                Cell = cell,
+                ReturnSequences = return_sequences,
+                ReturnState = return_state,
+                GoBackwards = go_backwards,
+                Stateful = stateful,
+                Unroll = unroll,
+                TimeMajor = time_major
+            });
+
+        public ILayer RNN(
+            IEnumerable<IRnnCell> cells,
+            bool return_sequences = false,
+            bool return_state = false,
+            bool go_backwards = false,
+            bool stateful = false,
+            bool unroll = false,
+            bool time_major = false)
+            => new RNN(new RNNArgs
+            {
+                Cells = cells.ToList(),
+                ReturnSequences = return_sequences,
+                ReturnState = return_state,
+                GoBackwards = go_backwards,
+                Stateful = stateful,
+                Unroll = unroll,
+                TimeMajor = time_major
+            });
+
         /// <summary>
         /// Long Short-Term Memory layer - Hochreiter 1997.
         /// </summary>
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs b/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs
index 21396853f..78d3dac96 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs
@@ -17,6 +17,21 @@ public DropoutRNNCellMixin(LayerArgs args): base(args)
 
         }
 
+        // TODO: cache the created dropout masks here so one mask is reused across timesteps.
+        protected void _create_non_trackable_mask_cache()
+        {
+
+        }
+
+        // Stub: in Keras, these clear the cached masks between forward passes.
+        public void reset_dropout_mask()
+        {
+
+        }
+
+        public void reset_recurrent_dropout_mask()
+        {
+
+        }
+
         public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
         {
             if (dropout == 0f)
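reset_dropout_mask and reset_recurrent_dropout_mask are stubs for now. In Keras they clear masks that are cached once per forward pass so the same mask is reused at every timestep; a minimal shape of that cache could look like this (hypothetical fields, not in this patch):

    // Hypothetical cache: one mask per call, cleared by the reset methods.
    Tensors _cachedDropoutMask;
    Tensors _cachedRecurrentDropoutMask;

    public void reset_dropout_mask() => _cachedDropoutMask = null;
    public void reset_recurrent_dropout_mask() => _cachedRecurrentDropoutMask = null;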
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
index ab4cef124..0ebd73628 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
@@ -38,7 +38,17 @@ public RNN(RNNArgs args) : base(PreConstruct(args))
             SupportsMasking = true;
 
             // If a list of cells was passed, wrap it in a StackedRNNCells.
-            _cell = args.Cell;
+            if (args.Cells != null)
+            {
+                _cell = new StackedRNNCells(new StackedRNNCellsArgs
+                {
+                    Cells = args.Cells
+                });
+            }
+            else
+            {
+                _cell = args.Cell;
+            }
 
             // get input_shape
             _args = PreConstruct(args);
@@ -122,6 +132,8 @@ private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
                 var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list());
                 return new Shape(state_shape);
             };
+
+            var state_shape = _get_state_shape(state_size);
 
             return new List<Shape> { output_shape, state_shape };
@@ -240,7 +252,7 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
             if (_cell is StackedRNNCells)
             {
                 var stack_cell = _cell as StackedRNNCells;
-                foreach (var cell in stack_cell.Cells)
+                foreach (IRnnCell cell in stack_cell.Cells)
                 {
                     _maybe_reset_cell_dropout_mask(cell);
                 }
@@ -253,7 +265,7 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
             }
 
             Shape input_shape;
-            if (!inputs.IsSingle())
+            if (!inputs.IsNested())
             {
                 // In the case of nested input, use the first element for shape check
                 // input_shape = nest.flatten(inputs)[0].shape;
@@ -267,7 +279,7 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
 
             var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];
 
-            if (_args.Unroll && timesteps != null)
+            if (_args.Unroll && timesteps == null)
             {
                 throw new ValueError(
                     "Cannot unroll a RNN if the " +
@@ -302,7 +314,6 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
                     states = new Tensors(states.SkipLast(_num_constants));
                     states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
                     var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
-                    // TODO(Wanglongzhi2001),should cell_call_fn's return value be Tensors, Tensors?
                     return (output, new_states.Single);
                 };
             }
@@ -310,13 +321,14 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
             {
                 step = (inputs, states) =>
                 {
-                    states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
+                    states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states;
                     var (output, new_states) = _cell.Apply(inputs, states);
-                    return (output, new_states.Single);
+                    return (output, new_states);
                 };
             }
 
-            var (last_output, outputs, states) = keras.backend.rnn(step,
+            var (last_output, outputs, states) = keras.backend.rnn(
+                step,
                 inputs,
                 initial_state,
                 constants: constants,
@@ -394,6 +406,7 @@ public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null)
                 initial_state = null;
                 inputs = inputs[0];
             }
+
 
             if (_args.Stateful)
             {
@@ -402,7 +415,7 @@ public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null)
                     var tmp = new Tensor[] { };
                     foreach (var s in nest.flatten(States))
                     {
-                        tmp.add(tf.math.count_nonzero((Tensor)s));
+                        tmp.add(tf.math.count_nonzero(s.Single()));
                     }
                     var non_zero_count = tf.add_n(tmp);
                     //initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state);
@@ -415,6 +428,15 @@ public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null)
                 {
                     initial_state = States;
                 }
+                // TODO(Wanglongzhi2001): when the layer has an inferred dtype, cast the
+                // initial state to it, mirroring Keras:
+                //   initial_state = tf.nest.map_structure(
+                //       lambda v: tf.cast(v, self.compute_dtype or self.cell.compute_dtype),
+                //       initial_state,
+                //   )
             }
 
             else if (initial_state is null)
@@ -424,10 +446,9 @@ public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null)
 
             if (initial_state.Length != States.Length)
             {
-                throw new ValueError(
-                    $"Layer {this} expects {States.Length} state(s), " +
-                    $"but it received {initial_state.Length} " +
-                    $"initial state(s). Input received: {inputs}");
+                throw new ValueError($"Layer {this} expects {States.Length} state(s), " +
+                                     $"but it received {initial_state.Length} " +
+                                     $"initial state(s). Input received: {inputs}");
             }
 
             return (inputs, initial_state, constants);
@@ -458,11 +479,11 @@ private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
 
         void _maybe_reset_cell_dropout_mask(ILayer cell)
         {
-            //if (cell is DropoutRNNCellMixin)
-            //{
-            //    cell.reset_dropout_mask();
-            //    cell.reset_recurrent_dropout_mask();
-            //}
+            if (cell is DropoutRNNCellMixin CellDRCMixin)
+            {
+                CellDRCMixin.reset_dropout_mask();
+                CellDRCMixin.reset_recurrent_dropout_mask();
+            }
         }
 
         private static RNNArgs PreConstruct(RNNArgs args)
@@ -537,15 +558,24 @@ public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null)
         protected Tensors get_initial_state(Tensors inputs)
         {
+            var get_initial_state_fn = _cell.GetType().GetMethod("get_initial_state");
+
             var input = inputs[0];
-            var input_shape = input.shape;
+            var input_shape = inputs.shape;
             var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
             var dtype = input.dtype;
-            Tensors init_state;
-            if (_cell is RnnCellBase rnn_base_cell)
-            {
-                init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype);
-            }
+
+            Tensors init_state = new Tensors();
+
+            if (get_initial_state_fn != null)
+            {
+                init_state = (Tensors)get_initial_state_fn.Invoke(_cell, new object[] { inputs, batch_size, dtype });
+            }
             else
             {
                 init_state = RnnUtils.generate_zero_filled_state(batch_size, _cell.StateSize, dtype);
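get_initial_state is now resolved by reflection (GetType().GetMethod("get_initial_state")), both here and in StackedRNNCells below. That works, but it is stringly typed and silently falls back to zero-filled state on any rename. One alternative is to surface the capability on an interface so the compiler enforces it; a sketch (hypothetical interface, not part of this patch):

    public interface IHasInitialState
    {
        Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null);
    }

    // Dispatch becomes a type test instead of a GetMethod lookup:
    Tensors init_state = _cell is IHasInitialState c
        ? c.get_initial_state(inputs, batch_size, dtype)
        : RnnUtils.generate_zero_filled_state(batch_size, _cell.StateSize, dtype);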
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
index f0b2ed4d7..39610ff52 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
@@ -6,6 +6,7 @@
 using Tensorflow.Keras.Saving;
 using Tensorflow.Common.Types;
 using Tensorflow.Common.Extensions;
+using Tensorflow.Keras.Utils;
 
 namespace Tensorflow.Keras.Layers.Rnn
 {
@@ -113,5 +116,10 @@ protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null)
                 return new Tensors(output, output);
             }
         }
+
+        public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
+        {
+            return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value);
+        }
     }
 }
" + + // $"Received cell without a `Call` method: {cell}"); + //} + } Cells = args.Cells; + reverse_state_order = (bool)args.Kwargs.Get("reverse_state_order", false); if (reverse_state_order) @@ -33,91 +47,112 @@ public StackedRNNCells(StackedRNNCellsArgs args) : base(args) } } - public object state_size + public GeneralizedTensorShape StateSize { - get => throw new NotImplementedException(); - //@property - //def state_size(self) : - // return tuple(c.state_size for c in - // (self.cells[::- 1] if self.reverse_state_order else self.cells)) + get + { + GeneralizedTensorShape state_size = new GeneralizedTensorShape(1, Cells.Count); + if (reverse_state_order && Cells.Count > 0) + { + var idxAndCell = Cells.Reverse().Select((cell, idx) => (idx, cell)); + foreach (var cell in idxAndCell) + { + state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First(); + } + } + else + { + //foreach (var cell in Cells) + //{ + // state_size.Shapes.add(cell.StateSize.Shapes.First()); + + //} + var idxAndCell = Cells.Select((cell, idx) => (idx, cell)); + foreach (var cell in idxAndCell) + { + state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First(); + } + } + return state_size; + } } public object output_size { get { - var lastCell = Cells[Cells.Count - 1]; - - if (lastCell.output_size != -1) + var lastCell = Cells.LastOrDefault(); + if (lastCell.OutputSize.ToSingleShape() != -1) { - return lastCell.output_size; + return lastCell.OutputSize; } else if (RNN.is_multiple_state(lastCell.StateSize)) { - // return ((dynamic)Cells[-1].state_size)[0]; - throw new NotImplementedException(""); + return lastCell.StateSize.First(); + //throw new NotImplementedException(""); } else { - return Cells[-1].state_size; + return lastCell.StateSize; } } } - public object get_initial_state() + public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null) { - throw new NotImplementedException(); - // def get_initial_state(self, inputs= None, batch_size= None, dtype= None) : - // initial_states = [] - // for cell in self.cells[::- 1] if self.reverse_state_order else self.cells: - // get_initial_state_fn = getattr(cell, 'get_initial_state', None) - // if get_initial_state_fn: - // initial_states.append(get_initial_state_fn( - // inputs=inputs, batch_size=batch_size, dtype=dtype)) - // else: - // initial_states.append(_generate_zero_filled_state_for_cell( - // cell, inputs, batch_size, dtype)) - - // return tuple(initial_states) + var cells = reverse_state_order ? Cells.Reverse() : Cells; + Tensors initial_states = new Tensors(); + foreach (var cell in cells) + { + var get_initial_state_fn = cell.GetType().GetMethod("get_initial_state"); + if (get_initial_state_fn != null) + { + var result = (Tensors)get_initial_state_fn.Invoke(cell, new object[] { inputs, batch_size, dtype }); + initial_states.Add(result); + } + else + { + initial_states.Add(RnnUtils.generate_zero_filled_state_for_cell(cell, inputs, batch_size.Value, dtype.Value)); + } + } + return initial_states; } - public object call() + protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) { - throw new NotImplementedException(); - // def call(self, inputs, states, constants= None, training= None, ** kwargs): - // # Recover per-cell states. 
-        public object call()
+        protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
         {
-            throw new NotImplementedException();
-            // def call(self, inputs, states, constants= None, training= None, ** kwargs):
-            //     # Recover per-cell states.
-            //     state_size = (self.state_size[::- 1]
-            //                   if self.reverse_state_order else self.state_size)
-            //     nested_states = nest.pack_sequence_as(state_size, nest.flatten(states))
-            //     # Call the cells in order and store the returned states.
-            //     new_nested_states = []
-            //     for cell, states in zip(self.cells, nested_states) :
-            //         states = states if nest.is_nested(states) else [states]
-            //         # TF cell does not wrap the state into list when there is only one state.
-            //         is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None
-            //         states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
-            //         if generic_utils.has_arg(cell.call, 'training'):
-            //             kwargs['training'] = training
-            //         else:
-            //             kwargs.pop('training', None)
-            //         # Use the __call__ function for callable objects, eg layers, so that it
-            //         # will have the proper name scopes for the ops, etc.
-            //         cell_call_fn = cell.__call__ if callable(cell) else cell.call
-            //         if generic_utils.has_arg(cell.call, 'constants'):
-            //             inputs, states = cell_call_fn(inputs, states,
-            //                                           constants= constants, ** kwargs)
-            //         else:
-            //             inputs, states = cell_call_fn(inputs, states, ** kwargs)
-            //         new_nested_states.append(states)
-            //     return inputs, nest.pack_sequence_as(state_size,
-            //                                          nest.flatten(new_nested_states))
+            // Recover per-cell states, reversed when reverse_state_order is set.
+            var state_size = reverse_state_order ? StateSize.Reverse() : StateSize;
+            var nested_states = reverse_state_order ? state.Flatten().Reverse() : state.Flatten();
+
+            var new_nest_states = new Tensors();
+            // Call the cells in order and store the returned states.
+            foreach (var (cell, states) in zip(Cells, nested_states))
+            {
+                // TF cells do not wrap a single state in a list.
+                var type = cell.GetType();
+                bool IsTFRnnCell = type.GetProperty("IsTFRnnCell") != null;
+                state = len(state) == 1 && IsTFRnnCell ? state.FirstOrDefault() : state;
+
+                RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
+                Tensors? constants = rnn_optional_args?.Constants;
+
+                Tensors new_states;
+                (inputs, new_states) = cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
+
+                new_nest_states.Add(new_states);
+            }
+            new_nest_states = reverse_state_order ? new_nest_states.Reverse().ToArray() : new_nest_states.ToArray();
+            return new Nest<Tensor>(new List<INestStructure<Tensor>> {
+                new Nest<Tensor>(new List<INestStructure<Tensor>> { new Nest<Tensor>(inputs.Single()) }),
+                new Nest<Tensor>(new_nest_states) }).ToTensors();
         }
 
         public void build()
         {
-            throw new NotImplementedException();
+            built = true;
             // @tf_utils.shape_type_conversion
             // def build(self, input_shape) :
             //     if isinstance(input_shape, list) :
@@ -168,9 +203,9 @@ public void from_config()
         {
             throw new NotImplementedException();
         }
-        public GeneralizedTensorShape StateSize => throw new NotImplementedException();
+
         public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
-        public bool IsTFRnnCell => throw new NotImplementedException();
+        public bool IsTFRnnCell => true;
         public bool SupportOptionalArgs => throw new NotImplementedException();
     }
 }
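A note on state ordering for the class above: the flattened state vector carries one entry per cell, in cell order, reversed when reverse_state_order is set. Grounded in the StackedRNNCell unit test later in this patch:

    var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) };
    var stack = tf.keras.layers.StackedRNNCells(cells);
    // One state per cell, in cell order (reverse_state_order is false by default).
    var states = new Tensors { tf.zeros((32, 4)), tf.zeros((32, 5)) };
    var (output, newState) = stack.Apply(tf.ones((32, 10)), states);
    // output: (32, 5), the width of the last cell; newState[0]: (32, 4)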
diff --git a/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs b/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs
index ac5ba15ed..29648790f 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Callbacks/EarlystoppingTest.cs
@@ -2,6 +2,7 @@
 using System.Collections.Generic;
 using Tensorflow.Keras.Callbacks;
 using Tensorflow.Keras.Engine;
+using Tensorflow.NumPy;
 
 using static Tensorflow.KerasApi;
 
@@ -18,7 +19,7 @@ public void Earlystopping()
             var layers = keras.layers;
             var model = keras.Sequential(new List<ILayer>
             {
-                layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)),
+                layers.Rescaling(1.0f / 255, input_shape: (28, 28, 1)),
                 layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu),
                 layers.MaxPooling2D(),
                 layers.Flatten(),
@@ -36,8 +37,20 @@ public void Earlystopping()
             var num_epochs = 3;
             var batch_size = 8;
 
-            var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
-            x_train = x_train / 255.0f;
+            var data_loader = new MnistModelLoader();
+
+            var dataset = data_loader.LoadAsync(new ModelLoadSetting
+            {
+                TrainDir = "mnist",
+                OneHot = false,
+                ValidationSize = 59900,
+            }).Result;
+
+            NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1));
+            NDArray x2 = x1;
+
+            var x = new NDArray[] { x1, x2 };
+
             // define a CallbackParams first; the parameters you pass must at least contain Model and Epochs.
             CallbackParams callback_parameters = new CallbackParams
             {
@@ -47,10 +60,8 @@ public void Earlystopping()
             // define your earlystop
             ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy");
             // define a callback list, then add the earlystopping to it.
-            var callbacks = new List<ICallback>();
-            callbacks.add(earlystop);
-
-            model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs, callbacks: callbacks);
+            var callbacks = new List<ICallback> { earlystop };
+            model.fit(x, dataset.Train.Labels, batch_size, num_epochs, callbacks: callbacks);
         }
     }
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
index 55663d41c..28a16ad4e 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
@@ -4,25 +4,111 @@
 using System.Linq;
 using System.Text;
 using System.Threading.Tasks;
+using Tensorflow.Common.Types;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Layers.Rnn;
+using Tensorflow.Keras.Saving;
 using Tensorflow.NumPy;
+using Tensorflow.Train;
 using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
 
 namespace Tensorflow.Keras.UnitTest.Layers
 {
     [TestClass]
     public class Rnn
     {
+        [TestMethod]
+        public void SimpleRNNCell()
+        {
+            var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
+            var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
+            var x = tf.random.normal((4, 100));
+            var (y, h1) = cell.Apply(inputs: x, states: h0);
+            Assert.AreEqual((4, 64), y.shape);
+            Assert.AreEqual((4, 64), h1[0].shape);
+        }
+
+        [TestMethod]
+        public void StackedRNNCell()
+        {
+            var inputs = tf.ones((32, 10));
+            var states = new Tensors { tf.zeros((32, 4)), tf.zeros((32, 5)) };
+            var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) };
+            var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells);
+            var (output, state) = stackedRNNCell.Apply(inputs, states);
+            Assert.AreEqual((32, 5), output.shape);
+            Assert.AreEqual((32, 4), state[0].shape);
+        }
+
         [TestMethod]
         public void SimpleRNN()
         {
-            var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32);
-            /*var simple_rnn = keras.layers.SimpleRNN(4);
-            var output = simple_rnn.Apply(inputs);
-            Assert.AreEqual((32, 4), output.shape);*/
-            var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
-            var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
-            Console.WriteLine(whole_sequence_output);
-            Console.WriteLine(final_state);
+            var inputs = keras.Input(shape: (10, 8));
+            var x = keras.layers.SimpleRNN(4).Apply(inputs);
+            var output = keras.layers.Dense(10).Apply(x);
+            var model = keras.Model(inputs, output);
+            model.summary();
         }
+
+        [TestMethod]
+        public void RNNForSimpleRNNCell()
+        {
+            var inputs = tf.random.normal((32, 10, 8));
+            var cell = tf.keras.layers.SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f);
+            var rnn = tf.keras.layers.RNN(cell: cell);
+            var output = rnn.Apply(inputs);
+            Assert.AreEqual((32, 10), output.shape);
+        }
+
+        [TestMethod]
+        public void RNNForStackedRNNCell()
+        {
+            var inputs = tf.random.normal((32, 10, 8));
+            var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) };
+            var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells);
+            var rnn = tf.keras.layers.RNN(cell: stackedRNNCell);
+            var output = rnn.Apply(inputs);
+            Assert.AreEqual((32, 5), output.shape);
+        }
+
+        [TestMethod]
+        public void ShapeConcatenateWithUnknownDim()
+        {
+            long[] b = { 1, 2, 3 };
+            Shape a = new Shape(Unknown).concatenate(b);
+            Console.WriteLine(a); // expected: an unknown leading dim followed by 1, 2, 3
+        }
     }
 }
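Finally, a fuller exercise of the options plumbed through RNNArgs, combining return_sequences and return_state (a sketch; the expected shapes follow standard Keras SimpleRNN semantics for a (6, 10, 8) input with 4 units):

    var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32);
    var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
    var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
    // whole_sequence_output: (6, 10, 4); final_state: (6, 4)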