@@ -17,6 +17,7 @@ limitations under the License.
17
17
using System ;
18
18
using System . Collections . Generic ;
19
19
using System . Linq ;
20
+ using Tensorflow . Eager ;
20
21
using Tensorflow . Framework ;
21
22
using static Tensorflow . Binding ;
22
23
@@ -48,6 +49,7 @@ public class _EagerTensorArray : TensorArray
48
49
public override Tensor flow => _flow ;
49
50
bool _clear_after_read ;
50
51
List < Tensor > _tensor_array ;
52
+ List < int > _previous_read_indices ;
51
53
52
54
public _EagerTensorArray ( TF_DataType dtype , Tensor size , bool dynamic_size = false ,
53
55
bool clear_after_read = true , string tensor_array_name = null , Tensor handle = null , Tensor flow = null ,
@@ -61,16 +63,20 @@ public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = fal
61
63
_dtype = dtype . as_base_dtype ( ) ;
62
64
_dynamic_size = dynamic_size ;
63
65
_clear_after_read = clear_after_read ;
64
- _tensor_array = new List < Tensor > ( ) ;
66
+ _tensor_array = Enumerable . Repeat < Tensor > ( null , size . numpy ( ) ) . ToList ( ) ;
67
+ _previous_read_indices = new ( ) ;
65
68
}
66
69
67
70
public override TensorArray unstack ( Tensor value , string name = null )
68
71
{
69
- return tf_with ( ops . name_scope ( name , "TensorArrayUnstack" , new { _handle , value } ) , delegate
72
+ var tensors = array_ops . unstack ( value , name : name ) ;
73
+ if ( tensors . Length > _tensor_array . Count && ! _dynamic_size )
70
74
{
71
- var num_elements = array_ops . shape ( value ) [ 0 ] ;
72
- return scatter ( indices : math_ops . range ( 0 , num_elements ) , value : value , name : name ) ;
73
- } ) ;
75
+ throw new ValueError ( $ "Cannot unstack { tensors . Length } tensors into a TensorArray of static size { _tensor_array . Count } ") ;
76
+ }
77
+ _tensor_array = tensors . ToList ( ) ;
78
+ // TODO(Rinne): revise the implementation. Here we should return `parent()`.
79
+ return this ;
74
80
}
75
81
76
82
public TensorArray scatter ( Tensor indices , Tensor value , string name = null )
@@ -116,37 +122,95 @@ public void _maybe_colocate_with(Tensor value)
116
122
_colocate_with . Add ( value ) ;
117
123
}
118
124
125
+ private Tensor _maybe_zero ( int ix )
126
+ {
127
+ var val = _tensor_array [ ix ] ;
128
+ if ( val is null )
129
+ {
130
+ val = _tensor_array [ ix ] = array_ops . zeros ( _element_shape , _dtype ) ;
131
+ }
132
+ return val ;
133
+ }
134
+
119
135
public override Tensor read < T > ( T index , string name = null )
120
136
{
121
- int index_int = - 1 ;
137
+ int index_int ;
122
138
if ( index is int int_index )
123
139
index_int = int_index ;
124
140
else if ( index is Tensor tensor_index )
125
141
index_int = tensor_index . numpy ( ) ;
126
142
else
127
143
throw new ValueError ( "" ) ;
128
144
145
+ if ( index_int >= _tensor_array . Count )
146
+ {
147
+ throw new OutOfRangeError ( $ "Tried to read from index { index_int } but array size is: { _tensor_array . Count } ") ;
148
+ }
149
+
150
+ var res = _tensor_array [ index_int ] ;
151
+ if ( res is null )
152
+ {
153
+ if ( _previous_read_indices . Contains ( index_int ) )
154
+ {
155
+ throw new InvalidArgumentError ( $ "Could not read index { index_int } twice because it was cleared after " +
156
+ $ "a previous read (perhaps try setting clear_after_read = false?)") ;
157
+ }
158
+ else
159
+ {
160
+ res = _maybe_zero ( index_int ) ;
161
+ }
162
+ }
163
+
129
164
if ( _clear_after_read )
130
165
{
131
166
_tensor_array [ index_int ] = null ;
167
+ _previous_read_indices . Add ( index_int ) ;
132
168
}
133
-
134
- return _tensor_array [ index_int ] ;
169
+ return res ;
135
170
}
136
171
137
172
public override TensorArray write ( Tensor index , Tensor value , string name = null )
138
173
{
139
- if ( _infer_shape )
140
- _element_shape = _element_shape . merge_with ( value . shape ) ;
141
- _tensor_array . add ( value ) ;
142
- return this ;
174
+ int index_int ;
175
+ if ( index is EagerTensor eager )
176
+ {
177
+ return write < Tensor > ( eager . numpy ( ) , value , name ) ;
178
+ }
179
+ throw new InvalidArgumentError ( "The index is supposed to be an EagerTensor" ) ;
143
180
}
144
181
145
182
public override TensorArray write < T > ( int index , T value , string name = null )
146
183
{
147
- var value_tensor = ops . convert_to_tensor ( value , preferred_dtype : _dtype , name : "value" ) ;
148
- var index_tensor = ops . convert_to_tensor ( index , name : "index" ) ;
149
- return write ( index_tensor , value_tensor , name : name ) ;
184
+ int size = _tensor_array . Count ;
185
+ if ( index >= size )
186
+ {
187
+ if ( ! _dynamic_size )
188
+ {
189
+ throw new OutOfRangeError ( $ "Tried to write to index { index } but array is not resizeable and size " +
190
+ $ "is: { size } ") ;
191
+ }
192
+ _tensor_array . AddRange ( Enumerable . Repeat < Tensor > ( null , index - size + 1 ) ) ;
193
+ }
194
+
195
+ Tensor tensor = ops . convert_to_tensor ( value , preferred_dtype : _dtype , name : "value" ) ;
196
+
197
+ if ( _dtype != tensor . dtype )
198
+ {
199
+ throw new InvalidArgumentError ( $ "TensorArray dtype is { _dtype . as_python_name ( ) } but Op is " +
200
+ $ "trying to write dtype { tensor . dtype . as_python_name ( ) } ") ;
201
+ }
202
+
203
+ if ( ! _element_shape . is_compatible_with ( tensor . shape ) )
204
+ {
205
+ throw new ValueError ( $ "Incompatible shape for value ({ tensor . shape } ), expected ({ _element_shape } )") ;
206
+ }
207
+
208
+ if ( _infer_shape )
209
+ {
210
+ _element_shape = _element_shape . merge_with ( tensor . shape ) ;
211
+ }
212
+ _tensor_array [ index ] = tensor ;
213
+ return this ;
150
214
}
151
215
152
216
private Tensor size ( string name = null )
@@ -156,11 +220,26 @@ private Tensor size(string name = null)
156
220
157
221
public override Tensor stack ( string name = null )
158
222
{
159
- ops . colocate_with ( _handle ) ;
160
- return tf_with ( ops . name_scope ( name , "TensorArrayStack" , new { _handle } ) , delegate
223
+ if ( _tensor_array . Count > 0 )
161
224
{
162
- return gather ( math_ops . range ( 0 , size ( ) ) , name : name ) ;
163
- } ) ;
225
+ for ( int i = 0 ; i < _tensor_array . Count ; i ++ )
226
+ {
227
+ _maybe_zero ( i ) ;
228
+ }
229
+ }
230
+ if ( _tensor_array . Count == 0 && _element_shape . IsFullyDefined )
231
+ {
232
+ return ops . convert_to_tensor ( new Shape ( new long [ ] { 0 } . Concat ( _element_shape . dims ) . ToArray ( ) ) , name : name , dtype : _dtype ) ;
233
+ }
234
+ else
235
+ {
236
+ return ops . convert_to_tensor ( _tensor_array , name : name , dtype : _dtype ) ;
237
+ }
238
+ //ops.colocate_with(_handle);
239
+ //return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
240
+ //{
241
+ // return gather(math_ops.range(0, size()), name: name);
242
+ //});
164
243
}
165
244
166
245
public override Tensor gather ( Tensor indices , string name = null )
0 commit comments