@@ -116,17 +116,23 @@ public static Dictionary<string, ConcreteFunction> load_function_def_library(Fun
116
116
}
117
117
118
118
Dictionary < string , ConcreteFunction > loaded_gradients = new ( ) ;
119
- foreach ( var fdef in _sort_function_defs ( library , function_deps ) )
119
+ // Debug(Rinne)
120
+ var temp = _sort_function_defs ( library , function_deps ) ;
121
+ int i = 0 ;
122
+ foreach ( var fdef in temp )
120
123
{
124
+ i ++ ;
121
125
var orig_name = _fix_fdef_in_place ( fdef , functions , load_shared_name_suffix , new_gradient_op_types ) ;
122
126
123
127
object structured_input_signature = null ;
124
128
object structured_outputs = null ;
125
129
if ( saved_object_graph is not null && saved_object_graph . ConcreteFunctions . ContainsKey ( orig_name ) )
126
130
{
127
- var proto = saved_object_graph . ConcreteFunctions [ orig_name ] ;
128
- structured_input_signature = nested_structure_coder . decode_proto ( proto . CanonicalizedInputSignature ) ;
129
- structured_outputs = nested_structure_coder . decode_proto ( proto . OutputSignature ) ;
131
+ // TODO(Rinne): deal with structured_input_signature and structured_outputs.
132
+
133
+ //var proto = saved_object_graph.ConcreteFunctions[orig_name];
134
+ //structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature);
135
+ //structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature);
130
136
}
131
137
132
138
graph . as_default ( ) ;
@@ -234,27 +240,41 @@ private static Func<Operation, Tensor[], Tensor[]> _gen_gradient_func(ConcreteFu
234
240
235
241
/// <summary>
/// Restores custom gradient functions for the operations of <paramref name="func_graph"/>.
/// </summary>
/// <param name="func_graph">Graph whose operations are scanned and rewired.</param>
/// <param name="renamed_functions">Maps the (possibly renamed) called-function name of a
/// `PartitionedCall`/`StatefulPartitionedCall` op to its loaded <see cref="ConcreteFunction"/>.</param>
/// <param name="loaded_gradients">Maps a `_gradient_op_type` attribute value to the gradient
/// function to restore; may be null or empty, in which case only call-op rewiring happens.</param>
private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary<string, ConcreteFunction> renamed_functions, Dictionary<string, ConcreteFunction> loaded_gradients)
{
    // Hoisted once: when there are no loaded gradients, the per-op
    // `get_attr` try/catch (which is exception-driven and expensive)
    // is skipped entirely. This replaces the previous duplicated loops,
    // which maintained the call-op rewiring logic in two places.
    bool has_gradients = loaded_gradients is not null && loaded_gradients.Count > 0;

    foreach (var op in func_graph.get_operations())
    {
        if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
        {
            // Point the call op at the gradient function of the
            // (renamed) ConcreteFunction it invokes.
            var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
            op.op._gradient_function = function._get_gradient_function();
        }

        if (!has_gradients)
        {
            continue;
        }

        string gradient_op_type = null;
        try
        {
            gradient_op_type = op.op.get_attr("_gradient_op_type") as string;
        }
        catch (InvalidArgumentError)
        {
            // The op has no `_gradient_op_type` attribute; nothing to restore.
            continue;
        }

        // Null guard: `as string` yields null for a non-string attr value,
        // and Dictionary.ContainsKey/TryGetValue throw on a null key.
        // TryGetValue also avoids the ContainsKey + indexer double lookup.
        if (gradient_op_type is not null && loaded_gradients.TryGetValue(gradient_op_type, out var grad_fn))
        {
            grad_fn.NumPositionArgs = op.op.inputs.Length;
            grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name);
        }
    }
}
0 commit comments