Skip to content

Commit ed1a8d2

Browse files
committed
Add shape and dtype to KerasTensor
1 parent 70f873e commit ed1a8d2

File tree

3 files changed

+52
-26
lines changed

3 files changed

+52
-26
lines changed

src/TensorFlowNET.Core/Operations/array_ops.cs

+31-18
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,17 @@ public static Tensor shape_internal(Tensor input, string name = null, bool optim
603603
}
604604
}
605605

606-
return gen_array_ops.shape(input, name: name, out_type: out_type);
606+
return tf.Context.ExecuteOp("Shape", name, new ExecuteOpArgs(input)
607+
{
608+
GetGradientAttrs = (op) => new
609+
{
610+
T = op.get_attr<TF_DataType>("T"),
611+
out_type = op.get_attr<TF_DataType>("out_type")
612+
}
613+
}.SetAttributes(new
614+
{
615+
out_type
616+
})).First();
607617
});
608618
}
609619

@@ -703,23 +713,26 @@ public static Tensor strided_slice(Tensor input_, Tensor begin, Tensor end,
703713
int new_axis_mask = 0,
704714
int shrink_axis_mask = 0,
705715
string name = null)
706-
{
707-
var op = gen_array_ops.strided_slice(
708-
input: input_,
709-
begin: begin,
710-
end: end,
711-
strides: strides,
712-
begin_mask: begin_mask,
713-
end_mask: end_mask,
714-
ellipsis_mask: ellipsis_mask,
715-
new_axis_mask: new_axis_mask,
716-
shrink_axis_mask: shrink_axis_mask,
717-
name: name);
718-
719-
string parent_name = name;
720-
721-
return op;
722-
}
716+
=> tf.Context.ExecuteOp("StridedSlice", name, new ExecuteOpArgs(input_, begin, end, strides)
717+
{
718+
GetGradientAttrs = (op) => new
719+
{
720+
T = op.get_attr<TF_DataType>("T"),
721+
Index = op.get_attr<TF_DataType>("Index"),
722+
begin_mask = op.get_attr<long>("begin_mask"),
723+
end_mask = op.get_attr<long>("end_mask"),
724+
ellipsis_mask = op.get_attr<long>("ellipsis_mask"),
725+
new_axis_mask = op.get_attr<long>("new_axis_mask"),
726+
shrink_axis_mask = op.get_attr<long>("shrink_axis_mask")
727+
}
728+
}.SetAttributes(new
729+
{
730+
begin_mask,
731+
end_mask,
732+
ellipsis_mask,
733+
new_axis_mask,
734+
shrink_axis_mask
735+
}));
723736

724737
/// <summary>
725738
/// Returns the gradient of `StridedSlice`.

src/TensorFlowNET.Core/Tensors/KerasTensor.cs

+20-7
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,17 @@
55
/// </summary>
66
public class KerasTensor
77
{
8-
private Tensor _tensor;
9-
public void SetTensor(Tensors tensor)
10-
=> _tensor = tensor;
8+
private Tensors _inferred_value;
9+
public Tensors inferred_value
10+
{
11+
get => _inferred_value;
12+
set => _inferred_value = value;
13+
}
1114

12-
private TensorSpec _type_spec;
1315
private string _name;
16+
private TensorSpec _type_spec;
17+
public Shape shape => _type_spec.shape;
18+
public TF_DataType dtype => _type_spec.dtype;
1419

1520
public KerasTensor(TensorSpec type_spec, string name = null)
1621
{
@@ -22,15 +27,23 @@ public static KerasTensor from_tensor(Tensor tensor)
2227
{
2328
var type_spec = tensor.ToTensorSpec();
2429
var kt = new KerasTensor(type_spec, name: tensor.name);
25-
kt.SetTensor(tensor);
30+
kt.inferred_value = tensor;
2631
return kt;
2732
}
2833

34+
public override string ToString()
35+
=> _inferred_value.Length switch
36+
{
37+
> 1 => "[" + string.Join(", ", _inferred_value.Select(x => $"<KerasTensor: shape={x.shape} dtype={x.dtype}>")) + "]",
38+
1 => $"<KerasTensor: shape={_inferred_value.shape} dtype={_inferred_value.dtype}>",
39+
_ => _inferred_value.ToString(),
40+
};
41+
2942
public static implicit operator Tensors(KerasTensor kt)
30-
=> kt._tensor;
43+
=> kt._inferred_value;
3144

3245
public static implicit operator Tensor(KerasTensor kt)
33-
=> kt._tensor;
46+
=> kt._inferred_value;
3447

3548
public static implicit operator KerasTensor(Tensor tensor)
3649
=> from_tensor(tensor);

src/TensorFlowNET.Core/Tensors/Tensor.Index.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public Tensor this[params Slice[] slices]
4242
array_ops.stack(args.End),
4343
array_ops.stack(args.Strides));
4444

45-
return gen_array_ops.strided_slice(
45+
return array_ops.strided_slice(
4646
this,
4747
packed_begin,
4848
packed_end,

0 commit comments

Comments
 (0)