From d94e75d2cae24bc9f96576af6deb0b40b66f918c Mon Sep 17 00:00:00 2001
From: Marvin Friede <51965259+marvinfriede@users.noreply.github.com>
Date: Fri, 3 May 2024 17:16:19 +0200
Subject: [PATCH] Fix functorch error for 2.0.x (#59)

---
 src/tad_dftd3/__version__.py |  2 +-
 src/tad_dftd3/model/c6.py    | 19 ++++++++
 test/test_grad/test_pos.py   | 85 ++++++++++++++++++++++++++----------
 3 files changed, 81 insertions(+), 25 deletions(-)

diff --git a/src/tad_dftd3/__version__.py b/src/tad_dftd3/__version__.py
index e713393..25dbdbf 100644
--- a/src/tad_dftd3/__version__.py
+++ b/src/tad_dftd3/__version__.py
@@ -15,4 +15,4 @@
 """
 Version module for *tad-dftd3*.
 """
-__version__ = "0.2.1"
+__version__ = "0.2.2"
diff --git a/src/tad_dftd3/model/c6.py b/src/tad_dftd3/model/c6.py
index e55c3de..a146202 100644
--- a/src/tad_dftd3/model/c6.py
+++ b/src/tad_dftd3/model/c6.py
@@ -69,6 +69,25 @@ def atomic_c6(
     """
     _check_memory(numbers, weights, chunk_size)
 
+    # PyTorch 2.0.x has a bug with functorch and custom autograd functions as
+    # documented in: https://github.com/pytorch/pytorch/issues/99973
+    #
+    # RuntimeError: unwrapped_count > 0 INTERNAL ASSERT FAILED at "../aten/src/
+    # ATen/functorch/TensorWrapper.cpp":202, please report a bug to PyTorch.
+    # Should have at least one dead wrapper
+    #
+    # Hence, we cannot use the custom backwards for reduced memory consumption.
+    if __tversion__[0] == 2 and __tversion__[1] == 0:  # pragma: no cover
+        track_weights = torch._C._functorch.is_gradtrackingtensor(weights)
+        track_numbers = torch._C._functorch.is_gradtrackingtensor(numbers)
+        if track_weights or track_numbers:
+
+            if chunk_size is None:
+                return _atomic_c6_full(numbers, weights, reference)
+
+            return _atomic_c6_chunked(numbers, weights, reference, chunk_size)
+
+    # Use custom autograd function for reduced memory consumption
     AtomicC6 = AtomicC6_V1 if __tversion__ < (2, 0, 0) else AtomicC6_V2
     res = AtomicC6.apply(numbers, weights, reference, chunk_size)
     assert res is not None
diff --git a/test/test_grad/test_pos.py b/test/test_grad/test_pos.py
index 277c0b8..e1693b4 100644
--- a/test/test_grad/test_pos.py
+++ b/test/test_grad/test_pos.py
@@ -19,7 +19,7 @@
 
 import pytest
 import torch
-from tad_mctc.autograd import dgradcheck, dgradgradcheck
+from tad_mctc.autograd import dgradcheck, dgradgradcheck, jacrev
 from tad_mctc.batch import pack
 
 from tad_dftd3 import dftd3
@@ -151,29 +151,29 @@ def test_gradgradcheck_batch(dtype: torch.dtype, name1: str, name2: str) -> None
 @pytest.mark.parametrize("name", sample_list)
 def test_autograd(dtype: torch.dtype, name: str) -> None:
     """Compare with reference values from tblite."""
+    dd: DD = {"device": DEVICE, "dtype": dtype}
+
     sample = samples[name]
-    numbers = sample["numbers"]
-    positions = sample["positions"].type(dtype)
+    numbers = sample["numbers"].to(DEVICE)
+    positions = sample["positions"].to(**dd)
+
+    ref = sample["grad"].to(**dd)
 
     # GFN1-xTB parameters
     param = {
-        "s6": positions.new_tensor(1.00000000),
-        "s8": positions.new_tensor(2.40000000),
-        "s9": positions.new_tensor(0.00000000),
-        "a1": positions.new_tensor(0.63000000),
-        "a2": positions.new_tensor(5.00000000),
+        "s6": torch.tensor(1.00000000, **dd),
+        "s8": torch.tensor(2.40000000, **dd),
+        "s9": torch.tensor(0.00000000, **dd),
+        "a1": torch.tensor(0.63000000, **dd),
+        "a2": torch.tensor(5.00000000, **dd),
     }
 
-    ref = sample["grad"].type(dtype)
-
     # variable to be differentiated
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)
 
     # automatic gradient
-    energy = torch.sum(dftd3(numbers, positions, param))
-    (grad,) = torch.autograd.grad(energy, positions)
-
-    positions.detach_()
+    energy = torch.sum(dftd3(numbers, pos, param))
+    (grad,) = torch.autograd.grad(energy, pos)
 
     assert pytest.approx(ref.cpu(), abs=tol) == grad.cpu()
 
@@ -183,21 +183,23 @@ def test_autograd(dtype: torch.dtype, name: str) -> None:
 @pytest.mark.parametrize("name", sample_list)
 def test_backward(dtype: torch.dtype, name: str) -> None:
     """Compare with reference values from tblite."""
+    dd: DD = {"device": DEVICE, "dtype": dtype}
+
     sample = samples[name]
-    numbers = sample["numbers"]
-    positions = sample["positions"].type(dtype)
+    numbers = sample["numbers"].to(DEVICE)
+    positions = sample["positions"].to(**dd)
+
+    ref = sample["grad"].to(**dd)
 
     # GFN1-xTB parameters
     param = {
-        "s6": positions.new_tensor(1.00000000),
-        "s8": positions.new_tensor(2.40000000),
-        "s9": positions.new_tensor(0.00000000),
-        "a1": positions.new_tensor(0.63000000),
-        "a2": positions.new_tensor(5.00000000),
+        "s6": torch.tensor(1.00000000, **dd),
+        "s8": torch.tensor(2.40000000, **dd),
+        "s9": torch.tensor(0.00000000, **dd),
+        "a1": torch.tensor(0.63000000, **dd),
+        "a2": torch.tensor(5.00000000, **dd),
     }
 
-    ref = sample["grad"].type(dtype)
-
     # variable to be differentiated
     positions.requires_grad_(True)
 
@@ -213,3 +215,38 @@ def test_backward(dtype: torch.dtype, name: str) -> None:
     positions.grad.data.zero_()
 
     assert pytest.approx(ref.cpu(), abs=tol) == grad_backward.cpu()
+
+
+@pytest.mark.grad
+@pytest.mark.parametrize("dtype", [torch.double])
+@pytest.mark.parametrize("name", sample_list)
+def test_functorch(dtype: torch.dtype, name: str) -> None:
+    """Compare with reference values from tblite."""
+    dd: DD = {"device": DEVICE, "dtype": dtype}
+
+    sample = samples[name]
+    numbers = sample["numbers"].to(DEVICE)
+    positions = sample["positions"].to(**dd)
+
+    ref = sample["grad"].to(**dd)
+
+    # GFN1-xTB parameters
+    param = {
+        "s6": torch.tensor(1.00000000, **dd),
+        "s8": torch.tensor(2.40000000, **dd),
+        "s9": torch.tensor(0.00000000, **dd),
+        "a1": torch.tensor(0.63000000, **dd),
+        "a2": torch.tensor(5.00000000, **dd),
+    }
+
+    # variable to be differentiated
+    pos = positions.clone().requires_grad_(True)
+
+    def dftd3_func(p: Tensor) -> Tensor:
+        return dftd3(numbers, p, param).sum()
+
+    grad = jacrev(dftd3_func)(pos)
+    assert isinstance(grad, Tensor)
+
+    assert grad.shape == ref.shape
+    assert pytest.approx(ref.cpu(), abs=tol) == grad.detach().cpu()
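The hunk in src/tad_dftd3/model/c6.py guards the memory-saving custom autograd path behind a probe for functorch grad-tracking tensors on PyTorch 2.0.x. The following is a minimal, self-contained sketch of that dispatch pattern, assuming torch >= 2.0; the names _ScaledSum, _is_functorch_tracked, and scaled_sum are illustrative stand-ins for AtomicC6_V2 and the plain _atomic_c6_* fallbacks, not tad-dftd3 API. Only torch._C._functorch.is_gradtrackingtensor is taken from the patch itself.

import torch

# "2.0.1+cu118" -> (2, 0); assumes the usual torch version string layout.
_TORCH_VERSION = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2])


class _ScaledSum(torch.autograd.Function):
    """Toy stand-in for AtomicC6_V2: a Function with a hand-written backward."""

    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        return (2.0 * x).sum()

    @staticmethod
    def setup_context(ctx, inputs, output) -> None:
        # torch >= 2.0 style: context is populated outside of forward.
        ctx.save_for_backward(inputs[0])

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        return grad_out * 2.0 * torch.ones_like(x)


def _is_functorch_tracked(*tensors: torch.Tensor) -> bool:
    # Same probe the patch uses: detects tensors wrapped by a functorch
    # grad transform (torch.func.grad / jacrev).
    probe = torch._C._functorch.is_gradtrackingtensor
    return any(probe(t) for t in tensors)


def scaled_sum(x: torch.Tensor) -> torch.Tensor:
    if _TORCH_VERSION == (2, 0) and _is_functorch_tracked(x):
        # pytorch/pytorch#99973: custom backward + functorch asserts on 2.0.x,
        # so fall back to plain composite ops (more memory, same result).
        return (2.0 * x).sum()
    return _ScaledSum.apply(x)


x = torch.randn(5, dtype=torch.double)
print(torch.func.grad(scaled_sum)(x))  # tensor of 2.0s on 2.0.x and >= 2.1 alike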
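The new test_functorch drives the dispersion energy through jacrev, re-exported by tad_mctc.autograd, closing over numbers and param so that only the positions are differentiated. As a usage illustration of the same pattern, here is a small hypothetical example with torch.func.jacrev on a toy pair potential; the energy function and the cross-check are not part of the test suite and assume torch >= 2.0, where jacrev lives in torch.func.

import torch
from torch.func import jacrev


def energy(positions: torch.Tensor) -> torch.Tensor:
    # Toy scalar "energy": attractive r^-6 terms over all unique atom pairs.
    nat = positions.shape[-2]
    i, j = torch.triu_indices(nat, nat, offset=1)
    r = torch.linalg.norm(positions[i] - positions[j], dim=-1)
    return -(1.0 / r**6).sum()


positions = torch.randn(4, 3, dtype=torch.double)

# jacrev differentiates w.r.t. the first argument; for a scalar output the
# Jacobian has the input's shape, i.e. it is the nuclear gradient.
grad = jacrev(energy)(positions)
assert grad.shape == positions.shape

# Cross-check against plain autograd, mirroring test_autograd in the patch.
pos = positions.clone().requires_grad_(True)
(ref,) = torch.autograd.grad(energy(pos), pos)
assert torch.allclose(grad, ref)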