Skip to content

Commit 86b235f

Browse files
authored
Merge pull request #1123 from Wanglongzhi2001/master
fix: fix the bug of repeated progress bar in Model.fit()
2 parents 7e1568f + 8ebe3e3 commit 86b235f

File tree

8 files changed

+45
-35
lines changed

8 files changed

+45
-35
lines changed

src/TensorFlowNET.Core/Keras/Engine/ICallback.cs

+3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ public interface ICallback
1414
void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs);
1515
void on_predict_end();
1616
void on_test_begin();
17+
void on_test_end(Dictionary<string, float> logs);
1718
void on_test_batch_begin(long step);
1819
void on_test_batch_end(long end_step, Dictionary<string, float> logs);
20+
21+
1922
}

src/TensorFlowNET.Core/Keras/Engine/IModel.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ void load_weights(string filepath,
6060
bool skip_mismatch = false,
6161
object options = null);
6262

63-
Dictionary<string, float> evaluate(Tensor x, Tensor y,
63+
Dictionary<string, float> evaluate(NDArray x, NDArray y,
6464
int batch_size = -1,
6565
int verbose = 1,
6666
int steps = -1,

src/TensorFlowNET.Keras/Callbacks/CallbackList.cs

+5
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,9 @@ public void on_test_batch_end(long end_step, Dictionary<string, float> logs)
7373
{
7474
callbacks.ForEach(x => x.on_test_batch_end(end_step, logs));
7575
}
76+
77+
public void on_test_end(Dictionary<string, float> logs)
78+
{
79+
callbacks.ForEach(x => x.on_test_end(logs));
80+
}
7681
}

src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs

+4
Original file line numberDiff line numberDiff line change
@@ -150,4 +150,8 @@ public bool _is_improvement(float monitor_value, float reference_value)
150150
return less_op;
151151
}
152152
}
153+
154+
public void on_test_end(Dictionary<string, float> logs)
155+
{
156+
}
153157
}

src/TensorFlowNET.Keras/Callbacks/History.cs

+4
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,8 @@ public void on_test_batch_begin(long step)
8181
public void on_test_batch_end(long end_step, Dictionary<string, float> logs)
8282
{
8383
}
84+
85+
public void on_test_end(Dictionary<string, float> logs)
86+
{
87+
}
8488
}

src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs

+3
Original file line numberDiff line numberDiff line change
@@ -118,5 +118,8 @@ public void on_test_batch_end(long end_step, Dictionary<string, float> logs)
118118
}
119119
}
120120

121+
public void on_test_end(Dictionary<string, float> logs)
122+
{
123+
}
121124
}
122125
}

src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

+24-33
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public partial class Model
2727
/// <param name="use_multiprocessing"></param>
2828
/// <param name="return_dict"></param>
2929
/// <param name="is_val"></param>
30-
public Dictionary<string, float> evaluate(Tensor x, Tensor y,
30+
public Dictionary<string, float> evaluate(NDArray x, NDArray y,
3131
int batch_size = -1,
3232
int verbose = 1,
3333
int steps = -1,
@@ -115,62 +115,53 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
115115
/// <param name="test_func">The function to be called on each batch of data.</param>
116116
/// <param name="is_val">Whether it is validation or test.</param>
117117
/// <returns></returns>
118-
Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func<DataHandler, Tensor[], Dictionary<string, float>> test_func)
118+
Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func<DataHandler, OwnedIterator, Dictionary<string, float>> test_func)
119119
{
120120
callbacks.on_test_begin();
121121

122-
var results = new Dictionary<string, float>();
123-
var logs = results;
122+
var logs = new Dictionary<string, float>();
124123
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
125124
{
126125
reset_metrics();
127-
callbacks.on_epoch_begin(epoch);
128-
// data_handler.catch_stop_iteration();
129-
130126
foreach (var step in data_handler.steps())
131127
{
132128
callbacks.on_test_batch_begin(step);
133-
134-
logs = test_func(data_handler, iterator.next());
135-
136-
tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _train_counter.assign_add(1));
137-
129+
logs = test_func(data_handler, iterator);
138130
var end_step = step + data_handler.StepIncrement;
139131
if (!is_val)
140132
callbacks.on_test_batch_end(end_step, logs);
141133
}
142-
143-
if (!is_val)
144-
callbacks.on_epoch_end(epoch, logs);
145134
}
146-
147-
foreach (var log in logs)
148-
{
149-
results[log.Key] = log.Value;
150-
}
151-
135+
callbacks.on_test_end(logs);
136+
var results = new Dictionary<string, float>(logs);
152137
return results;
153138
}
154139

155-
Dictionary<string, float> test_function(DataHandler data_handler, Tensor[] data)
140+
Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
156141
{
157-
var (x, y) = data_handler.DataAdapter.Expand1d(data[0], data[1]);
158-
159-
var y_pred = Apply(x, training: false);
160-
var loss = compiled_loss.Call(y, y_pred);
161-
162-
compiled_metrics.update_state(y, y_pred);
163-
164-
var outputs = metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Name, x => (float)x.Item2);
142+
var data = iterator.next();
143+
var outputs = test_step(data_handler, data[0], data[1]);
144+
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
165145
return outputs;
166146
}
167147

168-
Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, Tensor[] data)
148+
Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
169149
{
150+
var data = iterator.next();
170151
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
171-
var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray()));
172-
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
152+
var outputs = test_step(data_handler, data.Take(x_size).ToArray(), data.Skip(x_size).ToArray());
153+
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
173154
return outputs;
174155
}
156+
157+
158+
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
159+
{
160+
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
161+
var y_pred = Apply(x, training: false);
162+
var loss = compiled_loss.Call(y, y_pred);
163+
compiled_metrics.update_state(y, y_pred);
164+
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
165+
}
175166
}
176167
}

src/TensorFlowNET.Keras/Engine/Model.Fit.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICal
266266
{
267267
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen
268268
// so we need to pass a is_val parameter to stop on_test_batch_end
269-
var val_logs = evaluate((Tensor)validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
269+
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
270270
foreach (var log in val_logs)
271271
{
272272
logs["val_" + log.Key] = log.Value;

0 commit comments

Comments
 (0)