@@ -27,7 +27,7 @@ public partial class Model
27
27
/// <param name="use_multiprocessing"></param>
28
28
/// <param name="return_dict"></param>
29
29
/// <param name="is_val"></param>
30
- public Dictionary < string , float > evaluate ( Tensor x , Tensor y ,
30
+ public Dictionary < string , float > evaluate ( NDArray x , NDArray y ,
31
31
int batch_size = - 1 ,
32
32
int verbose = 1 ,
33
33
int steps = - 1 ,
@@ -115,62 +115,53 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
115
115
/// <param name="test_func">The function to be called on each batch of data.</param>
116
116
/// <param name="is_val">Whether it is validation or test.</param>
117
117
/// <returns></returns>
118
- Dictionary < string , float > evaluate ( DataHandler data_handler , CallbackList callbacks , bool is_val , Func < DataHandler , Tensor [ ] , Dictionary < string , float > > test_func )
118
+ Dictionary < string , float > evaluate ( DataHandler data_handler , CallbackList callbacks , bool is_val , Func < DataHandler , OwnedIterator , Dictionary < string , float > > test_func )
119
119
{
120
120
callbacks . on_test_begin ( ) ;
121
121
122
- var results = new Dictionary < string , float > ( ) ;
123
- var logs = results ;
122
+ var logs = new Dictionary < string , float > ( ) ;
124
123
foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
125
124
{
126
125
reset_metrics ( ) ;
127
- callbacks . on_epoch_begin ( epoch ) ;
128
- // data_handler.catch_stop_iteration();
129
-
130
126
foreach ( var step in data_handler . steps ( ) )
131
127
{
132
128
callbacks . on_test_batch_begin ( step ) ;
133
-
134
- logs = test_func ( data_handler , iterator . next ( ) ) ;
135
-
136
- tf_with ( ops . control_dependencies ( Array . Empty < object > ( ) ) , ctl => _train_counter . assign_add ( 1 ) ) ;
137
-
129
+ logs = test_func ( data_handler , iterator ) ;
138
130
var end_step = step + data_handler . StepIncrement ;
139
131
if ( ! is_val )
140
132
callbacks . on_test_batch_end ( end_step , logs ) ;
141
133
}
142
-
143
- if ( ! is_val )
144
- callbacks . on_epoch_end ( epoch , logs ) ;
145
134
}
146
-
147
- foreach ( var log in logs )
148
- {
149
- results [ log . Key ] = log . Value ;
150
- }
151
-
135
+ callbacks . on_test_end ( logs ) ;
136
+ var results = new Dictionary < string , float > ( logs ) ;
152
137
return results ;
153
138
}
154
139
155
- Dictionary < string , float > test_function ( DataHandler data_handler , Tensor [ ] data )
140
+ Dictionary < string , float > test_function ( DataHandler data_handler , OwnedIterator iterator )
156
141
{
157
- var ( x , y ) = data_handler . DataAdapter . Expand1d ( data [ 0 ] , data [ 1 ] ) ;
158
-
159
- var y_pred = Apply ( x , training : false ) ;
160
- var loss = compiled_loss . Call ( y , y_pred ) ;
161
-
162
- compiled_metrics . update_state ( y , y_pred ) ;
163
-
164
- var outputs = metrics . Select ( x => ( x . Name , x . result ( ) ) ) . ToDictionary ( x => x . Name , x => ( float ) x . Item2 ) ;
142
+ var data = iterator . next ( ) ;
143
+ var outputs = test_step ( data_handler , data [ 0 ] , data [ 1 ] ) ;
144
+ tf_with ( ops . control_dependencies ( new object [ 0 ] ) , ctl => _test_counter . assign_add ( 1 ) ) ;
165
145
return outputs ;
166
146
}
167
147
168
- Dictionary < string , float > test_step_multi_inputs_function ( DataHandler data_handler , Tensor [ ] data )
148
+ Dictionary < string , float > test_step_multi_inputs_function ( DataHandler data_handler , OwnedIterator iterator )
169
149
{
150
+ var data = iterator . next ( ) ;
170
151
var x_size = data_handler . DataAdapter . GetDataset ( ) . FirstInputTensorCount ;
171
- var outputs = train_step ( data_handler , new Tensors ( data . Take ( x_size ) . ToArray ( ) ) , new Tensors ( data . Skip ( x_size ) . ToArray ( ) ) ) ;
172
- tf_with ( ops . control_dependencies ( new object [ 0 ] ) , ctl => _train_counter . assign_add ( 1 ) ) ;
152
+ var outputs = test_step ( data_handler , data . Take ( x_size ) . ToArray ( ) , data . Skip ( x_size ) . ToArray ( ) ) ;
153
+ tf_with ( ops . control_dependencies ( new object [ 0 ] ) , ctl => _test_counter . assign_add ( 1 ) ) ;
173
154
return outputs ;
174
155
}
156
+
157
+
158
+ Dictionary < string , float > test_step ( DataHandler data_handler , Tensors x , Tensors y )
159
+ {
160
+ ( x , y ) = data_handler . DataAdapter . Expand1d ( x , y ) ;
161
+ var y_pred = Apply ( x , training : false ) ;
162
+ var loss = compiled_loss . Call ( y , y_pred ) ;
163
+ compiled_metrics . update_state ( y , y_pred ) ;
164
+ return metrics . Select ( x => ( x . Name , x . result ( ) ) ) . ToDictionary ( x => x . Item1 , x => ( float ) x . Item2 ) ;
165
+ }
175
166
}
176
167
}
0 commit comments