Skip to content

Commit 6ec39ba

Browse files
committed
Fix inferred_value of KerasTensor. #1142
1 parent 12e3f54 commit 6ec39ba

File tree

10 files changed

+88
-14
lines changed

10 files changed

+88
-14
lines changed

src/TensorFlowNET.Core/APIs/tf.reshape.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,6 @@ public Tensor reshape(Tensor tensor,
3131
public Tensor reshape(Tensor tensor,
3232
object[] shape,
3333
string name = null)
34-
=> gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name);
34+
=> array_ops.reshape(tensor, shape, name);
3535
}
3636
}

src/TensorFlowNET.Core/APIs/tf.tile.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public Tensor tile(Tensor input, Tensor multiples, string name = null)
2323
=> gen_array_ops.tile(input, multiples, name);
2424

2525
public Tensor tile(Tensor input, object[] multiples, string name = null)
26-
=> gen_array_ops.tile(input, ops.convert_to_tensor(multiples), name);
26+
=> array_ops.tile(input, multiples, name);
2727

2828
public Tensor tile(Tensor input, Shape multiples, string name = null)
2929
{

src/TensorFlowNET.Core/GlobalUsing.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
global using System.Data;
66
global using System.Linq;
77
global using Tensorflow.Keras.Engine;
8-
global using Tensorflow.Framework.Models;
8+
global using Tensorflow.Framework.Models;
9+
global using static Tensorflow.Binding;

src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs

+15-4
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,32 @@ public KerasTensor(TensorSpec type_spec, Shape inferred_value = null, string nam
3030
public static KerasTensor from_tensor(Tensor tensor)
3131
{
3232
var type_spec = tensor.ToTensorSpec();
33-
var kt = new KerasTensor(type_spec, name: tensor.name);
33+
Shape? inferred_value = default;
34+
if (tensor.dtype == TF_DataType.TF_INT32 && tensor.rank < 2)
35+
{
36+
inferred_value = tf.ones(tensor).shape;
37+
}
38+
var kt = new KerasTensor(type_spec, inferred_value: inferred_value, name: tensor.name);
3439
kt.original_tensors = tensor;
3540
return kt;
3641
}
3742

43+
public KerasTensor this[int idx]
44+
=> _original_tensors.First()[idx];
45+
46+
public KerasTensor this[params Slice[] slices]
47+
=> _original_tensors.First()[slices];
48+
3849
public override string ToString()
3950
=> _original_tensors.Length switch
4051
{
41-
> 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype}")) + "]",
42-
1 => $"KerasTensor: shape={_original_tensors.shape} {GetInferredValueString()} dtype={_original_tensors.dtype}",
52+
> 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype.as_numpy_name()}{GetInferredValueString()}")) + "]",
53+
1 => $"KerasTensor: shape={_original_tensors.shape} dtype={_original_tensors.dtype.as_numpy_name()}{GetInferredValueString()}",
4354
_ => _original_tensors.ToString(),
4455
};
4556

4657
private string GetInferredValueString()
47-
=> _inferred_value == null ? "" : "";
58+
=> _inferred_value == null ? "" : $" inferred_value={_inferred_value}";
4859

4960
public static implicit operator Tensors(KerasTensor kt)
5061
=> kt._original_tensors;

src/TensorFlowNET.Core/Operations/array_ops.cs

+27-2
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ public static Tensor zeros(Tensors shape, TF_DataType dtype = TF_DataType.TF_FLO
137137
if(shape.Length > 1)
138138
{
139139
shapeTensor = ops.convert_to_tensor(shape, dtypes.int32);
140-
if(shapeTensor.ndim > 1)
140+
if (shapeTensor.ndim > 1)
141141
{
142142
shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1));
143143
}
@@ -304,6 +304,10 @@ public static Tensor _autopacking_helper(IEnumerable<object> list_or_tuple, TF_D
304304
{
305305
elems_as_tensors.Add(tensor);
306306
}
307+
else if (elem is KerasTensor kt)
308+
{
309+
elems_as_tensors.Add(kt);
310+
}
307311
else
308312
{
309313
var elem_tensor = constant_op.constant(elem, dtype: dtype, name: i.ToString());
@@ -404,7 +408,10 @@ public static Tensor reshape(Tensor tensor, Shape shape, string name = null)
404408
=> gen_array_ops.reshape(tensor, shape, name: name);
405409

406410
public static Tensor reshape(Tensor tensor, object[] shape, string name = null)
407-
=> gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name: name);
411+
{
412+
var dims = shape_utils.from_object_array(shape);
413+
return gen_array_ops.reshape(tensor, dims, name: name);
414+
}
408415

409416
private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true)
410417
{
@@ -425,6 +432,10 @@ public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT
425432
return tf_with(ops.name_scope(name, "ones", new { shape }), scope =>
426433
{
427434
name = scope;
435+
if (shape._shape_tuple().Length == 0)
436+
{
437+
shape = reshape(shape, new Shape(-1));
438+
}
428439
var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name);
429440
return output;
430441
});
@@ -647,6 +658,20 @@ public static Tensor tile(Tensor input, Tensor multiples, string name = null)
647658
}
648659
});
649660

661+
public static Tensor tile(Tensor input, object[] multiples, string name = null)
662+
{
663+
Shape dims = shape_utils.from_object_array(multiples);
664+
665+
return tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, dims)
666+
{
667+
GetGradientAttrs = (op) => new
668+
{
669+
T = op.get_attr<TF_DataType>("T"),
670+
Tmultiples = op.get_attr<TF_DataType>("Tmultiples")
671+
}
672+
});
673+
}
674+
650675
public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
651676
{
652677
return tf_with(ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope =>

src/TensorFlowNET.Core/Tensors/shape_utils.cs

+27
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Linq;
3+
using Tensorflow.Eager;
34
using static Tensorflow.Binding;
45

56
namespace Tensorflow
@@ -13,5 +14,31 @@ public static Tensor static_or_dynamic_map_fn(Func<Tensor, Tensor> fn, Tensor el
1314

1415
throw new NotImplementedException("");
1516
}
17+
18+
public static Shape from_object_array(object[] shape)
19+
{
20+
var dims = shape.Select(x =>
21+
{
22+
if (x is KerasTensor kt && kt.inferred_value != null)
23+
{
24+
return kt.inferred_value.as_int_list()[0];
25+
}
26+
else if (x is EagerTensor et && et.dtype == TF_DataType.TF_INT32)
27+
{
28+
return et.ToArray<int>()[0];
29+
}
30+
else if (x is int i)
31+
{
32+
return i;
33+
}
34+
else if (x is long l)
35+
{
36+
return l;
37+
}
38+
throw new NotImplementedException();
39+
}).ToArray();
40+
41+
return new Shape(dims);
42+
}
1643
}
1744
}

src/TensorFlowNET.Core/Tensors/tf.constant.cs

+3
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ public Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, stri
4646
public Tensor ones(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
4747
=> array_ops.ones(shape, dtype, name);
4848

49+
public Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
50+
=> array_ops.ones(shape, dtype, name);
51+
4952
public Tensor size(Tensor input,
5053
string name = null,
5154
TF_DataType out_type = TF_DataType.TF_INT32) => array_ops.size(input,

src/TensorFlowNET.Core/ops.cs

+9-2
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,18 @@ public static Tensor convert_to_tensor(object value,
144144
}
145145
if (!graph.building_function)
146146
{
147-
throw new RuntimeError("Attempting to capture an EagerTensor without building a function.");
148-
// return eager_tensor.AsPlaceholder(name: name);
147+
// throw new RuntimeError("Attempting to capture an EagerTensor without building a function.");
148+
return eager_tensor.AsPlaceholder(name: name);
149149
}
150150
}
151151
}
152+
else if (value is KerasTensor kt)
153+
{
154+
if (kt.inferred_value != null)
155+
{
156+
return convert_to_tensor(kt.inferred_value, dtype: kt.dtype, name: name);
157+
}
158+
}
152159

153160
// graph mode
154161
Tensor ret = value switch

src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
141141

142142
<ItemGroup>
143143
<PackageReference Include="HDF5-CSharp" Version="1.17.0" />
144-
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
144+
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.149" />
145145
<PackageReference Include="SharpZipLib" Version="1.4.2" />
146146
</ItemGroup>
147147

test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj

+2-2
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141

4242
<ItemGroup>
4343
<PackageReference Include="FluentAssertions" Version="5.10.3" />
44-
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
45-
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" />
44+
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.149" />
45+
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.6.3" />
4646
<PackageReference Include="MSTest.TestAdapter" Version="2.2.10" />
4747
<PackageReference Include="MSTest.TestFramework" Version="2.2.10" />
4848
</ItemGroup>

0 commit comments

Comments
 (0)