Skip to content

Commit 81a9d23

Browse files
authored
Merge pull request #1098 from AsakusaRinne/rnn-dev
fix: some possible errors of RNN.
2 parents 9da157f + dcaa0f4 commit 81a9d23

File tree

2 files changed

+46
-35
lines changed

2 files changed

+46
-35
lines changed

src/TensorFlowNET.Core/Tensors/Tensors.cs

+31-10
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,12 @@ public Tensor? SingleOrNull
5858
public Tensor this[params string[] slices]
5959
=> this.First()[slices];
6060

61-
public Tensors(Tensor tensor) : base(tensor)
62-
{
63-
64-
}
65-
6661
private Tensors(Nest<Tensor> nested) : base(nested)
6762
{
6863

6964
}
7065

71-
public Tensors(params Tensor[] tensors): base(tensors.Select(x => new Nest<Tensor>(x)))
66+
public Tensors(params Tensor[] tensors): base(DealWithConstructorArrayInput(tensors))
7267
{
7368

7469
}
@@ -83,6 +78,22 @@ public Tensors(NDArray nd): base(ops.convert_to_tensor(nd))
8378

8479
}
8580

81+
private static Nest<Tensor> DealWithConstructorArrayInput(Tensor[] tensors)
82+
{
83+
if (tensors.Length == 0)
84+
{
85+
return Nest<Tensor>.Empty;
86+
}
87+
else if(tensors.Length == 1)
88+
{
89+
return new Nest<Tensor>(tensors[0]);
90+
}
91+
else
92+
{
93+
return new Nest<Tensor>(tensors.Select(x => new Nest<Tensor>(x)));
94+
}
95+
}
96+
8697
public bool IsSingle()
8798
{
8899
return Length == 1;
@@ -107,9 +118,14 @@ public void Add(Tensor tensor)
107118
ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) };
108119
Value = null;
109120
}
110-
else
121+
else if(NestType == NestType.List)
122+
{
123+
ListValue!.Add(new Nest<Tensor>(tensor));
124+
}
125+
else //Empty
111126
{
112-
ListValue.Add(new Nest<Tensor>(tensor));
127+
NestType = NestType.Node;
128+
Value = tensor;
113129
}
114130
}
115131

@@ -128,9 +144,14 @@ public void AddRange(IEnumerable<Tensor> tensors)
128144
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
129145
Value = null;
130146
}
131-
else
147+
else if(NestType == NestType.List)
132148
{
133-
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
149+
ListValue!.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
150+
}
151+
else // empty
152+
{
153+
NestType = NestType.List;
154+
ListValue = tensors.Select(x => new Nest<Tensor>(x)).ToList();
134155
}
135156
}
136157

src/TensorFlowNET.Keras/BackendImpl.cs

+15-25
Original file line numberDiff line numberDiff line change
@@ -651,13 +651,13 @@ object _get_input_tensor(int time)
651651
states = Nest.PackSequenceAs(states, flat_final_states).ToTensors();
652652
if (return_all_outputs)
653653
{
654-
successive_outputs.Add(output);
655-
successive_states.Add(states);
654+
successive_outputs = successive_outputs.MergeWith(output);
655+
successive_outputs = successive_states.MergeWith(states);
656656
}
657657
else
658658
{
659-
successive_outputs = new Tensors { output };
660-
successive_states = new Tensors { states };
659+
successive_outputs = new Tensors(output);
660+
successive_states = new Tensors(states);
661661
}
662662

663663
}
@@ -722,16 +722,11 @@ object _get_input_tensor(int time)
722722
// Get the time(0) input and compute the output for that, the output will
723723
// be used to determine the dtype of output tensor array. Don't read from
724724
// input_ta due to TensorArray clear_after_read default to True.
725-
var inps = new Tensors();
726-
foreach (var inp in flatted_inptus)
727-
{
728-
inps.Add(inp[0]);
729-
}
730-
var input_time_zero = Nest.PackSequenceAs(inputs, inps).ToTensors();
725+
var input_time_zero = Nest.PackSequenceAs(inputs, flatted_inptus.Select(x => x[0]).ToArray()).ToTensors();
731726

732727
// output_time_zero is used to determine the cell output shape and its
733728
// dtype. the value is discarded.
734-
(output_time_zero, _) = step_function((Tensor)input_time_zero,
729+
(output_time_zero, _) = step_function(input_time_zero,
735730
constants is null ? initial_states : initial_states.MergeWith(constants));
736731

737732
int output_ta_size = return_all_outputs ? time_steps_t : 1;
@@ -816,6 +811,7 @@ object _get_input_tensor(int time)
816811

817812
Func<Tensor, Tensor> cond = (time) => (time < time_steps_t);
818813
int parallel_iterations = 32;
814+
new_states = states;
819815
if (masking_fn != null)
820816
{
821817
// Mask for the T output will be base on the output of T - 1. In the
@@ -846,7 +842,7 @@ RNN step function.
846842
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
847843
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
848844
var mask_t = masking_fn(time);
849-
var (output, new_states_internal) = step_function(current_input, states.MergeWith(constants));
845+
var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants));
850846
// mask output
851847
var flat_output = Nest.Flatten(output).ToList();
852848

@@ -871,11 +867,12 @@ RNN step function.
871867
new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors();
872868

873869
var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
874-
// TODO(Wanglongzhi2001),deal with zip output_ta_t
875-
foreach (var (ta, Out) in zip(output_ta_t, flat_new_output))
870+
output_ta_t = zip(output_ta_t, flat_new_output).Select(item =>
876871
{
877-
output_ta_t.Add(ta.write(ta_index_to_write, Out));
878-
}
872+
var (ta, out_) = item;
873+
return ta.write(ta_index_to_write, out_);
874+
}).ToList();
875+
879876

880877
new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();
881878

@@ -921,15 +918,8 @@ Tensor _step(Tensor time)
921918
}
922919
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
923920
}
924-
//Tensors outputs = new Tensors();
925-
foreach (var o in output_ta)
926-
{
927-
outputs.Add(o.stack());
928-
}
929-
foreach (var o in outputs)
930-
{
931-
last_output.Add(o[-1]);
932-
}
921+
outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToTensors());
922+
last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToTensors());
933923
outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors();
934924
last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors();
935925

0 commit comments

Comments
 (0)