Skip to content

Commit 89fe0bb

Browse files
Wanglongzhi2001Oceania2018
authored andcommitted
Fix model.evaluate don't have output
1 parent bdf229a commit 89fe0bb

File tree

5 files changed

+87
-16
lines changed

5 files changed

+87
-16
lines changed

src/TensorFlowNET.Core/Keras/Engine/ICallback.cs

+2
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,6 @@ public interface ICallback
1212
void on_predict_batch_begin(long step);
1313
void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs);
1414
void on_predict_end();
15+
void on_test_begin();
16+
void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs);
1517
}

src/TensorFlowNET.Keras/Callbacks/CallbackList.cs

+13-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ public void on_train_begin()
2020
{
2121
callbacks.ForEach(x => x.on_train_begin());
2222
}
23-
23+
public void on_test_begin()
24+
{
25+
callbacks.ForEach(x => x.on_test_begin());
26+
}
2427
public void on_epoch_begin(int epoch)
2528
{
2629
callbacks.ForEach(x => x.on_epoch_begin(epoch));
@@ -60,4 +63,13 @@ public void on_predict_end()
6063
{
6164
callbacks.ForEach(x => x.on_predict_end());
6265
}
66+
67+
public void on_test_batch_begin(long step)
68+
{
69+
callbacks.ForEach(x => x.on_train_batch_begin(step));
70+
}
71+
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
72+
{
73+
callbacks.ForEach(x => x.on_test_batch_end(end_step, logs));
74+
}
6375
}

src/TensorFlowNET.Keras/Callbacks/History.cs

+18-5
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,19 @@ public void on_train_begin()
1818
epochs = new List<int>();
1919
history = new Dictionary<string, List<float>>();
2020
}
21-
21+
public void on_test_begin()
22+
{
23+
epochs = new List<int>();
24+
history = new Dictionary<string, List<float>>();
25+
}
2226
public void on_epoch_begin(int epoch)
2327
{
2428

2529
}
2630

2731
public void on_train_batch_begin(long step)
2832
{
29-
33+
3034
}
3135

3236
public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
@@ -55,16 +59,25 @@ public void on_predict_begin()
5559

5660
public void on_predict_batch_begin(long step)
5761
{
58-
62+
5963
}
6064

6165
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
6266
{
63-
67+
6468
}
6569

6670
public void on_predict_end()
6771
{
68-
72+
73+
}
74+
75+
public void on_test_batch_begin(long step)
76+
{
77+
78+
}
79+
80+
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
81+
{
6982
}
7083
}

src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs

+26-5
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ public void on_train_begin()
2222
_called_in_fit = true;
2323
_sw = new Stopwatch();
2424
}
25-
25+
public void on_test_begin()
26+
{
27+
_sw = new Stopwatch();
28+
}
2629
public void on_epoch_begin(int epoch)
2730
{
2831
_reset_progbar();
@@ -44,7 +47,7 @@ public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
4447
var progress = "";
4548
var length = 30.0 / _parameters.Steps;
4649
for (int i = 0; i < Math.Floor(end_step * length - 1); i++)
47-
progress += "=";
50+
progress += "=";
4851
if (progress.Length < 28)
4952
progress += ">";
5053
else
@@ -84,17 +87,35 @@ public void on_predict_begin()
8487

8588
public void on_predict_batch_begin(long step)
8689
{
87-
90+
8891
}
8992

9093
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
9194
{
92-
95+
9396
}
9497

9598
public void on_predict_end()
9699
{
97-
100+
101+
}
102+
103+
public void on_test_batch_begin(long step)
104+
{
105+
_sw.Restart();
98106
}
107+
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
108+
{
109+
_sw.Stop();
110+
var elapse = _sw.ElapsedMilliseconds;
111+
var results = string.Join(" - ", logs.Select(x => $"{x.Item1}: {(float)x.Item2.numpy():F6}"));
112+
113+
Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} - {elapse}ms/step - {results}");
114+
if (!Console.IsOutputRedirected)
115+
{
116+
Console.CursorLeft = 0;
117+
}
118+
}
119+
99120
}
100121
}

src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

+28-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
using Tensorflow.Keras.ArgsDefinition;
66
using Tensorflow.Keras.Engine.DataAdapters;
77
using static Tensorflow.Binding;
8+
using Tensorflow.Keras.Layers;
9+
using Tensorflow.Keras.Utils;
10+
using Tensorflow;
11+
using Tensorflow.Keras.Callbacks;
812

913
namespace Tensorflow.Keras.Engine
1014
{
@@ -31,6 +35,11 @@ public void evaluate(NDArray x, NDArray y,
3135
bool use_multiprocessing = false,
3236
bool return_dict = false)
3337
{
38+
if (x.dims[0] != y.dims[0])
39+
{
40+
throw new InvalidArgumentError(
41+
$"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
42+
}
3443
var data_handler = new DataHandler(new DataHandlerArgs
3544
{
3645
X = x,
@@ -46,18 +55,31 @@ public void evaluate(NDArray x, NDArray y,
4655
StepsPerExecution = _steps_per_execution
4756
});
4857

58+
var callbacks = new CallbackList(new CallbackParams
59+
{
60+
Model = this,
61+
Verbose = verbose,
62+
Steps = data_handler.Inferredsteps
63+
});
64+
callbacks.on_test_begin();
65+
4966
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
5067
{
5168
reset_metrics();
52-
// callbacks.on_epoch_begin(epoch)
69+
//callbacks.on_epoch_begin(epoch);
5370
// data_handler.catch_stop_iteration();
54-
IEnumerable<(string, Tensor)> results = null;
71+
IEnumerable<(string, Tensor)> logs = null;
72+
5573
foreach (var step in data_handler.steps())
5674
{
57-
// callbacks.on_train_batch_begin(step)
58-
results = test_function(data_handler, iterator);
75+
callbacks.on_train_batch_begin(step);
76+
logs = test_function(data_handler, iterator);
77+
var end_step = step + data_handler.StepIncrement;
78+
callbacks.on_test_batch_end(end_step, logs);
5979
}
6080
}
81+
GC.Collect();
82+
GC.WaitForPendingFinalizers();
6183
}
6284

6385
public KeyValuePair<string, float>[] evaluate(IDatasetV2 x)
@@ -75,7 +97,8 @@ public KeyValuePair<string, float>[] evaluate(IDatasetV2 x)
7597
reset_metrics();
7698
// callbacks.on_epoch_begin(epoch)
7799
// data_handler.catch_stop_iteration();
78-
100+
101+
79102
foreach (var step in data_handler.steps())
80103
{
81104
// callbacks.on_train_batch_begin(step)

0 commit comments

Comments
 (0)