Skip to content

Commit

Permalink
change "bool training" => "bool? training"
Browse files Browse the repository at this point in the history
the bool to tensor has a bug, if in init the training is False, the program not start.
  • Loading branch information
dogvane committed Jul 9, 2023
1 parent b968fd7 commit fa213eb
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public interface ILayer: IWithTrackable, IKerasConfigable
List<ILayer> Layers { get; }
List<INode> InboundNodes { get; }
List<INode> OutboundNodes { get; }
Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null);
Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null);
List<IVariableV1> TrainableVariables { get; }
List<IVariableV1> TrainableWeights { get; }
List<IVariableV1> NonTrainableWeights { get; }
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ private Tensor _zero_state_tensors(object state_size, Tensor batch_size, TF_Data
throw new NotImplementedException("_zero_state_tensors");
}

public Tensors Apply(Tensors inputs, Tensors state = null, bool is_training = false, IOptionalArgs? optional_args = null)
public Tensors Apply(Tensors inputs, Tensors state = null, bool? is_training = false, IOptionalArgs? optional_args = null)
{
throw new NotImplementedException();
}
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Keras/Engine/Layer.Apply.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public partial class Layer
/// <param name="state"></param>
/// <param name="training"></param>
/// <returns></returns>
public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null)
public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool? training = false, IOptionalArgs? optional_args = null)
{
if (callContext.Value == null)
callContext.Value = new CallContext();
Expand Down
8 changes: 6 additions & 2 deletions src/TensorFlowNET.Keras/Engine/Model.Fit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ public History fit(IDatasetV2 dataset,
int verbose = 1,
List<ICallback> callbacks = null,
IDatasetV2 validation_data = null,
int validation_step = 10, // 间隔多少次会进行一次验证
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
Expand All @@ -164,11 +165,11 @@ public History fit(IDatasetV2 dataset,
});


return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data,
return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data,
train_step_func: train_step_function);
}

History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data,
History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data,
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
{
stop_training = false;
Expand Down Expand Up @@ -207,6 +208,9 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICal

if (validation_data != null)
{
if (validation_step > 0 && epoch ==0 || (epoch) % validation_step != 0)
continue;

var val_logs = evaluate(validation_data);
foreach(var log in val_logs)
{
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
}
}

public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null)
public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool? training = false, IOptionalArgs? optional_args = null)
{
RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
if (optional_args is not null && rnn_optional_args is null)
Expand Down

0 comments on commit fa213eb

Please sign in to comment.