Skip to content

Commit

Permalink
Removes self-inspection of constructor arguments (#91)
Browse files Browse the repository at this point in the history
* Removed self-inspection of arguments

* Added classes to __all__ to facilitate import
  • Loading branch information
craabreu authored Apr 15, 2024
1 parent 677f7f6 commit 907d7ca
Show file tree
Hide file tree
Showing 24 changed files with 161 additions and 146 deletions.
25 changes: 25 additions & 0 deletions cvpack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,28 @@
from .sheet_rmsd_content import SheetRMSDContent # noqa: F401
from .torsion import Torsion # noqa: F401
from .torsion_similarity import TorsionSimilarity # noqa: F401

__all__ = [
"Angle",
"AtomicFunction",
"AttractionStrength",
"CentroidFunction",
"CollectiveVariable",
"CompositeRMSD",
"Distance",
"HelixAngleContent",
"HelixHBondContent",
"HelixRMSDContent",
"HelixTorsionContent",
"MetaCollectiveVariable",
"NumberOfContacts",
"OpenMMForceWrapper",
"PathInCVSpace",
"RadiusOfGyration",
"RadiusOfGyrationSq",
"ResidueCoordination",
"RMSD",
"SheetRMSDContent",
"Torsion",
"TorsionSimilarity",
]
4 changes: 3 additions & 1 deletion cvpack/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def __init__(
super().__init__("theta")
self.addAngle(atom1, atom2, atom3, [])
self.setUsesPeriodicBoundaryConditions(pbc)
self._registerCV(name, mmunit.radians, atom1, atom2, atom3, pbc)
self._registerCV(
name, mmunit.radians, atom1=atom1, atom2=atom2, atom3=atom3, pbc=pbc
)
self._registerPeriodicBounds(-np.pi, np.pi)


Expand Down
10 changes: 5 additions & 5 deletions cvpack/atomic_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,11 @@ def __init__(
self._registerCV(
name,
unit,
function,
unit,
groups,
periodicBounds,
pbc,
function=function,
unit=unit,
groups=groups,
periodicBounds=periodicBounds,
pbc=pbc,
**overalls,
**perbonds,
)
Expand Down
14 changes: 7 additions & 7 deletions cvpack/attraction_strength.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,13 @@ def __init__( # pylint: disable=too-many-locals

self._registerCV(
name,
None,
group1,
group2,
nonbondedForce,
contrastGroup,
reference,
contrastScaling,
mmunit.dimensionless,
group1=group1,
group2=group2,
nonbondedForce=nonbondedForce,
contrastGroup=contrastGroup,
reference=reference,
contrastScaling=contrastScaling,
)


Expand Down
14 changes: 7 additions & 7 deletions cvpack/centroid_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,13 @@ def __init__(
self._registerCV(
name,
unit,
function,
unit,
groups,
collections,
periodicBounds,
pbc,
weighByMass,
function=function,
unit=unit,
groups=groups,
collections=collections,
periodicBounds=periodicBounds,
pbc=pbc,
weighByMass=weighByMass,
**overalls,
**perbonds,
)
Expand Down
48 changes: 3 additions & 45 deletions cvpack/collective_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
"""

import inspect
import typing as t
from collections import OrderedDict

import openmm
import yaml
Expand Down Expand Up @@ -46,8 +44,7 @@ def __deepcopy__(self, _) -> "CollectiveVariable":
def _registerCV(
self,
name: str,
unit: t.Optional[mmunit.Unit],
*args: t.Any,
cvUnit: Unit,
**kwargs: t.Any,
) -> None:
"""
Expand All @@ -63,18 +60,13 @@ def _registerCV(
The unit of measurement of this collective variable. It must be a unit
in the MD unit system (mass in Da, distance in nm, time in ps,
temperature in K, energy in kJ/mol, angle in rad).
args
The arguments needed to construct this collective variable
kwargs
The keyword arguments needed to construct this collective variable
"""
self.setName(name)
self._unit = Unit("dimensionless") if unit is None else unit
self._unit = cvUnit
self._mass_unit = Unit(mmunit.dalton * (mmunit.nanometers / self._unit) ** 2)
arguments, _ = self._getArguments()
arguments.pop("name")
self._args = dict(zip(arguments, args))
self._args["name"] = name
self._args = {"name": name}
self._args.update(kwargs)

def _registerPeriodicBounds(
Expand All @@ -99,40 +91,6 @@ def _registerPeriodicBounds(
upper = upper.value_in_unit(self.getUnit())
self._periodic_bounds = Quantity((lower, upper), self.getUnit())

@classmethod
def _getArguments(cls) -> t.Tuple[OrderedDict, OrderedDict]:
"""
Inspect the arguments needed for constructing an instance of this collective
variable.
Returns
-------
OrderedDict
A dictionary with the type annotations of all arguments
OrderedDict
A dictionary with the default values of optional arguments
Example
-------
>>> import cvpack
>>> args, defaults = cvpack.RadiusOfGyration._getArguments()
>>> for name, annotation in args.items():
... print(f"{name}: {annotation}")
group: typing.Iterable[int]
pbc: <class 'bool'>
weighByMass: <class 'bool'>
name: <class 'str'>
>>> print(*defaults.items())
('pbc', False) ('weighByMass', False) ('name', 'radius_of_gyration')
"""
arguments = OrderedDict()
defaults = OrderedDict()
for name, parameter in inspect.signature(cls).parameters.items():
arguments[name] = parameter.annotation
if parameter.default is not inspect.Parameter.empty:
defaults[name] = parameter.default
return arguments, defaults

def _setUnusedForceGroup(self, system: openmm.System) -> None:
"""
Set the force group of this collective variable to the one at a given position
Expand Down
8 changes: 7 additions & 1 deletion cvpack/composite_rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,13 @@ def __init__(
super().__init__(all_coords)
for group in groups:
self.addGroup(group)
self._registerCV(name, mmunit.nanometers, defined_coords, groups, num_atoms)
self._registerCV(
name,
mmunit.nanometers,
referencePositions=defined_coords,
groups=groups,
numAtoms=num_atoms,
)


CompositeRMSD.registerTag("!cvpack.CompositeRMSD")
2 changes: 1 addition & 1 deletion cvpack/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
super().__init__("r")
self.addBond(atom1, atom2, [])
self.setUsesPeriodicBoundaryConditions(pbc)
self._registerCV(name, mmunit.nanometers, atom1, atom2, pbc)
self._registerCV(name, mmunit.nanometers, atom1=atom1, atom2=atom2, pbc=pbc)


Distance.registerTag("!cvpack.Distance")
14 changes: 7 additions & 7 deletions cvpack/helix_angle_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,13 @@ def find_alpha_carbon(residue: mmapp.topology.Residue) -> int:
self.setUsesPeriodicBoundaryConditions(pbc)
self._registerCV(
name,
None,
residues,
pbc,
thetaReference,
tolerance,
halfExponent,
normalize,
mmunit.dimensionless,
residues=residues,
pbc=pbc,
thetaReference=thetaReference,
tolerance=tolerance,
halfExponent=halfExponent,
normalize=normalize,
)


Expand Down
8 changes: 7 additions & 1 deletion cvpack/helix_hbond_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,13 @@ def find_atom(residue: mmapp.topology.Residue, pattern: t.Pattern) -> int:
)
self.setUsesPeriodicBoundaryConditions(pbc)
self._registerCV(
name, None, residues, pbc, thresholdDistance, halfExponent, normalize
name,
mmunit.dimensionless,
residues=residues,
pbc=pbc,
thresholdDistance=thresholdDistance,
halfExponent=halfExponent,
normalize=normalize,
)


Expand Down
8 changes: 7 additions & 1 deletion cvpack/helix_rmsd_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,13 @@ def __init__(
normalize,
)
self._registerCV(
name, None, residues, numAtoms, thresholdRMSD, stepFunction, normalize
name,
mmunit.dimensionless,
residues=residues,
numAtoms=numAtoms,
thresholdRMSD=thresholdRMSD,
stepFunction=stepFunction,
normalize=normalize,
)


Expand Down
14 changes: 7 additions & 7 deletions cvpack/helix_torsion_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,13 @@ def find_atom(residue: mmapp.topology.Residue, name: str) -> int:
self.setUsesPeriodicBoundaryConditions(pbc)
self._registerCV(
name,
None,
residues,
pbc,
phiReference,
psiReference,
tolerance,
halfExponent,
mmunit.dimensionless,
residues=residues,
pbc=pbc,
phiReference=phiReference,
psiReference=psiReference,
tolerance=tolerance,
halfExponent=halfExponent,
)


Expand Down
8 changes: 4 additions & 4 deletions cvpack/meta_collective_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ def __init__(
self._registerCV(
name,
unit,
function,
variables,
unit,
periodicBounds,
function=function,
variables=variables,
unit=unit,
periodicBounds=periodicBounds,
**self._parameters,
)
if periodicBounds is not None:
Expand Down
18 changes: 9 additions & 9 deletions cvpack/number_of_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,15 +162,15 @@ def __init__(
self.setEnergyFunction(expression.replace("/1;", f"/{reference};"))
self._registerCV(
name,
None,
group1,
group2,
nonbondedForce,
reference,
stepFunction,
thresholdDistance,
cutoffFactor,
switchFactor,
mmunit.dimensionless,
group1=group1,
group2=group2,
nonbondedForce=nonbondedForce,
reference=reference,
stepFunction=stepFunction,
thresholdDistance=thresholdDistance,
cutoffFactor=cutoffFactor,
switchFactor=switchFactor,
)


Expand Down
8 changes: 7 additions & 1 deletion cvpack/openmm_force_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,13 @@ def __init__( # pylint: disable=super-init-not-called
unit = Unit(unit)
self._wrapped_force = openmm.XmlSerializer.deserialize(openmmForce)
self.this = self._wrapped_force.this
self._registerCV(name, unit, openmmForce, unit, periodicBounds)
self._registerCV(
name,
unit,
openmmForce=openmmForce,
unit=unit,
periodicBounds=periodicBounds,
)
if periodicBounds is not None:
self._registerPeriodicBounds(*periodicBounds)

Expand Down
13 changes: 7 additions & 6 deletions cvpack/path_in_cv_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from copy import deepcopy

import openmm
from openmm import unit as mmunit

from .collective_variable import CollectiveVariable
from .path import Metric, deviation, progress
Expand Down Expand Up @@ -173,12 +174,12 @@ def __init__( # pylint: disable=too-many-branches
self.addCollectiveVariable(f"cv{i}", deepcopy(variable))
self._registerCV(
name,
None,
metric,
variables,
milestones.tolist(),
sigma,
scales,
mmunit.dimensionless,
metric=metric,
variables=variables,
milestones=milestones.tolist(),
sigma=sigma,
scales=scales,
)


Expand Down
4 changes: 3 additions & 1 deletion cvpack/radius_of_gyration.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def __init__(
num_groups, f"sqrt(({sum_dist_sq})/{num_atoms})", group, pbc, weighByMass
)
self.addBond(list(range(num_groups)))
self._registerCV(name, mmunit.nanometers, group, pbc, weighByMass)
self._registerCV(
name, mmunit.nanometers, group=group, pbc=pbc, weighByMass=weighByMass
)


RadiusOfGyration.registerTag("!cvpack.RadiusOfGyration")
4 changes: 3 additions & 1 deletion cvpack/radius_of_gyration_sq.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def __init__(
super().__init__(2, f"distance(g1, g2)^2/{num_atoms}", group, pbc, weighByMass)
for atom in group:
self.addBond([atom, num_atoms])
self._registerCV(name, mmunit.nanometers**2, group, pbc, weighByMass)
self._registerCV(
name, mmunit.nanometers**2, group=group, pbc=pbc, weighByMass=weighByMass
)


RadiusOfGyrationSq.registerTag("!cvpack.RadiusOfGyrationSq")
Loading

0 comments on commit 907d7ca

Please sign in to comment.