diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index 21cae53..0085bd3 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -25,14 +25,19 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ["3.8", "3.9", "3.10", "3.11"] - # PyTorch does not support Python 3.11 on non-Linux platforms + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + # PyTorch now fully supports Python =<3.11 # see: https://github.com/pytorch/pytorch/issues/86566 + # + # PyTorch does not support Python 3.12 (all platforms) + # see: https://github.com/pytorch/pytorch/issues/110436 exclude: + - os: ubuntu-latest + python-version: "3.12" - os: macos-latest - python-version: "3.11" + python-version: "3.12" - os: windows-latest - python-version: "3.11" + python-version: "3.12" defaults: run: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e87e98d..0fe531b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. ci: - skip: [mypy] + skip: [mypy] repos: - repo: https://github.com/pre-commit/pre-commit-hooks @@ -59,10 +59,3 @@ repos: hooks: - id: black stages: [commit] - - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.8.0 - hooks: - - id: mypy - additional_dependencies: [types-all] - exclude: 'test/conftest.py' diff --git a/pyproject.toml b/pyproject.toml index b932220..8ef3613 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,10 +34,9 @@ disallow_untyped_defs = true warn_redundant_casts = true warn_unreachable = true warn_unused_ignores = true -exclude = ''' - (?x) - ^test?s/conftest.py$ -''' +exclude = '''(?x)( + test/conftest.py +)''' [tool.coverage.run] diff --git a/test/test_grad/test_hessian.py b/test/test_grad/test_hessian.py index 8d9fe12..abc0c90 100644 --- a/test/test_grad/test_hessian.py +++ b/test/test_grad/test_hessian.py @@ -41,7 +41,7 @@ def test_fail() -> None: param = {"a1": numbers} # differentiable variable is not a tensor - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): hessian(dftd3, (numbers, positions, param), argnums=2) diff --git a/test/test_model/test_reference.py b/test/test_model/test_reference.py index 05a55ab..99a6148 100644 --- a/test/test_model/test_reference.py +++ b/test/test_model/test_reference.py @@ -15,14 +15,15 @@ """ Test the reference. """ -from typing import Union +from typing import Optional, Union +from unittest.mock import patch import pytest import torch from tad_mctc.convert import str_to_device from tad_dftd3 import reference -from tad_dftd3.typing import DD +from tad_dftd3.typing import DD, Any, Tensor, TypedDict from ..conftest import DEVICE @@ -37,8 +38,12 @@ def test_reference_dtype(dtype: torch.dtype) -> None: @pytest.mark.parametrize("dtype", [torch.float16, None]) def test_reference_dtype_both(dtype: Union[torch.dtype, None]) -> None: + class DDNone(TypedDict): + device: torch.device + dtype: Optional[torch.dtype] + dev = torch.device("cpu") - dd = {"device": dev, "dtype": dtype} + dd: DDNone = {"device": dev, "dtype": dtype} ref = reference.Reference(device=dev).to(**dd) assert ref.dtype == torch.tensor(1.0, dtype=dtype).dtype @@ -63,6 +68,38 @@ def test_reference_device(device_str: str, device_str2: str) -> None: ref.device = device +def test_reference_different_devices() -> None: + # Custom Tensor class with overridable device property + class MockTensor(Tensor): + @property + def device(self) -> Any: + return self._device + + @device.setter + def device(self, value: Any) -> None: + self._device = value + + # Custom mock functions + def mock_load_cn(*_: Any, **__: Any) -> Tensor: + tensor = MockTensor([1, 2, 3]) + tensor.device = torch.device("cpu") + return tensor + + def mock_load_c6(*_: Any, **__: Any) -> Tensor: + tensor = MockTensor([4, 5, 6]) + tensor.device = torch.device("cuda") + return tensor + + with patch("tad_dftd3.reference._load_cn", new=mock_load_cn): + with patch("tad_dftd3.reference._load_c6", new=mock_load_c6): + with pytest.raises(RuntimeError) as exc: + # Assuming the device is not explicitly passed, so it picks + # from _load_cn and _load_c6 + reference.Reference() + + assert "All tensors must be on the same device!" in str(exc.value) + + def test_reference_fail() -> None: c6 = reference._load_c6() # pylint: disable=protected-access