Skip to content

Align keras.Input with tensorflow python. #993

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using Tensorflow.Framework.Models;
using Tensorflow.NumPy;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;

Expand Down Expand Up @@ -133,11 +134,16 @@ public ILayer EinsumDense(string equation,
/// <summary>
/// Declares the <c>GlobalMaxPooling1D</c> layer factory; implementations return an <see cref="ILayer"/>.
/// </summary>
/// <param name="data_format">
/// Data layout name, defaulting to "channels_last". The set of accepted values is not
/// visible from this declaration.
/// </param>
public ILayer GlobalMaxPooling1D(string data_format = "channels_last");
/// <summary>
/// Declares the <c>GlobalMaxPooling2D</c> layer factory; implementations return an <see cref="ILayer"/>.
/// </summary>
/// <param name="data_format">
/// Data layout name, defaulting to "channels_last". The set of accepted values is not
/// visible from this declaration.
/// </param>
public ILayer GlobalMaxPooling2D(string data_format = "channels_last");

/// <summary>
/// Declares the Keras <c>Input</c> entry point. Every parameter is optional, and the
/// parameter list matches the <c>Input</c> implementation in <c>LayersApi</c>.
/// </summary>
/// <param name="shape">Input shape; per the implementation's error text it does not include the batch dimension.</param>
/// <param name="batch_size">Batch size; <c>-1</c> is the default.</param>
/// <param name="name">Optional name for the input layer.</param>
/// <param name="dtype">Element type; defaults to <c>TF_DataType.DtInvalid</c>.</param>
/// <param name="sparse">Whether a sparse input is requested.</param>
/// <param name="tensor">Optional existing tensor to use as the layer's input tensor.</param>
/// <param name="ragged">Whether a ragged input is requested.</param>
/// <param name="type_spec">Optional type specification.</param>
/// <param name="batch_input_shape">Optional shape that includes the batch dimension.</param>
/// <param name="batch_shape">Optional batch shape.</param>
/// <returns>The tensors produced by the created input layer.</returns>
public Tensors Input(Shape shape = null,
    int batch_size = -1,
    string name = null,
    TF_DataType dtype = TF_DataType.DtInvalid,
    bool sparse = false,
    Tensor tensor = null,
    bool ragged = false,
    TypeSpec type_spec = null,
    Shape batch_input_shape = null,
    Shape batch_shape = null);
public ILayer InputLayer(Shape input_shape,
string name = null,
bool sparse = false,
Expand Down
40 changes: 12 additions & 28 deletions src/TensorFlowNET.Keras/KerasInterface.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.Utils;
using System.Threading;
using Tensorflow.Framework.Models;

namespace Tensorflow.Keras
{
Expand Down Expand Up @@ -66,33 +67,16 @@ public Functional Model(Tensors inputs, Tensors outputs, string name = null)
/// If set, the layer will not create a placeholder tensor.
/// </param>
/// <returns></returns>
public Tensor Input(Shape shape = null,
    int batch_size = -1,
    Shape batch_input_shape = null,
    TF_DataType dtype = TF_DataType.DtInvalid,
    string name = null,
    bool sparse = false,
    bool ragged = false,
    Tensor tensor = null)
{
    // A supplied batch shape overrides `shape`: the per-sample shape is
    // everything after the first dimension of `batch_input_shape`.
    if (batch_input_shape != null)
    {
        shape = batch_input_shape.dims.Skip(1).ToArray();
    }

    // Build the layer directly from an inline argument object.
    var inputLayer = new InputLayer(new InputLayerArgs
    {
        Name = name,
        InputShape = shape,
        BatchInputShape = batch_input_shape,
        BatchSize = batch_size,
        DType = dtype,
        Sparse = sparse,
        Ragged = ragged,
        InputTensor = tensor
    });

    // Hand back the outputs recorded on the layer's first inbound node.
    return inputLayer.InboundNodes[0].Outputs;
}
/// <summary>
/// Convenience wrapper that forwards every argument, unchanged, to
/// <c>keras.layers.Input</c> and returns its result.
/// </summary>
/// <returns>The <c>Tensors</c> returned by <c>keras.layers.Input</c>.</returns>
public Tensors Input(Shape shape = null,
    int batch_size = -1,
    string name = null,
    TF_DataType dtype = TF_DataType.DtInvalid,
    bool sparse = false,
    Tensor tensor = null,
    bool ragged = false,
    TypeSpec type_spec = null,
    Shape batch_input_shape = null,
    Shape batch_shape = null)
{
    // Arguments are named so the pass-through stays correct if it is ever reordered.
    return keras.layers.Input(
        shape: shape,
        batch_size: batch_size,
        name: name,
        dtype: dtype,
        sparse: sparse,
        tensor: tensor,
        ragged: ragged,
        type_spec: type_spec,
        batch_input_shape: batch_input_shape,
        batch_shape: batch_shape);
}
}
}
51 changes: 44 additions & 7 deletions src/TensorFlowNET.Keras/Layers/LayersApi.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Core;
using Tensorflow.Keras.ArgsDefinition.Rnn;
Expand Down Expand Up @@ -471,20 +472,56 @@ public ILayer Flatten(string data_format = null)
/// In this case, values of 'None' in the 'shape' argument represent ragged dimensions. For more information about RaggedTensors, see this guide.
/// </param>
/// <returns>A tensor.</returns>
/// <param name="tensor">Optional existing tensor; forwarded to the layer arguments as <c>InputTensor</c>.</param>
/// <param name="type_spec">
/// Accepted for parity with the Python API. It is only checked for null during argument
/// validation and is not forwarded to the layer arguments.
/// </param>
/// <param name="batch_input_shape">
/// Shape that includes the batch dimension. Mutually exclusive with <paramref name="shape"/>;
/// when given, it is forwarded as <c>BatchInputShape</c> and <paramref name="batch_size"/> /
/// <paramref name="shape"/> are not forwarded.
/// </param>
/// <param name="batch_shape">
/// Alias for <paramref name="batch_input_shape"/>; used only when
/// <paramref name="batch_input_shape"/> is null.
/// </param>
/// <exception cref="ValueError">
/// Thrown when both <paramref name="sparse"/> and <paramref name="ragged"/> are true, when both
/// <paramref name="shape"/> and a batch shape are given, or when none of shape, batch shape,
/// tensor or type_spec is given.
/// </exception>
public Tensors Input(Shape shape = null,
    int batch_size = -1,
    string name = null,
    TF_DataType dtype = TF_DataType.DtInvalid,
    bool sparse = false,
    Tensor tensor = null,
    bool ragged = false,
    TypeSpec type_spec = null,
    Shape batch_input_shape = null,
    Shape batch_shape = null)
{
    if (sparse && ragged)
    {
        throw new ValueError("Cannot set both `sparse` and `ragged` to `true` in a Keras `Input`.");
    }

    // `batch_shape` was previously accepted but never read, so `Input(batch_shape: ...)`
    // fell through to the "Please provide..." error. Treat it as an alias, with an
    // explicit `batch_input_shape` taking precedence.
    batch_input_shape ??= batch_shape;

    InputLayerArgs input_layer_config = new()
    {
        Name = name,
        DType = dtype,
        Sparse = sparse,
        Ragged = ragged,
        InputTensor = tensor,
        // skip the `type_spec`
    };

    if (shape is not null && batch_input_shape is not null)
    {
        throw new ValueError("Only provide the `shape` OR `batch_input_shape` argument "
            + "to Input, not both at the same time.");
    }

    if (batch_input_shape is null && shape is null && tensor is null && type_spec is null)
    {
        throw new ValueError("Please provide to Input a `shape` or a `tensor` or a `type_spec` argument. Note that " +
            "`shape` does not include the batch dimension.");
    }

    if (batch_input_shape is not null)
    {
        // Only the batch shape is forwarded here; the former reassignment of `shape`
        // was never read afterwards and has been dropped.
        input_layer_config.BatchInputShape = batch_input_shape;
    }
    else
    {
        input_layer_config.BatchSize = batch_size;
        input_layer_config.InputShape = shape;
    }

    var input_layer = new InputLayer(input_layer_config);

    // The result is the outputs recorded on the layer's first inbound node.
    return input_layer.InboundNodes[0].Outputs;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ public void test_masked_attention()
var value = keras.Input(shape: (2, 8));
var mask_tensor = keras.Input(shape:(4, 2));
var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2);
attention_layer.Apply(new[] { query, value, mask_tensor });
attention_layer.Apply(new Tensor[] { query, value, mask_tensor });

var from_data = 10 * np.random.randn(batch_size, 4, 8);
var to_data = 10 * np.random.randn(batch_size, 2, 8);
Expand Down