diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index e94c8bf10..2f92c4e57 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -15,7 +15,7 @@ public interface ILayer: IWithTrackable, IKerasConfigable List Layers { get; } List InboundNodes { get; } List OutboundNodes { get; } - Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null); + Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null); List TrainableVariables { get; } List TrainableWeights { get; } List NonTrainableWeights { get; } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index e488c47e7..4e99731f9 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -145,7 +145,7 @@ private Tensor _zero_state_tensors(object state_size, Tensor batch_size, TF_Data throw new NotImplementedException("_zero_state_tensors"); } - public Tensors Apply(Tensors inputs, Tensors state = null, bool is_training = false, IOptionalArgs? optional_args = null) + public Tensors Apply(Tensors inputs, Tensors state = null, bool? is_training = false, IOptionalArgs? optional_args = null) { throw new NotImplementedException(); } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs index d52190fd3..8a66948b9 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs @@ -13,7 +13,7 @@ public partial class Layer /// /// /// - public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null) + public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null) { if (callContext.Value == null) callContext.Value = new CallContext(); diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index 76c592ad6..de57f19ae 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -142,6 +142,7 @@ public History fit(IDatasetV2 dataset, int verbose = 1, List callbacks = null, IDatasetV2 validation_data = null, + int validation_step = 10, // 间隔多少次会进行一次验证 bool shuffle = true, int initial_epoch = 0, int max_queue_size = 10, @@ -164,11 +165,11 @@ public History fit(IDatasetV2 dataset, }); - return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data, + return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data, train_step_func: train_step_function); } - History FitInternal(DataHandler data_handler, int epochs, int verbose, List callbackList, IDatasetV2 validation_data, + History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List callbackList, IDatasetV2 validation_data, Func> train_step_func) { stop_training = false; @@ -207,6 +208,9 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List 0 && epoch ==0 || (epoch) % validation_step != 0) + continue; + var val_logs = evaluate(validation_data); foreach(var log in val_logs) { diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index f86de8a85..0ca62c391 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -393,7 +393,7 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo } } - public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null) + public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool? training = false, IOptionalArgs? optional_args = null) { RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs; if (optional_args is not null && rnn_optional_args is null)