Skip to content

Commit

Permalink
Cleanup Python API for dftd3.library
Browse files Browse the repository at this point in the history
  • Loading branch information
awvwgk committed Dec 2, 2024
1 parent 8018f94 commit 36c3671
Show file tree
Hide file tree
Showing 4 changed files with 265 additions and 148 deletions.
139 changes: 78 additions & 61 deletions python/dftd3/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
library in actual workflows than the low-level access provided in the CFFI generated wrappers.
"""

from typing import List, Optional, Union
from typing import Optional
import numpy as np

from . import library
Expand Down Expand Up @@ -58,7 +58,7 @@ class Structure:
on invalid input, like incorrect shape / type of the passed arrays
"""

_mol = library.ffi.NULL
_mol = library.StructureHandle.null()

def __init__(
self,
Expand Down Expand Up @@ -104,10 +104,10 @@ def __init__(

self._mol = library.new_structure(
self._natoms,
_cast("int*", _numbers),
_cast("double*", _positions),
_cast("double*", _lattice),
_cast("bool*", _periodic),
_numbers,
_positions,
_lattice,
_periodic,
)

def __len__(self):
Expand Down Expand Up @@ -147,8 +147,8 @@ def update(

library.update_structure(
self._mol,
_cast("double*", _positions),
_cast("double*", _lattice),
_positions,
_lattice,
)


Expand Down Expand Up @@ -176,7 +176,7 @@ class DampingParam:
of the object should use the second method.
"""

_param = library.ffi.NULL
_param = library.ParamHandle.null()

def __init__(self, **kwargs):
"""Create new damping parameter from method name or explicit data"""
Expand All @@ -190,11 +190,11 @@ def __init__(self, **kwargs):
self._param = self.new_param(**kwargs)

@staticmethod
def load_param(method, **kwargs):
def load_param(method, **kwargs) -> library.ParamHandle:
raise NotImplementedError("Child class has to define parameter loading")

@staticmethod
def new_param(**kwargs):
def new_param(**kwargs) -> library.ParamHandle:
raise NotImplementedError("Child class has to define parameter construction")


Expand All @@ -218,15 +218,22 @@ def __init__(self, **kwargs):
DampingParam.__init__(self, **kwargs)

@staticmethod
def load_param(method, atm=True):
_method = library.ffi.new("char[]", method.encode())
def load_param(method: str, atm: bool = True) -> library.ParamHandle:
return library.load_rational_damping(
_method,
method,
atm,
)

@staticmethod
def new_param(*, s6=1.0, s8, s9=1.0, a1, a2, alp=14.0):
def new_param(
*,
s6: float = 1.0,
s8: float,
s9: float = 1.0,
a1: float,
a2: float,
alp: float = 14.0,
) -> library.ParamHandle:
return library.new_rational_damping(
s6,
s8,
Expand Down Expand Up @@ -254,15 +261,22 @@ def __init__(self, **kwargs):
DampingParam.__init__(self, **kwargs)

@staticmethod
def load_param(method, atm=True):
_method = library.ffi.new("char[]", method.encode())
def load_param(method: str, atm: bool = True) -> library.ParamHandle:
return library.load_zero_damping(
_method,
method,
atm,
)

@staticmethod
def new_param(*, s6=1.0, s8, s9=1.0, rs6, rs8=1.0, alp=14.0):
def new_param(
*,
s6: float = 1.0,
s8: float,
s9: float = 1.0,
rs6: float,
rs8: float = 1.0,
alp: float = 14.0,
) -> library.ParamHandle:
return library.new_zero_damping(
s6,
s8,
Expand All @@ -289,15 +303,22 @@ def __init__(self, **kwargs):
DampingParam.__init__(self, **kwargs)

@staticmethod
def load_param(method, atm=True):
_method = library.ffi.new("char[]", method.encode())
def load_param(method: str, atm: bool = True) -> library.ParamHandle:
return library.load_mrational_damping(
_method,
method,
atm,
)

@staticmethod
def new_param(*, s6=1.0, s8, s9=1.0, a1, a2, alp=14.0):
def new_param(
*,
s6: float = 1.0,
s8: float,
s9: float = 1.0,
a1: float,
a2: float,
alp: float = 14.0,
) -> library.ParamHandle:
return library.new_mrational_damping(
s6,
s8,
Expand Down Expand Up @@ -327,15 +348,23 @@ def __init__(self, **kwargs):
DampingParam.__init__(self, **kwargs)

@staticmethod
def load_param(method, atm=True):
_method = library.ffi.new("char[]", method.encode())
def load_param(method: str, atm: bool = True) -> library.ParamHandle:
return library.load_mzero_damping(
_method,
method,
atm,
)

@staticmethod
def new_param(*, s6=1.0, s8, s9=1.0, rs6, rs8=1.0, alp=14.0, bet):
def new_param(
*,
s6: float = 1.0,
s8: float,
s9: float = 1.0,
rs6: float,
rs8: float = 1.0,
alp: float = 14.0,
bet: float,
) -> library.ParamHandle:
return library.new_mzero_damping(
s6,
s8,
Expand Down Expand Up @@ -364,15 +393,23 @@ def __init__(self, **kwargs):
DampingParam.__init__(self, **kwargs)

@staticmethod
def load_param(method, atm=True):
_method = library.ffi.new("char[]", method.encode())
def load_param(method: str, atm: bool = True) -> library.ParamHandle:
return library.load_optimizedpower_damping(
_method,
method,
atm,
)

@staticmethod
def new_param(*, s6=1.0, s8, s9=1.0, a1, a2, alp=14.0, bet):
def new_param(
*,
s6: float = 1.0,
s8: float,
s9: float = 1.0,
a1: float,
a2: float,
alp: float = 14.0,
bet,
) -> library.ParamHandle:
return library.new_optimizedpower_damping(
s6,
s8,
Expand All @@ -394,7 +431,7 @@ class DispersionModel(Structure):
input.
"""

_disp = library.ffi.NULL
_disp = library.ModelHandle.null()

def __init__(
self,
Expand Down Expand Up @@ -427,9 +464,9 @@ def get_dispersion(self, param: DampingParam, grad: bool) -> dict:
self._mol,
self._disp,
param._param,
_cast("double*", _energy),
_cast("double*", _gradient),
_cast("double*", _sigma),
_energy,
_gradient,
_sigma,
)

results = dict(energy=_energy)
Expand All @@ -449,8 +486,8 @@ def get_pairwise_dispersion(self, param: DampingParam) -> dict:
self._mol,
self._disp,
param._param,
_cast("double*", _pair_disp2),
_cast("double*", _pair_disp3),
_pair_disp2,
_pair_disp3,
)

return {
Expand All @@ -469,7 +506,7 @@ class GeometricCounterpoise(Structure):
superposition error (BSSE).
"""

_gcp = library.ffi.NULL
_gcp = library.GCPHandle.null()

def __init__(
self,
Expand All @@ -482,14 +519,12 @@ def __init__(
):
Structure.__init__(self, numbers, positions, lattice, periodic)

Check warning on line 520 in python/dftd3/interface.py

View check run for this annotation

Codecov / codecov/patch

python/dftd3/interface.py#L520

Added line #L520 was not covered by tests

_method = library.ffi.new("char[]", method.encode()) if method else library.ffi.NULL
_basis = library.ffi.new("char[]", basis.encode()) if basis else library.ffi.NULL
self._gcp = library.load_gcp_param(self._mol, _method, _basis)
self._gcp = library.load_gcp_param(self._mol, method, basis)

Check warning on line 522 in python/dftd3/interface.py

View check run for this annotation

Codecov / codecov/patch

python/dftd3/interface.py#L522

Added line #L522 was not covered by tests

def set_realspace_cutoff(self, bas: float, srb: float):
"""Set realspace cutoff for evaluation of interactions"""

library.set_gcp_realspace_cutoff(self._disp, bas, srb)
library.set_gcp_realspace_cutoff(self._gcp, bas, srb)

Check warning on line 527 in python/dftd3/interface.py

View check run for this annotation

Codecov / codecov/patch

python/dftd3/interface.py#L527

Added line #L527 was not covered by tests

def get_counterpoise(self, grad: bool) -> dict:
"""Evaluate the counterpoise corrected interaction energy"""
Expand All @@ -502,7 +537,7 @@ def get_counterpoise(self, grad: bool) -> dict:
_gradient = None
_sigma = None

Check warning on line 538 in python/dftd3/interface.py

View check run for this annotation

Codecov / codecov/patch

python/dftd3/interface.py#L537-L538

Added lines #L537 - L538 were not covered by tests

library.get_counterpoise(self._mol, self._gcp, _cast("double*", _energy), _cast("double*", _gradient), _cast("double*", _sigma))
library.get_counterpoise(self._mol, self._gcp, _energy, _gradient, _sigma)

Check warning on line 540 in python/dftd3/interface.py

View check run for this annotation

Codecov / codecov/patch

python/dftd3/interface.py#L540

Added line #L540 was not covered by tests

results = dict(energy=_energy)
if _gradient is not None:
Expand All @@ -512,24 +547,6 @@ def get_counterpoise(self, grad: bool) -> dict:
return results

Check warning on line 547 in python/dftd3/interface.py

View check run for this annotation

Codecov / codecov/patch

python/dftd3/interface.py#L542-L547

Added lines #L542 - L547 were not covered by tests


def _cast(ctype, array):
"""Cast a numpy array to a FFI pointer"""
return (
library.ffi.NULL
if array is None
else library.ffi.cast(ctype, array.ctypes.data)
)


def _ref(ctype, value):
"""Create a reference to a value"""
if value is None:
return library.ffi.NULL
ref = library.ffi.new(ctype + "*")
ref[0] = value
return ref


def _rename_kwargs(kwargs, old_name, new_name):
if old_name in kwargs and new_name not in kwargs:
kwargs[new_name] = kwargs[old_name]
Expand Down
Loading

0 comments on commit 36c3671

Please sign in to comment.