Skip to content

Commit 70f873e

Browse files
committed
Initially adding KerasTensor. #1142
1 parent 6dd15f7 commit 70f873e

File tree

6 files changed

+49
-5
lines changed

6 files changed

+49
-5
lines changed

src/TensorFlowNET.Core/GlobalUsing.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,6 @@
33
global using System.Text;
44
global using System.Collections;
55
global using System.Data;
6-
global using System.Linq;
6+
global using System.Linq;
7+
global using Tensorflow.Keras.Engine;
8+
global using Tensorflow.Framework.Models;

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using Tensorflow.Framework.Models;
3+
using Tensorflow.Keras.Engine;
34
using Tensorflow.Keras.Layers.Rnn;
45
using Tensorflow.NumPy;
56
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;
@@ -135,7 +136,7 @@ public ILayer EinsumDense(string equation,
135136
public ILayer GlobalMaxPooling1D(string data_format = "channels_last");
136137
public ILayer GlobalMaxPooling2D(string data_format = "channels_last");
137138

138-
public Tensors Input(Shape shape = null,
139+
public KerasTensor Input(Shape shape = null,
139140
int batch_size = -1,
140141
string name = null,
141142
TF_DataType dtype = TF_DataType.DtInvalid,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
namespace Tensorflow.Keras.Engine;
2+
3+
/// <summary>
4+
/// A representation of a Keras in/output during Functional API construction.
5+
/// </summary>
6+
public class KerasTensor
7+
{
8+
private Tensor _tensor;
9+
public void SetTensor(Tensors tensor)
10+
=> _tensor = tensor;
11+
12+
private TensorSpec _type_spec;
13+
private string _name;
14+
15+
public KerasTensor(TensorSpec type_spec, string name = null)
16+
{
17+
_type_spec = type_spec;
18+
_name = name;
19+
}
20+
21+
public static KerasTensor from_tensor(Tensor tensor)
22+
{
23+
var type_spec = tensor.ToTensorSpec();
24+
var kt = new KerasTensor(type_spec, name: tensor.name);
25+
kt.SetTensor(tensor);
26+
return kt;
27+
}
28+
29+
public static implicit operator Tensors(KerasTensor kt)
30+
=> kt._tensor;
31+
32+
public static implicit operator Tensor(KerasTensor kt)
33+
=> kt._tensor;
34+
35+
public static implicit operator KerasTensor(Tensor tensor)
36+
=> from_tensor(tensor);
37+
38+
public static implicit operator KerasTensor(Tensors tensors)
39+
=> from_tensor(tensors.First());
40+
}

src/TensorFlowNET.Keras/BackendImpl.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ public void track_variable(IVariableV1 v)
7676
_GRAPH_VARIABLES[graph.graph_key] = v;
7777
}
7878

79-
public Tensor placeholder(Shape shape = null,
79+
public KerasTensor placeholder(Shape shape = null,
8080
int ndim = -1,
8181
TF_DataType dtype = TF_DataType.DtInvalid,
8282
bool sparse = false,

src/TensorFlowNET.Keras/GlobalUsing.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@
44
global using System.Linq;
55
global using static Tensorflow.Binding;
66
global using static Tensorflow.KerasApi;
7-
global using Tensorflow.NumPy;
7+
global using Tensorflow.NumPy;
8+
global using Tensorflow.Keras.Engine;

src/TensorFlowNET.Keras/Layers/LayersApi.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ public ILayer Flatten(string data_format = null)
466466
/// In this case, values of 'None' in the 'shape' argument represent ragged dimensions. For more information about RaggedTensors, see this guide.
467467
/// </param>
468468
/// <returns>A tensor.</returns>
469-
public Tensors Input(Shape shape = null,
469+
public KerasTensor Input(Shape shape = null,
470470
int batch_size = -1,
471471
string name = null,
472472
TF_DataType dtype = TF_DataType.DtInvalid,

0 commit comments

Comments
 (0)