Skip to content

Commit a767765

Browse files
authored
Merge pull request #1034 from AsakusaRinne/support_bert_load
Add Tensorflow.NET.Hub
2 parents c20d854 + 9174eab commit a767765

23 files changed

+1433
-24
lines changed

TensorFlow.NET.sln

+28
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest",
2323
EndProject
2424
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}"
2525
EndProject
26+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub", "src\TensorflowNET.Hub\Tensorflow.Hub.csproj", "{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}"
27+
EndProject
28+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub.Unittest", "test\TensorflowNET.Hub.Unittest\Tensorflow.Hub.Unittest.csproj", "{7DEA8760-E401-4872-81F3-405F185A13A0}"
29+
EndProject
2630
Global
2731
GlobalSection(SolutionConfigurationPlatforms) = preSolution
2832
Debug|Any CPU = Debug|Any CPU
@@ -153,6 +157,30 @@ Global
153157
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|x64
154158
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU
155159
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU
160+
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
161+
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|Any CPU.Build.0 = Debug|Any CPU
162+
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.ActiveCfg = Debug|Any CPU
163+
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x64.Build.0 = Debug|Any CPU
164+
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.ActiveCfg = Debug|Any CPU
165+
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Debug|x86.Build.0 = Debug|Any CPU
166+
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.ActiveCfg = Release|Any CPU
167+
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|Any CPU.Build.0 = Release|Any CPU
168+
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.ActiveCfg = Release|Any CPU
169+
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x64.Build.0 = Release|Any CPU
170+
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.ActiveCfg = Release|Any CPU
171+
{9738D16A-CFA0-405C-A7DF-D3D203B0CB18}.Release|x86.Build.0 = Release|Any CPU
172+
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
173+
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|Any CPU.Build.0 = Debug|Any CPU
174+
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.ActiveCfg = Debug|Any CPU
175+
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x64.Build.0 = Debug|Any CPU
176+
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.ActiveCfg = Debug|Any CPU
177+
{7DEA8760-E401-4872-81F3-405F185A13A0}.Debug|x86.Build.0 = Debug|Any CPU
178+
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.ActiveCfg = Release|Any CPU
179+
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|Any CPU.Build.0 = Release|Any CPU
180+
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.ActiveCfg = Release|Any CPU
181+
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x64.Build.0 = Release|Any CPU
182+
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.ActiveCfg = Release|Any CPU
183+
{7DEA8760-E401-4872-81F3-405F185A13A0}.Release|x86.Build.0 = Release|Any CPU
156184
EndGlobalSection
157185
GlobalSection(SolutionProperties) = preSolution
158186
HideSolutionNode = FALSE

src/TensorFlowNET.Core/Tensors/Tensors.cs

+18-3
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,24 @@ private static void EnsureSingleTensor(Tensors tensors, string methodnName)
207207
}
208208

209209
public override string ToString()
210-
=> items.Count() == 1
211-
? items.First().ToString()
212-
: items.Count() + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name));
210+
{
211+
if(items.Count == 1)
212+
{
213+
return items[0].ToString();
214+
}
215+
else
216+
{
217+
StringBuilder sb = new StringBuilder();
218+
sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n");
219+
for(int i = 0; i < items.Count; i++)
220+
{
221+
var tensor = items[i];
222+
sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n");
223+
}
224+
sb.Append("]\n");
225+
return sb.ToString();
226+
}
227+
}
213228

214229
public void Dispose()
215230
{

src/TensorFlowNET.Core/Tensors/dtypes.cs

+11
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,17 @@ public static bool is_integer(this TF_DataType type)
301301
type == TF_DataType.DtInt32Ref || type == TF_DataType.DtInt64Ref;
302302
}
303303

304+
public static bool is_unsigned(this TF_DataType type)
305+
{
306+
return type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 ||
307+
type == TF_DataType.TF_UINT64;
308+
}
309+
310+
public static bool is_bool(this TF_DataType type)
311+
{
312+
return type == TF_DataType.TF_BOOL;
313+
}
314+
304315
public static bool is_floating(this TF_DataType type)
305316
{
306317
return type == TF_DataType.TF_HALF || type == TF_DataType.TF_FLOAT || type == TF_DataType.TF_DOUBLE;

src/TensorFlowNET.Keras/Engine/Layer.AddWeights.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ protected virtual IVariableV1 add_weight(string name,
2222
// If dtype is DT_FLOAT, provide a uniform unit scaling initializer
2323
if (dtype.is_floating())
2424
initializer = tf.glorot_uniform_initializer;
25-
else if (dtype.is_integer())
25+
else if (dtype.is_integer() || dtype.is_unsigned() || dtype.is_bool())
2626
initializer = tf.zeros_initializer;
27-
else
27+
else if(getter is null)
2828
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}");
2929
}
3030

src/TensorFlowNET.Keras/Saving/KerasMetaData.cs

+4
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,9 @@ public class KerasMetaData
3636
public bool? Stateful { get; set; }
3737
[JsonProperty("model_config")]
3838
public KerasModelConfig? ModelConfig { get; set; }
39+
[JsonProperty("sparse")]
40+
public bool Sparse { get; set; }
41+
[JsonProperty("ragged")]
42+
public bool Ragged { get; set; }
3943
}
4044
}

src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs

+11-2
Original file line numberDiff line numberDiff line change
@@ -401,13 +401,22 @@ private void _unblock_model_reconstruction(int layer_id, Layer layer)
401401

402402
private (Trackable, Action<object, object, object>) revive_custom_object(string identifier, KerasMetaData metadata)
403403
{
404-
if(identifier == SavedModel.Constants.LAYER_IDENTIFIER)
404+
if (identifier == SavedModel.Constants.LAYER_IDENTIFIER)
405405
{
406406
return RevivedLayer.init_from_metadata(metadata);
407407
}
408+
else if(identifier == SavedModel.Constants.MODEL_IDENTIFIER || identifier == SavedModel.Constants.SEQUENTIAL_IDENTIFIER
409+
|| identifier == SavedModel.Constants.NETWORK_IDENTIFIER)
410+
{
411+
return RevivedNetwork.init_from_metadata(metadata);
412+
}
413+
else if(identifier == SavedModel.Constants.INPUT_LAYER_IDENTIFIER)
414+
{
415+
return RevivedInputLayer.init_from_metadata(metadata);
416+
}
408417
else
409418
{
410-
throw new NotImplementedException();
419+
throw new ValueError($"Cannot revive the layer {identifier}.");
411420
}
412421
}
413422

Original file line numberDiff line numberDiff line change
@@ -1,15 +1,46 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Keras.ArgsDefinition;
45
using Tensorflow.Keras.Engine;
6+
using Tensorflow.Keras.Layers;
57

68
namespace Tensorflow.Keras.Saving.SavedModel
79
{
8-
public class RevivedInputLayer: Layer
10+
public class RevivedInputLayer: InputLayer
911
{
10-
private RevivedInputLayer(): base(null)
12+
protected RevivedConfig _config = null;
13+
private RevivedInputLayer(InputLayerArgs args): base(args)
1114
{
12-
throw new NotImplementedException();
15+
16+
}
17+
18+
public override IKerasConfig get_config()
19+
{
20+
return _config;
21+
}
22+
23+
public static (RevivedInputLayer, Action<object, object, object>) init_from_metadata(KerasMetaData metadata)
24+
{
25+
InputLayerArgs args = new InputLayerArgs()
26+
{
27+
Name = metadata.Name,
28+
DType = metadata.DType,
29+
Sparse = metadata.Sparse,
30+
Ragged = metadata.Ragged,
31+
BatchInputShape = metadata.BatchInputShape
32+
};
33+
34+
RevivedInputLayer revived_obj = new RevivedInputLayer(args);
35+
36+
revived_obj._config = new RevivedConfig() { Config = metadata.Config };
37+
38+
return (revived_obj, Loader.setattr);
39+
}
40+
41+
public override string ToString()
42+
{
43+
return $"Customized keras input layer: {Name}.";
1344
}
1445
}
1546
}

src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs

+2-14
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public static (RevivedLayer, Action<object, object, object>) init_from_metadata(
5353
return (revived_obj, ReviveUtils._revive_setter);
5454
}
5555

56-
private RevivedConfig _config = null;
56+
protected RevivedConfig _config = null;
5757

5858
public object keras_api
5959
{
@@ -70,7 +70,7 @@ public object keras_api
7070
}
7171
}
7272

73-
public RevivedLayer(LayerArgs args): base(args)
73+
protected RevivedLayer(LayerArgs args): base(args)
7474
{
7575

7676
}
@@ -84,17 +84,5 @@ public override IKerasConfig get_config()
8484
{
8585
return _config;
8686
}
87-
88-
//protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
89-
//{
90-
// if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function)
91-
// {
92-
// return base.Call(inputs, state, training);
93-
// }
94-
// else
95-
// {
96-
// return (func as Function).Apply(inputs);
97-
// }
98-
//}
9987
}
10088
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.ArgsDefinition;
5+
using Tensorflow.Keras.Utils;
6+
7+
namespace Tensorflow.Keras.Saving.SavedModel
8+
{
9+
public class RevivedNetwork: RevivedLayer
10+
{
11+
private RevivedNetwork(LayerArgs args) : base(args)
12+
{
13+
14+
}
15+
16+
public static (RevivedNetwork, Action<object, object, object>) init_from_metadata(KerasMetaData metadata)
17+
{
18+
RevivedNetwork revived_obj = new(new LayerArgs() { Name = metadata.Name });
19+
20+
// TODO(Rinne): with utils.no_automatic_dependency_tracking_scope(revived_obj)
21+
// TODO(Rinne): revived_obj._expects_training_arg
22+
var config = metadata.Config;
23+
if (generic_utils.validate_config(config))
24+
{
25+
revived_obj._config = new RevivedConfig() { Config = config };
26+
}
27+
if(metadata.ActivityRegularizer is not null)
28+
{
29+
throw new NotImplementedException();
30+
}
31+
32+
return (revived_obj, ReviveUtils._revive_setter);
33+
}
34+
35+
public override string ToString()
36+
{
37+
return $"Customized keras Network: {Name}.";
38+
}
39+
}
40+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
using System.IO;
2+
using System.Threading.Tasks;
3+
4+
namespace Tensorflow.Hub
5+
{
6+
public class GcsCompressedFileResolver : IResolver
7+
{
8+
const int LOCK_FILE_TIMEOUT_SEC = 10 * 60;
9+
public string Call(string handle)
10+
{
11+
var module_dir = _module_dir(handle);
12+
13+
return resolver.atomic_download_async(handle, download, module_dir, LOCK_FILE_TIMEOUT_SEC)
14+
.GetAwaiter().GetResult();
15+
}
16+
public bool IsSupported(string handle)
17+
{
18+
return handle.StartsWith("gs://") && _is_tarfile(handle);
19+
}
20+
21+
private async Task download(string handle, string tmp_dir)
22+
{
23+
new resolver.DownloadManager(handle).download_and_uncompress(
24+
new FileStream(handle, FileMode.Open, FileAccess.Read), tmp_dir);
25+
await Task.Run(() => { });
26+
}
27+
28+
private static string _module_dir(string handle)
29+
{
30+
var cache_dir = resolver.tfhub_cache_dir(use_temp: true);
31+
var sha1 = ComputeSha1(handle);
32+
return resolver.create_local_module_dir(cache_dir, sha1);
33+
}
34+
35+
private static bool _is_tarfile(string filename)
36+
{
37+
return filename.EndsWith(".tar") || filename.EndsWith(".tar.gz") || filename.EndsWith(".tgz");
38+
}
39+
40+
private static string ComputeSha1(string s)
41+
{
42+
using (var sha = new System.Security.Cryptography.SHA1Managed())
43+
{
44+
var bytes = System.Text.Encoding.UTF8.GetBytes(s);
45+
var hash = sha.ComputeHash(bytes);
46+
var stringBuilder = new System.Text.StringBuilder(hash.Length * 2);
47+
48+
foreach (var b in hash)
49+
{
50+
stringBuilder.Append(b.ToString("x2"));
51+
}
52+
53+
return stringBuilder.ToString();
54+
}
55+
}
56+
}
57+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
using System;
2+
using System.Net.Http;
3+
using System.Threading.Tasks;
4+
5+
namespace Tensorflow.Hub
6+
{
7+
public class HttpCompressedFileResolver : HttpResolverBase
8+
{
9+
const int LOCK_FILE_TIMEOUT_SEC = 10 * 60; // 10 minutes
10+
11+
private static readonly (string, string) _COMPRESSED_FORMAT_QUERY =
12+
("tf-hub-format", "compressed");
13+
14+
private static string _module_dir(string handle)
15+
{
16+
var cache_dir = resolver.tfhub_cache_dir(use_temp: true);
17+
var sha1 = ComputeSha1(handle);
18+
return resolver.create_local_module_dir(cache_dir, sha1);
19+
}
20+
21+
public override bool IsSupported(string handle)
22+
{
23+
if (!is_http_protocol(handle))
24+
{
25+
return false;
26+
}
27+
var load_format = resolver.model_load_format();
28+
return load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.COMPRESSED)
29+
|| load_format == Enum.GetName(typeof(resolver.ModelLoadFormat), resolver.ModelLoadFormat.AUTO);
30+
}
31+
32+
public override string Call(string handle)
33+
{
34+
var module_dir = _module_dir(handle);
35+
36+
return resolver.atomic_download_async(
37+
handle,
38+
download,
39+
module_dir,
40+
LOCK_FILE_TIMEOUT_SEC
41+
).GetAwaiter().GetResult();
42+
}
43+
44+
private async Task download(string handle, string tmp_dir)
45+
{
46+
var client = new HttpClient();
47+
48+
var response = await client.GetAsync(_append_compressed_format_query(handle));
49+
50+
using (var httpStream = await response.Content.ReadAsStreamAsync())
51+
{
52+
new resolver.DownloadManager(handle).download_and_uncompress(httpStream, tmp_dir);
53+
}
54+
}
55+
56+
private string _append_compressed_format_query(string handle)
57+
{
58+
return append_format_query(handle, _COMPRESSED_FORMAT_QUERY);
59+
}
60+
61+
private static string ComputeSha1(string s)
62+
{
63+
using (var sha = new System.Security.Cryptography.SHA1Managed())
64+
{
65+
var bytes = System.Text.Encoding.UTF8.GetBytes(s);
66+
var hash = sha.ComputeHash(bytes);
67+
var stringBuilder = new System.Text.StringBuilder(hash.Length * 2);
68+
69+
foreach (var b in hash)
70+
{
71+
stringBuilder.Append(b.ToString("x2"));
72+
}
73+
74+
return stringBuilder.ToString();
75+
}
76+
}
77+
}
78+
}

0 commit comments

Comments
 (0)