5
5
using Tensorflow . Keras . ArgsDefinition ;
6
6
using Tensorflow . Keras . Engine . DataAdapters ;
7
7
using static Tensorflow . Binding ;
8
+ using Tensorflow . Keras . Layers ;
9
+ using Tensorflow . Keras . Utils ;
10
+ using Tensorflow ;
11
+ using Tensorflow . Keras . Callbacks ;
8
12
9
13
namespace Tensorflow . Keras . Engine
10
14
{
@@ -31,6 +35,11 @@ public void evaluate(NDArray x, NDArray y,
31
35
bool use_multiprocessing = false ,
32
36
bool return_dict = false )
33
37
{
38
+ if ( x . dims [ 0 ] != y . dims [ 0 ] )
39
+ {
40
+ throw new InvalidArgumentError (
41
+ $ "The array x and y should have same value at dim 0, but got { x . dims [ 0 ] } and { y . dims [ 0 ] } ") ;
42
+ }
34
43
var data_handler = new DataHandler ( new DataHandlerArgs
35
44
{
36
45
X = x ,
@@ -46,18 +55,31 @@ public void evaluate(NDArray x, NDArray y,
46
55
StepsPerExecution = _steps_per_execution
47
56
} ) ;
48
57
58
+ var callbacks = new CallbackList ( new CallbackParams
59
+ {
60
+ Model = this ,
61
+ Verbose = verbose ,
62
+ Steps = data_handler . Inferredsteps
63
+ } ) ;
64
+ callbacks . on_test_begin ( ) ;
65
+
49
66
foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
50
67
{
51
68
reset_metrics ( ) ;
52
- // callbacks.on_epoch_begin(epoch)
69
+ //callbacks.on_epoch_begin(epoch);
53
70
// data_handler.catch_stop_iteration();
54
- IEnumerable < ( string , Tensor ) > results = null ;
71
+ IEnumerable < ( string , Tensor ) > logs = null ;
72
+
55
73
foreach ( var step in data_handler . steps ( ) )
56
74
{
57
- // callbacks.on_train_batch_begin(step)
58
- results = test_function ( data_handler , iterator ) ;
75
+ callbacks . on_train_batch_begin ( step ) ;
76
+ logs = test_function ( data_handler , iterator ) ;
77
+ var end_step = step + data_handler . StepIncrement ;
78
+ callbacks . on_test_batch_end ( end_step , logs ) ;
59
79
}
60
80
}
81
+ GC . Collect ( ) ;
82
+ GC . WaitForPendingFinalizers ( ) ;
61
83
}
62
84
63
85
public KeyValuePair < string , float > [ ] evaluate ( IDatasetV2 x )
@@ -75,7 +97,8 @@ public KeyValuePair<string, float>[] evaluate(IDatasetV2 x)
75
97
reset_metrics ( ) ;
76
98
// callbacks.on_epoch_begin(epoch)
77
99
// data_handler.catch_stop_iteration();
78
-
100
+
101
+
79
102
foreach ( var step in data_handler . steps ( ) )
80
103
{
81
104
// callbacks.on_train_batch_begin(step)
0 commit comments