Skip to content

Commit c8947fa

Browse files
committed
Refine the keras.Activation and add tf.keras.activations.
1 parent d8341e4 commit c8947fa

File tree

16 files changed

+111
-126
lines changed

16 files changed

+111
-126
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,44 @@
1-
namespace Tensorflow.Keras
1+
using Newtonsoft.Json;
2+
using System.Reflection;
3+
using System.Runtime.Versioning;
4+
using Tensorflow.Keras.Common;
5+
6+
namespace Tensorflow.Keras
27
{
3-
public delegate Tensor Activation(Tensor features, string name = null);
8+
[JsonConverter(typeof(CustomizedActivationJsonConverter))]
9+
public class Activation
10+
{
11+
public string Name { get; set; }
12+
/// <summary>
13+
/// The parameters are `features` and `name`.
14+
/// </summary>
15+
public Func<Tensor, string, Tensor> ActivationFunction { get; set; }
16+
17+
public Tensor Apply(Tensor input, string name = null) => ActivationFunction(input, name);
18+
19+
public static implicit operator Activation(Func<Tensor, string, Tensor> func)
20+
{
21+
return new Activation()
22+
{
23+
Name = func.GetMethodInfo().Name,
24+
ActivationFunction = func
25+
};
26+
}
27+
}
28+
29+
public interface IActivationsApi
30+
{
31+
Activation GetActivationFromName(string name);
32+
Activation Linear { get; }
33+
34+
Activation Relu { get; }
35+
36+
Activation Sigmoid { get; }
37+
38+
Activation Softmax { get; }
39+
40+
Activation Tanh { get; }
41+
42+
Activation Mish { get; }
43+
}
444
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs

+1-20
Original file line numberDiff line numberDiff line change
@@ -26,27 +26,8 @@ public class ConvolutionalArgs : AutoSerializeLayerArgs
2626
public Shape DilationRate { get; set; } = (1, 1);
2727
[JsonProperty("groups")]
2828
public int Groups { get; set; } = 1;
29-
public Activation Activation { get; set; }
30-
private string _activationName;
3129
[JsonProperty("activation")]
32-
public string ActivationName
33-
{
34-
get
35-
{
36-
if (string.IsNullOrEmpty(_activationName))
37-
{
38-
return Activation.Method.Name;
39-
}
40-
else
41-
{
42-
return _activationName;
43-
}
44-
}
45-
set
46-
{
47-
_activationName = value;
48-
}
49-
}
30+
public Activation Activation { get; set; }
5031
[JsonProperty("use_bias")]
5132
public bool UseBias { get; set; }
5233
[JsonProperty("kernel_initializer")]

src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs

+1-21
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,8 @@ public class DenseArgs : LayerArgs
1818
/// <summary>
1919
/// Activation function to use.
2020
/// </summary>
21-
public Activation Activation { get; set; }
22-
23-
private string _activationName;
2421
[JsonProperty("activation")]
25-
public string ActivationName
26-
{
27-
get
28-
{
29-
if (string.IsNullOrEmpty(_activationName))
30-
{
31-
return Activation.Method.Name;
32-
}
33-
else
34-
{
35-
return _activationName;
36-
}
37-
}
38-
set
39-
{
40-
_activationName = value;
41-
}
42-
}
22+
public Activation Activation { get; set; }
4323

4424
/// <summary>
4525
/// Whether the layer uses a bias vector.

src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs

+1-20
Original file line numberDiff line numberDiff line change
@@ -35,27 +35,8 @@ public class EinsumDenseArgs : AutoSerializeLayerArgs
3535
/// <summary>
3636
/// Activation function to use.
3737
/// </summary>
38-
public Activation Activation { get; set; }
39-
private string _activationName;
4038
[JsonProperty("activation")]
41-
public string ActivationName
42-
{
43-
get
44-
{
45-
if (string.IsNullOrEmpty(_activationName))
46-
{
47-
return Activation.Method.Name;
48-
}
49-
else
50-
{
51-
return _activationName;
52-
}
53-
}
54-
set
55-
{
56-
_activationName = value;
57-
}
58-
}
39+
public Activation Activation { get; set; }
5940

6041
/// <summary>
6142
/// Initializer for the `kernel` weights matrix.

src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs

+8-8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System;
55
using System.Collections.Generic;
66
using System.Text;
7+
using static Tensorflow.Binding;
78

89
namespace Tensorflow.Keras.Common
910
{
@@ -31,20 +32,19 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
3132
}
3233
else
3334
{
34-
var token = JToken.FromObject((value as Activation)!.GetType().Name);
35+
var token = JToken.FromObject(((Activation)value).Name);
3536
token.WriteTo(writer);
3637
}
3738
}
3839

3940
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
4041
{
41-
throw new NotImplementedException();
42-
//var dims = serializer.Deserialize(reader, typeof(string));
43-
//if (dims is null)
44-
//{
45-
// throw new ValueError("Cannot deserialize 'null' to `Activation`.");
46-
//}
47-
//return new Shape((long[])(dims!));
42+
var activationName = serializer.Deserialize<string>(reader);
43+
if (tf.keras is null)
44+
{
45+
throw new RuntimeError("Tensorflow.Keras is not loaded, please install it first.");
46+
}
47+
return tf.keras.activations.GetActivationFromName(string.IsNullOrEmpty(activationName) ? "linear" : activationName);
4848
}
4949
}
5050
}

src/TensorFlowNET.Core/Keras/IKerasApi.cs

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ public interface IKerasApi
1616
IInitializersApi initializers { get; }
1717
ILayersApi layers { get; }
1818
ILossesApi losses { get; }
19+
IActivationsApi activations { get; }
1920
IOptimizerApi optimizers { get; }
2021
IMetricsApi metrics { get; }
2122
IModelsApi models { get; }

src/TensorFlowNET.Keras/Activations.cs

+40-36
Original file line numberDiff line numberDiff line change
@@ -6,45 +6,61 @@
66

77
namespace Tensorflow.Keras
88
{
9-
public class Activations
9+
public class Activations: IActivationsApi
1010
{
1111
private static Dictionary<string, Activation> _nameActivationMap;
12-
private static Dictionary<Activation, string> _activationNameMap;
1312

14-
private static Activation _linear = (features, name) => features;
15-
private static Activation _relu = (features, name)
16-
=> tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features));
17-
private static Activation _sigmoid = (features, name)
18-
=> tf.Context.ExecuteOp("Sigmoid", name, new ExecuteOpArgs(features));
19-
private static Activation _softmax = (features, name)
20-
=> tf.Context.ExecuteOp("Softmax", name, new ExecuteOpArgs(features));
21-
private static Activation _tanh = (features, name)
22-
=> tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(features));
23-
private static Activation _mish = (features, name)
24-
=> features * tf.math.tanh(tf.math.softplus(features));
13+
private static Activation _linear = new Activation()
14+
{
15+
Name = "linear",
16+
ActivationFunction = (features, name) => features
17+
};
18+
private static Activation _relu = new Activation()
19+
{
20+
Name = "relu",
21+
ActivationFunction = (features, name) => tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features))
22+
};
23+
private static Activation _sigmoid = new Activation()
24+
{
25+
Name = "sigmoid",
26+
ActivationFunction = (features, name) => tf.Context.ExecuteOp("Sigmoid", name, new ExecuteOpArgs(features))
27+
};
28+
private static Activation _softmax = new Activation()
29+
{
30+
Name = "softmax",
31+
ActivationFunction = (features, name) => tf.Context.ExecuteOp("Softmax", name, new ExecuteOpArgs(features))
32+
};
33+
private static Activation _tanh = new Activation()
34+
{
35+
Name = "tanh",
36+
ActivationFunction = (features, name) => tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(features))
37+
};
38+
private static Activation _mish = new Activation()
39+
{
40+
Name = "mish",
41+
ActivationFunction = (features, name) => features * tf.math.tanh(tf.math.softplus(features))
42+
};
2543

2644
/// <summary>
2745
/// Register the name-activation mapping in this static class.
2846
/// </summary>
2947
/// <param name="name"></param>
3048
/// <param name="activation"></param>
31-
private static void RegisterActivation(string name, Activation activation)
49+
private static void RegisterActivation(Activation activation)
3250
{
33-
_nameActivationMap[name] = activation;
34-
_activationNameMap[activation] = name;
51+
_nameActivationMap[activation.Name] = activation;
3552
}
3653

3754
static Activations()
3855
{
3956
_nameActivationMap = new Dictionary<string, Activation>();
40-
_activationNameMap= new Dictionary<Activation, string>();
4157

42-
RegisterActivation("relu", _relu);
43-
RegisterActivation("linear", _linear);
44-
RegisterActivation("sigmoid", _sigmoid);
45-
RegisterActivation("softmax", _softmax);
46-
RegisterActivation("tanh", _tanh);
47-
RegisterActivation("mish", _mish);
58+
RegisterActivation(_relu);
59+
RegisterActivation(_linear);
60+
RegisterActivation(_sigmoid);
61+
RegisterActivation(_softmax);
62+
RegisterActivation(_tanh);
63+
RegisterActivation(_mish);
4864
}
4965

5066
public Activation Linear => _linear;
@@ -59,7 +75,7 @@ static Activations()
5975

6076
public Activation Mish => _mish;
6177

62-
public static Activation GetActivationByName(string name)
78+
public Activation GetActivationFromName(string name)
6379
{
6480
if (!_nameActivationMap.TryGetValue(name, out var res))
6581
{
@@ -70,17 +86,5 @@ public static Activation GetActivationByName(string name)
7086
return res;
7187
}
7288
}
73-
74-
public static string GetNameByActivation(Activation activation)
75-
{
76-
if(!_activationNameMap.TryGetValue(activation, out var name))
77-
{
78-
throw new Exception($"Activation {activation} not found");
79-
}
80-
else
81-
{
82-
return name;
83-
}
84-
}
8589
}
8690
}

src/TensorFlowNET.Keras/KerasInterface.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public static KerasInterface Instance
4545
public Regularizers regularizers { get; } = new Regularizers();
4646
public ILayersApi layers { get; } = new LayersApi();
4747
public ILossesApi losses { get; } = new LossesApi();
48-
public Activations activations { get; } = new Activations();
48+
public IActivationsApi activations { get; } = new Activations();
4949
public Preprocessing preprocessing { get; } = new Preprocessing();
5050
ThreadLocal<BackendImpl> _backend = new ThreadLocal<BackendImpl>(() => new BackendImpl());
5151
public BackendImpl backend => _backend.Value;

src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? train
110110
throw new NotImplementedException("");
111111

112112
if (activation != null)
113-
return activation(outputs);
113+
return activation.Apply(outputs);
114114

115115
return outputs;
116116
}

src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? train
117117
}
118118

119119
if (activation != null)
120-
outputs = activation(outputs);
120+
outputs = activation.Apply(outputs);
121121

122122
return outputs;
123123
}

src/TensorFlowNET.Keras/Layers/Core/Dense.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? train
8181
if (args.UseBias)
8282
outputs = tf.nn.bias_add(outputs, bias);
8383
if (args.Activation != null)
84-
outputs = activation(outputs);
84+
outputs = activation.Apply(outputs);
8585

8686
return outputs;
8787
}

src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? train
193193
if (this.bias != null)
194194
ret += this.bias.AsTensor();
195195
if (this.activation != null)
196-
ret = this.activation(ret);
196+
ret = this.activation.Apply(ret);
197197
return ret;
198198
}
199199
/// <summary>

0 commit comments

Comments
 (0)