Merge pull request #1187 from Wanglongzhi2001/master
feat: add the implementation of sample_weight in model.fit
Oceania2018 authored Oct 2, 2023
2 parents 15763df + f5af07c commit 0ee9d42
Showing 13 changed files with 250 additions and 100 deletions.
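In caller terms, this change lets fit take an optional per-sample weight array, and lets validation_data be a two- or three-element tuple via the new ValidationDataPack type (added in src/TensorFlowNET.Core/Util/Data.cs below). A minimal usage sketch, assuming a compiled model and NDArray data defined elsewhere; all variable names here are illustrative, not from this commit:

using Tensorflow.NumPy;

// Count the third and fourth samples twice as heavily in the loss.
var sample_weight = np.array(new float[] { 1f, 1f, 2f, 2f });

model.fit(x_train, y_train,
    batch_size: 2,
    epochs: 5,
    sample_weight: sample_weight,
    // The tuple converts implicitly to ValidationDataPack; the third
    // element is the optional validation-set sample weight.
    validation_data: (x_val, y_val, val_weight));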
src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
@@ -1,5 +1,6 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition
{
@@ -16,5 +17,7 @@ public class DataAdapterArgs: IKerasConfig
public int Worker { get; set; }
public bool UseMultiprocessing { get; set; }
public IModel Model { get; set; }
public Dictionary<int, float> ClassWeight = null;
public NDArray SampleWeight = null;
}
}
src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
@@ -1,5 +1,6 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition
{
@@ -18,5 +19,7 @@ public class DataHandlerArgs: IKerasConfig
public bool UseMultiprocessing { get; set; } = false;
public IModel Model { get; set; }
public IVariableV1 StepsPerExecution { get; set; }
public Dictionary<int, float> ClassWeight = null;
public NDArray SampleWeight = null;
}
}
11 changes: 9 additions & 2 deletions src/TensorFlowNET.Core/Keras/Engine/IModel.cs
@@ -3,6 +3,7 @@
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Util;

namespace Tensorflow.Keras.Engine;

@@ -22,8 +23,10 @@ ICallback fit(NDArray x, NDArray y,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(NDArray val_x, NDArray val_y)? validation_data = null,
ValidationDataPack validation_data = null,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -35,8 +38,10 @@ ICallback fit(IEnumerable<NDArray> x, NDArray y,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
ValidationDataPack validation_data = null,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -63,6 +68,8 @@ void load_weights(string filepath,
Dictionary<string, float> evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
NDArray sample_weight = null,
int steps = -1,
int max_queue_size = 10,
int workers = 1,
66 changes: 66 additions & 0 deletions src/TensorFlowNET.Core/Util/Data.cs
@@ -0,0 +1,66 @@
using Tensorflow.NumPy;

namespace Tensorflow.Util
{
/// <summary>
/// ValidationDataPack is used to pass validation data to fit method.
/// It can receive data as a tuple `(x_val, y_val)` or `(x_val, y_val, sample_weight_val)` of NumPy arrays.
/// </summary>
public class ValidationDataPack
{
public NDArray val_x;
public NDArray val_y;
public NDArray val_sample_weight = null;

public ValidationDataPack((NDArray, NDArray) validation_data)
{
this.val_x = validation_data.Item1;
this.val_y = validation_data.Item2;
}

public ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
{
this.val_x = validation_data.Item1;
this.val_y = validation_data.Item2;
this.val_sample_weight = validation_data.Item3;
}

public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
{
// NOTE: only the first array of a multi-input x is kept as validation input.
this.val_x = validation_data.Item1.ToArray()[0];
this.val_y = validation_data.Item2;
}

public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
{
this.val_x = validation_data.Item1.ToArray()[0];
this.val_y = validation_data.Item2;
this.val_sample_weight = validation_data.Item3;
}

public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data)
=> new ValidationDataPack(validation_data);

public static implicit operator ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
=> new ValidationDataPack(validation_data);

public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
=> new ValidationDataPack(validation_data);

public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
=> new ValidationDataPack(validation_data);

public void Deconstruct(out NDArray val_x, out NDArray val_y)
{
val_x = this.val_x;
val_y = this.val_y;
}

public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight)
{
val_x = this.val_x;
val_y = this.val_y;
val_sample_weight = this.val_sample_weight;
}
}
}
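A short sketch of how these implicit conversions and deconstructors are meant to be used; the arrays are illustrative, not part of the commit:

using Tensorflow.NumPy;
using Tensorflow.Util;

NDArray x_val = np.array(new float[,] { { 1f }, { 2f } });
NDArray y_val = np.array(new float[] { 0f, 1f });
NDArray w_val = np.array(new float[] { 1f, 3f });

// Plain tuples convert implicitly, so fit(...) callers never name the type.
ValidationDataPack pack = (x_val, y_val);             // without sample weights
ValidationDataPack weighted = (x_val, y_val, w_val);  // with sample weights

// Downstream code deconstructs either two- or three-way.
var (vx, vy, vw) = weighted;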
59 changes: 59 additions & 0 deletions src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Util;

namespace Tensorflow.Keras.Engine.DataAdapters
{
@@ -34,9 +35,67 @@ public virtual (Tensors, Tensors) Expand1d(Tensors x, Tensors y)
return (x, y);
}

public virtual (Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight)
{
for (int i = 0; i < x.Length; i++)
{
if (x[i].shape.ndim == 1)
x[i] = array_ops.expand_dims(x[i], axis: -1);
}
for (int i = 0; i < y.Length; i++)
{
if (y[i].shape.ndim == 1)
y[i] = array_ops.expand_dims(y[i], axis: -1);
}
if (sample_weight != null)
{
for (int i = 0; i < sample_weight.Length; i++)
{
if (sample_weight[i].shape.ndim == 1)
sample_weight[i] = array_ops.expand_dims(sample_weight[i], axis: -1);
}
}
return (x, y, sample_weight);
}

public virtual bool ShouldRecreateIterator()
{
return true;
}

public static ((NDArray, NDArray, NDArray), ValidationDataPack) train_validation_split((NDArray, NDArray, NDArray) x_y_sample_weight, float validation_split)
{
var x = x_y_sample_weight.Item1;
var y = x_y_sample_weight.Item2;
var sample_weight = x_y_sample_weight.Item3;
int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
var train_x = x[new Slice(0, train_count)];
var train_y = y[new Slice(0, train_count)];
ValidationDataPack validation_data;
if (sample_weight != null)
{
validation_data = (x[new Slice(train_count)], y[new Slice(train_count)], sample_weight[new Slice(train_count)]);
sample_weight = sample_weight[new Slice(0, train_count)];
}
else
{
validation_data = (x[new Slice(train_count)], y[new Slice(train_count)]);
}

return ((train_x, train_y, sample_weight), validation_data);
}

public static ((IEnumerable<NDArray>, NDArray, NDArray), ValidationDataPack) train_validation_split((IEnumerable<NDArray>, NDArray, NDArray) x_y_sample_weight, float validation_split)
{
var x = x_y_sample_weight.Item1;
var y = x_y_sample_weight.Item2;
var sample_weight = x_y_sample_weight.Item3;
int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
var train_x = x.Select(arr => arr[new Slice(0, train_count)] as NDArray);
var train_y = y[new Slice(0, train_count)];
var val_x = x.Select(arr => arr[new Slice(train_count)] as NDArray);
var val_y = y[new Slice(train_count)];
ValidationDataPack validation_data;
if (sample_weight != null)
{
validation_data = (val_x, val_y, sample_weight[new Slice(train_count)]);
sample_weight = sample_weight[new Slice(0, train_count)];
}
else
{
validation_data = (val_x, val_y);
}
return ((train_x, train_y, sample_weight), validation_data);
}
}
}
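To make the split arithmetic concrete: with 100 rows and a validation_split of 0.2f, train_count is Convert.ToInt32(100 * 0.8) = 80, so rows [0, 80) stay in training and rows [80, 100), with their labels and weights, become the validation pack. A hypothetical call against the NDArray overload (assumes using Tensorflow, Tensorflow.NumPy, and Tensorflow.Keras.Engine.DataAdapters):

var x = np.ones(new Shape(100, 4));
var y = np.zeros(new Shape(100));
var w = np.ones(new Shape(100));

var ((train_x, train_y, train_w), val_pack) =
    DataAdapter.train_validation_split((x, y, w), validation_split: 0.2f);
// train_x.shape[0] == 80; val_pack holds the remaining 20 rows,
// their labels, and their sample weights.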
3 changes: 3 additions & 0 deletions src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Engine.DataAdapters
{
@@ -28,6 +29,7 @@ public class DataHandler
public DataHandler(DataHandlerArgs args)
{
this.args = args;

if (args.StepsPerExecution == null)
{
_steps_per_execution = tf.Variable(1L);
@@ -48,6 +50,7 @@ public DataHandler(DataHandlerArgs args)
BatchSize = args.BatchSize,
Steps = args.StepsPerEpoch,
Epochs = args.Epochs - args.InitialEpoch,
SampleWeight = args.SampleWeight,
Shuffle = args.Shuffle,
MaxQueueSize = args.MaxQueueSize,
Worker = args.Workers,
2 changes: 2 additions & 0 deletions src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs
@@ -17,6 +17,8 @@ public interface IDataAdapter
IDatasetV2 GetDataset();
int GetSize();
(Tensors, Tensors) Expand1d(Tensors x, Tensors y);
(Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight);

bool ShouldRecreateIterator();
}
}
src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
@@ -20,7 +20,7 @@ public class TensorLikeDataAdapter : DataAdapter, IDataAdapter
public TensorLikeDataAdapter(DataAdapterArgs args)
{
this.args = args;
_process_tensorlike();
Tensor sample_weight_tensor = args.SampleWeight != null ? _process_tensorlike(args.SampleWeight) : null;
num_samples = (int)args.X.shape[0];
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
_batch_size = batch_size;
@@ -37,6 +38,8 @@ public TensorLikeDataAdapter(DataAdapterArgs args)
inputs.AddRange(args.X);
if (args.Y != null)
inputs.AddRange(args.Y);
if (sample_weight_tensor != null)
inputs.Add(sample_weight_tensor);
dataset = slice_inputs(indices_dataset, inputs);
dataset.FirstInputTensorCount = args.X.Length;
}
@@ -94,8 +96,9 @@ IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensors elements)

public override bool ShouldRecreateIterator() => false;

void _process_tensorlike()
Tensor _process_tensorlike(NDArray sample_weights)
{
return tf.convert_to_tensor(sample_weights);
}
}
}
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Keras/Engine/LossesContainer.cs
@@ -26,11 +26,11 @@ public LossesContainer(ILossFunc losses, string[] output_names = null)
/// </summary>
/// <param name="y_true"></param>
/// <param name="y_pred"></param>
public Tensor Call(Tensor y_true, Tensor y_pred)
public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
if (!_built)
Build(y_pred);
var loss_value = _losses.Call(y_true, y_pred);
var loss_value = _losses.Call(y_true, y_pred, sample_weight: sample_weight);
var loss_metric_value = loss_value;
var batch_dim = array_ops.shape(y_true)[0];

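For intuition about what the new sample_weight argument does to the loss: each sample's loss is scaled by its weight before reduction. A framework-independent illustration of a weighted mean, not this repo's actual reduction code (whose normalization may differ):

// Weighted mean of per-sample losses.
static float WeightedMeanLoss(float[] perSampleLoss, float[] weight)
{
    float num = 0f, den = 0f;
    for (int i = 0; i < perSampleLoss.Length; i++)
    {
        num += perSampleLoss[i] * weight[i];
        den += weight[i];
    }
    return num / den;  // losses {0.5, 1.0} with weights {1, 3} -> 3.5 / 4 = 0.875
}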
19 changes: 14 additions & 5 deletions src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
@@ -30,6 +30,7 @@ public partial class Model
public Dictionary<string, float> evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
NDArray sample_weight = null,
int steps = -1,
int max_queue_size = 10,
int workers = 1,
@@ -51,6 +52,7 @@ public Dictionary<string, float> evaluate(NDArray x, NDArray y,
StepsPerEpoch = steps,
InitialEpoch = 0,
Epochs = 1,
SampleWeight = sample_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
@@ -140,7 +142,8 @@ Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callba
Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
{
var data = iterator.next();
var outputs = test_step(data_handler, data[0], data[1]);
var outputs = data.Length == 2 ? test_step(data_handler, data[0], data[1]) :
test_step(data_handler, data[0], data[1], data[2]);
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}
@@ -149,17 +152,23 @@ Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handl
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = test_step(data_handler, data.Take(x_size).ToArray(), data.Skip(x_size).ToArray());
var outputs = data.Length == 2 ?
test_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) :
test_step(
data_handler,
new Tensors(data.Take(x_size).ToArray()),
new Tensors(data.Skip(x_size).Take(x_size).ToArray()),
new Tensors(data.Skip(2 * x_size).ToArray()));
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}


Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
(x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
var y_pred = Apply(x, training: false);
var loss = compiled_loss.Call(y, y_pred);
var loss = compiled_loss.Call(y, y_pred, sample_weight: sample_weight);
compiled_metrics.update_state(y, y_pred);
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
}
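The caller-side view of the evaluate change, again as a hedged sketch; model and the test arrays are assumed to exist, and the "loss" key assumes the default loss metric name:

var results = model.evaluate(x_test, y_test,
    batch_size: 32,
    // Per-sample weights scale each test sample's contribution to the loss.
    sample_weight: np.ones(new Shape(x_test.shape[0])));
Console.WriteLine("loss: " + results["loss"]);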
(Diff truncated: the remaining changed files in this commit are not shown here.)
