Skip to content

Commit

Permalink
Fix type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Apr 22, 2024
1 parent 93a1217 commit 5c3c3b9
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions src/tad_dftd3/disp.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
>>> print(torch.sum(energy[0] - energy[1] - energy[2])) # energy in Hartree
tensor(-0.0003964, dtype=torch.float64)
"""
from typing import Dict, Optional
from __future__ import annotations

import torch
from tad_mctc import storch
Expand All @@ -77,17 +77,17 @@
def dftd3(
numbers: Tensor,
positions: Tensor,
param: Dict[str, Tensor],
param: dict[str, Tensor],
*,
ref: Optional[Reference] = None,
rcov: Optional[Tensor] = None,
rvdw: Optional[Tensor] = None,
r4r2: Optional[Tensor] = None,
cutoff: Optional[Tensor] = None,
ref: Reference | None = None,
rcov: Tensor | None = None,
rvdw: Tensor | None = None,
r4r2: Tensor | None = None,
cutoff: Tensor | None = None,
counting_function: CountingFunction = ncoord.exp_count,
weighting_function: WeightingFunction = model.gaussian_weight,
damping_function: DampingFunction = rational_damping,
chunk_size: None | int = None,
chunk_size: int | None = None,
) -> Tensor:
"""
Evaluate DFT-D3 dispersion energy for a batch of geometries.
Expand Down Expand Up @@ -163,12 +163,12 @@ def dftd3(
def dispersion(
numbers: Tensor,
positions: Tensor,
param: Dict[str, Tensor],
param: dict[str, Tensor],
c6: Tensor,
rvdw: Optional[Tensor] = None,
r4r2: Optional[Tensor] = None,
rvdw: Tensor | None = None,
r4r2: Tensor | None = None,
damping_function: DampingFunction = rational_damping,
cutoff: Optional[Tensor] = None,
cutoff: Tensor | None = None,
**kwargs: Any,
) -> Tensor:
"""
Expand Down Expand Up @@ -236,7 +236,7 @@ def dispersion(
def dispersion2(
numbers: Tensor,
positions: Tensor,
param: Dict[str, Tensor],
param: dict[str, Tensor],
c6: Tensor,
r4r2: Tensor,
damping_function: DampingFunction,
Expand Down Expand Up @@ -296,7 +296,7 @@ def dispersion2(
def dispersion3(
numbers: Tensor,
positions: Tensor,
param: Dict[str, Tensor],
param: dict[str, Tensor],
c6: Tensor,
rvdw: Tensor,
cutoff: Tensor,
Expand Down

0 comments on commit 5c3c3b9

Please sign in to comment.