Skip to content

Commit efa28d4

Browse files
committed
Add a Tag to Tensor, Fixed Tensor Slice, AdamOptimizer #271
1 parent 60ec5af commit efa28d4

File tree

7 files changed

+87
-21
lines changed

7 files changed

+87
-21
lines changed

src/TensorFlowNET.Core/Framework/IndexedSlices.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,18 @@ public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null)
3131
_values = values;
3232
_indices = indices;
3333
_dense_shape = dense_shape;
34+
35+
_values.Tag = this;
3436
}
3537

3638
public static implicit operator Tensor(IndexedSlices indexedSlices)
3739
{
3840
return indexedSlices.values;
3941
}
42+
43+
public static implicit operator IndexedSlices(Tensor tensor)
44+
{
45+
return tensor.Tag as IndexedSlices;
46+
}
4047
}
4148
}

src/TensorFlowNET.Core/Gradients/array_grad.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ public static Tensor[] _GatherV2Grad(Operation op, Tensor[] grads)
156156
// For axis 0 gathers, build an appropriately shaped IndexedSlices.
157157
if((int)axis_static == 0)
158158
{
159-
var params_tail_shape = params_shape[1];
159+
var params_tail_shape = params_shape[new NumSharp.Slice(start:1)];
160160
var values_shape = array_ops.concat(new[] { indices_size, params_tail_shape }, 0);
161161
var values = array_ops.reshape(grad, values_shape);
162162
indices = array_ops.reshape(indices, indices_size);

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,8 @@ public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataTyp
223223
{
224224
var _op = _op_def_lib._apply_op_helper("Unique", name, new { x, out_idx });
225225
// TODO
226-
throw new NotImplementedException("_result = _UniqueOutput._make(_result)");
227-
// return _op.outputs[0];
226+
//var _result = _UniqueOutput._make(_op.outputs);
227+
return (_op.outputs[0], _op.outputs[1]);
228228
}
229229

230230
public static Tensor where()

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ public partial class Tensor : IDisposable, ITensorOrOperation
5858

5959
private TF_Output? _tf_output;
6060

61+
/// <summary>
62+
/// used for keep other pointer when do implicit operating
63+
/// </summary>
64+
public object Tag { get; set; }
65+
6166
public int[] shape
6267
{
6368
get
@@ -219,11 +224,11 @@ public TF_DataType ToTFDataType(Type type)
219224
}
220225
}
221226

222-
public Tensor this[int start, int? stop, int? step]
227+
public Tensor this[Slice slice]
223228
{
224229
get
225230
{
226-
var slice_spec = new int[] { start };
231+
var slice_spec = new int[] { slice.Start.Value };
227232
var begin = new List<int>();
228233
var end = new List<int>();
229234
var strides = new List<int>();
@@ -236,14 +241,16 @@ public TF_DataType ToTFDataType(Type type)
236241
foreach (var s in slice_spec)
237242
{
238243
begin.Add(s);
239-
if (stop == null)
244+
if(slice.Stop.HasValue)
245+
{
246+
end.Add(slice.Stop.Value);
247+
}
248+
else
240249
{
241250
end.Add(0);
242251
end_mask |= (1 << index);
243252
}
244-
else
245-
end.Add(s + 1);
246-
strides.Add(1);
253+
strides.Add(slice.Step);
247254

248255
index += 1;
249256
}
@@ -277,7 +284,57 @@ public TF_DataType ToTFDataType(Type type)
277284
}
278285
}
279286

280-
public Tensor this[int slice_spec] => this[slice_spec, null, null];
287+
public Tensor this[int start]
288+
{
289+
get
290+
{
291+
var slice_spec = new int[] { start };
292+
var begin = new List<int>();
293+
var end = new List<int>();
294+
var strides = new List<int>();
295+
296+
var index = 0;
297+
var (new_axis_mask, shrink_axis_mask) = (0, 0);
298+
var (begin_mask, end_mask) = (0, 0);
299+
var ellipsis_mask = 0;
300+
301+
foreach (var s in slice_spec)
302+
{
303+
begin.Add(s);
304+
end.Add(s + 1);
305+
strides.Add(1);
306+
shrink_axis_mask |= (1 << index);
307+
index += 1;
308+
}
309+
310+
return with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
311+
{
312+
string name = scope;
313+
if (begin != null)
314+
{
315+
var (packed_begin, packed_end, packed_strides) =
316+
(array_ops.stack(begin.ToArray()),
317+
array_ops.stack(end.ToArray()),
318+
array_ops.stack(strides.ToArray()));
319+
320+
return gen_array_ops.strided_slice(
321+
this,
322+
packed_begin,
323+
packed_end,
324+
packed_strides,
325+
begin_mask: begin_mask,
326+
end_mask: end_mask,
327+
shrink_axis_mask: shrink_axis_mask,
328+
new_axis_mask: new_axis_mask,
329+
ellipsis_mask: ellipsis_mask,
330+
331+
name: name);
332+
}
333+
334+
throw new NotImplementedException("");
335+
});
336+
}
337+
}
281338

282339
public override string ToString()
283340
{

src/TensorFlowNET.Core/Train/Optimizer.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,8 @@ public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var)
227227
public virtual (Tensor, Tensor) _deduplicate_indexed_slices(Tensor values, Tensor indices)
228228
{
229229
var (unique_indices, new_index_positions) = array_ops.unique(indices);
230-
var summed_values = math_ops.unsorted_segment_sum(
231-
values, new_index_positions,
232-
array_ops.shape(unique_indices)[0]);
230+
var shape = array_ops.shape(unique_indices)[0];
231+
var summed_values = math_ops.unsorted_segment_sum(values, new_index_positions, shape);
233232
return (summed_values, unique_indices);
234233
}
235234

src/TensorFlowNET.Core/Train/_OptimizableVariable.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Framework;
45

56
namespace Tensorflow
67
{

src/TensorFlowNET.Core/Train/optimizer.py.cs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,16 @@ public Tensor target()
2929

3030
public Operation update_op(Optimizer optimizer, Tensor g)
3131
{
32-
var update_op = optimizer._apply_dense(g, _v);
33-
34-
return update_op;
35-
}
36-
37-
public Operation update_op(Optimizer optimizer, IndexedSlices g)
38-
{
39-
var update_op = optimizer._apply_dense(g, _v);
32+
Operation update_op = null;
33+
34+
if (g.Tag == null)
35+
{
36+
update_op = optimizer._apply_dense(g, _v);
37+
}
38+
else if (g.Tag is IndexedSlices)
39+
{
40+
return optimizer._apply_sparse_duplicate_indices(g, _v);
41+
}
4042

4143
return update_op;
4244
}

0 commit comments

Comments
 (0)