Skip to content

Commit 747e658

Browse files
committed
Change type of BuildInputShape to KerasShapesWrapper.
1 parent 3d0e2d0 commit 747e658

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+373
-123
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using Newtonsoft.Json.Linq;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
6+
namespace Tensorflow.Extensions
7+
{
8+
public static class JObjectExtensions
9+
{
10+
public static T? TryGetOrReturnNull<T>(this JObject obj, string key)
11+
{
12+
var res = obj[key];
13+
if(res is null)
14+
{
15+
return default(T);
16+
}
17+
else
18+
{
19+
return res.ToObject<T>();
20+
}
21+
}
22+
}
23+
}

src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ public class TensorSpec : DenseSpec
77
public TensorSpec(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) :
88
base(shape, dtype, name)
99
{
10-
10+
1111
}
1212

1313
public TensorSpec _unbatch()

src/TensorFlowNET.Core/Keras/Activations/Activations.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Newtonsoft.Json;
22
using System.Reflection;
33
using System.Runtime.Versioning;
4-
using Tensorflow.Keras.Common;
4+
using Tensorflow.Keras.Saving.Common;
55

66
namespace Tensorflow.Keras
77
{

src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System;
33
using System.Collections.Generic;
44
using System.Text;
5+
using Tensorflow.Keras.Saving;
56

67
namespace Tensorflow.Keras.ArgsDefinition
78
{
@@ -18,7 +19,7 @@ public class AutoSerializeLayerArgs: LayerArgs
1819
[JsonProperty("dtype")]
1920
public override TF_DataType DType { get => base.DType; set => base.DType = value; }
2021
[JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)]
21-
public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
22+
public override KerasShapesWrapper BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
2223
[JsonProperty("trainable")]
2324
public override bool Trainable { get => base.Trainable; set => base.Trainable = value; }
2425
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Newtonsoft.Json;
22
using Newtonsoft.Json.Serialization;
3-
using Tensorflow.Keras.Common;
3+
using Tensorflow.Keras.Saving;
44

55
namespace Tensorflow.Keras.ArgsDefinition
66
{
@@ -17,6 +17,6 @@ public class InputLayerArgs : LayerArgs
1717
[JsonProperty("dtype")]
1818
public override TF_DataType DType { get => base.DType; set => base.DType = value; }
1919
[JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)]
20-
public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
20+
public override KerasShapesWrapper BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
2121
}
2222
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public class LayerArgs: IKerasConfig
3333
/// <summary>
3434
/// Only applicable to input layers.
3535
/// </summary>
36-
public virtual Shape BatchInputShape { get; set; }
36+
public virtual KerasShapesWrapper BatchInputShape { get; set; }
3737

3838
public virtual int BatchSize { get; set; } = -1;
3939

src/TensorFlowNET.Core/Keras/Layers/ILayer.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public interface ILayer: IWithTrackable, IKerasConfigable
1010
string Name { get; }
1111
bool Trainable { get; }
1212
bool Built { get; }
13-
void build(Shape input_shape);
13+
void build(KerasShapesWrapper input_shape);
1414
List<ILayer> Layers { get; }
1515
List<INode> InboundNodes { get; }
1616
List<INode> OutboundNodes { get; }
@@ -22,8 +22,8 @@ public interface ILayer: IWithTrackable, IKerasConfigable
2222
void set_weights(IEnumerable<NDArray> weights);
2323
List<NDArray> get_weights();
2424
Shape OutputShape { get; }
25-
Shape BatchInputShape { get; }
26-
TensorShapeConfig BuildInputShape { get; }
25+
KerasShapesWrapper BatchInputShape { get; }
26+
KerasShapesWrapper BuildInputShape { get; }
2727
TF_DataType DType { get; }
2828
int count_params();
2929
void adapt(Tensor data, int? batch_size = null, int? steps = null);

src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs renamed to src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedActivationJsonConverter.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
using System.Text;
77
using static Tensorflow.Binding;
88

9-
namespace Tensorflow.Keras.Common
9+
namespace Tensorflow.Keras.Saving.Common
1010
{
1111
public class CustomizedActivationJsonConverter : JsonConverter
1212
{

src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs renamed to src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedAxisJsonConverter.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
using System.Collections.Generic;
55
using System.Text;
66

7-
namespace Tensorflow.Keras.Common
7+
namespace Tensorflow.Keras.Saving.Common
88
{
99
public class CustomizedAxisJsonConverter : JsonConverter
1010
{
@@ -38,7 +38,7 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
3838
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
3939
{
4040
int[]? axis;
41-
if(reader.ValueType == typeof(long))
41+
if (reader.ValueType == typeof(long))
4242
{
4343
axis = new int[1];
4444
axis[0] = (int)serializer.Deserialize(reader, typeof(int));
@@ -51,7 +51,7 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
5151
{
5252
throw new ValueError("Cannot deserialize 'null' to `Axis`.");
5353
}
54-
return new Axis((int[])(axis!));
54+
return new Axis(axis!);
5555
}
5656
}
5757
}

src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs renamed to src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedDTypeJsonConverter.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Newtonsoft.Json.Linq;
22
using Newtonsoft.Json;
33

4-
namespace Tensorflow.Keras.Common
4+
namespace Tensorflow.Keras.Saving.Common
55
{
66
public class CustomizedDTypeJsonConverter : JsonConverter
77
{
@@ -16,7 +16,7 @@ public override bool CanConvert(Type objectType)
1616

1717
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
1818
{
19-
var token = JToken.FromObject(dtypes.as_numpy_name((TF_DataType)value));
19+
var token = JToken.FromObject(((TF_DataType)value).as_numpy_name());
2020
token.WriteTo(writer);
2121
}
2222

src/TensorFlowNET.Core/Keras/Common/CustomizedIInitializerJsonConverter.cs renamed to src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedIInitializerJsonConverter.cs

+6-5
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
using System.Collections.Generic;
55
using System.Text;
66
using Tensorflow.Operations;
7+
78
using Tensorflow.Operations.Initializers;
89

9-
namespace Tensorflow.Keras.Common
10+
namespace Tensorflow.Keras.Saving.Common
1011
{
1112
class InitializerInfo
1213
{
@@ -27,7 +28,7 @@ public override bool CanConvert(Type objectType)
2728
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
2829
{
2930
var initializer = value as IInitializer;
30-
if(initializer is null)
31+
if (initializer is null)
3132
{
3233
JToken.FromObject(null).WriteTo(writer);
3334
return;
@@ -42,7 +43,7 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
4243
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
4344
{
4445
var info = serializer.Deserialize<InitializerInfo>(reader);
45-
if(info is null)
46+
if (info is null)
4647
{
4748
return null;
4849
}
@@ -54,8 +55,8 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
5455
"Orthogonal" => new Orthogonal(info.config["gain"].ToObject<float>(), info.config["seed"].ToObject<int?>()),
5556
"RandomNormal" => new RandomNormal(info.config["mean"].ToObject<float>(), info.config["stddev"].ToObject<float>(),
5657
info.config["seed"].ToObject<int?>()),
57-
"RandomUniform" => new RandomUniform(minval:info.config["minval"].ToObject<float>(),
58-
maxval:info.config["maxval"].ToObject<float>(), seed: info.config["seed"].ToObject<int?>()),
58+
"RandomUniform" => new RandomUniform(minval: info.config["minval"].ToObject<float>(),
59+
maxval: info.config["maxval"].ToObject<float>(), seed: info.config["seed"].ToObject<int?>()),
5960
"TruncatedNormal" => new TruncatedNormal(info.config["mean"].ToObject<float>(), info.config["stddev"].ToObject<float>(),
6061
info.config["seed"].ToObject<int?>()),
6162
"VarianceScaling" => new VarianceScaling(info.config["scale"].ToObject<float>(), info.config["mode"].ToObject<string>(),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
using Newtonsoft.Json.Linq;
2+
using Newtonsoft.Json;
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Text;
6+
7+
namespace Tensorflow.Keras.Saving.Json
8+
{
9+
public class CustomizedKerasShapesWrapperJsonConverter : JsonConverter
10+
{
11+
public override bool CanConvert(Type objectType)
12+
{
13+
return objectType == typeof(KerasShapesWrapper);
14+
}
15+
16+
public override bool CanRead => true;
17+
18+
public override bool CanWrite => true;
19+
20+
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
21+
{
22+
if (value is null)
23+
{
24+
JToken.FromObject(null).WriteTo(writer);
25+
return;
26+
}
27+
if (value is not KerasShapesWrapper wrapper)
28+
{
29+
throw new TypeError($"Expected `KerasShapesWrapper` to be serialized, bug got {value.GetType()}");
30+
}
31+
if (wrapper.Shapes.Length == 0)
32+
{
33+
JToken.FromObject(null).WriteTo(writer);
34+
}
35+
else if (wrapper.Shapes.Length == 1)
36+
{
37+
JToken.FromObject(wrapper.Shapes[0]).WriteTo(writer);
38+
}
39+
else
40+
{
41+
JToken.FromObject(wrapper.Shapes).WriteTo(writer);
42+
}
43+
}
44+
45+
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
46+
{
47+
if (reader.TokenType == JsonToken.StartArray)
48+
{
49+
TensorShapeConfig[] shapes = serializer.Deserialize<TensorShapeConfig[]>(reader);
50+
if (shapes is null)
51+
{
52+
return null;
53+
}
54+
return new KerasShapesWrapper(shapes);
55+
}
56+
else if (reader.TokenType == JsonToken.StartObject)
57+
{
58+
var shape = serializer.Deserialize<TensorShapeConfig>(reader);
59+
if (shape is null)
60+
{
61+
return null;
62+
}
63+
return new KerasShapesWrapper(shape);
64+
}
65+
else if (reader.TokenType == JsonToken.Null)
66+
{
67+
return null;
68+
}
69+
else
70+
{
71+
throw new ValueError($"Cannot deserialize the token type {reader.TokenType}");
72+
}
73+
}
74+
}
75+
}

src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs renamed to src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedNodeConfigJsonConverter.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
using System.Text;
88
using Tensorflow.Keras.Saving;
99

10-
namespace Tensorflow.Keras.Common
10+
namespace Tensorflow.Keras.Saving.Common
1111
{
1212
public class CustomizedNodeConfigJsonConverter : JsonConverter
1313
{
@@ -46,10 +46,10 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
4646
{
4747
throw new ValueError("Cannot deserialize 'null' to `Shape`.");
4848
}
49-
if(values.Length == 1)
49+
if (values.Length == 1)
5050
{
5151
var array = values[0] as JArray;
52-
if(array is null)
52+
if (array is null)
5353
{
5454
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`.");
5555
}

src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs renamed to src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedShapeJsonConverter.cs

+17-9
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
using System.Collections.Generic;
66
using System.Text;
77

8-
namespace Tensorflow.Keras.Common
8+
namespace Tensorflow.Keras.Saving.Common
99
{
1010
class ShapeInfoFromPython
1111
{
1212
public string class_name { get; set; }
1313
public long?[] items { get; set; }
1414
}
15-
public class CustomizedShapeJsonConverter: JsonConverter
15+
public class CustomizedShapeJsonConverter : JsonConverter
1616
{
1717
public override bool CanConvert(Type objectType)
1818
{
@@ -25,20 +25,20 @@ public override bool CanConvert(Type objectType)
2525

2626
public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
2727
{
28-
if(value is null)
28+
if (value is null)
2929
{
3030
var token = JToken.FromObject(null);
3131
token.WriteTo(writer);
3232
}
33-
else if(value is not Shape)
33+
else if (value is not Shape)
3434
{
3535
throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}.");
3636
}
3737
else
3838
{
3939
var shape = (value as Shape)!;
4040
long?[] dims = new long?[shape.ndim];
41-
for(int i = 0; i < dims.Length; i++)
41+
for (int i = 0; i < dims.Length; i++)
4242
{
4343
if (shape.dims[i] == -1)
4444
{
@@ -61,7 +61,7 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
6161
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
6262
{
6363
long?[] dims;
64-
try
64+
if (reader.TokenType == JsonToken.StartObject)
6565
{
6666
var shape_info_from_python = serializer.Deserialize<ShapeInfoFromPython>(reader);
6767
if (shape_info_from_python is null)
@@ -70,14 +70,22 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
7070
}
7171
dims = shape_info_from_python.items;
7272
}
73-
catch(JsonSerializationException)
73+
else if (reader.TokenType == JsonToken.StartArray)
7474
{
7575
dims = serializer.Deserialize<long?[]>(reader);
7676
}
77+
else if (reader.TokenType == JsonToken.Null)
78+
{
79+
return null;
80+
}
81+
else
82+
{
83+
throw new ValueError($"Cannot deserialize the token {reader} as Shape.");
84+
}
7785
long[] convertedDims = new long[dims.Length];
78-
for(int i = 0; i < dims.Length; i++)
86+
for (int i = 0; i < dims.Length; i++)
7987
{
80-
convertedDims[i] = dims[i] ?? (-1);
88+
convertedDims[i] = dims[i] ?? -1;
8189
}
8290
return new Shape(convertedDims);
8391
}

0 commit comments

Comments
 (0)