Skip to content

fix: some possible errors of RNN. #1098

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions src/TensorFlowNET.Core/Tensors/Tensors.cs
Original file line number Diff line number Diff line change
@@ -58,17 +58,12 @@ public Tensor? SingleOrNull
public Tensor this[params string[] slices]
=> this.First()[slices];

public Tensors(Tensor tensor) : base(tensor)
{

}

private Tensors(Nest<Tensor> nested) : base(nested)
{

}

public Tensors(params Tensor[] tensors): base(tensors.Select(x => new Nest<Tensor>(x)))
public Tensors(params Tensor[] tensors): base(DealWithConstructorArrayInput(tensors))
{

}
@@ -83,6 +78,22 @@ public Tensors(NDArray nd): base(ops.convert_to_tensor(nd))

}

private static Nest<Tensor> DealWithConstructorArrayInput(Tensor[] tensors)
{
if (tensors.Length == 0)
{
return Nest<Tensor>.Empty;
}
else if(tensors.Length == 1)
{
return new Nest<Tensor>(tensors[0]);
}
else
{
return new Nest<Tensor>(tensors.Select(x => new Nest<Tensor>(x)));
}
}

public bool IsSingle()
{
return Length == 1;
@@ -107,9 +118,14 @@ public void Add(Tensor tensor)
ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) };
Value = null;
}
else
else if(NestType == NestType.List)
{
ListValue!.Add(new Nest<Tensor>(tensor));
}
else //Empty
{
ListValue.Add(new Nest<Tensor>(tensor));
NestType = NestType.Node;
Value = tensor;
}
}

@@ -128,9 +144,14 @@ public void AddRange(IEnumerable<Tensor> tensors)
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
Value = null;
}
else
else if(NestType == NestType.List)
{
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
ListValue!.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
}
else // empty
{
NestType = NestType.List;
ListValue = tensors.Select(x => new Nest<Tensor>(x)).ToList();
}
}

40 changes: 15 additions & 25 deletions src/TensorFlowNET.Keras/BackendImpl.cs
Original file line number Diff line number Diff line change
@@ -651,13 +651,13 @@ object _get_input_tensor(int time)
states = Nest.PackSequenceAs(states, flat_final_states).ToTensors();
if (return_all_outputs)
{
successive_outputs.Add(output);
successive_states.Add(states);
successive_outputs = successive_outputs.MergeWith(output);
successive_outputs = successive_states.MergeWith(states);
}
else
{
successive_outputs = new Tensors { output };
successive_states = new Tensors { states };
successive_outputs = new Tensors(output);
successive_states = new Tensors(states);
}

}
@@ -722,16 +722,11 @@ object _get_input_tensor(int time)
// Get the time(0) input and compute the output for that, the output will
// be used to determine the dtype of output tensor array. Don't read from
// input_ta due to TensorArray clear_after_read default to True.
var inps = new Tensors();
foreach (var inp in flatted_inptus)
{
inps.Add(inp[0]);
}
var input_time_zero = Nest.PackSequenceAs(inputs, inps).ToTensors();
var input_time_zero = Nest.PackSequenceAs(inputs, flatted_inptus.Select(x => x[0]).ToArray()).ToTensors();

// output_time_zero is used to determine the cell output shape and its
// dtype. the value is discarded.
(output_time_zero, _) = step_function((Tensor)input_time_zero,
(output_time_zero, _) = step_function(input_time_zero,
constants is null ? initial_states : initial_states.MergeWith(constants));

int output_ta_size = return_all_outputs ? time_steps_t : 1;
@@ -816,6 +811,7 @@ object _get_input_tensor(int time)

Func<Tensor, Tensor> cond = (time) => (time < time_steps_t);
int parallel_iterations = 32;
new_states = states;
if (masking_fn != null)
{
// Mask for the T output will be base on the output of T - 1. In the
@@ -846,7 +842,7 @@ RNN step function.
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
var mask_t = masking_fn(time);
var (output, new_states_internal) = step_function(current_input, states.MergeWith(constants));
var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants));
// mask output
var flat_output = Nest.Flatten(output).ToList();

@@ -871,11 +867,12 @@ RNN step function.
new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors();

var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
// TODO(Wanglongzhi2001),deal with zip output_ta_t
foreach (var (ta, Out) in zip(output_ta_t, flat_new_output))
output_ta_t = zip(output_ta_t, flat_new_output).Select(item =>
{
output_ta_t.Add(ta.write(ta_index_to_write, Out));
}
var (ta, out_) = item;
return ta.write(ta_index_to_write, out_);
}).ToList();


new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();

@@ -921,15 +918,8 @@ Tensor _step(Tensor time)
}
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
}
//Tensors outputs = new Tensors();
foreach (var o in output_ta)
{
outputs.Add(o.stack());
}
foreach (var o in outputs)
{
last_output.Add(o[-1]);
}
outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToTensors());
last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToTensors());
outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors();
last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors();