
Commit 1d97b71

Merge pull request #1100 from Wanglongzhi2001/rnn-dev
Add feature (not completed): add SimpleRNNCell, StackedRNNCell, RNN and test.
2 parents 81a9d23 + db8e43b commit 1d97b71

File tree

14 files changed: +445 −119 lines


src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs

+12 −2
@@ -12,9 +12,14 @@ public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>
     /// create a single-dim generalized Tensor shape.
     /// </summary>
     /// <param name="dim"></param>
-    public GeneralizedTensorShape(int dim)
+    public GeneralizedTensorShape(int dim, int size = 1)
     {
-        Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
+        var elem = new TensorShapeConfig() { Items = new long?[] { dim } };
+        Shapes = Enumerable.Repeat(elem, size).ToArray();
+        //Shapes = new TensorShapeConfig[size];
+        //Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } });
+        //Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } });
+        ////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
     }
 
     public GeneralizedTensorShape(Shape shape)

@@ -113,6 +118,11 @@ public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func)
         return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s)));
     }
 }
+
+
+
+    public static implicit operator GeneralizedTensorShape(int dims)
+        => new GeneralizedTensorShape(dims);
 
     public IEnumerator<long?[]> GetEnumerator()
     {

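The new `size` parameter repeats one single-dim shape, which is what a cell stack with several identical state sizes needs; note that `Enumerable.Repeat` reuses the same `TensorShapeConfig` instance for every entry, which is harmless as long as the config is treated as read-only. A minimal usage sketch (variable names are illustrative, not from the commit):

// Hypothetical illustration of the new constructor and the implicit conversion.
var stacked = new GeneralizedTensorShape(64, size: 2); // two {64} shape entries
GeneralizedTensorShape single = 64;                    // implicit int -> shape, size 1
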
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs

+3
@@ -10,6 +10,9 @@ public class RNNArgs : AutoSerializeLayerArgs
     [JsonProperty("cell")]
     // TODO: the cell should be serialized with `serialize_keras_object`.
     public IRnnCell Cell { get; set; } = null;
+    [JsonProperty("cells")]
+    public IList<IRnnCell> Cells { get; set; } = null;
+
     [JsonProperty("return_sequences")]
     public bool ReturnSequences { get; set; } = false;
     [JsonProperty("return_state")]

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs (path missing in the capture; inferred from the namespace below)

+2 −1

@@ -1,10 +1,11 @@
 using System.Collections.Generic;
+using Tensorflow.Keras.Layers.Rnn;
 
 namespace Tensorflow.Keras.ArgsDefinition.Rnn
 {
     public class StackedRNNCellsArgs : LayerArgs
     {
-        public IList<RnnCell> Cells { get; set; }
+        public IList<IRnnCell> Cells { get; set; }
         public Dictionary<string, object> Kwargs { get; set; } = null;
     }
 }

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

+34
@@ -1,5 +1,6 @@
 using System;
 using Tensorflow.Framework.Models;
+using Tensorflow.Keras.Layers.Rnn;
 using Tensorflow.NumPy;
 using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;

@@ -192,6 +193,19 @@ public ILayer Rescaling(float scale,
         float offset = 0,
         Shape input_shape = null);
 
+    public IRnnCell SimpleRNNCell(
+        int units,
+        string activation = "tanh",
+        bool use_bias = true,
+        string kernel_initializer = "glorot_uniform",
+        string recurrent_initializer = "orthogonal",
+        string bias_initializer = "zeros",
+        float dropout = 0f,
+        float recurrent_dropout = 0f);
+
+    public IRnnCell StackedRNNCells(
+        IEnumerable<IRnnCell> cells);
+
     public ILayer SimpleRNN(int units,
         string activation = "tanh",
         string kernel_initializer = "glorot_uniform",

@@ -200,6 +214,26 @@ public ILayer SimpleRNN(int units,
         bool return_sequences = false,
         bool return_state = false);
 
+    public ILayer RNN(
+        IRnnCell cell,
+        bool return_sequences = false,
+        bool return_state = false,
+        bool go_backwards = false,
+        bool stateful = false,
+        bool unroll = false,
+        bool time_major = false);
+
+    public ILayer RNN(
+        IEnumerable<IRnnCell> cell,
+        bool return_sequences = false,
+        bool return_state = false,
+        bool go_backwards = false,
+        bool stateful = false,
+        bool unroll = false,
+        bool time_major = false);
+
     public ILayer Subtract();
 }
}

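Together these members let callers compose recurrent layers from individual cells rather than only from the prebuilt SimpleRNN layer. A hedged sketch of how the interface might be exercised (assuming `tf.keras.layers` exposes this ILayersApi, as elsewhere in the codebase):

// Illustrative only: build a two-layer recurrent stack from cells.
var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(4) };
var stacked = tf.keras.layers.StackedRNNCells(cells);           // one IRnnCell wrapping both
var rnn = tf.keras.layers.RNN(stacked, return_sequences: true); // or pass `cells` to the IEnumerable overload
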
src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs

+13 −1

@@ -109,7 +109,19 @@ public TensorArray scatter(Tensor indices, Tensor value, string name = null)
 
             return ta;
         });*/
-        throw new NotImplementedException("");
+        //if (indices is EagerTensor)
+        //{
+        //    indices = indices as EagerTensor;
+        //    indices = indices.numpy();
+        //}
+
+        //foreach (var (index, val) in zip(indices.ToArray<int>(), array_ops.unstack(value)))
+        //{
+        //    this.write(index, val);
+        //}
+        //return base;
+        //throw new NotImplementedException("");
+        return this;
     }
 
     public void _merge_element_shape(Shape shape)

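The commented-out block records the intended semantics: unstack `value` along its first axis and write each slice at the corresponding index. A hedged reconstruction of that intent, using only the helpers the sketch itself references (`zip`, `array_ops.unstack`, `write`); the commit actually ships the `return this;` stopgap instead:

// Hypothetical scatter body implied by the commented-out sketch above.
public TensorArray scatter(Tensor indices, Tensor value, string name = null)
{
    // write each unstacked slice of `value` at its matching index
    foreach (var (index, val) in zip(indices.ToArray<int>(), array_ops.unstack(value)))
        write(index, val);
    return this;
}
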
src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs

+4 −1

@@ -17,6 +17,7 @@ limitations under the License.
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using Tensorflow.Eager;
 using static Tensorflow.Binding;
 
 namespace Tensorflow.Operations

@@ -146,7 +147,9 @@ public TensorArray scatter(Tensor indices, Tensor value, string name = null)
 
             return ta;
         });*/
-        throw new NotImplementedException("");
+
+        //throw new NotImplementedException("");
+        return this;
     }
 
     public void _merge_element_shape(Shape shape)

src/TensorFlowNET.Keras/BackendImpl.cs

+14 −13

@@ -510,7 +510,7 @@ Tensor swap_batch_timestep(Tensor input_t)
     }
 
 }
-
+
 // tf.where needs its condition tensor to be the same shape as its two
 // result tensors, but in our case the condition (mask) tensor is
 // (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.

@@ -535,7 +535,7 @@ Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
     {
         mask_t = tf.expand_dims(mask_t, -1);
     }
-    var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().ToList().GetRange(fixed_dim, input_t.rank));
+    var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().Skip(fixed_dim).ToArray());
     return tf.tile(mask_t, multiples);
 }

@@ -570,9 +570,6 @@ Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
 // individually. The result of this will be a tuple of lists, each of
 // the item in tuple is list of the tensor with shape (batch, feature)
 
-
-
-
 Tensors _process_single_input_t(Tensor input_t)
 {
     var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim

@@ -609,7 +606,7 @@ object _get_input_tensor(int time)
     var mask_list = tf.unstack(mask);
     if (go_backwards)
     {
-        mask_list.Reverse();
+        mask_list.Reverse().ToArray();
     }
 
     for (int i = 0; i < time_steps; i++)

@@ -629,9 +626,10 @@ object _get_input_tensor(int time)
     }
     else
     {
-        prev_output = successive_outputs[successive_outputs.Length - 1];
+        prev_output = successive_outputs.Last();
     }
 
+    // output could be a tensor
     output = tf.where(tiled_mask_t, output, prev_output);
 
     var flat_states = Nest.Flatten(states).ToList();

@@ -661,13 +659,13 @@ object _get_input_tensor(int time)
     }
 
 }
-last_output = successive_outputs[successive_outputs.Length - 1];
-new_states = successive_states[successive_states.Length - 1];
+last_output = successive_outputs.Last();
+new_states = successive_states.Last();
 outputs = tf.stack(successive_outputs);
 
 if (zero_output_for_mask)
 {
-    last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output));
+    last_output = tf.where(_expand_mask(mask_list.Last(), last_output), last_output, tf.zeros_like(last_output));
     outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs));
 }
 else // mask is null

@@ -689,8 +687,8 @@ object _get_input_tensor(int time)
         successive_states = new Tensors { newStates };
     }
 }
-last_output = successive_outputs[successive_outputs.Length - 1];
-new_states = successive_states[successive_states.Length - 1];
+last_output = successive_outputs.Last();
+new_states = successive_states.Last();
 outputs = tf.stack(successive_outputs);

@@ -701,6 +699,8 @@ object _get_input_tensor(int time)
 // Create input tensor array, if the inputs is nested tensors, then it
 // will be flattened first, and tensor array will be created one per
 // flattened tensor.
+
+
 var input_ta = new List<TensorArray>();
 for (int i = 0; i < flatted_inptus.Count; i++)
 {

@@ -719,6 +719,7 @@ object _get_input_tensor(int time)
     }
 }
 
+
 // Get the time(0) input and compute the output for that, the output will
 // be used to determine the dtype of output tensor array. Don't read from
 // input_ta due to TensorArray clear_after_read default to True.

@@ -773,7 +774,7 @@ object _get_input_tensor(int time)
     return res;
 };
 }
-// TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)?
+// TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor), it could be an integer or tensor
 else if (input_length is Tensor)
 {
     if (go_backwards)

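The `_expand_mask` change is worth a second look: the old `GetRange(fixed_dim, input_t.rank)` asks for `rank` elements starting at `fixed_dim`, which overruns the list whenever `fixed_dim > 0`, while `Skip(fixed_dim)` takes exactly the remaining dimensions. (Note that `mask_list.Reverse().ToArray()` still discards its result; LINQ's Reverse does not mutate in place.) A self-contained sketch of the multiples computation with plain integers (example shape values assumed):

using System.Linq;

int fixed_dim = 1;
int[] inputShape = { 32, 10, 8 };        // (batch, time, features), illustrative values
var multiples = Enumerable.Repeat(1, fixed_dim)
    .Concat(inputShape.Skip(fixed_dim))  // Skip never reads past the end of the list
    .ToArray();                          // => [1, 10, 8]: tiles a (32, 1, 1) mask to (32, 10, 8)
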
src/TensorFlowNET.Keras/Layers/LayersApi.cs

+77
@@ -685,6 +685,34 @@ public ILayer LeakyReLU(float alpha = 0.3f)
         Alpha = alpha
     });
 
+
+public IRnnCell SimpleRNNCell(
+    int units,
+    string activation = "tanh",
+    bool use_bias = true,
+    string kernel_initializer = "glorot_uniform",
+    string recurrent_initializer = "orthogonal",
+    string bias_initializer = "zeros",
+    float dropout = 0f,
+    float recurrent_dropout = 0f)
+    => new SimpleRNNCell(new SimpleRNNCellArgs
+    {
+        Units = units,
+        Activation = keras.activations.GetActivationFromName(activation),
+        UseBias = use_bias,
+        KernelInitializer = GetInitializerByName(kernel_initializer),
+        RecurrentInitializer = GetInitializerByName(recurrent_initializer),
+        Dropout = dropout,
+        RecurrentDropout = recurrent_dropout
+    });
+
+public IRnnCell StackedRNNCells(
+    IEnumerable<IRnnCell> cells)
+    => new StackedRNNCells(new StackedRNNCellsArgs
+    {
+        Cells = cells.ToList()
+    });
+
 /// <summary>
 ///
 /// </summary>

@@ -709,6 +737,55 @@ public ILayer SimpleRNN(int units,
         ReturnState = return_state
     });
 
+/// <summary>
+///
+/// </summary>
+/// <param name="cell"></param>
+/// <param name="return_sequences"></param>
+/// <param name="return_state"></param>
+/// <param name="go_backwards"></param>
+/// <param name="stateful"></param>
+/// <param name="unroll"></param>
+/// <param name="time_major"></param>
+/// <returns></returns>
+public ILayer RNN(
+    IRnnCell cell,
+    bool return_sequences = false,
+    bool return_state = false,
+    bool go_backwards = false,
+    bool stateful = false,
+    bool unroll = false,
+    bool time_major = false)
+    => new RNN(new RNNArgs
+    {
+        Cell = cell,
+        ReturnSequences = return_sequences,
+        ReturnState = return_state,
+        GoBackwards = go_backwards,
+        Stateful = stateful,
+        Unroll = unroll,
+        TimeMajor = time_major
+    });
+
+public ILayer RNN(
+    IEnumerable<IRnnCell> cell,
+    bool return_sequences = false,
+    bool return_state = false,
+    bool go_backwards = false,
+    bool stateful = false,
+    bool unroll = false,
+    bool time_major = false)
+    => new RNN(new RNNArgs
+    {
+        Cells = cell.ToList(),
+        ReturnSequences = return_sequences,
+        ReturnState = return_state,
+        GoBackwards = go_backwards,
+        Stateful = stateful,
+        Unroll = unroll,
+        TimeMajor = time_major
+    });
+
 /// <summary>
 /// Long Short-Term Memory layer - Hochreiter 1997.
 /// </summary>

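For context, a hedged end-to-end sketch of the new factories (the shapes and the Apply call follow the repo's usual layer usage; values are illustrative):

// Illustrative only: run a SimpleRNNCell through the generic RNN layer.
var inputs = tf.random.normal((32, 10, 8));                     // (batch, time, features)
var rnn = tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(4),
                              return_sequences: true);
var outputs = rnn.Apply(inputs);                                // expected shape (32, 10, 4)

One detail a reviewer might flag: the SimpleRNNCell factory accepts bias_initializer but never assigns it to SimpleRNNCellArgs, so that argument is currently dropped.
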
src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs

+15
@@ -17,6 +17,21 @@ public DropoutRNNCellMixin(LayerArgs args): base(args)
 
 }
 
+protected void _create_non_trackable_mask_cache()
+{
+
+}
+
+public void reset_dropout_mask()
+{
+
+}
+
+public void reset_recurrent_dropout_mask()
+{
+
+}
+
 public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
 {
     if (dropout == 0f)

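The three new members are empty stubs for Keras's dropout-mask caching protocol: a mask is sampled once per forward pass, cached, and reset between calls so every timestep sees the same mask. A hedged sketch of the recipe the stubs point at (helper and API names here are assumptions for illustration, not the repo's final implementation):

// Hypothetical sketch: sample one inverted-dropout mask to reuse across timesteps.
Tensor SampleDropoutMask(Tensor input, float rate)
{
    // keep each element with probability (1 - rate), then rescale so the
    // expected activation magnitude is unchanged
    var keep = tf.cast(tf.random.uniform(input.shape) >= rate, input.dtype);
    return keep / (1f - rate);
}
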