Skip to content

Commit 7b26d66

Browse files
committed
Adjust location of KerasTensor.
1 parent d452d8c commit 7b26d66

File tree

7 files changed

+173
-164
lines changed

7 files changed

+173
-164
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
namespace Tensorflow.Keras.Engine;
2+
3+
/// <summary>
4+
/// A representation of a Keras in/output during Functional API construction.
5+
/// </summary>
6+
public class KerasTensor
7+
{
8+
private Tensors _original_tensors;
9+
public Tensors original_tensors
10+
{
11+
get => _original_tensors;
12+
set => _original_tensors = value;
13+
}
14+
15+
private Shape _inferred_value;
16+
public Shape inferred_value => _inferred_value;
17+
18+
private string _name;
19+
private TensorSpec _type_spec;
20+
public Shape shape => _type_spec.shape;
21+
public TF_DataType dtype => _type_spec.dtype;
22+
23+
public KerasTensor(TensorSpec type_spec, Shape inferred_value = null, string name = null)
24+
{
25+
_type_spec = type_spec;
26+
_inferred_value = inferred_value;
27+
_name = name;
28+
}
29+
30+
public static KerasTensor from_tensor(Tensor tensor)
31+
{
32+
var type_spec = tensor.ToTensorSpec();
33+
var kt = new KerasTensor(type_spec, name: tensor.name);
34+
kt.original_tensors = tensor;
35+
return kt;
36+
}
37+
38+
public override string ToString()
39+
=> _original_tensors.Length switch
40+
{
41+
> 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype}")) + "]",
42+
1 => $"KerasTensor: shape={_original_tensors.shape} {GetInferredValueString()} dtype={_original_tensors.dtype}",
43+
_ => _original_tensors.ToString(),
44+
};
45+
46+
private string GetInferredValueString()
47+
=> _inferred_value == null ? "" : "";
48+
49+
public static implicit operator Tensors(KerasTensor kt)
50+
=> kt._original_tensors;
51+
52+
public static implicit operator Tensor(KerasTensor kt)
53+
{
54+
Tensor tensor = kt._original_tensors;
55+
tensor.IsFromKerasTensor = true;
56+
return tensor;
57+
}
58+
59+
public static implicit operator KerasTensor(Tensor tensor)
60+
=> from_tensor(tensor);
61+
62+
public static implicit operator KerasTensor(Tensors tensors)
63+
=> from_tensor(tensors.First());
64+
}

src/TensorFlowNET.Core/Tensors/KerasTensor.cs

-53
This file was deleted.

src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs

+4-13
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,10 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17-
using Tensorflow.NumPy;
18-
using System;
19-
using System.Diagnostics.CodeAnalysis;
20-
using System.Text;
21-
using Tensorflow.Framework.Models;
22-
using static Tensorflow.Binding;
17+
namespace Tensorflow;
2318

24-
namespace Tensorflow
19+
public partial class Tensor
2520
{
26-
[SuppressMessage("ReSharper", "InvokeAsExtensionMethod")]
27-
public partial class Tensor
28-
{
29-
public TensorSpec ToTensorSpec()
30-
=> new TensorSpec(shape, dtype, name);
31-
}
21+
public TensorSpec ToTensorSpec()
22+
=> new TensorSpec(shape, dtype, name);
3223
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
namespace Tensorflow;
18+
19+
public partial class Tensor
20+
{
21+
public bool IsFromKerasTensor { get; set; }
22+
23+
/// <summary>
24+
/// Keras History: (Layer, (node_index, tensor_index))
25+
/// </summary>
26+
public KerasHistory KerasHistory { get; set; }
27+
}

src/TensorFlowNET.Core/Tensors/Tensor.cs

-5
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,6 @@ public int[] _shape_tuple()
146146
return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray();
147147
}
148148

149-
/// <summary>
150-
/// Keras History: (Layer, (node_index, tensor_index))
151-
/// </summary>
152-
public KerasHistory KerasHistory { get; set; }
153-
154149
/// <summary>
155150
/// Updates the shape of this tensor.
156151
/// </summary>
+9-16
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,15 @@
1-
using System;
2-
using System.Collections.Generic;
3-
using System.IO;
4-
using System.Text;
5-
using Tensorflow.Keras.Engine;
6-
using Tensorflow.Keras.Saving;
1+
using Tensorflow.Keras.Saving;
72
using Tensorflow.Keras.Saving.SavedModel;
8-
using ThirdParty.Tensorflow.Python.Keras.Protobuf;
93

10-
namespace Tensorflow.Keras.Models
4+
namespace Tensorflow.Keras.Models;
5+
6+
public class ModelsApi: IModelsApi
117
{
12-
public class ModelsApi: IModelsApi
13-
{
14-
public Functional from_config(FunctionalConfig config)
15-
=> Functional.from_config(config);
8+
public Functional from_config(FunctionalConfig config)
9+
=> Functional.from_config(config);
1610

17-
public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null)
18-
{
19-
return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model;
20-
}
11+
public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null)
12+
{
13+
return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model;
2114
}
2215
}
Original file line numberDiff line numberDiff line change
@@ -1,97 +1,89 @@
1-
using Google.Protobuf;
2-
using System;
3-
using System.Collections.Generic;
4-
using System.IO;
5-
using System.Text;
6-
using Tensorflow.Keras.Engine;
1+
using System.IO;
72
using Tensorflow.Train;
83
using ThirdParty.Tensorflow.Python.Keras.Protobuf;
9-
using static Tensorflow.Binding;
10-
using static Tensorflow.KerasApi;
114

12-
namespace Tensorflow.Keras.Saving.SavedModel
5+
namespace Tensorflow.Keras.Saving.SavedModel;
6+
7+
public class KerasLoadModelUtils
138
{
14-
public class KerasLoadModelUtils
9+
/// <summary>
10+
/// Corresponding to keras/saving/save.py/load_model
11+
/// </summary>
12+
/// <param name="filepath"></param>
13+
/// <param name="custom_objects"></param>
14+
/// <param name="compile"></param>
15+
/// <param name="options"></param>
16+
/// <returns></returns>
17+
public static Trackable load_model(string filepath, IDictionary<string, object>? custom_objects = null,
18+
bool compile = true, LoadOptions? options = null)
1519
{
16-
/// <summary>
17-
/// Corresponding to keras/saving/save.py/load_model
18-
/// </summary>
19-
/// <param name="filepath"></param>
20-
/// <param name="custom_objects"></param>
21-
/// <param name="compile"></param>
22-
/// <param name="options"></param>
23-
/// <returns></returns>
24-
public static Trackable load_model(string filepath, IDictionary<string, object>? custom_objects = null,
25-
bool compile = true, LoadOptions? options = null)
20+
using var savingScope = SharedObjectSavingScope.Enter();
21+
22+
using var ctx = LoadContext.load_context(options);
23+
24+
if (!File.Exists(filepath) && !Directory.Exists(filepath))
2625
{
27-
using (SharedObjectSavingScope.Enter())
28-
{
29-
using (LoadContext.load_context(options))
30-
{
31-
if (!File.Exists(filepath) && !Directory.Exists(filepath))
32-
{
33-
throw new IOException($"No file or directory found at {filepath}.");
34-
}
35-
if (Directory.Exists(filepath))
36-
{
37-
return load(filepath, compile, options);
38-
}
39-
else
40-
{
41-
throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed.");
42-
}
43-
}
44-
}
26+
throw new IOException($"No file or directory found at {filepath}.");
4527
}
4628

47-
private static Trackable load(string path, bool compile = true, LoadOptions? options = null)
29+
if (Directory.Exists(filepath))
30+
{
31+
return load(filepath, compile, options);
32+
}
33+
else
4834
{
49-
SavedMetadata metadata = new SavedMetadata();
50-
var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0];
51-
var object_graph_def = meta_graph_def.ObjectGraphDef;
52-
string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH);
53-
if (File.Exists(path_to_metadata_pb))
54-
{
55-
metadata.MergeFrom(new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read));
56-
}
57-
else
58-
{
59-
throw new NotImplementedException("SavedModel saved prior to TF 2.5 detected when loading Keras model, please" +
60-
" use higher version or submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues. to let us know you need it.");
61-
}
35+
throw new NotImplementedException("Model load of h5 format has not been supported. Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues if it's needed.");
36+
}
37+
}
6238

63-
if (metadata.Nodes is null || metadata.Nodes.Count == 0)
64-
{
65-
return Loader.load(path, options: options) as Model;
66-
}
39+
private static Trackable load(string path, bool compile = true, LoadOptions? options = null)
40+
{
41+
SavedMetadata metadata;
42+
var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0];
43+
var object_graph_def = meta_graph_def.ObjectGraphDef;
44+
string path_to_metadata_pb = Path.Combine(path, Constants.SAVED_METADATA_PATH);
45+
if (File.Exists(path_to_metadata_pb))
46+
{
47+
using var stream = new FileStream(path_to_metadata_pb, FileMode.Open, FileAccess.Read);
48+
metadata = SavedMetadata.Parser.ParseFrom(stream);
49+
}
50+
else
51+
{
52+
throw new NotImplementedException("SavedModel saved prior to TF 2.5 detected when loading Keras model, please" +
53+
" use higher version or submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues. to let us know you need it.");
54+
}
6755

68-
var keras_loader = new KerasObjectLoader(metadata, object_graph_def);
69-
keras_loader.load_layers(compile: compile);
56+
if (metadata.Nodes is null || metadata.Nodes.Count == 0)
57+
{
58+
return Loader.load(path, options: options) as Model;
59+
}
7060

71-
Dictionary<string, (Trackable, Action<object, object, object>)> nodes_to_load = new();
72-
nodes_to_load["root"] = (null, null);
73-
foreach(var item in keras_loader.LoadedNodes)
74-
{
75-
nodes_to_load[keras_loader.get_path(item.Key)] = item.Value;
76-
}
77-
var loaded = Loader.load_partial(path, nodes_to_load, options);
61+
var keras_loader = new KerasObjectLoader(metadata, object_graph_def);
62+
keras_loader.load_layers(compile: compile);
7863

79-
keras_loader.finalize_objects();
80-
keras_loader.del_tracking();
64+
Dictionary<string, (Trackable, Action<object, object, object>)> nodes_to_load = new();
65+
nodes_to_load["root"] = (null, null);
66+
foreach(var item in keras_loader.LoadedNodes)
67+
{
68+
nodes_to_load[keras_loader.get_path(item.Key)] = item.Value;
69+
}
70+
var loaded = Loader.load_partial(path, nodes_to_load, options);
8171

82-
var model = loaded["root"];
72+
keras_loader.finalize_objects();
73+
keras_loader.del_tracking();
8374

84-
if(model is Model && compile)
85-
{
86-
// TODO(Rinne): implement it.
87-
}
75+
var model = loaded["root"];
8876

89-
if (!tf.Context.executing_eagerly())
90-
{
91-
// TODO(Rinne): implement it.
92-
}
77+
if (model is Model && compile)
78+
{
79+
// TODO(Rinne): implement it.
80+
}
9381

94-
return model;
82+
if (!tf.Context.executing_eagerly())
83+
{
84+
// TODO(Rinne): implement it.
9585
}
86+
87+
return model;
9688
}
9789
}

0 commit comments

Comments
 (0)