Skip to content

Commit 537b3e1

Browse files
committed
feat: support simple RNN.
1 parent 4939105 commit 537b3e1

File tree

9 files changed

+507
-414
lines changed

9 files changed

+507
-414
lines changed

src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ public interface IRnnCell: ILayer
99
{
1010
GeneralizedTensorShape StateSize { get; }
1111
GeneralizedTensorShape OutputSize { get; }
12+
bool IsTFRnnCell { get; }
1213
/// <summary>
1314
/// Whether the optional RNN args are supported when applying the layer.
1415
/// In other words, whether `Apply` is overridden to process `RnnOptionalArgs`.
1516
/// </summary>
1617
bool SupportOptionalArgs { get; }
17-
(Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null);
1818
}
1919
}

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

+1
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ public void adapt(Tensor data, int? batch_size = null, int? steps = null)
183183
}
184184
public GeneralizedTensorShape StateSize => throw new NotImplementedException();
185185
public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
186+
public bool IsTFRnnCell => throw new NotImplementedException();
186187
public bool SupportOptionalArgs => throw new NotImplementedException();
187188
}
188189
}

src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs

+98-19
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
using System;
1818
using System.Collections.Generic;
1919
using System.Linq;
20+
using Tensorflow.Eager;
2021
using Tensorflow.Framework;
2122
using static Tensorflow.Binding;
2223

@@ -48,6 +49,7 @@ public class _EagerTensorArray : TensorArray
4849
public override Tensor flow => _flow;
4950
bool _clear_after_read;
5051
List<Tensor> _tensor_array;
52+
List<int> _previous_read_indices;
5153

5254
public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false,
5355
bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
@@ -61,16 +63,20 @@ public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = fal
6163
_dtype = dtype.as_base_dtype();
6264
_dynamic_size = dynamic_size;
6365
_clear_after_read = clear_after_read;
64-
_tensor_array = new List<Tensor>();
66+
_tensor_array = Enumerable.Repeat<Tensor>(null, size.numpy()).ToList();
67+
_previous_read_indices = new();
6568
}
6669

6770
public override TensorArray unstack(Tensor value, string name = null)
6871
{
69-
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate
72+
var tensors = array_ops.unstack(value, name: name);
73+
if(tensors.Length > _tensor_array.Count && !_dynamic_size)
7074
{
71-
var num_elements = array_ops.shape(value)[0];
72-
return scatter(indices: math_ops.range(0, num_elements), value: value, name: name);
73-
});
75+
throw new ValueError($"Cannot unstack {tensors.Length} tensors into a TensorArray of static size {_tensor_array.Count}");
76+
}
77+
_tensor_array = tensors.ToList();
78+
// TODO(Rinne): revise the implementation. Here we should return `parent()`.
79+
return this;
7480
}
7581

7682
public TensorArray scatter(Tensor indices, Tensor value, string name = null)
@@ -116,37 +122,95 @@ public void _maybe_colocate_with(Tensor value)
116122
_colocate_with.Add(value);
117123
}
118124

125+
private Tensor _maybe_zero(int ix)
126+
{
127+
var val = _tensor_array[ix];
128+
if(val is null)
129+
{
130+
val = _tensor_array[ix] = array_ops.zeros(_element_shape, _dtype);
131+
}
132+
return val;
133+
}
134+
119135
public override Tensor read<T>(T index, string name = null)
120136
{
121-
int index_int = -1;
137+
int index_int;
122138
if (index is int int_index)
123139
index_int = int_index;
124140
else if (index is Tensor tensor_index)
125141
index_int = tensor_index.numpy();
126142
else
127143
throw new ValueError("");
128144

145+
if(index_int >= _tensor_array.Count)
146+
{
147+
throw new OutOfRangeError($"Tried to read from index {index_int} but array size is: {_tensor_array.Count} ");
148+
}
149+
150+
var res = _tensor_array[index_int];
151+
if(res is null)
152+
{
153+
if (_previous_read_indices.Contains(index_int))
154+
{
155+
throw new InvalidArgumentError($"Could not read index {index_int} twice because it was cleared after " +
156+
$"a previous read (perhaps try setting clear_after_read = false?)");
157+
}
158+
else
159+
{
160+
res = _maybe_zero(index_int);
161+
}
162+
}
163+
129164
if (_clear_after_read)
130165
{
131166
_tensor_array[index_int] = null;
167+
_previous_read_indices.Add(index_int);
132168
}
133-
134-
return _tensor_array[index_int];
169+
return res;
135170
}
136171

137172
public override TensorArray write(Tensor index, Tensor value, string name = null)
138173
{
139-
if (_infer_shape)
140-
_element_shape = _element_shape.merge_with(value.shape);
141-
_tensor_array.add(value);
142-
return this;
174+
int index_int;
175+
if(index is EagerTensor eager)
176+
{
177+
return write<Tensor>(eager.numpy(), value, name);
178+
}
179+
throw new InvalidArgumentError("The index is supposed to be an EagerTensor");
143180
}
144181

145182
public override TensorArray write<T>(int index, T value, string name = null)
146183
{
147-
var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
148-
var index_tensor = ops.convert_to_tensor(index, name: "index");
149-
return write(index_tensor, value_tensor, name: name);
184+
int size = _tensor_array.Count;
185+
if(index >= size)
186+
{
187+
if (!_dynamic_size)
188+
{
189+
throw new OutOfRangeError($"Tried to write to index {index} but array is not resizeable and size " +
190+
$"is: {size} ");
191+
}
192+
_tensor_array.AddRange(Enumerable.Repeat<Tensor>(null, index - size + 1));
193+
}
194+
195+
Tensor tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
196+
197+
if(_dtype != tensor.dtype)
198+
{
199+
throw new InvalidArgumentError($"TensorArray dtype is {_dtype.as_python_name()} but Op is " +
200+
$"trying to write dtype {tensor.dtype.as_python_name()} ");
201+
}
202+
203+
if (!_element_shape.is_compatible_with(tensor.shape))
204+
{
205+
throw new ValueError($"Incompatible shape for value ({tensor.shape}), expected ({_element_shape})");
206+
}
207+
208+
if (_infer_shape)
209+
{
210+
_element_shape = _element_shape.merge_with(tensor.shape);
211+
}
212+
_tensor_array[index] = tensor;
213+
return this;
150214
}
151215

152216
private Tensor size(string name = null)
@@ -156,11 +220,26 @@ private Tensor size(string name = null)
156220

157221
public override Tensor stack(string name = null)
158222
{
159-
ops.colocate_with(_handle);
160-
return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
223+
if(_tensor_array.Count > 0)
161224
{
162-
return gather(math_ops.range(0, size()), name: name);
163-
});
225+
for(int i = 0; i < _tensor_array.Count; i++)
226+
{
227+
_maybe_zero(i);
228+
}
229+
}
230+
if(_tensor_array.Count == 0 && _element_shape.IsFullyDefined)
231+
{
232+
return ops.convert_to_tensor(new Shape(new long[] { 0 }.Concat(_element_shape.dims).ToArray()), name: name, dtype: _dtype);
233+
}
234+
else
235+
{
236+
return ops.convert_to_tensor(_tensor_array, name: name, dtype: _dtype);
237+
}
238+
//ops.colocate_with(_handle);
239+
//return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
240+
//{
241+
// return gather(math_ops.range(0, size()), name: name);
242+
//});
164243
}
165244

166245
public override Tensor gather(Tensor indices, string name = null)

0 commit comments

Comments
 (0)