Skip to content

Partially Support the function loading #1022

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Apr 18, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions TensorFlow.NET.sln
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Version 16
VisualStudioVersion = 16.0.31624.102
# Visual Studio Version 17
VisualStudioVersion = 17.4.33213.308
MinimumVisualStudioVersion = 10.0.40219.1
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}"
EndProject
@@ -23,6 +23,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest",
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tensorflow.Common", "Tensorflow.Common\Tensorflow.Common.csproj", "{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -153,6 +155,18 @@ Global
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|Any CPU.Build.0 = Debug|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x64.ActiveCfg = Debug|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x64.Build.0 = Debug|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x86.ActiveCfg = Debug|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Debug|x86.Build.0 = Debug|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|Any CPU.ActiveCfg = Release|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|Any CPU.Build.0 = Release|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x64.ActiveCfg = Release|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x64.Build.0 = Release|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x86.ActiveCfg = Release|Any CPU
{0C5DD8A8-AB1E-40AB-8CE3-F6EA0C1ED680}.Release|x86.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
31 changes: 31 additions & 0 deletions Tensorflow.Common/Extensions/DictionaryExtension.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;

namespace Tensorflow.Common.Extensions
{
public static class DictionaryExtension
{
public static void Deconstruct<T1, T2>(this KeyValuePair<T1, T2> pair, out T1 first, out T2 second)
{
first = pair.Key;
second = pair.Value;
}
public static void Update<T1, T2>(this Dictionary<T1, T2> dic, IDictionary<T1, T2> other)
{
foreach(var (key, value) in other)
{
dic[key] = value;
}
}
public static T2 GetOrDefault<T1, T2>(this Dictionary<T1, T2> dic, T1 key, T2 defaultValue)
{
if (dic.ContainsKey(key))
{
return dic[key];
}
return defaultValue;
}
}
}
13 changes: 13 additions & 0 deletions Tensorflow.Common/Extensions/OneofExtension.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using OneOf;
using System;

namespace Tensorflow.Common.Extensions
{
public static class OneofExtension
{
public static bool IsTypeOrDeriveFrom<T>(this IOneOf src)
{
return src.Value is T;
}
}
}
11 changes: 11 additions & 0 deletions Tensorflow.Common/Tensorflow.Common.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="OneOf" Version="3.0.223" />
</ItemGroup>

</Project>
13 changes: 13 additions & 0 deletions Tensorflow.Common/Types/NamedTuple.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;

namespace Tensorflow.Common.Types
{
public class NamedTuple
{
public string Name { get; set; }
public Dictionary<string, object> ValueDict { get; set; }
}
}
17 changes: 17 additions & 0 deletions src/TensorFlowNET.Core/APIs/c_api.customize.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
{
public partial class c_api
{
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
}
}
18 changes: 18 additions & 0 deletions src/TensorFlowNET.Core/APIs/tf.compat.cs
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
limitations under the License.
******************************************************************************/

using Google.Protobuf;
using System.Text;

namespace Tensorflow
@@ -45,6 +46,23 @@ internal string as_str(byte[] bytes_or_text, Encoding? encoding = null)
{
return as_text(bytes_or_text, encoding);
}

public ByteString as_bytes(ByteString bytes, Encoding encoding = null)
{
return bytes;
}
public ByteString as_bytes(byte[] bytes, Encoding encoding = null)
{
return ByteString.CopyFrom(bytes);
}
public ByteString as_bytes(string text, Encoding encoding = null)
{
if(encoding is null)
{
encoding = Encoding.UTF8;
}
return ByteString.CopyFrom(encoding.GetBytes(text));
}
}

public bool executing_eagerly()
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/APIs/tf.io.cs
Original file line number Diff line number Diff line change
@@ -54,6 +54,6 @@ public ITensorOrOperation[] import_graph_def(GraphDef graph_def,
Dictionary<string, Tensor> input_map = null,
string[] return_elements = null,
string name = null,
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name, producer_op_list);
OpList producer_op_list = null) => importer.import_graph_def(graph_def, input_map, return_elements, name: name, producer_op_list: producer_op_list);
}
}
7 changes: 7 additions & 0 deletions src/TensorFlowNET.Core/APIs/tf.tensor.cs
Original file line number Diff line number Diff line change
@@ -14,6 +14,8 @@ You may obtain a copy of the License at
limitations under the License.
******************************************************************************/

using Tensorflow.Operations;

namespace Tensorflow
{
public partial class tensorflow
@@ -79,5 +81,10 @@ public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
num_split: num_split,
axis: axis,
name: name);

public Tensor ensure_shape(Tensor x, Shape shape, string name = null)
{
return gen_ops.ensure_shape(x, shape, name);
}
}
}
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Attributes/c_api.ops.cs
Original file line number Diff line number Diff line change
@@ -61,7 +61,7 @@ public partial class c_api
public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value);

[DllImport(TensorFlowLibName)]
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status);
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, ulong proto_len, SafeStatusHandle status);

/// <summary>
/// Set `num_dims` to -1 to represent "unknown rank".
1 change: 1 addition & 0 deletions src/TensorFlowNET.Core/Binding.Util.cs
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@ limitations under the License.
using System.Diagnostics;
using System.IO;
using System.Linq;
using Tensorflow.Operations;

namespace Tensorflow
{
6 changes: 6 additions & 0 deletions src/TensorFlowNET.Core/Buffers/Buffer.cs
Original file line number Diff line number Diff line change
@@ -107,6 +107,12 @@ public unsafe byte[] ToArray()
}
}

public void Release()
{
_handle.Dispose();
_handle = null;
}

public override string ToString()
=> $"0x{_handle.DangerousGetHandle():x16}";

27 changes: 27 additions & 0 deletions src/TensorFlowNET.Core/Buffers/TF_Buffer.cs
Original file line number Diff line number Diff line change
@@ -25,5 +25,32 @@ public struct TF_Buffer
public IntPtr data;
public ulong length;
public IntPtr data_deallocator;

public unsafe Span<T> AsSpan<T>() where T: unmanaged
{
if(length > int.MaxValue)
{
throw new ValueError($"The length {length} is too large to use in the span.");
}
return new Span<T>(data.ToPointer(), (int)length);
}

public unsafe byte[] ToByteArray()
{
byte[] res = new byte[length];
if(length > int.MaxValue)
{
byte* root = (byte*)data;
for(ulong i = 0; i < length; i++)
{
res[i] = *(root++);
}
}
else
{
new Span<byte>(data.ToPointer(), (int)length).CopyTo(res.AsSpan());
}
return res;
}
}
}
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
Original file line number Diff line number Diff line change
@@ -161,7 +161,7 @@ public static IList<Trackable> list_objects(ObjectGraphView graph_view)

internal static IEnumerable<Trackable> _objects_with_attributes(IEnumerable<Trackable> full_list)
{
return full_list.TakeWhile(x =>
return full_list.Where(x =>
{
var saveables = x.gather_saveables_for_checkpoint();
return saveables is not null && saveables.Count > 0;
34 changes: 20 additions & 14 deletions src/TensorFlowNET.Core/Checkpoint/SaveUtil.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
using System;
using OneOf;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using Tensorflow.Train;
using Tensorflow.Training;
using Tensorflow.Common.Extensions;
using pbc = global::Google.Protobuf.Collections;

namespace Tensorflow.Checkpoint
@@ -28,7 +30,7 @@ Trackable object_to_save
);
public static class SaveUtil
{
public static (IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
public static (IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>>, IDictionary<Tensor, object>, IDictionary<string, IDictionary<string, Trackable>>, TrackableObjectGraph)
serialize_graph_view(ObjectGraphView graph_view, IDictionary<Trackable, Trackable>? object_map = null, bool call_with_mapped_captures = false, object? cache = null)
{
var (trackable_data, node_ids) = gather_trackable_data(graph_view, object_map);
@@ -104,7 +106,10 @@ private static TrackableObjectGraph fill_object_graph_proto(IList<TrackableData>
{
var td = trackable_data[i];
Debug.Assert(td.node_id == i);
object_graph_proto.Nodes.Add(new TrackableObjectGraph.Types.TrackableObject(td.slot_variable_proto, td.children_proto));
TrackableObjectGraph.Types.TrackableObject trackable_object = new();
trackable_object.SlotVariables.AddRange(td.slot_variable_proto);
trackable_object.Children.AddRange(td.children_proto);
object_graph_proto.Nodes.Add(trackable_object);
}
return object_graph_proto;
}
@@ -117,16 +122,16 @@ private static TrackableObjectGraph fill_object_graph_proto(IList<TrackableData>
/// <param name="call_with_mapped_captures"></param>
/// <param name="cache"></param>
/// <param name="object_graph_proto"></param>
private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
private static IDictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> get_and_write_tensors_to_serialize(IList<TrackableData> tensor_trackables, IDictionary<Trackable, int> node_ids,
bool call_with_mapped_captures, object? cache, TrackableObjectGraph object_graph_proto)
{
Dictionary<Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>> serialized_tensors = new();
Dictionary<Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>> serialized_tensors = new();
foreach(var td in tensor_trackables)
{
// TODO: deal with cache.
var legacy_name = SaveableCompat.get_saveable_name(td.object_to_save) ?? "";
Trackable trackable = null;
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict;
IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict;
if(!saveable_object_util.trackable_has_serialize_to_tensor(td.object_to_save) || legacy_name.Length > 0)
{
(trackable, tensor_dict) = get_tensors_from_legacy_saveable(td, node_ids, call_with_mapped_captures, object_graph_proto);
@@ -148,12 +153,12 @@ private static IDictionary<Trackable, IDictionary<string, Maybe<Tensor, IDiction
return serialized_tensors;
}

private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
private static IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> get_tensors_from_trackable(TrackableData trackable_data, bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
{
var trackable = trackable_data.object_to_save;

// TODO: complete it. Note that actually `call_with_mapped_captures` is of function type.
IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> ret_tensor_dict;
IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> ret_tensor_dict;
if (call_with_mapped_captures)
{
throw new NotImplementedException();
@@ -163,8 +168,7 @@ private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> g
ret_tensor_dict = trackable.serialize_to_tensors();
}

// TODO: deal with the type `SaveSpce` (currently it will never be it).
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> tensor_dict = new();
Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>> tensor_dict = new();
foreach(var pair in ret_tensor_dict)
{
var local_name = TrackableUtils.escape_local_name(pair.Key);
@@ -173,10 +177,12 @@ private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> g

tensor_dict[checkpoint_key] = maybe_tensor;

if(maybe_tensor.IsTypeOrDeriveFrom<SaveSpec>())
foreach(var key in maybe_tensor.Keys)
{
throw new NotImplementedException();
//((SaveSpec)maybe_tensor).name = local_name + ((SaveSpec)maybe_tensor).name;
if (maybe_tensor[key].IsTypeOrDeriveFrom<SaveSpec>())
{
maybe_tensor[key].AsT1.name = local_name + maybe_tensor[key].AsT1.name;
}
}

if(object_graph_proto is not null)
@@ -200,7 +206,7 @@ private static IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> g
/// <param name="call_with_mapped_captures"></param>
/// <param name="object_graph_proto"></param>
/// <returns></returns>
private static (Trackable, IDictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
private static (Trackable, IDictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>) get_tensors_from_legacy_saveable(TrackableData trackable_data, IDictionary<Trackable, int> node_ids,
bool call_with_mapped_captures, TrackableObjectGraph object_graph_proto)
{
Dictionary<Trackable, string> object_names = new();
Loading