diff --git a/torchmdnet/datasets/memdataset.py b/torchmdnet/datasets/memdataset.py index 329d076f..76549dfe 100644 --- a/torchmdnet/datasets/memdataset.py +++ b/torchmdnet/datasets/memdataset.py @@ -240,11 +240,10 @@ def get(self, idx): props = {} if "y" in self.properties: - props["y"] = pt.tensor(self.y_mm[idx], dtype=pt.float32).view( - 1, 1 - ) # It would be better to use float64, but the trainer complains + y = self.y_mm[idx] if self.remove_ref_energy: - props["y"] -= self.compute_reference_energy(z) + y -= self.compute_reference_energy(z) + props["y"] = pt.tensor(y, dtype=pt.float32).view(1, 1) if "neg_dy" in self.properties: props["neg_dy"] = pt.tensor(self.neg_dy_mm[atoms], dtype=pt.float32) if "q" in self.properties: