Skip to content

refactor: Model.eval #1092

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Keras/Engine/IModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ void load_weights(string filepath,
bool skip_mismatch = false,
object options = null);

Dictionary<string, float> evaluate(NDArray x, NDArray y,
Dictionary<string, float> evaluate(Tensor x, Tensor y,
int batch_size = -1,
int verbose = 1,
int steps = -1,
Expand Down
24 changes: 12 additions & 12 deletions src/TensorFlowNET.Core/Tensors/Tensors.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,73 +90,73 @@ public T[] ToArray<T>() where T: unmanaged
}

#region Explicit Conversions
public unsafe static explicit operator bool(Tensors tensor)
public static explicit operator bool(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to bool");
return (bool)tensor[0];
}

public unsafe static explicit operator sbyte(Tensors tensor)
public static explicit operator sbyte(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to sbyte");
return (sbyte)tensor[0];
}

public unsafe static explicit operator byte(Tensors tensor)
public static explicit operator byte(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to byte");
return (byte)tensor[0];
}

public unsafe static explicit operator ushort(Tensors tensor)
public static explicit operator ushort(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to ushort");
return (ushort)tensor[0];
}

public unsafe static explicit operator short(Tensors tensor)
public static explicit operator short(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to short");
return (short)tensor[0];
}

public unsafe static explicit operator int(Tensors tensor)
public static explicit operator int(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to int");
return (int)tensor[0];
}

public unsafe static explicit operator uint(Tensors tensor)
public static explicit operator uint(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to uint");
return (uint)tensor[0];
}

public unsafe static explicit operator long(Tensors tensor)
public static explicit operator long(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to long");
return (long)tensor[0];
}

public unsafe static explicit operator ulong(Tensors tensor)
public static explicit operator ulong(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to ulong");
return (ulong)tensor[0];
}

public unsafe static explicit operator float(Tensors tensor)
public static explicit operator float(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to byte");
return (byte)tensor[0];
}

public unsafe static explicit operator double(Tensors tensor)
public static explicit operator double(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to double");
return (double)tensor[0];
}

public unsafe static explicit operator string(Tensors tensor)
public static explicit operator string(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to string");
return (string)tensor[0];
Expand Down
121 changes: 45 additions & 76 deletions src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
using Tensorflow.NumPy;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Callbacks;
using Tensorflow.Keras.Engine.DataAdapters;
using static Tensorflow.Binding;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Utils;
using Tensorflow;
using Tensorflow.Keras.Callbacks;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
Expand All @@ -27,7 +27,7 @@ public partial class Model
/// <param name="use_multiprocessing"></param>
/// <param name="return_dict"></param>
/// <param name="is_val"></param>
public Dictionary<string, float> evaluate(NDArray x, NDArray y,
public Dictionary<string, float> evaluate(Tensor x, Tensor y,
int batch_size = -1,
int verbose = 1,
int steps = -1,
Expand Down Expand Up @@ -64,34 +64,11 @@ public Dictionary<string, float> evaluate(NDArray x, NDArray y,
Verbose = verbose,
Steps = data_handler.Inferredsteps
});
callbacks.on_test_begin();

//Dictionary<string, float>? logs = null;
var logs = new Dictionary<string, float>();
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
// data_handler.catch_stop_iteration();

foreach (var step in data_handler.steps())
{
callbacks.on_test_batch_begin(step);
logs = test_function(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
if (is_val == false)
callbacks.on_test_batch_end(end_step, logs);
}
}

var results = new Dictionary<string, float>();
foreach (var log in logs)
{
results[log.Key] = log.Value;
}
return results;
return evaluate(data_handler, callbacks, is_val, test_function);
}

public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int verbose = 1, bool is_val = false)
public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false)
{
var data_handler = new DataHandler(new DataHandlerArgs
{
Expand All @@ -107,34 +84,10 @@ public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int
Verbose = verbose,
Steps = data_handler.Inferredsteps
});
callbacks.on_test_begin();

Dictionary<string, float> logs = null;
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();

foreach (var step in data_handler.steps())
{
callbacks.on_test_batch_begin(step);
logs = test_step_multi_inputs_function(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
if (is_val == false)
callbacks.on_test_batch_end(end_step, logs);
}
}

var results = new Dictionary<string, float>();
foreach (var log in logs)
{
results[log.Key] = log.Value;
}
return results;
return evaluate(data_handler, callbacks, is_val, test_step_multi_inputs_function);
}


public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false)
{
var data_handler = new DataHandler(new DataHandlerArgs
Expand All @@ -150,9 +103,24 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
Verbose = verbose,
Steps = data_handler.Inferredsteps
});

return evaluate(data_handler, callbacks, is_val, test_function);
}

/// <summary>
/// Internal bare implementation of evaluate function.
/// </summary>
/// <param name="data_handler">Interations handling objects</param>
/// <param name="callbacks"></param>
/// <param name="test_func">The function to be called on each batch of data.</param>
/// <param name="is_val">Whether it is validation or test.</param>
/// <returns></returns>
Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func<DataHandler, Tensor[], Dictionary<string, float>> test_func)
{
callbacks.on_test_begin();

Dictionary<string, float> logs = null;
var results = new Dictionary<string, float>();
var logs = results;
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
Expand All @@ -162,45 +130,46 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
foreach (var step in data_handler.steps())
{
callbacks.on_test_batch_begin(step);
logs = test_function(data_handler, iterator);

logs = test_func(data_handler, iterator.next());

tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _train_counter.assign_add(1));

var end_step = step + data_handler.StepIncrement;
if (is_val == false)
if (!is_val)
callbacks.on_test_batch_end(end_step, logs);
}

if (!is_val)
callbacks.on_epoch_end(epoch, logs);
}

var results = new Dictionary<string, float>();
foreach (var log in logs)
{
results[log.Key] = log.Value;
}

return results;
}

Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
Dictionary<string, float> test_function(DataHandler data_handler, Tensor[] data)
{
var data = iterator.next();
var outputs = test_step(data_handler, data[0], data[1]);
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
var (x, y) = data_handler.DataAdapter.Expand1d(data[0], data[1]);

var y_pred = Apply(x, training: false);
var loss = compiled_loss.Call(y, y_pred);

compiled_metrics.update_state(y, y_pred);

var outputs = metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Name, x => (float)x.Item2);
return outputs;
}
Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)

Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, Tensor[] data)
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return outputs;
}
Dictionary<string, float> test_step(DataHandler data_handler, Tensor x, Tensor y)
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
var y_pred = Apply(x, training: false);
var loss = compiled_loss.Call(y, y_pred);

compiled_metrics.update_state(y, y_pred);

return metrics.Select(x => (x.Name, x.result())).ToDictionary(x=>x.Item1, x=>(float)x.Item2);
}
}
}
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Keras/Engine/Model.Fit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICal
{
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen
// so we need to pass a is_val parameter to stop on_test_batch_end
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
var val_logs = evaluate((Tensor)validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
foreach (var log in val_logs)
{
logs["val_" + log.Key] = log.Value;
Expand Down