Skip to content

Commit baf620a

Browse files
committed
解决 Keras 模式下，使用 GPU 训练时会爆显存的 bug。
观察到的现象是,一些模型增大batchsize后,会在首个epoch的中途爆显存不足,只要过了一个epoch后,就能完整训练。同样的batchsize在python下能设置大得多的值。 最后使用最小训练代码分析出,是每个step之后,图片加载到显存里的数据没有释放导致的。 在寻找释放显存接口没有结果的时候,直接使用了GC.Collect();可以让显存主动回收。 因此当前的修复方案是在每个step里,都执行一次 GC.Collect(); 用来释放显存资源。
1 parent 5e4f530 commit baf620a

File tree

4 files changed

+33
-7
lines changed

4 files changed

+33
-7
lines changed

src/TensorFlowNET.Core/Keras/Engine/IModel.cs

+23
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ ICallback fit(NDArray x, NDArray y,
2424
List<ICallback> callbacks = null,
2525
float validation_split = 0f,
2626
ValidationDataPack validation_data = null,
27+
int validation_step = 10,
2728
bool shuffle = true,
2829
Dictionary<int, float> class_weight = null,
2930
NDArray sample_weight = null,
@@ -47,6 +48,20 @@ ICallback fit(IEnumerable<NDArray> x, NDArray y,
4748
int workers = 1,
4849
bool use_multiprocessing = false);
4950

51+
public ICallback fit(IDatasetV2 dataset,
52+
int batch_size = -1,
53+
int epochs = 1,
54+
int verbose = 1,
55+
List<ICallback> callbacks = null,
56+
IDatasetV2 validation_data = null,
57+
int validation_step = 10, // 间隔多少次会进行一次验证
58+
bool shuffle = true,
59+
Dictionary<int, float> class_weight = null,
60+
int initial_epoch = 0,
61+
int max_queue_size = 10,
62+
int workers = 1,
63+
bool use_multiprocessing = false);
64+
5065
void save(string filepath,
5166
bool overwrite = true,
5267
bool include_optimizer = true,
@@ -85,6 +100,14 @@ Tensors predict(Tensors x,
85100
int workers = 1,
86101
bool use_multiprocessing = false);
87102

103+
public Tensors predict(IDatasetV2 dataset,
104+
int batch_size = -1,
105+
int verbose = 0,
106+
int steps = -1,
107+
int max_queue_size = 10,
108+
int workers = 1,
109+
bool use_multiprocessing = false);
110+
88111
void summary(int line_length = -1, float[] positions = null);
89112

90113
IKerasConfig get_config();

src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

+3
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callba
132132
var end_step = step + data_handler.StepIncrement;
133133
if (!is_val)
134134
callbacks.on_test_batch_end(end_step, logs);
135+
GC.Collect();
135136
}
136137
}
137138
callbacks.on_test_end(logs);
@@ -167,7 +168,9 @@ Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handl
167168
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
168169
{
169170
(x,y) = data_handler.DataAdapter.Expand1d(x, y);
171+
170172
var y_pred = Apply(x, training: false);
173+
171174
var loss = compiled_loss.Call(y, y_pred);
172175
compiled_metrics.update_state(y, y_pred);
173176
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);

src/TensorFlowNET.Keras/Engine/Model.Fit.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ public ICallback fit(NDArray x, NDArray y,
4141
List<ICallback> callbacks = null,
4242
float validation_split = 0f,
4343
ValidationDataPack validation_data = null,
44+
int validation_step = 10,
4445
bool shuffle = true,
4546
Dictionary<int, float> class_weight = null,
4647
NDArray sample_weight = null,
@@ -147,7 +148,7 @@ public ICallback fit(IEnumerable<NDArray> x, NDArray y,
147148
}
148149
}
149150

150-
public History fit(IDatasetV2 dataset,
151+
public ICallback fit(IDatasetV2 dataset,
151152
int batch_size = -1,
152153
int epochs = 1,
153154
int verbose = 1,
@@ -156,7 +157,6 @@ public History fit(IDatasetV2 dataset,
156157
int validation_step = 10,
157158
bool shuffle = true,
158159
Dictionary<int, float> class_weight = null,
159-
NDArray sample_weight = null,
160160
int initial_epoch = 0,
161161
int max_queue_size = 10,
162162
int workers = 1,
@@ -170,7 +170,7 @@ public History fit(IDatasetV2 dataset,
170170
InitialEpoch = initial_epoch,
171171
Epochs = epochs,
172172
Shuffle = shuffle,
173-
SampleWeight = sample_weight,
173+
ClassWeight = class_weight,
174174
MaxQueueSize = max_queue_size,
175175
Workers = workers,
176176
UseMultiprocessing = use_multiprocessing,
@@ -218,6 +218,7 @@ History FitInternal(DataHandler data_handler, int epochs, int validation_step, i
218218
var end_step = step + data_handler.StepIncrement;
219219
End_step = end_step;
220220
callbacks.on_train_batch_end(end_step, logs);
221+
GC.Collect();
221222
}
222223

223224
if (validation_data != null)
@@ -233,11 +234,10 @@ History FitInternal(DataHandler data_handler, int epochs, int validation_step, i
233234
callbacks.on_train_batch_end(End_step, logs);
234235
}
235236

237+
GC.Collect();
236238

237239
callbacks.on_epoch_end(epoch, logs);
238240

239-
GC.Collect();
240-
GC.WaitForPendingFinalizers();
241241
if (stop_training)
242242
{
243243
break;
@@ -282,6 +282,7 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICal
282282
var end_step = step + data_handler.StepIncrement;
283283
End_step = end_step;
284284
callbacks.on_train_batch_end(end_step, logs);
285+
GC.Collect();
285286
}
286287

287288
if (validation_data != null)
@@ -301,7 +302,6 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICal
301302
callbacks.on_epoch_end(epoch, logs);
302303

303304
GC.Collect();
304-
GC.WaitForPendingFinalizers();
305305
if (stop_training)
306306
{
307307
break;

src/TensorFlowNET.Keras/Engine/Model.Predict.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,9 @@ Tensors PredictInternal(DataHandler data_handler, int verbose)
102102
for (int i = 0; i < batch_outputs.Length; i++)
103103
batch_outputs[i] = tf.concat(new Tensor[] { batch_outputs[i], tmp_batch_outputs[i] }, axis: 0);
104104
}
105-
106105
var end_step = step + data_handler.StepIncrement;
107106
callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } });
107+
GC.Collect();
108108
}
109109
}
110110

0 commit comments

Comments (0)