
Commit aa13352

fix AdamOptimizer for Graph mode.
1 parent b3cd413 commit aa13352

6 files changed, +35 -33 lines changed

src/TensorFlowNET.Core/Gradients/math_grad.cs

Lines changed: 17 additions & 4 deletions
@@ -542,15 +542,28 @@ public static Tensor[] _SumGrad(Operation op, Tensor[] grads)
             }
 
             input_shape = array_ops.shape(op.inputs[0]);
-            if (!op.get_attr<bool>("keep_dims"))
+
+            if (tf.executing_eagerly())
+            {
+                if (!op.get_attr<bool>("keep_dims"))
+                {
+                    ops.colocate_with(input_shape);
+                    var output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]);
+                    // var tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims);
+                    grad = gen_array_ops.reshape(grad, output_shape_kept_dims);
+                }
+
+                return new Tensor[] { gen_array_ops.broadcast_to(grad, input_shape), null };
+            }
+            else
             {
                 ops.colocate_with(input_shape);
                 var output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]);
-                // var tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims);
+                var tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims);
                 grad = gen_array_ops.reshape(grad, output_shape_kept_dims);
-            }
 
-            return new Tensor[] { gen_array_ops.broadcast_to(grad, input_shape), null };
+                return new Tensor[] { gen_array_ops.tile(grad, tile_scaling), null };
+            }
         }
 
         [RegisterGradient("RealDiv")]
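Why the split: in graph mode this hunk restores the classic Sum-gradient path (reduced_shape, then reshape, then tile by _safe_shape_div), while eager execution keeps the newer broadcast_to shortcut. A minimal graph-mode sketch of what that gradient computes; tf.placeholder, tf.reduce_sum, tf.gradients and FeedItem are the usual TF.NET 0.x surface, but treat the exact overloads as assumptions:

    // d(reduce_sum(x))/dx spreads the upstream gradient back over the
    // reduced axes -- all ones here -- via reshape + tile in graph mode.
    var x = tf.placeholder(tf.float32, shape: new TensorShape(2, 3));
    var loss = tf.reduce_sum(x);                              // scalar
    var grad = tf.gradients(new[] { loss }, new[] { x })[0];  // shape (2, 3)
    using (var sess = tf.Session())
    {
        var ones = sess.run(grad, new FeedItem(x, np.ones(2, 3)));
    }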

src/TensorFlowNET.Core/Operations/gen_image_ops.cs

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ public static Tensor decode_jpeg(Tensor contents,
             int ratio = 1,
             bool fancy_upscaling = true,
             bool try_recover_truncated = false,
-            int acceptable_fraction = 1,
+            float acceptable_fraction = 1,
             string dct_method = "",
             string name = null)
         {
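The DecodeJpeg kernel defines acceptable_fraction as a float attribute, so the int parameter could only express whole numbers; widening it to float matches the op. A hedged call sketch, assuming contents is a scalar string tensor of raw JPEG bytes obtained elsewhere:

    // Tolerate JPEGs where only half of the input lines decode cleanly.
    var img = gen_image_ops.decode_jpeg(contents,
        try_recover_truncated: true,
        acceptable_fraction: 0.5f,       // not expressible before this change
        dct_method: "INTEGER_ACCURATE");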

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 1 addition & 1 deletion
@@ -652,7 +652,7 @@ private static Tensor _ReductionDims(Tensor x, Tensor axis)
             }
             else
             {
-                if(x.rank > -1)
+                if (x.rank > -1 && tf.executing_eagerly())
                     return constant_op.constant(np.arange(x.rank));
 
                 var rank = array_ops.rank(x);
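The shortcut of folding the reduction axes into a host-side constant [0, 1, ..., rank-1] is now taken only when executing eagerly; in graph mode the axes are built as ops by the lines that follow the hunk. As an assumption about that surrounding code, the graph-mode fallback is roughly:

    // Graph-mode fallback: compute [0 .. rank-1] with graph ops rather
    // than baking a constant from the statically known rank.
    var rank = array_ops.rank(x);                      // scalar int32 tensor
    return math_ops.range(0, limit: rank, delta: 1);   // hypothetical overload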

src/TensorFlowNET.Core/Training/AdamOptimizer.cs

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@ private Operation _apply_sparse_shared(Tensor grad, IVariableV1 var, Tensor indi
             return control_flow_ops.group(new[] { var_update, m_t, v_t });
         }
 
-        protected override void _create_slots(ResourceVariable[] var_list)
+        protected override void _create_slots(IVariableV1[] var_list)
         {
             var first_var = var_list.OrderBy(x => x.Name).First();
             _create_non_slot_variable(initial_value: _beta1, name: "beta1_power", colocate_with: first_var);
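This one-line widening is the headline fix: graph-mode trainable variables reach the optimizer as IVariableV1, not necessarily ResourceVariable, so Adam's slot creation (the m and v accumulators) previously never ran for them. A minimal end-to-end graph-mode sketch of the scenario the commit repairs; this is the standard TF.NET 0.x training loop, with exact operator overloads assumed:

    using static Tensorflow.Binding;   // assumed entry point for tf.*

    var x = tf.constant(2.0f);
    var w = tf.Variable(1.0f);                        // graph-mode IVariableV1
    var loss = tf.square(w * x - 3.0f);               // (w*x - 3)^2
    var train_op = tf.train.AdamOptimizer(0.01f).minimize(loss);

    using (var sess = tf.Session())
    {
        sess.run(tf.global_variables_initializer());
        for (var step = 0; step < 100; step++)
            sess.run(train_op);                       // Adam slots for w now exist
    }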

src/TensorFlowNET.Core/Training/Optimizer.cs

Lines changed: 12 additions & 25 deletions
@@ -107,7 +107,7 @@ public Optimizer(Tensor learning_rate, bool use_locking, string name = null)
         /// </returns>
         public Operation minimize(Tensor loss,
             IVariableV1 global_step = null,
-            List<ResourceVariable> var_list=null,
+            List<IVariableV1> var_list=null,
             GateGradientType gate_gradients = GateGradientType.GATE_OP,
             int? aggregation_method=null,
             bool colocate_gradients_with_ops = false, string name=null, Tensor grad_loss=null)
@@ -142,17 +142,17 @@ public Operation minimize(Tensor loss,
         /// <returns>
         /// An `Operation` that applies the specified gradients. If `global_step`
         /// was not None, that operation also increments `global_step`.</returns>
-        public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_vars, IVariableV1 global_step = null, string name = null)
+        public Operation apply_gradients(Tuple<Tensor, IVariableV1>[] grads_and_vars, IVariableV1 global_step = null, string name = null)
         {
             // No DistributionStrategy case.
-            var converted_grads_and_vars = new List<(Tensor, ResourceVariable, _OptimizableVariable)>();
+            var converted_grads_and_vars = new List<(Tensor, IVariableV1, _OptimizableVariable)>();
             foreach (var (g, v) in grads_and_vars)
             {
                 if(g != null)
                 {
                     // Convert the grad to Tensor or IndexedSlices if necessary.
                     var gR = ops.convert_to_tensor_or_indexed_slices(g);
-                    var p = optimizer._get_processor(v);
+                    var p = optimizer._get_processor(v as ResourceVariable);
                     converted_grads_and_vars.Add((gR, v, p));
                 }
             }
@@ -230,7 +230,7 @@ public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_var
         /// silently ignored).
         /// </summary>
         /// <param name="var_list"></param>
-        protected virtual void _create_slots(ResourceVariable[] var_list)
+        protected virtual void _create_slots(IVariableV1[] var_list)
         {
 
         }
@@ -369,8 +369,8 @@ protected IVariableV1 _get_non_slot_variable(string name, Graph graph = null)
         /// A list of (gradient, variable) pairs. Variable is always present, but
         /// gradient can be `None`.
         /// </returns>
-        public Tuple<Tensor, ResourceVariable>[] compute_gradients(Tensor loss,
-            List<ResourceVariable> var_list = null,
+        public Tuple<Tensor, IVariableV1>[] compute_gradients(Tensor loss,
+            List<IVariableV1> var_list = null,
             int? aggregation_method = null,
             GateGradientType gate_gradients = GateGradientType.GATE_OP,
             bool colocate_gradients_with_ops = false,
@@ -381,26 +381,13 @@ public Tuple<Tensor, ResourceVariable>[] compute_gradients(Tensor loss,
 
             if(var_list == null)
             {
-                var vars = ops.get_collection<ResourceVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
+                var vars = ops.get_collection<IVariableV1>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
                 var tmp = variables.trainable_variables();
-                switch (tmp)
-                {
-                    case List<ResourceVariable> values:
-                        var_list = values.Concat(vars).ToList();
-                        break;
-                    /*case List<RefVariable> values:
-                        var_list = values.Concat(vars).ToList();
-                        break;
-                    case List<IVariableV1> values:
-                        var_list = values.Select(x => x as RefVariable).Concat(vars).ToList();
-                        break;*/
-                    default:
-                        throw new NotImplementedException("");
-                }
+                var_list = (tmp as List<IVariableV1>).Concat(vars).ToList();
             }
 
-            var_list = var_list.Concat(ops.get_collection<ResourceVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
-            var processors = var_list.Select(v => optimizer._get_processor(v)).ToList();
+            var_list = var_list.Concat(ops.get_collection<IVariableV1>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
+            var processors = var_list.Select(v => optimizer._get_processor(v as ResourceVariable)).ToList();
             var var_refs = processors.Select(x => x.target()).ToArray();
 
             var grads = gradients_impl.gradients(new Tensor[] { loss }, var_refs, grad_ys: grad_loss == null ? null : new Tensor[] { grad_loss },
@@ -412,7 +399,7 @@ public Tuple<Tensor, ResourceVariable>[] compute_gradients(Tensor loss,
             grads = control_flow_ops.tuple(grads);
 
             var grads_and_vars = zip(grads, var_list)
-                .Select(x => new Tuple<Tensor, ResourceVariable>(x.Item1, x.Item2))
+                .Select(x => new Tuple<Tensor, IVariableV1>(x.Item1, x.Item2))
                 .ToArray();
 
             return grads_and_vars;
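Taken together, the Optimizer changes keep one pipeline type-consistent end to end: compute_gradients now yields Tuple<Tensor, IVariableV1> pairs that apply_gradients accepts directly, with the cast to ResourceVariable pushed down into _get_processor. A sketch of driving the two steps manually, e.g. to clip gradients in between (tf.clip_by_value assumed available with this shape):

    // Manual equivalent of minimize(), now typed over IVariableV1.
    var grads_and_vars = optimizer.compute_gradients(loss);
    var clipped = grads_and_vars
        .Select(gv => new Tuple<Tensor, IVariableV1>(
            tf.clip_by_value(gv.Item1, -1.0f, 1.0f), gv.Item2))
        .ToArray();
    var train_op = optimizer.apply_gradients(clipped);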

src/TensorFlowNET.Core/Variables/ResourceVariable.cs

Lines changed: 3 additions & 1 deletion
@@ -155,7 +155,7 @@ private void _init_from_args(object initial_value = null,
                     _graph_element = value;
                 });
 
-                ops.add_to_collections(collections, this);
+                ops.add_to_collections<IVariableV1>(collections, this);
             }
             else
             {
@@ -184,6 +184,8 @@ private void _init_from_proto(VariableDef variable_def, string import_scope = nu
             var g = ops.get_default_graph();
             var prepend_name_scope = ops.prepend_name_scope(variable_def.VariableName, import_scope: import_scope);
             handle = g.as_graph_element(prepend_name_scope) as Tensor;
+            _handle_name = handle.name;
+            _name = handle.name;
             _shape = new TensorShape(handle.op.get_attr("shape") as TensorShapeProto);
 
             prepend_name_scope = ops.prepend_name_scope(variable_def.InitializerName, import_scope: import_scope);
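The collection change closes the loop with the optimizer side: variables now register under IVariableV1, so the ops.get_collection<IVariableV1>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES) lookup in compute_gradients actually finds them, and the proto-import path backfills _name/_handle_name so graph-loaded variables carry usable names. A small sketch of the lookup this enables:

    // Enumerate trainables through the interface the optimizer now uses.
    var trainables = ops.get_collection<IVariableV1>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
    foreach (var v in trainables)
        Console.WriteLine(v.Name);   // set even for proto-imported variables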
