@@ -107,7 +107,7 @@ public Optimizer(Tensor learning_rate, bool use_locking, string name = null)
         /// </returns>
         public Operation minimize(Tensor loss,
             IVariableV1 global_step = null,
-            List<ResourceVariable> var_list = null,
+            List<IVariableV1> var_list = null,
             GateGradientType gate_gradients = GateGradientType.GATE_OP,
             int? aggregation_method = null,
             bool colocate_gradients_with_ops = false, string name = null, Tensor grad_loss = null)
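Caller-side, the widened `var_list` type means a mixed list of `IVariableV1` implementations can be passed straight through. A minimal sketch, assuming an `optimizer` instance and a scalar `loss` tensor already exist (both names are illustrative, not part of this diff):

```csharp
// Sketch only: assumes `using System.Collections.Generic;` and an in-scope optimizer/loss.
var var_list = new List<IVariableV1>();
var_list.AddRange(variables.trainable_variables() as List<IVariableV1>); // any IVariableV1 qualifies now
Operation train_op = optimizer.minimize(loss, var_list: var_list);
```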
@@ -142,17 +142,17 @@ public Operation minimize(Tensor loss,
         /// <returns>
         /// An `Operation` that applies the specified gradients. If `global_step`
         /// was not None, that operation also increments `global_step`.</returns>
-        public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_vars, IVariableV1 global_step = null, string name = null)
+        public Operation apply_gradients(Tuple<Tensor, IVariableV1>[] grads_and_vars, IVariableV1 global_step = null, string name = null)
         {
             // No DistributionStrategy case.
-            var converted_grads_and_vars = new List<(Tensor, ResourceVariable, _OptimizableVariable)>();
+            var converted_grads_and_vars = new List<(Tensor, IVariableV1, _OptimizableVariable)>();
             foreach (var (g, v) in grads_and_vars)
             {
                 if (g != null)
                 {
                     // Convert the grad to Tensor or IndexedSlices if necessary.
                     var gR = ops.convert_to_tensor_or_indexed_slices(g);
-                    var p = optimizer._get_processor(v);
+                    var p = optimizer._get_processor(v as ResourceVariable);
                     converted_grads_and_vars.Add((gR, v, p));
                 }
             }
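A hedged sketch of assembling the retyped `grads_and_vars` argument from gradients computed elsewhere; `grads` (a `Tensor[]`), `vars` (a `List<IVariableV1>`), and `global_step` are assumed to be in scope:

```csharp
// Sketch only: assumes `using System.Linq;` for Zip.
Tuple<Tensor, IVariableV1>[] grads_and_vars = grads
    .Zip(vars, (g, v) => new Tuple<Tensor, IVariableV1>(g, v))
    .ToArray();
var apply_op = optimizer.apply_gradients(grads_and_vars, global_step: global_step);
```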
@@ -230,7 +230,7 @@ public Operation apply_gradients(Tuple<Tensor, ResourceVariable>[] grads_and_var
         /// silently ignored).
         /// </summary>
         /// <param name="var_list"></param>
-        protected virtual void _create_slots(ResourceVariable[] var_list)
+        protected virtual void _create_slots(IVariableV1[] var_list)
         {
 
         }
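With `_create_slots` now taking `IVariableV1[]`, subclasses override against the interface rather than a concrete variable type. An illustrative subclass, using the `Optimizer(Tensor, bool, string)` constructor from this file; `CreateAccumulatorFor` is a hypothetical placeholder for whatever slot-creation helper a concrete optimizer uses:

```csharp
class MomentumLikeOptimizer : Optimizer
{
    public MomentumLikeOptimizer(Tensor lr) : base(lr, use_locking: false, name: "MomentumLike") { }

    protected override void _create_slots(IVariableV1[] var_list)
    {
        // One state slot per trainable variable; CreateAccumulatorFor is
        // hypothetical, standing in for the implementation's slot helper.
        foreach (var v in var_list)
            CreateAccumulatorFor(v, "momentum");
    }

    // Hypothetical helper, elided: would allocate a zeros variable shaped like `v`.
    void CreateAccumulatorFor(IVariableV1 v, string slot_name) { /* ... */ }
}
```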
@@ -369,8 +369,8 @@ protected IVariableV1 _get_non_slot_variable(string name, Graph graph = null)
         /// A list of (gradient, variable) pairs. Variable is always present, but
         /// gradient can be `None`.
         /// </returns>
-        public Tuple<Tensor, ResourceVariable>[] compute_gradients(Tensor loss,
-            List<ResourceVariable> var_list = null,
+        public Tuple<Tensor, IVariableV1>[] compute_gradients(Tensor loss,
+            List<IVariableV1> var_list = null,
             int? aggregation_method = null,
             GateGradientType gate_gradients = GateGradientType.GATE_OP,
             bool colocate_gradients_with_ops = false,
@@ -381,26 +381,13 @@ public Tuple<Tensor, ResourceVariable>[] compute_gradients(Tensor loss,
 
             if (var_list == null)
             {
-                var vars = ops.get_collection<ResourceVariable>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
+                var vars = ops.get_collection<IVariableV1>(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
                 var tmp = variables.trainable_variables();
-                switch (tmp)
-                {
-                    case List<ResourceVariable> values:
-                        var_list = values.Concat(vars).ToList();
-                        break;
-                    /*case List<RefVariable> values:
-                        var_list = values.Concat(vars).ToList();
-                        break;
-                    case List<IVariableV1> values:
-                        var_list = values.Select(x => x as RefVariable).Concat(vars).ToList();
-                        break;*/
-                    default:
-                        throw new NotImplementedException("");
-                }
+                var_list = (tmp as List<IVariableV1>).Concat(vars).ToList();
             }
 
-            var_list = var_list.Concat(ops.get_collection<ResourceVariable>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
-            var processors = var_list.Select(v => optimizer._get_processor(v)).ToList();
+            var_list = var_list.Concat(ops.get_collection<IVariableV1>(tf.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
+            var processors = var_list.Select(v => optimizer._get_processor(v as ResourceVariable)).ToList();
             var var_refs = processors.Select(x => x.target()).ToArray();
 
             var grads = gradients_impl.gradients(new Tensor[] { loss }, var_refs, grad_ys: grad_loss == null ? null : new Tensor[] { grad_loss },
@@ -412,7 +399,7 @@ public Tuple<Tensor, ResourceVariable>[] compute_gradients(Tensor loss,
                 grads = control_flow_ops.tuple(grads);
 
             var grads_and_vars = zip(grads, var_list)
-                .Select(x => new Tuple<Tensor, ResourceVariable>(x.Item1, x.Item2))
+                .Select(x => new Tuple<Tensor, IVariableV1>(x.Item1, x.Item2))
                 .ToArray();
 
             return grads_and_vars;
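Taken together, these changes let the two-step form that `minimize` wraps round-trip entirely through `IVariableV1`. A sketch of that manual path, assuming `optimizer`, `loss`, and `global_step` are in scope; the null filter follows the doc comment above, since a pair's gradient can be null:

```csharp
// Sketch only: compute gradients, drop pairs with no gradient, then apply.
Tuple<Tensor, IVariableV1>[] grads_and_vars = optimizer.compute_gradients(loss);
var applicable = grads_and_vars.Where(gv => gv.Item1 != null).ToArray();
Operation train_op = optimizer.apply_gradients(applicable, global_step: global_step);
```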