diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
index 5bc97952b..2559638b3 100644
--- a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
+++ b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
@@ -85,5 +85,11 @@ public static NDArray dot(NDArray x1, NDArray x2, NDArray? axes = null, string?
 
         [AutoNumPy]
         public static NDArray add(NDArray x, NDArray y) => new NDArray(math_ops.add(x, y));
+
+        [AutoNumPy]
+        public static NDArray greater(NDArray x, NDArray y) => new NDArray(tf.greater(x, y));
+
+        [AutoNumPy]
+        public static NDArray less(NDArray x, NDArray y) => new NDArray(tf.less(x, y));
     }
 }
diff --git a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs
index 36993b637..a2a2ecfe2 100644
--- a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs
+++ b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs
@@ -19,8 +19,10 @@ public class EarlyStopping: ICallback
     string _monitor;
     string _mode;
     bool _restore_best_weights;
-    List<IVariableV1>? _best_weights;
+    List<NDArray>? _best_weights;
     CallbackParams _parameters;
+    Func<NDArray, NDArray, NDArray> _monitor_op;
 
+    public Dictionary<string, List<float>>? history { get; set; }
     // user need to pass a CallbackParams to EarlyStopping, CallbackParams at least need the model
     public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", float min_delta = 0f, int patience = 0,
@@ -38,17 +40,49 @@ public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", floa
         _min_delta = Math.Abs(min_delta);
         _restore_best_weights = restore_best_weights;
         _mode = mode;
-        if (mode != "auto" && mode != "min" && mode != "max")
+
+        if (_mode != "auto" && _mode != "min" && _mode != "max")
+        {
+            Console.WriteLine($"EarlyStopping mode {_mode} is unknown, fallback to auto mode.");
+            _mode = "auto";
+        }
+
+        if (_mode == "min")
+        {
+            _monitor_op = np.less;
+        }
+        else if (_mode == "max")
+        {
+            _monitor_op = np.greater;
+        }
+        else
+        {
+            if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc"))
+            {
+                _monitor_op = np.greater;
+            }
+            else
+            {
+                _monitor_op = np.less;
+            }
+        }
+
+        if (_monitor_op == np.greater)
         {
-            Console.WriteLine("EarlyStopping mode %s is unknown, fallback to auto mode.", mode);
+            _min_delta *= 1;
+        }
+        else
+        {
+            _min_delta *= -1;
         }
     }
     public void on_train_begin()
     {
         _wait = 0;
         _stopped_epoch = 0;
+        _best = _monitor_op == np.less ? (float)np.Inf : (float)-np.Inf;
+        _best_weights = null;
         _best_epoch = 0;
-        _best = (float)np.Inf;
     }
 
     public void on_epoch_begin(int epoch)
@@ -74,7 +108,7 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
         // Restore the weights after first epoch if no progress is ever made.
         if (_restore_best_weights && _best_weights == null)
         {
-            _best_weights = _parameters.Model.Weights;
+            _best_weights = _parameters.Model.get_weights();
         }
 
         _wait += 1;
@@ -83,7 +117,7 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
             _best = current;
             _best_epoch = epoch;
             if (_restore_best_weights)
-                _best_weights = _parameters.Model.TrainableWeights;
+                _best_weights = _parameters.Model.get_weights();
             // Only restart wait if we beat both the baseline and our previous best.
             if (_baseline == 0f || _is_improvement(current, _baseline))
                 _wait = 0;
@@ -99,7 +133,7 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
                 {
                     Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}");
                 }
-                _parameters.Model.Weights = _best_weights;
+                _parameters.Model.set_weights(_best_weights);
             }
         }
     }
@@ -131,21 +165,7 @@ float get_monitor_value(Dictionary<string, float> logs)
     }
     public bool _is_improvement(float monitor_value, float reference_value)
    {
-        bool less_op = (monitor_value - _min_delta) < reference_value;
-        bool greater_op = (monitor_value - _min_delta) >= reference_value;
-        if (_mode == "min")
-            return less_op;
-        else if (_mode == "max")
-            return greater_op;
-        else
-        {
-            if (_monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc"))
-            {
-                return greater_op;
-            }
-            else
-                return less_op;
-        }
+        return _monitor_op(monitor_value - _min_delta, reference_value);
     }
 
     public void on_test_end(Dictionary<string, float> logs)
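
Note on the comparator logic introduced above: "min" maps to a less-than test, "max" to a greater-than test, and "auto" falls back to greater-than only for metrics ending in "acc", "accuracy", or "auc". The standalone sketch below restates that selection and the improvement check in plain C#; it deliberately substitutes Func<float, float, bool> delegates for the NDArray-based np.greater/np.less used in the diff, and the MonitorOpSketch/ChooseOp names are illustrative only, not part of TensorFlow.NET.

using System;

// Standalone sketch of EarlyStopping's comparator selection and improvement test.
// Assumption: plain float comparisons stand in for the NDArray-based np.greater / np.less.
class MonitorOpSketch
{
    static Func<float, float, bool> ChooseOp(string mode, string monitor)
    {
        if (mode == "min")
            return (a, b) => a < b;
        if (mode == "max")
            return (a, b) => a > b;
        // "auto": metrics that should grow (acc/accuracy/auc) use greater-than, everything else less-than.
        if (monitor.EndsWith("acc") || monitor.EndsWith("accuracy") || monitor.EndsWith("auc"))
            return (a, b) => a > b;
        return (a, b) => a < b;
    }

    static void Main()
    {
        var op = ChooseOp(mode: "auto", monitor: "val_loss"); // val_loss => less-than comparator
        float minDelta = 0.01f;

        // Mirror the `_min_delta *= -1` step: for a less-than comparator the margin is negated,
        // so `current - minDelta` becomes `current + |minDelta|` and must still undercut the best.
        bool usesGreater = op(1f, 0f);
        minDelta = usesGreater ? minDelta : -minDelta;

        float best = 0.500f, current = 0.485f;
        bool improved = op(current - minDelta, best);
        Console.WriteLine($"improved: {improved}"); // True: the loss dropped by 0.015, beating the 0.01 margin
    }
}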