Fix functorch error for 2.0.x (#59)
marvinfriede authored May 3, 2024
1 parent ba00675 commit d94e75d
Showing 3 changed files with 81 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/tad_dftd3/__version__.py
@@ -15,4 +15,4 @@
"""
Version module for *tad-dftd3*.
"""
-__version__ = "0.2.1"
+__version__ = "0.2.2"
19 changes: 19 additions & 0 deletions src/tad_dftd3/model/c6.py
@@ -69,6 +69,25 @@ def atomic_c6(
    """
    _check_memory(numbers, weights, chunk_size)

+    # PyTorch 2.0.x has a bug with functorch and custom autograd functions as
+    # documented in: https://github.com/pytorch/pytorch/issues/99973
+    #
+    # RuntimeError: unwrapped_count > 0 INTERNAL ASSERT FAILED at "../aten/src/
+    # ATen/functorch/TensorWrapper.cpp":202, please report a bug to PyTorch.
+    # Should have at least one dead wrapper
+    #
+    # Hence, we cannot use the custom backwards for reduced memory consumption.
+    if __tversion__[0] == 2 and __tversion__[1] == 0:  # pragma: no cover
+        track_weights = torch._C._functorch.is_gradtrackingtensor(weights)
+        track_numbers = torch._C._functorch.is_gradtrackingtensor(numbers)
+        if track_weights or track_numbers:
+
+            if chunk_size is None:
+                return _atomic_c6_full(numbers, weights, reference)
+
+            return _atomic_c6_chunked(numbers, weights, reference, chunk_size)
+
    # Use custom autograd function for reduced memory consumption
    AtomicC6 = AtomicC6_V1 if __tversion__ < (2, 0, 0) else AtomicC6_V2
    res = AtomicC6.apply(numbers, weights, reference, chunk_size)
    assert res is not None
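For readers who want the failure mode in isolation: the guard above only matters when atomic_c6 runs under a functorch transform on PyTorch 2.0.x. The following standalone sketch mirrors that pattern; SquareFn and _plain_impl are hypothetical stand-ins for the AtomicC6_V* autograd classes and the _atomic_c6_full/_atomic_c6_chunked fallbacks, and only the private is_gradtrackingtensor check is taken verbatim from the diff.

import torch


class SquareFn(torch.autograd.Function):
    """Toy custom autograd function (stand-in for AtomicC6_V2, new-style API)."""

    generate_vmap_rule = True

    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        return x * x

    @staticmethod
    def setup_context(ctx, inputs, output) -> None:
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        return 2.0 * x * grad_out


def _plain_impl(x: torch.Tensor) -> torch.Tensor:
    """Fallback without the custom backward (more memory, but functorch-safe)."""
    return x * x


def square(x: torch.Tensor) -> torch.Tensor:
    # On PyTorch 2.0.x, custom autograd functions break under functorch
    # transforms (pytorch/pytorch#99973): detect grad-tracking wrappers and
    # route around the custom Function, mirroring the guard added above.
    major, minor = (int(v) for v in torch.__version__.split(".")[:2])
    if (major, minor) == (2, 0) and torch._C._functorch.is_gradtrackingtensor(x):
        return _plain_impl(x)
    return SquareFn.apply(x)


# Without the guard, this jacrev call raises the INTERNAL ASSERT on 2.0.x.
x = torch.tensor([1.0, 2.0, 3.0])
jac = torch.func.jacrev(lambda t: square(t).sum())(x)

On PyTorch 2.1 and later, custom Functions written in the new style (setup_context plus generate_vmap_rule) work under jacrev, which is presumably why the fallback in the diff is restricted to the 2.0.x releases.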
85 changes: 61 additions & 24 deletions test/test_grad/test_pos.py
@@ -19,7 +19,7 @@

import pytest
import torch
-from tad_mctc.autograd import dgradcheck, dgradgradcheck
+from tad_mctc.autograd import dgradcheck, dgradgradcheck, jacrev
from tad_mctc.batch import pack

from tad_dftd3 import dftd3
@@ -151,29 +151,29 @@ def test_gradgradcheck_batch(dtype: torch.dtype, name1: str, name2: str) -> None
@pytest.mark.parametrize("name", sample_list)
def test_autograd(dtype: torch.dtype, name: str) -> None:
    """Compare with reference values from tblite."""
+    dd: DD = {"device": DEVICE, "dtype": dtype}
+
    sample = samples[name]
-    numbers = sample["numbers"]
-    positions = sample["positions"].type(dtype)
+    numbers = sample["numbers"].to(DEVICE)
+    positions = sample["positions"].to(**dd)

+    ref = sample["grad"].to(**dd)
+
    # GFN1-xTB parameters
    param = {
-        "s6": positions.new_tensor(1.00000000),
-        "s8": positions.new_tensor(2.40000000),
-        "s9": positions.new_tensor(0.00000000),
-        "a1": positions.new_tensor(0.63000000),
-        "a2": positions.new_tensor(5.00000000),
+        "s6": torch.tensor(1.00000000, **dd),
+        "s8": torch.tensor(2.40000000, **dd),
+        "s9": torch.tensor(0.00000000, **dd),
+        "a1": torch.tensor(0.63000000, **dd),
+        "a2": torch.tensor(5.00000000, **dd),
    }

-    ref = sample["grad"].type(dtype)
-
    # variable to be differentiated
-    positions.requires_grad_(True)
+    pos = positions.clone().requires_grad_(True)

    # automatic gradient
-    energy = torch.sum(dftd3(numbers, positions, param))
-    (grad,) = torch.autograd.grad(energy, positions)
-
-    positions.detach_()
+    energy = torch.sum(dftd3(numbers, pos, param))
+    (grad,) = torch.autograd.grad(energy, pos)

    assert pytest.approx(ref.cpu(), abs=tol) == grad.cpu()
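A side note on the pattern change in this test (my reading of the diff, not something the commit states): differentiating pos = positions.clone().requires_grad_(True) leaves the shared sample tensor untouched, so the old in-place positions.requires_grad_(True) / positions.detach_() bookkeeping is no longer needed. A minimal illustration:

import torch

positions = torch.rand(3, 3, dtype=torch.double)  # stand-in for the sample geometry

# Old pattern: flip requires_grad on the shared tensor, then undo it afterwards.
positions.requires_grad_(True)
energy_old = (positions**2).sum()  # placeholder for dftd3(numbers, positions, param).sum()
(grad_old,) = torch.autograd.grad(energy_old, positions)
positions.detach_()

# New pattern: differentiate a clone; the original tensor is never modified.
pos = positions.clone().requires_grad_(True)
energy_new = (pos**2).sum()
(grad_new,) = torch.autograd.grad(energy_new, pos)

assert torch.allclose(grad_old, grad_new)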

@@ -183,21 +183,23 @@ def test_autograd(dtype: torch.dtype, name: str) -> None:
@pytest.mark.parametrize("name", sample_list)
def test_backward(dtype: torch.dtype, name: str) -> None:
    """Compare with reference values from tblite."""
+    dd: DD = {"device": DEVICE, "dtype": dtype}
+
    sample = samples[name]
-    numbers = sample["numbers"]
-    positions = sample["positions"].type(dtype)
+    numbers = sample["numbers"].to(DEVICE)
+    positions = sample["positions"].to(**dd)

+    ref = sample["grad"].to(**dd)
+
    # GFN1-xTB parameters
    param = {
-        "s6": positions.new_tensor(1.00000000),
-        "s8": positions.new_tensor(2.40000000),
-        "s9": positions.new_tensor(0.00000000),
-        "a1": positions.new_tensor(0.63000000),
-        "a2": positions.new_tensor(5.00000000),
+        "s6": torch.tensor(1.00000000, **dd),
+        "s8": torch.tensor(2.40000000, **dd),
+        "s9": torch.tensor(0.00000000, **dd),
+        "a1": torch.tensor(0.63000000, **dd),
+        "a2": torch.tensor(5.00000000, **dd),
    }

-    ref = sample["grad"].type(dtype)
-
    # variable to be differentiated
    positions.requires_grad_(True)

@@ -213,3 +215,38 @@ def test_backward(dtype: torch.dtype, name: str) -> None:
    positions.grad.data.zero_()

    assert pytest.approx(ref.cpu(), abs=tol) == grad_backward.cpu()
+
+
+@pytest.mark.grad
+@pytest.mark.parametrize("dtype", [torch.double])
+@pytest.mark.parametrize("name", sample_list)
+def test_functorch(dtype: torch.dtype, name: str) -> None:
+    """Compare with reference values from tblite."""
+    dd: DD = {"device": DEVICE, "dtype": dtype}
+
+    sample = samples[name]
+    numbers = sample["numbers"].to(DEVICE)
+    positions = sample["positions"].to(**dd)
+
+    ref = sample["grad"].to(**dd)
+
+    # GFN1-xTB parameters
+    param = {
+        "s6": torch.tensor(1.00000000, **dd),
+        "s8": torch.tensor(2.40000000, **dd),
+        "s9": torch.tensor(0.00000000, **dd),
+        "a1": torch.tensor(0.63000000, **dd),
+        "a2": torch.tensor(5.00000000, **dd),
+    }
+
+    # variable to be differentiated
+    pos = positions.clone().requires_grad_(True)
+
+    def dftd3_func(p: Tensor) -> Tensor:
+        return dftd3(numbers, p, param).sum()
+
+    grad = jacrev(dftd3_func)(pos)
+    assert isinstance(grad, Tensor)
+
+    assert grad.shape == ref.shape
+    assert pytest.approx(ref.cpu(), abs=tol) == grad.detach().cpu()
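The new test_functorch drives the dispersion energy through jacrev instead of torch.autograd.grad, which is precisely the code path that hit the PyTorch 2.0.x assert before this fix. For context, a standalone equivalent using the public torch.func API is sketched below (assuming PyTorch >= 2.0; tad_mctc.autograd.jacrev presumably dispatches to this or to functorch.jacrev on older versions). The water geometry and atomic numbers are invented for illustration and are not taken from the test samples.

import torch
from tad_dftd3 import dftd3

# Hypothetical water molecule (atomic numbers, positions in Bohr, made up).
numbers = torch.tensor([8, 1, 1])
positions = torch.tensor(
    [
        [0.0000, 0.0000, -0.7407],
        [1.4425, 0.0000, 0.3704],
        [-1.4425, 0.0000, 0.3704],
    ],
    dtype=torch.double,
)

# GFN1-xTB damping parameters, as in the tests above.
param = {
    "s6": torch.tensor(1.00, dtype=torch.double),
    "s8": torch.tensor(2.40, dtype=torch.double),
    "s9": torch.tensor(0.00, dtype=torch.double),
    "a1": torch.tensor(0.63, dtype=torch.double),
    "a2": torch.tensor(5.00, dtype=torch.double),
}


def energy(pos: torch.Tensor) -> torch.Tensor:
    # Total D3 dispersion energy as a function of the positions only.
    return dftd3(numbers, pos, param).sum()


# Reverse-mode Jacobian of a scalar function is the nuclear gradient, shape (3, 3).
grad = torch.func.jacrev(energy)(positions)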
