diff --git a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
index 19f3df9ba..ddc72aeec 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
@@ -60,7 +60,7 @@
         void load_weights(string filepath, bool skip_mismatch = false, object options = null);
 
-        Dictionary<string, float> evaluate(NDArray x, NDArray y,
+        Dictionary<string, float> evaluate(Tensor x, Tensor y,
             int batch_size = -1,
             int verbose = 1,
             int steps = -1,
diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs
index d063ee39f..8d382d619 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensors.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs
@@ -90,73 +90,73 @@ public T[] ToArray<T>() where T: unmanaged
         }
 
         #region Explicit Conversions
-        public unsafe static explicit operator bool(Tensors tensor)
+        public static explicit operator bool(Tensors tensor)
         {
             EnsureSingleTensor(tensor, "explicit conversion to bool");
             return (bool)tensor[0];
         }
 
-        public unsafe static explicit operator sbyte(Tensors tensor)
+        public static explicit operator sbyte(Tensors tensor)
         {
             EnsureSingleTensor(tensor, "explicit conversion to sbyte");
             return (sbyte)tensor[0];
         }
 
-        public unsafe static explicit operator byte(Tensors tensor)
+        public static explicit operator byte(Tensors tensor)
         {
             EnsureSingleTensor(tensor, "explicit conversion to byte");
             return (byte)tensor[0];
         }
 
-        public unsafe static explicit operator ushort(Tensors tensor)
+        public static explicit operator ushort(Tensors tensor)
         {
             EnsureSingleTensor(tensor, "explicit conversion to ushort");
             return (ushort)tensor[0];
         }
 
-        public unsafe static explicit operator short(Tensors tensor)
+        public static explicit operator short(Tensors tensor)
        {
             EnsureSingleTensor(tensor, "explicit conversion to short");
             return (short)tensor[0];
         }
 
-        public unsafe static explicit operator int(Tensors tensor)
+        public static explicit operator int(Tensors tensor)
         {
             EnsureSingleTensor(tensor, "explicit conversion to int");
             return (int)tensor[0];
         }
 
-        public unsafe static explicit operator uint(Tensors tensor)
+        public static explicit operator uint(Tensors tensor)
         {
             EnsureSingleTensor(tensor, "explicit conversion to uint");
             return (uint)tensor[0];
         }
 
-        public unsafe static explicit operator long(Tensors tensor)
+        public static explicit operator long(Tensors tensor)
         {
             EnsureSingleTensor(tensor, "explicit conversion to long");
             return (long)tensor[0];
         }
 
-        public unsafe static explicit operator ulong(Tensors tensor)
+        public static explicit operator ulong(Tensors tensor)
         {
             EnsureSingleTensor(tensor, "explicit conversion to ulong");
             return (ulong)tensor[0];
         }
 
-        public unsafe static explicit operator float(Tensors tensor)
+        public static explicit operator float(Tensors tensor)
         {
             EnsureSingleTensor(tensor, "explicit conversion to byte");
             return (byte)tensor[0];
         }
 
-        public unsafe static explicit operator double(Tensors tensor)
+        public static explicit operator double(Tensors tensor)
         {
             EnsureSingleTensor(tensor, "explicit conversion to double");
             return (double)tensor[0];
         }
 
-        public unsafe static explicit operator string(Tensors tensor)
+        public static explicit operator string(Tensors tensor)
         {
             EnsureSingleTensor(tensor, "explicit conversion to string");
             return (string)tensor[0];
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
index 185de4f48..912f5e06d 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
@@ -1,14 +1,14 @@
-using Tensorflow.NumPy;
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using Tensorflow;
 using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Callbacks;
 using Tensorflow.Keras.Engine.DataAdapters;
-using static Tensorflow.Binding;
 using Tensorflow.Keras.Layers;
 using Tensorflow.Keras.Utils;
-using Tensorflow;
-using Tensorflow.Keras.Callbacks;
+using Tensorflow.NumPy;
+using static Tensorflow.Binding;
 
 namespace Tensorflow.Keras.Engine
 {
@@ -27,7 +27,7 @@ public partial class Model
         ///
         ///
         ///
-        public Dictionary<string, float> evaluate(NDArray x, NDArray y,
+        public Dictionary<string, float> evaluate(Tensor x, Tensor y,
             int batch_size = -1,
             int verbose = 1,
             int steps = -1,
@@ -64,34 +64,11 @@ public Dictionary<string, float> evaluate(NDArray x, NDArray y,
                 Verbose = verbose,
                 Steps = data_handler.Inferredsteps
             });
-            callbacks.on_test_begin();
-
-            //Dictionary<string, float>? logs = null;
-            var logs = new Dictionary<string, float>();
-            foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
-            {
-                reset_metrics();
-                // data_handler.catch_stop_iteration();
-
-                foreach (var step in data_handler.steps())
-                {
-                    callbacks.on_test_batch_begin(step);
-                    logs = test_function(data_handler, iterator);
-                    var end_step = step + data_handler.StepIncrement;
-                    if (is_val == false)
-                        callbacks.on_test_batch_end(end_step, logs);
-                }
-            }
-            var results = new Dictionary<string, float>();
-            foreach (var log in logs)
-            {
-                results[log.Key] = log.Value;
-            }
-            return results;
+            return evaluate(data_handler, callbacks, is_val, test_function);
         }
 
-        public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int verbose = 1, bool is_val = false)
+        public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false)
         {
             var data_handler = new DataHandler(new DataHandlerArgs
             {
@@ -107,34 +84,10 @@ public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int
                 Verbose = verbose,
                 Steps = data_handler.Inferredsteps
             });
-            callbacks.on_test_begin();
-            Dictionary<string, float> logs = null;
-            foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
-            {
-                reset_metrics();
-                callbacks.on_epoch_begin(epoch);
-                // data_handler.catch_stop_iteration();
-
-                foreach (var step in data_handler.steps())
-                {
-                    callbacks.on_test_batch_begin(step);
-                    logs = test_step_multi_inputs_function(data_handler, iterator);
-                    var end_step = step + data_handler.StepIncrement;
-                    if (is_val == false)
-                        callbacks.on_test_batch_end(end_step, logs);
-                }
-            }
-
-            var results = new Dictionary<string, float>();
-            foreach (var log in logs)
-            {
-                results[log.Key] = log.Value;
-            }
-            return results;
+            return evaluate(data_handler, callbacks, is_val, test_step_multi_inputs_function);
         }
 
-
         public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false)
         {
             var data_handler = new DataHandler(new DataHandlerArgs
             {
@@ -150,9 +103,24 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
                 Verbose = verbose,
                 Steps = data_handler.Inferredsteps
             });
+
+            return evaluate(data_handler, callbacks, is_val, test_function);
+        }
+
+        /// <summary>
+        /// Internal bare implementation of evaluate function.
+        /// </summary>
+        /// <param name="data_handler">Iterations handling objects</param>
+        /// <param name="callbacks"></param>
+        /// <param name="test_func">The function to be called on each batch of data.</param>
+        /// <param name="is_val">Whether it is validation or test.</param>
+        /// <returns></returns>
+        Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func<DataHandler, Tensor[], Dictionary<string, float>> test_func)
+        {
             callbacks.on_test_begin();
-            Dictionary<string, float> logs = null;
+            var results = new Dictionary<string, float>();
+            var logs = results;
             foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
             {
                 reset_metrics();
@@ -162,45 +130,46 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
                 foreach (var step in data_handler.steps())
                 {
                     callbacks.on_test_batch_begin(step);
-                    logs = test_function(data_handler, iterator);
+
+                    logs = test_func(data_handler, iterator.next());
+
+                    tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _train_counter.assign_add(1));
+
                     var end_step = step + data_handler.StepIncrement;
-                    if (is_val == false)
+                    if (!is_val)
                         callbacks.on_test_batch_end(end_step, logs);
                 }
+
+                if (!is_val)
+                    callbacks.on_epoch_end(epoch, logs);
             }
-            var results = new Dictionary<string, float>();
             foreach (var log in logs)
             {
                 results[log.Key] = log.Value;
             }
+
             return results;
         }
 
-        Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
+        Dictionary<string, float> test_function(DataHandler data_handler, Tensor[] data)
         {
-            var data = iterator.next();
-            var outputs = test_step(data_handler, data[0], data[1]);
-            tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
+            var (x, y) = data_handler.DataAdapter.Expand1d(data[0], data[1]);
+
+            var y_pred = Apply(x, training: false);
+            var loss = compiled_loss.Call(y, y_pred);
+
+            compiled_metrics.update_state(y, y_pred);
+
+            var outputs = metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Name, x => (float)x.Item2);
             return outputs;
         }
-        Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
+
+        Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, Tensor[] data)
        {
-            var data = iterator.next();
             var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
             var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
-            tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
             return outputs;
         }
 
-        Dictionary<string, float> test_step(DataHandler data_handler, Tensor x, Tensor y)
-        {
-            (x, y) = data_handler.DataAdapter.Expand1d(x, y);
-            var y_pred = Apply(x, training: false);
-            var loss = compiled_loss.Call(y, y_pred);
-
-            compiled_metrics.update_state(y, y_pred);
-
-            return metrics.Select(x => (x.Name, x.result())).ToDictionary(x=>x.Item1, x=>(float)x.Item2);
-        }
     }
 }
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
index bb8e18ccf..17ecde984 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
@@ -266,7 +266,7 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List