Skip to content

Commit

Permalink
Merge pull request #1188 from hchen2020/master
Browse files Browse the repository at this point in the history
Allow Model to cache weights.
  • Loading branch information
Oceania2018 authored Oct 3, 2023
2 parents 0ee9d42 + 0f02885 commit f16902d
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
35 changes: 34 additions & 1 deletion src/TensorFlowNET.Keras/Engine/Model.Training.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,38 @@ namespace Tensorflow.Keras.Engine
{
public partial class Model
{
static Dictionary<string, List<(string, NDArray)>> weightsCache
= new Dictionary<string, List<(string, NDArray)>>();

public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null)
{
// Get from cache
if (weightsCache.ContainsKey(filepath))
{
var filtered_layers = new List<ILayer>();
foreach (var layer in Layers)
{
var weights = hdf5_format._legacy_weights(layer);
if (weights.Count > 0)
filtered_layers.append(layer);
}

var weight_value_tuples = new List<(IVariableV1, NDArray)>();
filtered_layers.Select((layer, i) =>
{
var symbolic_weights = hdf5_format._legacy_weights(layer);
foreach(var weight in symbolic_weights)
{
var weight_value = weightsCache[filepath].First(x => x.Item1 == weight.Name).Item2;
weight_value_tuples.Add((weight, weight_value));
}
return layer;
}).ToList();

keras.backend.batch_set_value(weight_value_tuples);
return;
}

long fileId = Hdf5.OpenFile(filepath, true);
if(fileId < 0)
{
Expand All @@ -29,8 +59,11 @@ public void load_weights(string filepath, bool by_name = false, bool skip_mismat
throw new NotImplementedException("");
else
{
hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
var weight_value_tuples = hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
Hdf5.CloseFile(fileId);

weightsCache[filepath] = weight_value_tuples.Select(x => (x.Item1.Name, x.Item2)).ToList();
keras.backend.batch_set_value(weight_value_tuples);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Keras/Saving/hdf5_format.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public static void load_optimizer_weights_from_hdf5_group(long filepath = -1, Di

}

public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
public static List<(IVariableV1, NDArray)> load_weights_from_hdf5_group(long f, List<ILayer> layers)
{
string original_keras_version = "2.5.0";
string original_backend = null;
Expand Down Expand Up @@ -152,7 +152,7 @@ public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
weight_value_tuples.AddRange(zip(symbolic_weights, weight_values));
}

keras.backend.batch_set_value(weight_value_tuples);
return weight_value_tuples;
}

public static void toarrayf4(long filepath = -1, Dictionary<string, object> custom_objects = null, bool compile = false)
Expand Down

0 comments on commit f16902d

Please sign in to comment.