From 5aa3cef8287605c2facb2f13a6eaa493a577026b Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 22 Apr 2024 09:45:42 +0200 Subject: [PATCH] Fix old model loading --- torchmdnet/models/model.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index bfb218fe..efc3c0d6 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -235,11 +235,22 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs): state_dict = {re.sub(r"^model\.", "", k): v for k, v in ckpt["state_dict"].items()} # In ET, before we had output_model.output_network.{0,1}.update_net.[0-9].{weight,bias} # Now we have output_model.output_network.{0,1}.update_net.layers.[0-9].{weight,bias} + # In other models, we had output_model.output_network.{0,1}.{weight,bias}, + # which is now output_model.output_network.layers.{0,1}.{weight,bias} # This change was introduced in https://github.com/torchmd/torchmd-net/pull/314 - state_dict = { - re.sub(r"update_net\.(\d+)\.", r"update_net.layers.\1.", k): v - for k, v in state_dict.items() - } + patterns = [ + ( + r"output_model.output_network.(\d+).update_net.(\d+).", + r"output_model.output_network.\1.update_net.layers.\2.", + ), + ( + r"output_model.output_network.([02]).(weight|bias)", + r"output_model.output_network.layers.\1.\2", + ), + ] + for p in patterns: + state_dict = {re.sub(p[0], p[1], k): v for k, v in state_dict.items()} + model.load_state_dict(state_dict) return model.to(device)