Skip to content

Commit

Permalink
Merge pull request #275 from torchmd/reduce_energy_error
Browse files Browse the repository at this point in the history
Float precision of energies
  • Loading branch information
stefdoerr authored Feb 12, 2024
2 parents 4601b54 + a7fecf8 commit 6d8e315
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions torchmdnet/datasets/memdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,10 @@ def get(self, idx):

props = {}
if "y" in self.properties:
props["y"] = pt.tensor(self.y_mm[idx], dtype=pt.float32).view(
1, 1
) # It would be better to use float64, but the trainer complains
y = self.y_mm[idx]
if self.remove_ref_energy:
props["y"] -= self.compute_reference_energy(z)
y -= self.compute_reference_energy(z)
props["y"] = pt.tensor(y, dtype=pt.float32).view(1, 1)
if "neg_dy" in self.properties:
props["neg_dy"] = pt.tensor(self.neg_dy_mm[atoms], dtype=pt.float32)
if "q" in self.properties:
Expand Down

0 comments on commit 6d8e315

Please sign in to comment.