diff --git a/src/TensorFlowNET.Core/Keras/Engine/ICallback.cs b/src/TensorFlowNET.Core/Keras/Engine/ICallback.cs index 096dbd2ef..e114ca97f 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/ICallback.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/ICallback.cs @@ -14,6 +14,9 @@ public interface ICallback void on_predict_batch_end(long end_step, Dictionary logs); void on_predict_end(); void on_test_begin(); + void on_test_end(Dictionary logs); void on_test_batch_begin(long step); void on_test_batch_end(long end_step, Dictionary logs); + + } diff --git a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs index ddc72aeec..19f3df9ba 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs @@ -60,7 +60,7 @@ void load_weights(string filepath, bool skip_mismatch = false, object options = null); - Dictionary evaluate(Tensor x, Tensor y, + Dictionary evaluate(NDArray x, NDArray y, int batch_size = -1, int verbose = 1, int steps = -1, diff --git a/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs b/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs index 362f2280c..cb16aafa3 100644 --- a/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs +++ b/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs @@ -73,4 +73,9 @@ public void on_test_batch_end(long end_step, Dictionary logs) { callbacks.ForEach(x => x.on_test_batch_end(end_step, logs)); } + + public void on_test_end(Dictionary logs) + { + callbacks.ForEach(x => x.on_test_end(logs)); + } } diff --git a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs index 59152d9b2..b3b78423c 100644 --- a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs +++ b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs @@ -150,4 +150,8 @@ public bool _is_improvement(float monitor_value, float reference_value) return less_op; } } + + public void on_test_end(Dictionary logs) + { + } } diff --git a/src/TensorFlowNET.Keras/Callbacks/History.cs b/src/TensorFlowNET.Keras/Callbacks/History.cs index c34f253d1..6d3ff6c38 100644 --- a/src/TensorFlowNET.Keras/Callbacks/History.cs +++ b/src/TensorFlowNET.Keras/Callbacks/History.cs @@ -81,4 +81,8 @@ public void on_test_batch_begin(long step) public void on_test_batch_end(long end_step, Dictionary logs) { } + + public void on_test_end(Dictionary logs) + { + } } diff --git a/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs b/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs index 9f2b1eb31..23b18cd47 100644 --- a/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs +++ b/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs @@ -118,5 +118,8 @@ public void on_test_batch_end(long end_step, Dictionary logs) } } + public void on_test_end(Dictionary logs) + { + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs index eaa9eb23c..c4761f873 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs @@ -27,7 +27,7 @@ public partial class Model /// /// /// - public Dictionary evaluate(Tensor x, Tensor y, + public Dictionary evaluate(NDArray x, NDArray y, int batch_size = -1, int verbose = 1, int steps = -1, @@ -115,62 +115,53 @@ public Dictionary evaluate(IDatasetV2 x, int verbose = 1, bool is /// The function to be called on each batch of data. /// Whether it is validation or test. /// - Dictionary evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func> test_func) + Dictionary evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func> test_func) { callbacks.on_test_begin(); - var results = new Dictionary(); - var logs = results; + var logs = new Dictionary(); foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) { reset_metrics(); - callbacks.on_epoch_begin(epoch); - // data_handler.catch_stop_iteration(); - foreach (var step in data_handler.steps()) { callbacks.on_test_batch_begin(step); - - logs = test_func(data_handler, iterator.next()); - - tf_with(ops.control_dependencies(Array.Empty()), ctl => _train_counter.assign_add(1)); - + logs = test_func(data_handler, iterator); var end_step = step + data_handler.StepIncrement; if (!is_val) callbacks.on_test_batch_end(end_step, logs); } - - if (!is_val) - callbacks.on_epoch_end(epoch, logs); } - - foreach (var log in logs) - { - results[log.Key] = log.Value; - } - + callbacks.on_test_end(logs); + var results = new Dictionary(logs); return results; } - Dictionary test_function(DataHandler data_handler, Tensor[] data) + Dictionary test_function(DataHandler data_handler, OwnedIterator iterator) { - var (x, y) = data_handler.DataAdapter.Expand1d(data[0], data[1]); - - var y_pred = Apply(x, training: false); - var loss = compiled_loss.Call(y, y_pred); - - compiled_metrics.update_state(y, y_pred); - - var outputs = metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Name, x => (float)x.Item2); + var data = iterator.next(); + var outputs = test_step(data_handler, data[0], data[1]); + tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1)); return outputs; } - Dictionary test_step_multi_inputs_function(DataHandler data_handler, Tensor[] data) + Dictionary test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator) { + var data = iterator.next(); var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; - var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())); - tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); + var outputs = test_step(data_handler, data.Take(x_size).ToArray(), data.Skip(x_size).ToArray()); + tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1)); return outputs; } + + + Dictionary test_step(DataHandler data_handler, Tensors x, Tensors y) + { + (x, y) = data_handler.DataAdapter.Expand1d(x, y); + var y_pred = Apply(x, training: false); + var loss = compiled_loss.Call(y, y_pred); + compiled_metrics.update_state(y, y_pred); + return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2); + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index 68dc5976c..76c592ad6 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -266,7 +266,7 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List