From 898a876230bcafc74aa006676d61701f70086d97 Mon Sep 17 00:00:00 2001 From: marvinfriede <51965259+marvinfriede@users.noreply.github.com> Date: Sun, 26 Nov 2023 22:43:47 +0100 Subject: [PATCH] Fixes --- src/tad_dftd4/charges.py | 2 -- test/conftest.py | 1 + test/test_charge/test_charges.py | 6 +++++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/tad_dftd4/charges.py b/src/tad_dftd4/charges.py index 7f8e2e5..3025c72 100644 --- a/src/tad_dftd4/charges.py +++ b/src/tad_dftd4/charges.py @@ -93,8 +93,6 @@ def __init__( self.eta = eta self.rad = rad - print(self.device) - print(self.chi.device) if any( tensor.device != self.device for tensor in (self.chi, self.kcn, self.eta, self.rad) diff --git a/test/conftest.py b/test/conftest.py index a0f0fa6..2d33490 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -18,6 +18,7 @@ """ Setup for pytest. """ +from __future__ import annotations import pytest import torch diff --git a/test/test_charge/test_charges.py b/test/test_charge/test_charges.py index 8769cbd..f23e06f 100644 --- a/test/test_charge/test_charges.py +++ b/test/test_charge/test_charges.py @@ -31,6 +31,8 @@ """ from __future__ import annotations +from math import sqrt + import pytest import torch @@ -45,7 +47,9 @@ @pytest.mark.parametrize("dtype", [torch.float, torch.double]) def test_single(dtype: torch.dtype): dd: DD = {"device": DEVICE, "dtype": dtype} - tol = torch.finfo(dtype).eps ** 0.5 * 10 + tol = sqrt(torch.finfo(dtype).eps) * 10 + print(tol) + print(sqrt(torch.finfo(dtype).eps) * 10) sample = samples["NH3-dimer"] numbers = sample["numbers"].to(DEVICE)