Skip to content

Add feature (not completed): add SimpleRNNCell, StackedRNNCells, RNN, and tests. #1100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@ public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>
/// create a single-dim generalized Tensor shape.
/// </summary>
/// <param name="dim"></param>
/// <summary>
/// Create a generalized tensor shape made of <paramref name="size"/> single-dim shapes.
/// </summary>
/// <param name="dim">The dimension value stored in each shape entry.</param>
/// <param name="size">Number of shape entries to create; defaults to 1.</param>
public GeneralizedTensorShape(int dim, int size = 1)
{
    // Build one distinct TensorShapeConfig per slot. Enumerable.Repeat(elem, size)
    // would put the SAME reference into every slot (TensorShapeConfig is a
    // reference type), so mutating one entry would silently mutate them all.
    Shapes = Enumerable.Range(0, size)
        .Select(_ => new TensorShapeConfig() { Items = new long?[] { dim } })
        .ToArray();
}

public GeneralizedTensorShape(Shape shape)
Expand Down Expand Up @@ -113,6 +118,11 @@ public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func)
return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s)));
}
}



/// <summary>
/// Implicitly converts an <see cref="int"/> into a single-entry, single-dim
/// generalized shape (equivalent to calling the (int dim) constructor).
/// </summary>
public static implicit operator GeneralizedTensorShape(int dims)
=> new GeneralizedTensorShape(dims);

public IEnumerator<long?[]> GetEnumerator()
{
Expand Down
3 changes: 3 additions & 0 deletions src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ public class RNNArgs : AutoSerializeLayerArgs
[JsonProperty("cell")]
// TODO: the cell should be serialized with `serialize_keras_object`.
public IRnnCell Cell { get; set; } = null;
[JsonProperty("cells")]
public IList<IRnnCell> Cells { get; set; } = null;

[JsonProperty("return_sequences")]
public bool ReturnSequences { get; set; } = false;
[JsonProperty("return_state")]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
using System.Collections.Generic;
using Tensorflow.Keras.Layers.Rnn;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class StackedRNNCellsArgs : LayerArgs
{
public IList<RnnCell> Cells { get; set; }
public IList<IRnnCell> Cells { get; set; }
public Dictionary<string, object> Kwargs { get; set; } = null;
}
}
34 changes: 34 additions & 0 deletions src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.NumPy;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;

Expand Down Expand Up @@ -192,6 +193,19 @@ public ILayer Rescaling(float scale,
float offset = 0,
Shape input_shape = null);

/// <summary>
/// Creates a cell for a simple RNN, configured from string-named activation
/// and initializers.
/// </summary>
/// <param name="units">Dimensionality of the output space.</param>
/// <param name="activation">Name of the activation function to use.</param>
/// <param name="use_bias">Whether the cell uses a bias vector.</param>
/// <param name="kernel_initializer">Initializer name for the input-to-hidden kernel.</param>
/// <param name="recurrent_initializer">Initializer name for the hidden-to-hidden kernel.</param>
/// <param name="bias_initializer">Initializer name for the bias vector.</param>
/// <param name="dropout">Dropout fraction for input transformations.</param>
/// <param name="recurrent_dropout">Dropout fraction for recurrent transformations.</param>
public IRnnCell SimpleRNNCell(
int units,
string activation = "tanh",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
float dropout = 0f,
float recurrent_dropout = 0f);

/// <summary>
/// Wraps a sequence of RNN cells into a single composite cell.
/// </summary>
/// <param name="cells">The cells to stack, in order.</param>
public IRnnCell StackedRNNCells(
IEnumerable<IRnnCell> cells);

public ILayer SimpleRNN(int units,
string activation = "tanh",
string kernel_initializer = "glorot_uniform",
Expand All @@ -200,6 +214,26 @@ public ILayer SimpleRNN(int units,
bool return_sequences = false,
bool return_state = false);

/// <summary>
/// Creates a recurrent layer driven by a single RNN cell.
/// </summary>
/// <param name="cell">The recurrent cell applied across timesteps.</param>
/// <param name="return_sequences">Whether to return the full output sequence or only the last output.</param>
/// <param name="return_state">Whether to return the final state in addition to the output.</param>
/// <param name="go_backwards">Whether to process the input sequence in reverse.</param>
/// <param name="stateful">Whether batch-final states are reused as initial states for the next batch.</param>
/// <param name="unroll">Whether to unroll the recurrence instead of using a symbolic loop.</param>
/// <param name="time_major">Whether inputs are shaped time-major — presumably (timesteps, batch, ...); confirm against RNN implementation.</param>
public ILayer RNN(
IRnnCell cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false
);

/// <summary>
/// Creates a recurrent layer driven by a sequence of RNN cells.
/// </summary>
/// <param name="cell">The recurrent cells applied across timesteps, in order.</param>
/// <param name="return_sequences">Whether to return the full output sequence or only the last output.</param>
/// <param name="return_state">Whether to return the final state in addition to the output.</param>
/// <param name="go_backwards">Whether to process the input sequence in reverse.</param>
/// <param name="stateful">Whether batch-final states are reused as initial states for the next batch.</param>
/// <param name="unroll">Whether to unroll the recurrence instead of using a symbolic loop.</param>
/// <param name="time_major">Whether inputs are shaped time-major — presumably (timesteps, batch, ...); confirm against RNN implementation.</param>
public ILayer RNN(
IEnumerable<IRnnCell> cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false
);

public ILayer Subtract();
}
}
14 changes: 13 additions & 1 deletion src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,19 @@ public TensorArray scatter(Tensor indices, Tensor value, string name = null)

return ta;
});*/
throw new NotImplementedException("");
//if (indices is EagerTensor)
//{
// indices = indices as EagerTensor;
// indices = indices.numpy();
//}

//foreach (var (index, val) in zip(indices.ToArray<int>(), array_ops.unstack(value)))
//{
// this.write(index, val);
//}
//return base;
//throw new NotImplementedException("");
return this;
}

public void _merge_element_shape(Shape shape)
Expand Down
5 changes: 4 additions & 1 deletion src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow.Operations
Expand Down Expand Up @@ -146,7 +147,9 @@ public TensorArray scatter(Tensor indices, Tensor value, string name = null)

return ta;
});*/
throw new NotImplementedException("");

//throw new NotImplementedException("");
return this;
}

public void _merge_element_shape(Shape shape)
Expand Down
27 changes: 14 additions & 13 deletions src/TensorFlowNET.Keras/BackendImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ Tensor swap_batch_timestep(Tensor input_t)
}

}

// tf.where needs its condition tensor to be the same shape as its two
// result tensors, but in our case the condition (mask) tensor is
// (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
Expand All @@ -535,7 +535,7 @@ Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
{
mask_t = tf.expand_dims(mask_t, -1);
}
var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().ToList().GetRange(fixed_dim, input_t.rank));
var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().Skip(fixed_dim).ToArray());
return tf.tile(mask_t, multiples);
}

Expand Down Expand Up @@ -570,9 +570,6 @@ Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
// individually. The result of this will be a tuple of lists, each of
// the item in tuple is list of the tensor with shape (batch, feature)




Tensors _process_single_input_t(Tensor input_t)
{
var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim
Expand Down Expand Up @@ -609,7 +606,7 @@ object _get_input_tensor(int time)
var mask_list = tf.unstack(mask);
if (go_backwards)
{
mask_list.Reverse();
mask_list.Reverse().ToArray();
}

for (int i = 0; i < time_steps; i++)
Expand All @@ -629,9 +626,10 @@ object _get_input_tensor(int time)
}
else
{
prev_output = successive_outputs[successive_outputs.Length - 1];
prev_output = successive_outputs.Last();
}

// output could be a tensor
output = tf.where(tiled_mask_t, output, prev_output);

var flat_states = Nest.Flatten(states).ToList();
Expand Down Expand Up @@ -661,13 +659,13 @@ object _get_input_tensor(int time)
}

}
last_output = successive_outputs[successive_outputs.Length - 1];
new_states = successive_states[successive_states.Length - 1];
last_output = successive_outputs.Last();
new_states = successive_states.Last();
outputs = tf.stack(successive_outputs);

if (zero_output_for_mask)
{
last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output));
last_output = tf.where(_expand_mask(mask_list.Last(), last_output), last_output, tf.zeros_like(last_output));
outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs));
}
else // mask is null
Expand All @@ -689,8 +687,8 @@ object _get_input_tensor(int time)
successive_states = new Tensors { newStates };
}
}
last_output = successive_outputs[successive_outputs.Length - 1];
new_states = successive_states[successive_states.Length - 1];
last_output = successive_outputs.Last();
new_states = successive_states.Last();
outputs = tf.stack(successive_outputs);
}
}
Expand All @@ -701,6 +699,8 @@ object _get_input_tensor(int time)
// Create input tensor array, if the inputs is nested tensors, then it
// will be flattened first, and tensor array will be created one per
// flattened tensor.


var input_ta = new List<TensorArray>();
for (int i = 0; i < flatted_inptus.Count; i++)
{
Expand All @@ -719,6 +719,7 @@ object _get_input_tensor(int time)
}
}


// Get the time(0) input and compute the output for that, the output will
// be used to determine the dtype of output tensor array. Don't read from
// input_ta due to TensorArray clear_after_read default to True.
Expand Down Expand Up @@ -773,7 +774,7 @@ object _get_input_tensor(int time)
return res;
};
}
// TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)?
// TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor), it could be an integer or tensor
else if (input_length is Tensor)
{
if (go_backwards)
Expand Down
77 changes: 77 additions & 0 deletions src/TensorFlowNET.Keras/Layers/LayersApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,34 @@ public ILayer LeakyReLU(float alpha = 0.3f)
Alpha = alpha
});


/// <summary>
/// Creates a cell for a simple RNN, resolving the string-named activation and
/// initializers and wrapping them into a <c>SimpleRNNCellArgs</c>.
/// </summary>
/// <param name="units">Dimensionality of the output space.</param>
/// <param name="activation">Name of the activation function to use.</param>
/// <param name="use_bias">Whether the cell uses a bias vector.</param>
/// <param name="kernel_initializer">Initializer name for the input-to-hidden kernel.</param>
/// <param name="recurrent_initializer">Initializer name for the hidden-to-hidden kernel.</param>
/// <param name="bias_initializer">Initializer name for the bias vector.</param>
/// <param name="dropout">Dropout fraction for input transformations.</param>
/// <param name="recurrent_dropout">Dropout fraction for recurrent transformations.</param>
public IRnnCell SimpleRNNCell(
    int units,
    string activation = "tanh",
    bool use_bias = true,
    string kernel_initializer = "glorot_uniform",
    string recurrent_initializer = "orthogonal",
    string bias_initializer = "zeros",
    float dropout = 0f,
    float recurrent_dropout = 0f)
    => new SimpleRNNCell(new SimpleRNNCellArgs
    {
        Units = units,
        Activation = keras.activations.GetActivationFromName(activation),
        UseBias = use_bias,
        KernelInitializer = GetInitializerByName(kernel_initializer),
        RecurrentInitializer = GetInitializerByName(recurrent_initializer),
        // Bug fix: bias_initializer was accepted but never forwarded, so the
        // caller's choice was silently ignored. (Assumes SimpleRNNCellArgs
        // exposes BiasInitializer like its sibling args classes — confirm.)
        BiasInitializer = GetInitializerByName(bias_initializer),
        Dropout = dropout,
        RecurrentDropout = recurrent_dropout
    });

/// <summary>
/// Wraps a sequence of RNN cells into a single composite stacked cell.
/// </summary>
/// <param name="cells">The cells to stack, in order.</param>
public IRnnCell StackedRNNCells(
    IEnumerable<IRnnCell> cells)
{
    // Materialize the sequence once so the args hold a concrete list.
    var args = new StackedRNNCellsArgs
    {
        Cells = new List<IRnnCell>(cells)
    };
    return new StackedRNNCells(args);
}

/// <summary>
///
/// </summary>
Expand All @@ -709,6 +737,55 @@ public ILayer SimpleRNN(int units,
ReturnState = return_state
});

/// <summary>
/// Creates a recurrent layer driven by a single RNN cell.
/// </summary>
/// <param name="cell">The recurrent cell applied across timesteps.</param>
/// <param name="return_sequences">Whether to return the full output sequence or only the last output.</param>
/// <param name="return_state">Whether to return the final state in addition to the output.</param>
/// <param name="go_backwards">Whether to process the input sequence in reverse.</param>
/// <param name="stateful">Whether batch-final states carry over to the next batch.</param>
/// <param name="unroll">Whether to unroll the recurrence instead of using a symbolic loop.</param>
/// <param name="time_major">Whether inputs are time-major.</param>
/// <returns>The configured RNN layer.</returns>
public ILayer RNN(
    IRnnCell cell,
    bool return_sequences = false,
    bool return_state = false,
    bool go_backwards = false,
    bool stateful = false,
    bool unroll = false,
    bool time_major = false)
{
    // Bundle the configuration into the args object consumed by the layer.
    var args = new RNNArgs
    {
        Cell = cell,
        ReturnSequences = return_sequences,
        ReturnState = return_state,
        GoBackwards = go_backwards,
        Stateful = stateful,
        Unroll = unroll,
        TimeMajor = time_major
    };
    return new RNN(args);
}

/// <summary>
/// Creates a recurrent layer driven by a sequence of RNN cells.
/// </summary>
/// <param name="cell">The recurrent cells applied across timesteps, in order.</param>
/// <param name="return_sequences">Whether to return the full output sequence or only the last output.</param>
/// <param name="return_state">Whether to return the final state in addition to the output.</param>
/// <param name="go_backwards">Whether to process the input sequence in reverse.</param>
/// <param name="stateful">Whether batch-final states carry over to the next batch.</param>
/// <param name="unroll">Whether to unroll the recurrence instead of using a symbolic loop.</param>
/// <param name="time_major">Whether inputs are time-major.</param>
/// <returns>The configured RNN layer.</returns>
public ILayer RNN(
    IEnumerable<IRnnCell> cell,
    bool return_sequences = false,
    bool return_state = false,
    bool go_backwards = false,
    bool stateful = false,
    bool unroll = false,
    bool time_major = false)
{
    // Materialize the cell sequence once, then bundle the configuration
    // into the args object consumed by the layer.
    var args = new RNNArgs
    {
        Cells = new List<IRnnCell>(cell),
        ReturnSequences = return_sequences,
        ReturnState = return_state,
        GoBackwards = go_backwards,
        Stateful = stateful,
        Unroll = unroll,
        TimeMajor = time_major
    };
    return new RNN(args);
}

/// <summary>
/// Long Short-Term Memory layer - Hochreiter 1997.
/// </summary>
Expand Down
15 changes: 15 additions & 0 deletions src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,21 @@ public DropoutRNNCellMixin(LayerArgs args): base(args)

}

// NOTE(review): empty stub — the dropout-mask cache is not implemented yet
// in this port; callers get no caching behavior. TODO: implement or document
// why it is intentionally a no-op.
protected void _create_non_trackable_mask_cache()
{

}

// NOTE(review): empty stub — resetting the cached dropout mask is not
// implemented yet; calling this currently has no effect.
public void reset_dropout_mask()
{

}

// NOTE(review): empty stub — resetting the cached recurrent dropout mask is
// not implemented yet; calling this currently has no effect.
public void reset_recurrent_dropout_mask()
{

}

public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
{
if (dropout == 0f)
Expand Down
Loading