From 43d8f9f10f9da96d7c4fd7f1bab0b8a603b9fcf7 Mon Sep 17 00:00:00 2001 From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com> Date: Tue, 5 Mar 2024 11:42:11 +0100 Subject: [PATCH 1/3] Update tensornet.py --- torchmdnet/models/tensornet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchmdnet/models/tensornet.py b/torchmdnet/models/tensornet.py index a2006c9e..2b830ce1 100644 --- a/torchmdnet/models/tensornet.py +++ b/torchmdnet/models/tensornet.py @@ -239,7 +239,8 @@ def forward( # Total charge q is a molecule-wise property. We transform it into an atom-wise property, with all atoms belonging to the same molecule being assigned the same charge q if q is None: q = torch.zeros_like(z, device=z.device, dtype=z.dtype) - else: + # if not atom-wise, make atom-wise (pq is already atom-wise) + if z.shape != q.shape: q = q[batch] zp = z if self.static_shapes: From 9ccc2c0bc98c38d0f641b7e785a03d8ef8c53502 Mon Sep 17 00:00:00 2001 From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com> Date: Tue, 5 Mar 2024 11:44:00 +0100 Subject: [PATCH 2/3] Update model.py --- torchmdnet/models/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index a2a80f90..a8788dc6 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -378,6 +378,10 @@ def forward( assert z.dim() == 1 and z.dtype == torch.long batch = torch.zeros_like(z) if batch is None else batch + # trick to incorporate SPICE pqs + # set charge: true in yaml + q = extra_args["pq"] + if self.derivative: pos.requires_grad_(True) # run the potentially wrapped representation model From 25854fb2ce03888108277669e32e2768022514a3 Mon Sep 17 00:00:00 2001 From: Guillem Simeon <55756547+guillemsimeon@users.noreply.github.com> Date: Tue, 5 Mar 2024 11:46:44 +0100 Subject: [PATCH 3/3] Update model.py --- torchmdnet/models/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index a8788dc6..015d629f 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -379,7 +379,7 @@ def forward( batch = torch.zeros_like(z) if batch is None else batch # trick to incorporate SPICE pqs - # set charge: true in yaml + # set charge: true in yaml ((?) currently I do it) q = extra_args["pq"] if self.derivative: