@@ -142,6 +142,7 @@ public History fit(IDatasetV2 dataset,
142
142
int verbose = 1 ,
143
143
List < ICallback > callbacks = null ,
144
144
IDatasetV2 validation_data = null ,
145
+ int validation_step = 10 , // 间隔多少次会进行一次验证
145
146
bool shuffle = true ,
146
147
int initial_epoch = 0 ,
147
148
int max_queue_size = 10 ,
@@ -164,11 +165,11 @@ public History fit(IDatasetV2 dataset,
164
165
} ) ;
165
166
166
167
167
- return FitInternal ( data_handler , epochs , verbose , callbacks , validation_data : validation_data ,
168
+ return FitInternal ( data_handler , epochs , validation_step , verbose , callbacks , validation_data : validation_data ,
168
169
train_step_func : train_step_function ) ;
169
170
}
170
171
171
- History FitInternal ( DataHandler data_handler , int epochs , int verbose , List < ICallback > callbackList , IDatasetV2 validation_data ,
172
+ History FitInternal ( DataHandler data_handler , int epochs , int validation_step , int verbose , List < ICallback > callbackList , IDatasetV2 validation_data ,
172
173
Func < DataHandler , OwnedIterator , Dictionary < string , float > > train_step_func )
173
174
{
174
175
stop_training = false ;
@@ -207,6 +208,9 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICal
207
208
208
209
if ( validation_data != null )
209
210
{
211
+ if ( validation_step > 0 && epoch == 0 || ( epoch ) % validation_step != 0 )
212
+ continue ;
213
+
210
214
var val_logs = evaluate ( validation_data ) ;
211
215
foreach ( var log in val_logs )
212
216
{
0 commit comments