Skip to content

Commit ece36e6

Browse files
AsakusaRinneOceania2018
authored and committed
Automatically add KerasInterface to f.
1 parent e5837dc commit ece36e6

File tree

12 files changed

+132
-28
lines changed

12 files changed

+132
-28
lines changed

src/TensorFlowNET.Console/Program.cs

-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ class Program
88
{
99
static void Main(string[] args)
1010
{
11-
tf.UseKeras<KerasInterface>();
12-
1311
var diag = new Diagnostician();
1412
// diag.Diagnose(@"D:\memory.txt");
1513

+39-5
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,60 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using System.Threading;
5+
using Tensorflow.Framework.Models;
46
using Tensorflow.Keras.Engine;
57
using Tensorflow.Keras.Layers;
68
using Tensorflow.Keras.Losses;
79
using Tensorflow.Keras.Metrics;
10+
using Tensorflow.Keras.Models;
811

912
namespace Tensorflow.Keras
1013
{
1114
public interface IKerasApi
1215
{
13-
public ILayersApi layers { get; }
14-
public ILossesApi losses { get; }
15-
public IMetricsApi metrics { get; }
16-
public IInitializersApi initializers { get; }
16+
IInitializersApi initializers { get; }
17+
ILayersApi layers { get; }
18+
ILossesApi losses { get; }
19+
IOptimizerApi optimizers { get; }
20+
IMetricsApi metrics { get; }
21+
IModelsApi models { get; }
1722

1823
/// <summary>
1924
/// `Model` groups layers into an object with training and inference features.
2025
/// </summary>
2126
/// <param name="input"></param>
2227
/// <param name="output"></param>
2328
/// <returns></returns>
24-
public IModel Model(Tensors inputs, Tensors outputs, string name = null);
29+
IModel Model(Tensors inputs, Tensors outputs, string name = null);
30+
31+
/// <summary>
32+
/// Instantiate a Keras tensor.
33+
/// </summary>
34+
/// <param name="shape"></param>
35+
/// <param name="batch_size"></param>
36+
/// <param name="dtype"></param>
37+
/// <param name="name"></param>
38+
/// <param name="sparse">
39+
/// A boolean specifying whether the placeholder to be created is sparse.
40+
/// </param>
41+
/// <param name="ragged">
42+
/// A boolean specifying whether the placeholder to be created is ragged.
43+
/// </param>
44+
/// <param name="tensor">
45+
/// Optional existing tensor to wrap into the `Input` layer.
46+
/// If set, the layer will not create a placeholder tensor.
47+
/// </param>
48+
/// <returns></returns>
49+
Tensors Input(Shape shape = null,
50+
int batch_size = -1,
51+
string name = null,
52+
TF_DataType dtype = TF_DataType.DtInvalid,
53+
bool sparse = false,
54+
Tensor tensor = null,
55+
bool ragged = false,
56+
TypeSpec type_spec = null,
57+
Shape batch_input_shape = null,
58+
Shape batch_shape = null);
2559
}
2660
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.Engine;
5+
6+
namespace Tensorflow.Keras
7+
{
8+
public interface IOptimizerApi
9+
{
10+
/// <summary>
11+
/// Adam optimization is a stochastic gradient descent method that is based on
12+
/// adaptive estimation of first-order and second-order moments.
13+
/// </summary>
14+
/// <param name="learning_rate"></param>
15+
/// <param name="beta_1"></param>
16+
/// <param name="beta_2"></param>
17+
/// <param name="epsilon"></param>
18+
/// <param name="amsgrad"></param>
19+
/// <param name="name"></param>
20+
/// <returns></returns>
21+
IOptimizer Adam(float learning_rate = 0.001f,
22+
float beta_1 = 0.9f,
23+
float beta_2 = 0.999f,
24+
float epsilon = 1e-7f,
25+
bool amsgrad = false,
26+
string name = "Adam");
27+
28+
/// <summary>
29+
/// Construct a new RMSprop optimizer.
30+
/// </summary>
31+
/// <param name="learning_rate"></param>
32+
/// <param name="rho"></param>
33+
/// <param name="momentum"></param>
34+
/// <param name="epsilon"></param>
35+
/// <param name="centered"></param>
36+
/// <param name="name"></param>
37+
/// <returns></returns>
38+
IOptimizer RMSprop(float learning_rate = 0.001f,
39+
float rho = 0.9f,
40+
float momentum = 0.0f,
41+
float epsilon = 1e-7f,
42+
bool centered = false,
43+
string name = "RMSprop");
44+
45+
IOptimizer SGD(float learning_rate);
46+
}
47+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.Engine;
5+
6+
namespace Tensorflow.Keras.Models
7+
{
8+
public interface IModelsApi
9+
{
10+
public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null);
11+
}
12+
}

src/TensorFlowNET.Core/tensorflow.cs

-8
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,6 @@ public tensorflow()
6565
InitGradientEnvironment();
6666
}
6767

68-
public void UseKeras<T>() where T : IKerasApi, new()
69-
{
70-
if (keras == null)
71-
{
72-
keras = new T();
73-
}
74-
}
75-
7668
public string VERSION => c_api.StringPiece(c_api.TF_Version());
7769

7870
private void InitGradientEnvironment()

src/TensorFlowNET.Keras/KerasApi.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ namespace Tensorflow
77
/// </summary>
88
public static class KerasApi
99
{
10-
public static KerasInterface keras { get; } = new KerasInterface();
10+
public static KerasInterface keras { get; } = KerasInterface.Instance;
1111
}
1212
}

src/TensorFlowNET.Keras/KerasInterface.cs

+24-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,28 @@ namespace Tensorflow.Keras
1818
{
1919
public class KerasInterface : IKerasApi
2020
{
21+
private static KerasInterface _instance = null;
22+
private static readonly object _lock = new object();
23+
private KerasInterface()
24+
{
25+
Tensorflow.Binding.tf.keras = this;
26+
}
27+
28+
public static KerasInterface Instance
29+
{
30+
get
31+
{
32+
lock (_lock)
33+
{
34+
if (_instance is null)
35+
{
36+
_instance = new KerasInterface();
37+
}
38+
return _instance;
39+
}
40+
}
41+
}
42+
2143
public KerasDataset datasets { get; } = new KerasDataset();
2244
public IInitializersApi initializers { get; } = new InitializersApi();
2345
public Regularizers regularizers { get; } = new Regularizers();
@@ -27,9 +49,9 @@ public class KerasInterface : IKerasApi
2749
public Preprocessing preprocessing { get; } = new Preprocessing();
2850
ThreadLocal<BackendImpl> _backend = new ThreadLocal<BackendImpl>(() => new BackendImpl());
2951
public BackendImpl backend => _backend.Value;
30-
public OptimizerApi optimizers { get; } = new OptimizerApi();
52+
public IOptimizerApi optimizers { get; } = new OptimizerApi();
3153
public IMetricsApi metrics { get; } = new MetricsApi();
32-
public ModelsApi models { get; } = new ModelsApi();
54+
public IModelsApi models { get; } = new ModelsApi();
3355
public KerasUtils utils { get; } = new KerasUtils();
3456

3557
public Sequential Sequential(List<ILayer> layers = null,

src/TensorFlowNET.Keras/Models/ModelsApi.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99

1010
namespace Tensorflow.Keras.Models
1111
{
12-
public class ModelsApi
12+
public class ModelsApi: IModelsApi
1313
{
1414
public Functional from_config(ModelConfig config)
1515
=> Functional.from_config(config);
1616

17-
public Model load_model(string filepath, bool compile = true, LoadOptions? options = null)
17+
public IModel load_model(string filepath, bool compile = true, LoadOptions? options = null)
1818
{
1919
return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model;
2020
}

src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
using Tensorflow.Keras.ArgsDefinition;
2+
using Tensorflow.Keras.Engine;
23

34
namespace Tensorflow.Keras.Optimizers
45
{
5-
public class OptimizerApi
6+
public class OptimizerApi: IOptimizerApi
67
{
78
/// <summary>
89
/// Adam optimization is a stochastic gradient descent method that is based on
@@ -15,7 +16,7 @@ public class OptimizerApi
1516
/// <param name="amsgrad"></param>
1617
/// <param name="name"></param>
1718
/// <returns></returns>
18-
public OptimizerV2 Adam(float learning_rate = 0.001f,
19+
public IOptimizer Adam(float learning_rate = 0.001f,
1920
float beta_1 = 0.9f,
2021
float beta_2 = 0.999f,
2122
float epsilon = 1e-7f,
@@ -38,7 +39,7 @@ public OptimizerV2 Adam(float learning_rate = 0.001f,
3839
/// <param name="centered"></param>
3940
/// <param name="name"></param>
4041
/// <returns></returns>
41-
public OptimizerV2 RMSprop(float learning_rate = 0.001f,
42+
public IOptimizer RMSprop(float learning_rate = 0.001f,
4243
float rho = 0.9f,
4344
float momentum = 0.0f,
4445
float epsilon = 1e-7f,
@@ -54,7 +55,7 @@ public OptimizerV2 RMSprop(float learning_rate = 0.001f,
5455
Name = name
5556
});
5657

57-
public SGD SGD(float learning_rate)
58+
public IOptimizer SGD(float learning_rate)
5859
=> new SGD(learning_rate);
5960
}
6061
}

test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs

-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ public class EagerModeTestBase
1010
[TestInitialize]
1111
public void TestInit()
1212
{
13-
tf.UseKeras<KerasInterface>();
14-
1513
if (!tf.executing_eagerly())
1614
tf.enable_eager_execution();
1715
tf.Context.ensure_initialized();

test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs

-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ public void EinsumDense()
150150
[TestMethod, Ignore("WIP")]
151151
public void SimpleRNN()
152152
{
153-
tf.UseKeras<KerasInterface>();
154153
var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32);
155154
/*var simple_rnn = keras.layers.SimpleRNN(4);
156155
var output = simple_rnn.Apply(inputs);

test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Diagnostics;
44
using static Tensorflow.KerasApi;
55
using Tensorflow.Keras.Saving;
6+
using Tensorflow.Keras.Models;
67

78
namespace TensorFlowNET.Keras.UnitTest
89
{
@@ -18,7 +19,7 @@ public void GetAndFromConfig()
1819
var model = GetFunctionalModel();
1920
var config = model.get_config();
2021
Debug.Assert(config is ModelConfig);
21-
var new_model = keras.models.from_config(config as ModelConfig);
22+
var new_model = new ModelsApi().from_config(config as ModelConfig);
2223
Assert.AreEqual(model.Layers.Count, new_model.Layers.Count);
2324
}
2425

0 commit comments

Comments
 (0)