diff --git a/src/TensorFlowNET.Keras/Engine/Model.Training.cs b/src/TensorFlowNET.Keras/Engine/Model.Training.cs index 50d934d9d..457b3d694 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Training.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Training.cs @@ -10,8 +10,38 @@ namespace Tensorflow.Keras.Engine { public partial class Model { + static Dictionary> weightsCache + = new Dictionary>(); + public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null) { + // Get from cache + if (weightsCache.ContainsKey(filepath)) + { + var filtered_layers = new List(); + foreach (var layer in Layers) + { + var weights = hdf5_format._legacy_weights(layer); + if (weights.Count > 0) + filtered_layers.append(layer); + } + + var weight_value_tuples = new List<(IVariableV1, NDArray)>(); + filtered_layers.Select((layer, i) => + { + var symbolic_weights = hdf5_format._legacy_weights(layer); + foreach(var weight in symbolic_weights) + { + var weight_value = weightsCache[filepath].First(x => x.Item1 == weight.Name).Item2; + weight_value_tuples.Add((weight, weight_value)); + } + return layer; + }).ToList(); + + keras.backend.batch_set_value(weight_value_tuples); + return; + } + long fileId = Hdf5.OpenFile(filepath, true); if(fileId < 0) { @@ -29,8 +59,11 @@ public void load_weights(string filepath, bool by_name = false, bool skip_mismat throw new NotImplementedException(""); else { - hdf5_format.load_weights_from_hdf5_group(fileId, Layers); + var weight_value_tuples = hdf5_format.load_weights_from_hdf5_group(fileId, Layers); Hdf5.CloseFile(fileId); + + weightsCache[filepath] = weight_value_tuples.Select(x => (x.Item1.Name, x.Item2)).ToList(); + keras.backend.batch_set_value(weight_value_tuples); } } diff --git a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs index bab0efecf..68b73953d 100644 --- a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs +++ b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs @@ -82,7 +82,7 @@ public static void load_optimizer_weights_from_hdf5_group(long filepath = -1, Di } - public static void load_weights_from_hdf5_group(long f, List layers) + public static List<(IVariableV1, NDArray)> load_weights_from_hdf5_group(long f, List layers) { string original_keras_version = "2.5.0"; string original_backend = null; @@ -152,7 +152,7 @@ public static void load_weights_from_hdf5_group(long f, List layers) weight_value_tuples.AddRange(zip(symbolic_weights, weight_values)); } - keras.backend.batch_set_value(weight_value_tuples); + return weight_value_tuples; } public static void toarrayf4(long filepath = -1, Dictionary custom_objects = null, bool compile = false)