Skip to content

Commit

Permalink
Fix: model.load_weights
Browse files Browse the repository at this point in the history
  • Loading branch information
Beacontownfc authored Jul 22, 2023
1 parent fa2d2dc commit 737910d
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions src/TensorFlowNET.Keras/Saving/hdf5_format.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,8 @@ public static void load_weights_from_hdf5_group(long f, List<ILayer> layers)
long g = H5G.open(f, name);
var weight_names = load_attributes_from_hdf5_group(g, "weight_names");
foreach (var i_ in weight_names)
{
var vm = Regex.Replace(i_, "/", "$");
vm = i_.Split('/')[0] + "/$" + vm.Substring(i_.Split('/')[0].Length + 1, i_.Length - i_.Split('/')[0].Length - 1);
(success, Array result) = Hdf5.ReadDataset<float>(g, vm);
{
(success, Array result) = Hdf5.ReadDataset<float>(g, i_);
if (success)
weight_values.Add(np.array(result));
}
Expand Down Expand Up @@ -196,9 +194,14 @@ public static void save_weights_to_hdf5_group(long f, List<ILayer> layers)
var tensor = val.AsTensor();
if (name.IndexOf("/") > 1)
{
var crDataGroup = Hdf5.CreateOrOpenGroup(g, Hdf5Utils.NormalizedName(name.Split('/')[0]));
var _name = Regex.Replace(name.Substring(name.Split('/')[0].Length, name.Length - name.Split('/')[0].Length), "/", "$");
WriteDataset(crDataGroup, _name, tensor);
var crDataGroup = g;
string[] name_split = name.Split('/');
for(int i = 0; i < name_split.Length; i++)
{
if (i == name_split.Length - 1) break;
crDataGroup = Hdf5.CreateOrOpenGroup(crDataGroup, Hdf5Utils.NormalizedName(name_split[i]));
}
WriteDataset(crDataGroup, name_split[name_split.Length - 1], tensor);
Hdf5.CloseGroup(crDataGroup);
}
else
Expand Down

0 comments on commit 737910d

Please sign in to comment.