diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index cba8f9541..259b1eec7 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -58,17 +58,12 @@ public Tensor? SingleOrNull public Tensor this[params string[] slices] => this.First()[slices]; - public Tensors(Tensor tensor) : base(tensor) - { - - } - private Tensors(Nest nested) : base(nested) { } - public Tensors(params Tensor[] tensors): base(tensors.Select(x => new Nest(x))) + public Tensors(params Tensor[] tensors): base(DealWithConstructorArrayInput(tensors)) { } @@ -83,6 +78,22 @@ public Tensors(NDArray nd): base(ops.convert_to_tensor(nd)) } + private static Nest DealWithConstructorArrayInput(Tensor[] tensors) + { + if (tensors.Length == 0) + { + return Nest.Empty; + } + else if(tensors.Length == 1) + { + return new Nest(tensors[0]); + } + else + { + return new Nest(tensors.Select(x => new Nest(x))); + } + } + public bool IsSingle() { return Length == 1; @@ -107,9 +118,14 @@ public void Add(Tensor tensor) ListValue = new() { new Nest(Value), new Nest(tensor) }; Value = null; } - else + else if(NestType == NestType.List) + { + ListValue!.Add(new Nest(tensor)); + } + else //Empty { - ListValue.Add(new Nest(tensor)); + NestType = NestType.Node; + Value = tensor; } } @@ -128,9 +144,14 @@ public void AddRange(IEnumerable tensors) ListValue.AddRange(tensors.Select(x => new Nest(x))); Value = null; } - else + else if(NestType == NestType.List) { - ListValue.AddRange(tensors.Select(x => new Nest(x))); + ListValue!.AddRange(tensors.Select(x => new Nest(x))); + } + else // empty + { + NestType = NestType.List; + ListValue = tensors.Select(x => new Nest(x)).ToList(); } } diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index 30b73e82f..144910669 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -651,13 +651,13 @@ object _get_input_tensor(int time) states = Nest.PackSequenceAs(states, flat_final_states).ToTensors(); if (return_all_outputs) { - successive_outputs.Add(output); - successive_states.Add(states); + successive_outputs = successive_outputs.MergeWith(output); + successive_outputs = successive_states.MergeWith(states); } else { - successive_outputs = new Tensors { output }; - successive_states = new Tensors { states }; + successive_outputs = new Tensors(output); + successive_states = new Tensors(states); } } @@ -722,16 +722,11 @@ object _get_input_tensor(int time) // Get the time(0) input and compute the output for that, the output will // be used to determine the dtype of output tensor array. Don't read from // input_ta due to TensorArray clear_after_read default to True. - var inps = new Tensors(); - foreach (var inp in flatted_inptus) - { - inps.Add(inp[0]); - } - var input_time_zero = Nest.PackSequenceAs(inputs, inps).ToTensors(); + var input_time_zero = Nest.PackSequenceAs(inputs, flatted_inptus.Select(x => x[0]).ToArray()).ToTensors(); // output_time_zero is used to determine the cell output shape and its // dtype. the value is discarded. - (output_time_zero, _) = step_function((Tensor)input_time_zero, + (output_time_zero, _) = step_function(input_time_zero, constants is null ? initial_states : initial_states.MergeWith(constants)); int output_ta_size = return_all_outputs ? time_steps_t : 1; @@ -816,6 +811,7 @@ object _get_input_tensor(int time) Func cond = (time) => (time < time_steps_t); int parallel_iterations = 32; + new_states = states; if (masking_fn != null) { // Mask for the T output will be base on the output of T - 1. In the @@ -846,7 +842,7 @@ RNN step function. // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); var mask_t = masking_fn(time); - var (output, new_states_internal) = step_function(current_input, states.MergeWith(constants)); + var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants)); // mask output var flat_output = Nest.Flatten(output).ToList(); @@ -871,11 +867,12 @@ RNN step function. new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors(); var ta_index_to_write = return_all_outputs ? time : tf.constant(0); - // TODO(Wanglongzhi2001),deal with zip output_ta_t - foreach (var (ta, Out) in zip(output_ta_t, flat_new_output)) + output_ta_t = zip(output_ta_t, flat_new_output).Select(item => { - output_ta_t.Add(ta.write(ta_index_to_write, Out)); - } + var (ta, out_) = item; + return ta.write(ta_index_to_write, out_); + }).ToList(); + new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); @@ -921,15 +918,8 @@ Tensor _step(Tensor time) } var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations); } - //Tensors outputs = new Tensors(); - foreach (var o in output_ta) - { - outputs.Add(o.stack()); - } - foreach (var o in outputs) - { - last_output.Add(o[-1]); - } + outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToTensors()); + last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToTensors()); outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors(); last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors();