Commit 48403a5

Merge pull request #1161 from Wanglongzhi2001/master
fix: add the momentum parameter's implementation of SGD
2 parents 7fec40b + f3b3d8b

File tree: 4 files changed (+25 −4)

src/TensorFlowNET.Core/Keras/IOptimizerApi.cs (+1 −1)

@@ -63,6 +63,6 @@ IOptimizer RMSprop(float learning_rate = 0.001f,
                    bool centered = false,
                    string name = "RMSprop");
 
-        IOptimizer SGD(float learning_rate);
+        IOptimizer SGD(float learning_rate, float momentum);
    }
 }
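For callers, this interface change means momentum is now passed explicitly alongside the learning rate. A minimal usage sketch, assuming the usual `keras` accessor exposed by TensorFlow.NET (anything not in this diff is illustrative):

    using static Tensorflow.KerasApi;

    // SGD now takes momentum alongside the learning rate;
    // momentum = 0 reproduces plain gradient descent.
    var plain = keras.optimizers.SGD(learning_rate: 0.01f, momentum: 0.0f);
    var heavy = keras.optimizers.SGD(learning_rate: 0.01f, momentum: 0.9f);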

src/TensorFlowNET.Core/Training/gen_training_ops.cs (+4 −0)

@@ -51,5 +51,9 @@ public static Tensor apply_gradient_descent(IVariableV1 var, Tensor alpha, Tenso
         public static Tensor resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null)
             => tf.Context.ExecuteOp("ResourceApplyGradientDescent", name,
                 new ExecuteOpArgs(var, alpha, delta).SetAttributes(new { use_locking }));
+
+        public static Tensor resource_apply_keras_momentum(Tensor var, Tensor accum, Tensor lr, Tensor grad, Tensor momentum, bool use_locking = false, bool use_nesterov = false, string name = null)
+            => tf.Context.ExecuteOp("ResourceApplyKerasMomentum", name,
+                new ExecuteOpArgs(var, accum, lr, grad, momentum).SetAttributes(new { use_locking, use_nesterov }));
     }
 }
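The new binding targets TensorFlow's `ResourceApplyKerasMomentum` op, which fuses the Keras-style momentum update into a single kernel. A scalar sketch of the element-wise update the op performs, per its documented semantics (not the actual kernel code):

    // One element-wise step of ResourceApplyKerasMomentum.
    static (float variable, float accum) KerasMomentumStep(
        float variable, float accum, float lr, float grad,
        float momentum, bool useNesterov)
    {
        // accum = momentum * accum - lr * grad
        accum = momentum * accum - lr * grad;
        // Nesterov looks one step ahead; plain momentum applies accum directly.
        variable += useNesterov ? momentum * accum - lr * grad : accum;
        return (variable, accum);
    }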

src/TensorFlowNET.Keras/Optimizers/OptimizerApi.cs (+2 −2)

@@ -71,7 +71,7 @@ public IOptimizer RMSprop(float learning_rate = 0.001f,
                 Name = name
             });
 
-        public IOptimizer SGD(float learning_rate)
-            => new SGD(learning_rate);
+        public IOptimizer SGD(float learning_rate, float momentum)
+            => new SGD(learning_rate, momentum);
     }
 }

src/TensorFlowNET.Keras/Optimizers/SGD.cs (+18 −1)

@@ -22,6 +22,8 @@ public SGD(float learning_rate,
             _set_hyper("decay", decay);
 
             _momentum = momentum > 0;
+            if (momentum < 0 || momentum > 1)
+                throw new ValueError($"momentum must be a number between 0 and 1, got {momentum}.");
 
             _set_hyper("momentum", momentum);
 
@@ -30,6 +32,13 @@ public SGD(float learning_rate,
 #pragma warning restore CS1717 // Assignment made to same variable
         }
 
+        protected override void _create_slots(IVariableV1[] var_list)
+        {
+            if (_momentum)
+                foreach (var var in var_list)
+                    add_slot(var, "momentum");
+        }
+
         protected override void _prepare_local(DeviceDType device_dtype,
             Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
         {
@@ -43,7 +52,15 @@ protected override Operation _resource_apply_dense(IVariableV1 var, Tensor grad,
         {
             if (_momentum)
             {
-                throw new NotImplementedException("_resource_apply_dense");
+                var momentum_var = get_slot(var, "momentum");
+                return gen_training_ops.resource_apply_keras_momentum(
+                    var.Handle,
+                    momentum_var.Handle,
+                    _get_hyper("learning_rate", var.dtype),
+                    grad,
+                    _get_hyper("momentum", var.dtype),
+                    use_locking: _use_locking,
+                    use_nesterov: nesterov);
             }
             var device_dtype = _apply_state.Keys.FirstOrDefault(x => x.Device == var.Device && x.DType == var.dtype.as_base_dtype());
 
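Taken together: the constructor now rejects momentum outside [0, 1], `_create_slots` allocates one "momentum" accumulator slot per trainable variable (only when momentum > 0), and `_resource_apply_dense` passes the variable, its slot, and the hyperparameters to the fused kernel instead of throwing. A hedged sketch of the new constructor behavior (same illustrative `keras` accessor as above):

    // In range: accepted, and the momentum path is enabled.
    var opt = keras.optimizers.SGD(learning_rate: 0.01f, momentum: 0.9f);

    // Out of range: the constructor now throws ValueError,
    // e.g. "momentum must be a number between 0 and 1, got 1.5."
    // var bad = keras.optimizers.SGD(learning_rate: 0.01f, momentum: 1.5f);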