5
5
using System . Collections . Generic ;
6
6
using System . Text ;
7
7
8
- namespace Tensorflow . Keras . Common
8
+ namespace Tensorflow . Keras . Saving . Common
9
9
{
10
10
class ShapeInfoFromPython
11
11
{
12
12
public string class_name { get ; set ; }
13
13
public long ? [ ] items { get ; set ; }
14
14
}
15
- public class CustomizedShapeJsonConverter : JsonConverter
15
+ public class CustomizedShapeJsonConverter : JsonConverter
16
16
{
17
17
public override bool CanConvert ( Type objectType )
18
18
{
@@ -25,20 +25,20 @@ public override bool CanConvert(Type objectType)
25
25
26
26
public override void WriteJson ( JsonWriter writer , object ? value , JsonSerializer serializer )
27
27
{
28
- if ( value is null )
28
+ if ( value is null )
29
29
{
30
30
var token = JToken . FromObject ( null ) ;
31
31
token . WriteTo ( writer ) ;
32
32
}
33
- else if ( value is not Shape )
33
+ else if ( value is not Shape )
34
34
{
35
35
throw new TypeError ( $ "Unable to use `CustomizedShapeJsonConverter` to serialize the type { value . GetType ( ) } .") ;
36
36
}
37
37
else
38
38
{
39
39
var shape = ( value as Shape ) ! ;
40
40
long ? [ ] dims = new long ? [ shape . ndim ] ;
41
- for ( int i = 0 ; i < dims . Length ; i ++ )
41
+ for ( int i = 0 ; i < dims . Length ; i ++ )
42
42
{
43
43
if ( shape . dims [ i ] == - 1 )
44
44
{
@@ -61,7 +61,7 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
61
61
public override object ? ReadJson ( JsonReader reader , Type objectType , object ? existingValue , JsonSerializer serializer )
62
62
{
63
63
long ? [ ] dims ;
64
- try
64
+ if ( reader . TokenType == JsonToken . StartObject )
65
65
{
66
66
var shape_info_from_python = serializer . Deserialize < ShapeInfoFromPython > ( reader ) ;
67
67
if ( shape_info_from_python is null )
@@ -70,14 +70,22 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
70
70
}
71
71
dims = shape_info_from_python . items ;
72
72
}
73
- catch ( JsonSerializationException )
73
+ else if ( reader . TokenType == JsonToken . StartArray )
74
74
{
75
75
dims = serializer . Deserialize < long ? [ ] > ( reader ) ;
76
76
}
77
+ else if ( reader . TokenType == JsonToken . Null )
78
+ {
79
+ return null ;
80
+ }
81
+ else
82
+ {
83
+ throw new ValueError ( $ "Cannot deserialize the token { reader } as Shape.") ;
84
+ }
77
85
long [ ] convertedDims = new long [ dims . Length ] ;
78
- for ( int i = 0 ; i < dims . Length ; i ++ )
86
+ for ( int i = 0 ; i < dims . Length ; i ++ )
79
87
{
80
- convertedDims [ i ] = dims [ i ] ?? ( - 1 ) ;
88
+ convertedDims [ i ] = dims [ i ] ?? - 1 ;
81
89
}
82
90
return new Shape ( convertedDims ) ;
83
91
}
0 commit comments