Skip to content

Commit

Permalink
Typing refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Dec 21, 2023
1 parent 9db9647 commit a2f058a
Show file tree
Hide file tree
Showing 40 changed files with 7,017 additions and 6,802 deletions.
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,10 @@ repos:
hooks:
- id: black
stages: [commit]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.1
hooks:
- id: mypy
additional_dependencies: [types-all]
exclude: 'test/conftest.py'
2 changes: 2 additions & 0 deletions docs/modules/data/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@

hardness
r4r2
radii
zeff
1 change: 1 addition & 0 deletions docs/modules/data/radii.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.. automodule:: tad_dftd4.data.radii
2 changes: 2 additions & 0 deletions docs/modules/data/zeff.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.. automodule:: tad_dftd4.data.zeff
:members:
2 changes: 2 additions & 0 deletions docs/modules/typing/builtin.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.. automodule:: tad_dftd4.typing.builtin
:members:
8 changes: 8 additions & 0 deletions docs/modules/typing/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.. _typing:

.. automodule:: tad_dftd4.typing

.. toctree::

builtin
pytorch
2 changes: 2 additions & 0 deletions docs/modules/typing/pytorch.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.. automodule:: tad_dftd4.typing.pytorch
:members:
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ disallow_untyped_defs = true
warn_redundant_casts = true
warn_unreachable = true
warn_unused_ignores = true
exclude = '''
(?x)
^test?s/conftest.py$
'''


[tool.coverage.run]
Expand Down
2 changes: 1 addition & 1 deletion src/tad_dftd4/cutoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
from __future__ import annotations

import torch
from tad_mctc.typing import Tensor, TensorLike

from . import defaults
from .typing import Tensor, TensorLike


class Cutoff(TensorLike):
Expand Down
8 changes: 4 additions & 4 deletions src/tad_dftd4/damping/atm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@
import torch
from tad_mctc import storch
from tad_mctc.batch import real_pairs, real_triples
from tad_mctc.typing import DD, Tensor

from .. import defaults
from ..data import r4r2
from .. import data, defaults
from ..typing import DD, Tensor


def get_atm_dispersion(
Expand Down Expand Up @@ -103,7 +102,8 @@ def get_atm_dispersion(
torch.abs(c6.unsqueeze(-1) * c6.unsqueeze(-2) * c6.unsqueeze(-3)),
)

radii = r4r2[numbers].unsqueeze(-1) * r4r2[numbers].unsqueeze(-2)
rad = data.R4R2[numbers]
radii = rad.unsqueeze(-1) * rad.unsqueeze(-2)
temp = a1 * storch.sqrt(3.0 * radii) + a2

r0ij = temp.unsqueeze(-1)
Expand Down
2 changes: 1 addition & 1 deletion src/tad_dftd4/damping/rational.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
from __future__ import annotations

import torch
from tad_mctc.typing import Tensor

from .. import defaults
from ..typing import Tensor


def rational_damping(
Expand Down
4 changes: 3 additions & 1 deletion src/tad_dftd4/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
- covalent radii
- effective nuclear charge
Some atomic data is imported from the `tad_mctc` library.
Some atomic data is imported from the `tad_mctc` library or indirectly used within the `tad_mctc` library.
"""
from .hardness import *
from .r4r2 import *
from .radii import *
from .zeff import *
10 changes: 5 additions & 5 deletions src/tad_dftd4/data/hardness.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@
# You should have received a copy of the GNU Lesser General Public License
# along with tad-dftd4. If not, see <https://www.gnu.org/licenses/>.
"""
Atomic data: Chemical hardnesses
================================
Data: Chemical hardnesses
=========================
Element-specific chemical hardnesses for the charge scaling function used
to extrapolate the C6 coefficients in DFT-D4.
"""

import torch

__all__ = ["gam"]
__all__ = ["GAM"]


gam = torch.tensor(
GAM = torch.tensor(
[
0.00000000, # None
0.47259288, # H
Expand Down
6 changes: 3 additions & 3 deletions src/tad_dftd4/data/r4r2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
Also new super heavies Cn, Nh, Fl, Lv, Og and Am-Rg calculated at
4c-PBE/Dyall-AE4Z level (Dirac 2022).
"""

import torch

__all__ = ["r4r2"]
__all__ = ["R4R2"]


# fmt: off
r4_over_r2 = torch.tensor([
Expand Down Expand Up @@ -60,5 +60,5 @@
# fmt: on


r4r2 = torch.sqrt(0.5 * (r4_over_r2 * torch.sqrt(torch.arange(r4_over_r2.shape[0]))))
R4R2 = torch.sqrt(0.5 * (r4_over_r2 * torch.sqrt(torch.arange(r4_over_r2.shape[0]))))
"""r⁴ over r² expectation values."""
26 changes: 26 additions & 0 deletions src/tad_dftd4/data/radii.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# This file is part of tad-dftd4.
#
# SPDX-Identifier: LGPL-3.0
# Copyright (C) 2022 Marvin Friede
#
# tad-dftd4 is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tad-dftd4 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with tad-dftd4. If not, see <https://www.gnu.org/licenses/>.
"""
Data: Radii
===========
Covalent radii (imported from *tad-mctc*).
"""
from tad_mctc.data import COV_D3

__all__ = ["COV_D3"]
26 changes: 26 additions & 0 deletions src/tad_dftd4/data/zeff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# This file is part of tad-dftd4.
#
# SPDX-Identifier: LGPL-3.0
# Copyright (C) 2022 Marvin Friede
#
# tad-dftd4 is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tad-dftd4 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with tad-dftd4. If not, see <https://www.gnu.org/licenses/>.
"""
Data: Charges
=============
Effective charges (imported from *tad-mctc*).
"""
from tad_mctc.data.zeff import ZEFF

__all__ = ["ZEFF"]
7 changes: 3 additions & 4 deletions src/tad_dftd4/disp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,14 @@
import torch
from tad_mctc import storch
from tad_mctc.batch import real_pairs
from tad_mctc.data import radii
from tad_mctc.ncoord import cn_d4, erf_count
from tad_mctc.typing import DD, Any, CountingFunction, DampingFunction, Tensor
from tad_multicharge.eeq import get_charges

from . import data, defaults
from .cutoff import Cutoff
from .damping import get_atm_dispersion, rational_damping
from .model import D4Model
from .typing import DD, Any, CountingFunction, DampingFunction, Tensor


def dftd4(
Expand Down Expand Up @@ -108,9 +107,9 @@ def dftd4(
cutoff = Cutoff(**dd)

if rcov is None:
rcov = radii.COV_D3.to(**dd)[numbers]
rcov = data.COV_D3.to(**dd)[numbers]
if r4r2 is None:
r4r2 = data.r4r2.to(**dd)[numbers]
r4r2 = data.R4R2.to(**dd)[numbers]
if q is None:
q = get_charges(numbers, positions, charge, cutoff=cutoff.cn_eeq)

Expand Down
11 changes: 5 additions & 6 deletions src/tad_dftd4/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,9 @@
from __future__ import annotations

import torch
from tad_mctc.data.zeff import ZEFF
from tad_mctc.typing import Tensor, TensorLike

from . import data, params
from .typing import Tensor, TensorLike

ga_default = 3.0
gc_default = 2.0
Expand Down Expand Up @@ -230,8 +229,8 @@ def refc_pow(n: int) -> Tensor:
)

# unsqueeze for reference dimension
zeff = ZEFF.to(self.device)[self.numbers].unsqueeze(-1)
gam = data.gam.to(**self.dd)[self.numbers].unsqueeze(-1) * self.gc
zeff = data.ZEFF.to(self.device)[self.numbers].unsqueeze(-1)
gam = data.GAM.to(**self.dd)[self.numbers].unsqueeze(-1) * self.gc
q = q.unsqueeze(-1)

# charge scaling
Expand Down Expand Up @@ -322,8 +321,8 @@ def _set_refalpha_eeq(self) -> Tensor:

mask = refsys > 0

zeff = ZEFF.to(self.device)[refsys]
gam = data.gam.to(**self.dd)[refsys] * self.gc
zeff = data.ZEFF.to(self.device)[refsys]
gam = data.GAM.to(**self.dd)[refsys] * self.gc

# charge scaling
zeta = torch.where(
Expand Down
25 changes: 25 additions & 0 deletions src/tad_dftd4/typing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# This file is part of tad-dftd4.
#
# SPDX-Identifier: LGPL-3.0
# Copyright (C) 2022 Marvin Friede
#
# tad-dftd4 is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tad-dftd4 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with tad-dftd4. If not, see <https://www.gnu.org/licenses/>.
"""
Type annotations
================
All type annotations for this project.
"""
from .builtin import *
from .pytorch import *
25 changes: 25 additions & 0 deletions src/tad_dftd4/typing/builtin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# This file is part of tad-dftd4.
#
# SPDX-Identifier: LGPL-3.0
# Copyright (C) 2022 Marvin Friede
#
# tad-dftd4 is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tad-dftd4 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with tad-dftd4. If not, see <https://www.gnu.org/licenses/>.
"""
Type annotations: Built-ins
===========================
Built-in type annotations are imported from the *tad-mctc* library, which
handles some version checking.
"""
from tad_mctc.typing import Any, Callable, NoReturn, TypedDict
34 changes: 34 additions & 0 deletions src/tad_dftd4/typing/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# This file is part of tad-dftd4.
#
# SPDX-Identifier: LGPL-3.0
# Copyright (C) 2022 Marvin Friede
#
# tad-dftd4 is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tad-dftd4 is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with tad-dftd4. If not, see <https://www.gnu.org/licenses/>.
"""
Type annotations: PyTorch
=========================
PyTorch-related type annotations for this project.
"""
from tad_mctc.typing import (
DD,
CountingFunction,
DampingFunction,
Molecule,
Tensor,
TensorLike,
TensorOrTensors,
get_default_device,
get_default_dtype,
)
6 changes: 3 additions & 3 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ def pytest_configure(config: pytest.Config) -> None:
torch.autograd.anomaly_mode.set_detect_anomaly(True)

if config.getoption("--jit"):
torch.jit._state.enable() # type: ignore # pylint: disable=protected-access
torch.jit._state.enable() # type:ignore # pylint: disable=protected-access
else:
torch.jit._state.disable() # type: ignore # pylint: disable=protected-access
torch.jit._state.disable() # type:ignore # pylint: disable=protected-access

if config.getoption("--fast"):
FAST_MODE = True
Expand Down Expand Up @@ -147,7 +147,7 @@ def pytest_configure(config: pytest.Config) -> None:
if torch.__version__ < (2, 0, 0): # type: ignore
torch.set_default_tensor_type("torch.cuda.FloatTensor") # type: ignore
else:
torch.set_default_device(DEVICE) # type: ignore
torch.set_default_device(DEVICE) # type: ignore[attr-defined]
else:
torch.use_deterministic_algorithms(True)
DEVICE = None
Expand Down
2 changes: 1 addition & 1 deletion test/test_cutoff/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,4 @@ def test_change_device_fail() -> None:

# trying to use setter
with pytest.raises(AttributeError):
cutoff.device = "cpu"
cutoff.device = torch.device("cpu")
Loading

0 comments on commit a2f058a

Please sign in to comment.