Skip to content

Commit 559d471

Browse files
AsakusaRinneOceania2018
authored andcommitted
Align keras.Input with tensorflow python.
1 parent b8fd21c commit 559d471

File tree

4 files changed

+65
-38
lines changed

4 files changed

+65
-38
lines changed

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using Tensorflow.Framework.Models;
23
using Tensorflow.NumPy;
34
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;
45

@@ -133,11 +134,16 @@ public ILayer EinsumDense(string equation,
133134
public ILayer GlobalMaxPooling1D(string data_format = "channels_last");
134135
public ILayer GlobalMaxPooling2D(string data_format = "channels_last");
135136

136-
public Tensors Input(Shape shape,
137+
public Tensors Input(Shape shape = null,
137138
int batch_size = -1,
138139
string name = null,
140+
TF_DataType dtype = TF_DataType.DtInvalid,
139141
bool sparse = false,
140-
bool ragged = false);
142+
Tensor tensor = null,
143+
bool ragged = false,
144+
TypeSpec type_spec = null,
145+
Shape batch_input_shape = null,
146+
Shape batch_shape = null);
141147
public ILayer InputLayer(Shape input_shape,
142148
string name = null,
143149
bool sparse = false,

src/TensorFlowNET.Keras/KerasInterface.cs

+12-28
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using Tensorflow.Keras.Optimizers;
1313
using Tensorflow.Keras.Utils;
1414
using System.Threading;
15+
using Tensorflow.Framework.Models;
1516

1617
namespace Tensorflow.Keras
1718
{
@@ -66,33 +67,16 @@ public Functional Model(Tensors inputs, Tensors outputs, string name = null)
6667
/// If set, the layer will not create a placeholder tensor.
6768
/// </param>
6869
/// <returns></returns>
69-
public Tensor Input(Shape shape = null,
70-
int batch_size = -1,
71-
Shape batch_input_shape = null,
72-
TF_DataType dtype = TF_DataType.DtInvalid,
73-
string name = null,
74-
bool sparse = false,
75-
bool ragged = false,
76-
Tensor tensor = null)
77-
{
78-
if (batch_input_shape != null)
79-
shape = batch_input_shape.dims.Skip(1).ToArray();
80-
81-
var args = new InputLayerArgs
82-
{
83-
Name = name,
84-
InputShape = shape,
85-
BatchInputShape = batch_input_shape,
86-
BatchSize = batch_size,
87-
DType = dtype,
88-
Sparse = sparse,
89-
Ragged = ragged,
90-
InputTensor = tensor
91-
};
92-
93-
var layer = new InputLayer(args);
94-
95-
return layer.InboundNodes[0].Outputs;
96-
}
70+
public Tensors Input(Shape shape = null,
71+
int batch_size = -1,
72+
string name = null,
73+
TF_DataType dtype = TF_DataType.DtInvalid,
74+
bool sparse = false,
75+
Tensor tensor = null,
76+
bool ragged = false,
77+
TypeSpec type_spec = null,
78+
Shape batch_input_shape = null,
79+
Shape batch_shape = null) => keras.layers.Input(shape, batch_size, name,
80+
dtype, sparse, tensor, ragged, type_spec, batch_input_shape, batch_shape);
9781
}
9882
}

src/TensorFlowNET.Keras/Layers/LayersApi.cs

+44-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using Tensorflow.Framework.Models;
23
using Tensorflow.Keras.ArgsDefinition;
34
using Tensorflow.Keras.ArgsDefinition.Core;
45
using Tensorflow.Keras.ArgsDefinition.Rnn;
@@ -471,20 +472,56 @@ public ILayer Flatten(string data_format = null)
471472
/// In this case, values of 'None' in the 'shape' argument represent ragged dimensions. For more information about RaggedTensors, see this guide.
472473
/// </param>
473474
/// <returns>A tensor.</returns>
474-
public Tensors Input(Shape shape,
475+
public Tensors Input(Shape shape = null,
475476
int batch_size = -1,
476477
string name = null,
478+
TF_DataType dtype = TF_DataType.DtInvalid,
477479
bool sparse = false,
478-
bool ragged = false)
480+
Tensor tensor = null,
481+
bool ragged = false,
482+
TypeSpec type_spec = null,
483+
Shape batch_input_shape = null,
484+
Shape batch_shape = null)
479485
{
480-
var input_layer = new InputLayer(new InputLayerArgs
486+
if(sparse && ragged)
487+
{
488+
throw new ValueError("Cannot set both `sparse` and `ragged` to `true` in a Keras `Input`.");
489+
}
490+
491+
InputLayerArgs input_layer_config = new()
481492
{
482-
InputShape = shape,
483-
BatchSize= batch_size,
484493
Name = name,
494+
DType = dtype,
485495
Sparse = sparse,
486-
Ragged = ragged
487-
});
496+
Ragged = ragged,
497+
InputTensor = tensor,
498+
// skip the `type_spec`
499+
};
500+
501+
if(shape is not null && batch_input_shape is not null)
502+
{
503+
throw new ValueError("Only provide the `shape` OR `batch_input_shape` argument "
504+
+ "to Input, not both at the same time.");
505+
}
506+
507+
if(batch_input_shape is null && shape is null && tensor is null && type_spec is null)
508+
{
509+
throw new ValueError("Please provide to Input a `shape` or a `tensor` or a `type_spec` argument. Note that " +
510+
"`shape` does not include the batch dimension.");
511+
}
512+
513+
if(batch_input_shape is not null)
514+
{
515+
shape = batch_input_shape["1:"];
516+
input_layer_config.BatchInputShape = batch_input_shape;
517+
}
518+
else
519+
{
520+
input_layer_config.BatchSize = batch_size;
521+
input_layer_config.InputShape = shape;
522+
}
523+
524+
var input_layer = new InputLayer(input_layer_config);
488525

489526
return input_layer.InboundNodes[0].Outputs;
490527
}

test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ public void test_masked_attention()
158158
var value = keras.Input(shape: (2, 8));
159159
var mask_tensor = keras.Input(shape:(4, 2));
160160
var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2);
161-
attention_layer.Apply(new[] { query, value, mask_tensor });
161+
attention_layer.Apply(new Tensor[] { query, value, mask_tensor });
162162

163163
var from_data = 10 * np.random.randn(batch_size, 4, 8);
164164
var to_data = 10 * np.random.randn(batch_size, 2, 8);

0 commit comments

Comments
 (0)