Skip to content

Add Tensorflow.NET.Hub #1034

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions TensorFlow.NET.sln
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest",
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub", "src\TensorflowNET.Hub\Tensorflow.Hub.csproj", "{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub.Unittest", "test\TensorflowNET.Hub.Unittest\Tensorflow.Hub.Unittest.csproj", "{7DEA8760-E401-4872-81F3-405F185A13A0}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -153,6 +157,30 @@ Global
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.Build.0 = Debug|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.ActiveCfg = Debug|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.Build.0 = Debug|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.ActiveCfg = Debug|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.Build.0 = Debug|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.ActiveCfg = Release|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.Build.0 = Release|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.ActiveCfg = Release|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.Build.0 = Release|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.ActiveCfg = Release|Any CPU
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.Build.0 = Release|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.Build.0 = Debug|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.ActiveCfg = Debug|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.Build.0 = Debug|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.ActiveCfg = Debug|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.Build.0 = Debug|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.ActiveCfg = Release|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.Build.0 = Release|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.ActiveCfg = Release|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.Build.0 = Release|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.ActiveCfg = Release|Any CPU
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down
21 changes: 18 additions & 3 deletions src/TensorFlowNET.Core/Tensors/Tensors.cs
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,24 @@ private static void EnsureSingleTensor(Tensors tensors, string methodnName)
}

public override string ToString()
=> items.Count() == 1
? items.First().ToString()
: items.Count() + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name));
{
if(items.Count == 1)
{
return items[0].ToString();
}
else
{
StringBuilder sb = new StringBuilder();
sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n");
for(int i = 0; i < items.Count; i++)
{
var tensor = items[i];
sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n");
}
sb.Append("]\n");
return sb.ToString();
}
}

public void Dispose()
{
Expand Down
11 changes: 11 additions & 0 deletions src/TensorFlowNET.Core/Tensors/dtypes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,17 @@ public static bool is_integer(this TF_DataType type)
type == TF_DataType.DtInt32Ref || type == TF_DataType.DtInt64Ref;
}

public static bool is_unsigned(this TF_DataType type)
{
return type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 ||
type == TF_DataType.TF_UINT64;
}

public static bool is_bool(this TF_DataType type)
{
return type == TF_DataType.TF_BOOL;
}

public static bool is_floating(this TF_DataType type)
{
return type == TF_DataType.TF_HALF || type == TF_DataType.TF_FLOAT || type == TF_DataType.TF_DOUBLE;
Expand Down
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Keras/Engine/Layer.AddWeights.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ protected virtual IVariableV1 add_weight(string name,
// If dtype is DT_FLOAT, provide a uniform unit scaling initializer
if (dtype.is_floating())
initializer = tf.glorot_uniform_initializer;
else if (dtype.is_integer())
else if (dtype.is_integer() || dtype.is_unsigned() || dtype.is_bool())
initializer = tf.zeros_initializer;
else
else if(getter is null)
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}");
}

Expand Down
4 changes: 4 additions & 0 deletions src/TensorFlowNET.Keras/Saving/KerasMetaData.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,9 @@ public class KerasMetaData
public bool? Stateful { get; set; }
[JsonProperty("model_config")]
public KerasModelConfig? ModelConfig { get; set; }
[JsonProperty("sparse")]
public bool Sparse { get; set; }
[JsonProperty("ragged")]
public bool Ragged { get; set; }
}
}
13 changes: 11 additions & 2 deletions src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -401,13 +401,22 @@ private void _unblock_model_reconstruction(int layer_id, Layer layer)

private (Trackable, Action<object, object, object>) revive_custom_object(string identifier, KerasMetaData metadata)
{
if(identifier == SavedModel.Constants.LAYER_IDENTIFIER)
if (identifier == SavedModel.Constants.LAYER_IDENTIFIER)
{
return RevivedLayer.init_from_metadata(metadata);
}
else if(identifier == SavedModel.Constants.MODEL_IDENTIFIER || identifier == SavedModel.Constants.SEQUENTIAL_IDENTIFIER
|| identifier == SavedModel.Constants.NETWORK_IDENTIFIER)
{
return RevivedNetwork.init_from_metadata(metadata);
}
else if(identifier == SavedModel.Constants.INPUT_LAYER_IDENTIFIER)
{
return RevivedInputLayer.init_from_metadata(metadata);
}
else
{
throw new NotImplementedException();
throw new ValueError($"Cannot revive the layer {identifier}.");
}
}

Expand Down
37 changes: 34 additions & 3 deletions src/TensorFlowNET.Keras/Saving/SavedModel/RevivedInputLayer.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,46 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;

namespace Tensorflow.Keras.Saving.SavedModel
{
public class RevivedInputLayer: Layer
public class RevivedInputLayer: InputLayer
{
private RevivedInputLayer(): base(null)
protected RevivedConfig _config = null;
private RevivedInputLayer(InputLayerArgs args): base(args)
{
throw new NotImplementedException();

}

public override IKerasConfig get_config()
{
return _config;
}

public static (RevivedInputLayer, Action<object, object, object>) init_from_metadata(KerasMetaData metadata)
{
InputLayerArgs args = new InputLayerArgs()
{
Name = metadata.Name,
DType = metadata.DType,
Sparse = metadata.Sparse,
Ragged = metadata.Ragged,
BatchInputShape = metadata.BatchInputShape
};

RevivedInputLayer revived_obj = new RevivedInputLayer(args);

revived_obj._config = new RevivedConfig() { Config = metadata.Config };

return (revived_obj, Loader.setattr);
}

public override string ToString()
{
return $"Customized keras input layer: {Name}.";
}
}
}
16 changes: 2 additions & 14 deletions src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public static (RevivedLayer, Action<object, object, object>) init_from_metadata(
return (revived_obj, ReviveUtils._revive_setter);
}

private RevivedConfig _config = null;
protected RevivedConfig _config = null;

public object keras_api
{
Expand All @@ -70,7 +70,7 @@ public object keras_api
}
}

public RevivedLayer(LayerArgs args): base(args)
protected RevivedLayer(LayerArgs args): base(args)
{

}
Expand All @@ -84,17 +84,5 @@ public override IKerasConfig get_config()
{
return _config;
}

//protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
//{
// if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function)
// {
// return base.Call(inputs, state, training);
// }
// else
// {
// return (func as Function).Apply(inputs);
// }
//}
}
}
40 changes: 40 additions & 0 deletions src/TensorFlowNET.Keras/Saving/SavedModel/RevivedNetwork.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Saving.SavedModel
{
public class RevivedNetwork: RevivedLayer
{
private RevivedNetwork(LayerArgs args) : base(args)
{

}

public static (RevivedNetwork, Action<object, object, object>) init_from_metadata(KerasMetaData metadata)
{
RevivedNetwork revived_obj = new(new LayerArgs() { Name = metadata.Name });

// TODO(Rinne): with utils.no_automatic_dependency_tracking_scope(revived_obj)
// TODO(Rinne): revived_obj._expects_training_arg
var config = metadata.Config;
if (generic_utils.validate_config(config))
{
revived_obj._config = new RevivedConfig() { Config = config };
}
if(metadata.ActivityRegularizer is not null)
{
throw new NotImplementedException();
}

return (revived_obj, ReviveUtils._revive_setter);
}

public override string ToString()
{
return $"Customized keras Network: {Name}.";
}
}
}
57 changes: 57 additions & 0 deletions src/TensorflowNET.Hub/GcsCompressedFileResolver.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
using System.IO;
using System.Threading.Tasks;

namespace Tensorflow.Hub
{
public class GcsCompressedFileResolver : IResolver
{
const int LOCK_FILE_TIMEOUT_SEC = 10 * 60;
public string Call(string handle)
{
var module_dir = _module_dir(handle);

return resolver.atomic_download_async(handle, download, module_dir, LOCK_FILE_TIMEOUT_SEC)
.GetAwaiter().GetResult();
}
public bool IsSupported(string handle)
{
return handle.StartsWith("gs://") && _is_tarfile(handle);
}

private async Task download(string handle, string tmp_dir)
{
new resolver.DownloadManager(handle).download_and_uncompress(
new FileStream(handle, FileMode.Open, FileAccess.Read), tmp_dir);
await Task.Run(() => { });
}

private static string _module_dir(string handle)
{
var cache_dir = resolver.tfhub_cache_dir(use_temp: true);
var sha1 = ComputeSha1(handle);
return resolver.create_local_module_dir(cache_dir, sha1);
}

private static bool _is_tarfile(string filename)
{
return filename.EndsWith(".tar") || filename.EndsWith(".tar.gz") || filename.EndsWith(".tgz");
}

private static string ComputeSha1(string s)
{
using (var sha = new System.Security.Cryptography.SHA1Managed())
{
var bytes = System.Text.Encoding.UTF8.GetBytes(s);
var hash = sha.ComputeHash(bytes);
var stringBuilder = new System.Text.StringBuilder(hash.Length * 2);

foreach (var b in hash)
{
stringBuilder.Append(b.ToString("x2"));
}

return stringBuilder.ToString();
}
}
}
}
78 changes: 78 additions & 0 deletions src/TensorflowNET.Hub/HttpCompressedFileResolver.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
using System;
using System.Net.Http;
using System.Threading.Tasks;

namespace Tensorflow.Hub
{
public class HttpCompressedFileResolver : HttpResolverBase
{
const int LOCK_FILE_TIMEOUT_SEC = 10 * 60; // 10 minutes

private static readonly (string, string) _COMPRESSED_FORMAT_QUERY =
("tf-hub-format", "compressed");

private static string _module_dir(string handle)
{
var cache_dir = resolver.tfhub_cache_dir(use_temp: true);
var sha1 = ComputeSha1(handle);
return resolver.create_local_module_dir(cache_dir, sha1);
}

public override bool IsSupported(string handle)
{
if (!is_http_protocol(handle))
{
return false;
}
var load_format = resolver.model_load_format();
return load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.COMPRESSED)
|| load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.AUTO);
}

public override string Call(string handle)
{
var module_dir = _module_dir(handle);

return resolver.atomic_download_async(
handle,
download,
module_dir,
LOCK_FILE_TIMEOUT_SEC
).GetAwaiter().GetResult();
}

private async Task download(string handle, string tmp_dir)
{
var client = new HttpClient();

var response = await client.GetAsync(_append_compressed_format_query(handle));

using (var httpStream = await response.Content.ReadAsStreamAsync())
{
new resolver.DownloadManager(handle).download_and_uncompress(httpStream, tmp_dir);
}
}

private string _append_compressed_format_query(string handle)
{
return append_format_query(handle, _COMPRESSED_FORMAT_QUERY);
}

private static string ComputeSha1(string s)
{
using (var sha = new System.Security.Cryptography.SHA1Managed())
{
var bytes = System.Text.Encoding.UTF8.GetBytes(s);
var hash = sha.ComputeHash(bytes);
var stringBuilder = new System.Text.StringBuilder(hash.Length * 2);

foreach (var b in hash)
{
stringBuilder.Append(b.ToString("x2"));
}

return stringBuilder.ToString();
}
}
}
}
Loading