Skip to content

Commit

Permalink
Fixes NumberOfContacts CV (breaks API) (#50)
Browse files Browse the repository at this point in the history
* Add nonbonded exclusions in NumberOfContacts CV

* Fixed number of contacts unit test

* Sent force evalution function from attraction_strength to utils

* Allows reference definition for number of contacts
  • Loading branch information
craabreu authored Jan 21, 2024
1 parent 8fd36f9 commit c6642d4
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 159 deletions.
111 changes: 11 additions & 100 deletions cvpack/attraction_strength.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,89 +14,11 @@
from cvpack import unit as mmunit

from .cvpack import AbstractCollectiveVariable
from .utils import NonbondedForceSurrogate, evaluate_in_context

ONE_4PI_EPS0 = 138.93545764438198


class _NonbondedForceSurrogate: # pylint: disable=too-many-instance-attributes
"""A surrogate class for the NonbondedForce class in OpenMM."""

def __init__(self, other: openmm.NonbondedForce) -> None:
self._cutoff = other.getCutoffDistance()
self._uses_pbc = other.usesPeriodicBoundaryConditions()
self._num_particles = other.getNumParticles()
self._particle_parameters = list(
map(other.getParticleParameters, range(self._num_particles))
)
self._num_exceptions = other.getNumExceptions()
self._exception_parameters = list(
map(other.getExceptionParameters, range(self._num_exceptions))
)
self._use_switching_function = other.getUseSwitchingFunction()
self._switching_distance = other.getSwitchingDistance()

def __getstate__(self) -> t.Dict[str, str]:
return {
"cutoff": self.getCutoffDistance(),
"uses_pbc": self.usesPeriodicBoundaryConditions(),
"num_particles": self.getNumParticles(),
"particle_parameters": [
self.getParticleParameters(i) for i in range(self.getNumParticles())
],
"num_exceptions": self.getNumExceptions(),
"exception_parameters": [
self.getExceptionParameters(i) for i in range(self.getNumExceptions())
],
"use_switching_function": self.getUseSwitchingFunction(),
"switching_distance": self.getSwitchingDistance(),
}

def __setstate__(self, state: t.Dict[str, str]) -> None:
self._cutoff = state["cutoff"]
self._uses_pbc = state["uses_pbc"]
self._num_particles = state["num_particles"]
self._particle_parameters = state["particle_parameters"]
self._num_exceptions = state["num_exceptions"]
self._exception_parameters = state["exception_parameters"]
self._use_switching_function = state["use_switching_function"]
self._switching_distance = state["switching_distance"]

def getCutoffDistance(self) -> float:
"""Get the cutoff distance."""
return mmunit.value_in_md_units(self._cutoff)

def usesPeriodicBoundaryConditions(self) -> bool:
"""Return whether periodic boundary conditions are used."""
return self._uses_pbc

def getNumParticles(self) -> int:
"""Get the number of particles."""
return self._num_particles

def getParticleParameters(self, index: int) -> t.Tuple[float, float, float]:
"""Get the parameters of a particle at the given index."""
return tuple(map(mmunit.value_in_md_units, self._particle_parameters[index]))

def getNumExceptions(self):
"""Get the number of exceptions."""
return self._num_exceptions

def getExceptionParameters(
self, index: int
) -> t.Tuple[int, int, float, float, float]:
"""Get the parameters of an exception at the given index."""
i, j, *params = self._exception_parameters[index]
return i, j, *map(mmunit.value_in_md_units, params)

def getUseSwitchingFunction(self) -> bool:
"""Return whether a switching function is used."""
return self._use_switching_function

def getSwitchingDistance(self) -> float:
"""Get the switching distance."""
return mmunit.value_in_md_units(self._switching_distance)


class AttractionStrength(openmm.CustomNonbondedForce, AbstractCollectiveVariable):
"""
The strength of the attraction between two atom groups:
Expand Down Expand Up @@ -151,9 +73,12 @@ class AttractionStrength(openmm.CustomNonbondedForce, AbstractCollectiveVariable
The Lennard-Jones parameters, atomic charges, cutoff distance, boundary conditions,
as well as whether to use a switching function and its corresponding switching
distance, are taken from :openmm:`NonbondedForce` object. Any non-exclusion
exceptions involving atoms in :math:`{\\bf g}_1` and :math:`{\\bf g}_2` are turned
into exclusions.
distance, are taken from :openmm:`NonbondedForce` object.
.. note::
Any non-exclusion exceptions involving atoms in :math:`{\\bf g}_1` and
:math:`{\\bf g}_2` in the provided :class:`openmm.NonbondedForce` are turned
into exclusions in this collective variable.
Parameters
----------
Expand All @@ -162,12 +87,12 @@ class AttractionStrength(openmm.CustomNonbondedForce, AbstractCollectiveVariable
group2
The second atom group.
nonbondedForce
The :openmm:`NonbondedForce` object from which to collect the necessary
The :class:`openmm.NonbondedForce` object from which to collect the necessary
parameters.
reference
A reference value (in energy units per mole) to which the collective variable
should be normalized. One can also provide an :OpenMM:`Context` object from
which to obtain the reference value.
which to obtain a reference attraction strength.
Examples
--------
Expand Down Expand Up @@ -229,7 +154,7 @@ def __init__( # pylint: disable=too-many-arguments
1.0, mmunit.kilojoule_per_mole
),
) -> None:
nonbondedForce = _NonbondedForceSurrogate(nonbondedForce)
nonbondedForce = NonbondedForceSurrogate(nonbondedForce)
cutoff = nonbondedForce.getCutoffDistance()
expression = (
"-(lj + coul)/ref"
Expand Down Expand Up @@ -260,22 +185,8 @@ def __init__( # pylint: disable=too-many-arguments
self.setUseLongRangeCorrection(False)
self.addInteractionGroup(group1, group2)
if isinstance(reference, openmm.Context):
reference = self._getValue(reference)
reference = evaluate_in_context(self, reference)
self.setEnergyFunction(expression.replace("ref = 1", f"ref = {reference}"))
self._registerCV(
mmunit.dimensionless, group1, group2, nonbondedForce, reference
)

def _getValue(self, context: openmm.Context) -> float:
system = openmm.System()
for _ in range(context.getSystem().getNumParticles()):
system.addParticle(1.0)
system.addForce(openmm.CustomNonbondedForce(self))
state = context.getState(getPositions=True)
context = openmm.Context(system, openmm.VerletIntegrator(1.0))
context.setPositions(state.getPositions())
context.setPeriodicBoxVectors(*state.getPeriodicBoxVectors())
# pylint: disable=unexpected-keyword-arg # to avoid false positive
state = context.getState(getEnergy=True)
# pylint: enable=unexpected-keyword-arg
return mmunit.value_in_md_units(state.getPotentialEnergy())
141 changes: 87 additions & 54 deletions cvpack/number_of_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from cvpack import unit as mmunit

from .cvpack import AbstractCollectiveVariable
from .utils import NonbondedForceSurrogate, evaluate_in_context


class NumberOfContacts(openmm.CustomNonbondedForce, AbstractCollectiveVariable):
Expand Down Expand Up @@ -48,93 +49,125 @@ class NumberOfContacts(openmm.CustomNonbondedForce, AbstractCollectiveVariable):
(:math:`i = j`) are ignored and each pair of distinct atoms (:math:`i \\neq j`)
is counted only once.
.. note::
Any non-exclusion exceptions involving atoms in :math:`{\\bf g}_1` and
:math:`{\\bf g}_2` in the provided :class:`openmm.NonbondedForce` are turned
into exclusions in this collective variable.
Parameters
----------
group1
The indices of the atoms in the first group
group2
The indices of the atoms in the second group
numAtoms
The total number of atoms in the system (required by OpenMM)
pbc
Whether the system has periodic boundary conditions
stepFunction
The function "step(1-x)" (for analysis only) or a continuous approximation
thereof
thresholdDistance
The threshold distance (:math:`r_0`) for considering two atoms as being in
contact
cutoffFactor
The factor :math:`x_c` that multiplies the threshold distance to define
the cutoff distance
switchFactor
The factor :math:`x_s` that multiplies the threshold distance to define
the distance at which the step function starts switching off smoothly.
If None, it switches off abruptly at the cutoff distance.
group1
The indices of the atoms in the first group
group2
The indices of the atoms in the second group
nonbondedForce
The :class:`openmm.NonbondedForce` object from which the total number of
atoms, the exclusions, and whether to use periodic boundary conditions are
taken
reference
A dimensionless reference value to which the collective variable should be
normalized. One can also provide an :OpenMM:`Context` object from which to
obtain the reference number of contacts.
stepFunction
The function "step(1-x)" (for analysis only) or a continuous approximation
thereof
thresholdDistance
The threshold distance (:math:`r_0`) for considering two atoms as being in
contact
cutoffFactor
The factor :math:`x_c` that multiplies the threshold distance to define
the cutoff distance
switchFactor
The factor :math:`x_s` that multiplies the threshold distance to define
the distance at which the step function starts switching off smoothly.
If None, it switches off abruptly at the cutoff distance.
Example
-------
>>> import cvpack
>>> import openmm
>>> from openmm import app
>>> from openmmtools import testsystems
>>> model = testsystems.AlanineDipeptideVacuum()
>>> carbons = [
... a.index
... for a in model.topology.atoms()
... if a.element == app.element.carbon
... ]
>>> num_atoms = model.topology.getNumAtoms()
>>> optionals = {"pbc": False, "stepFunction": "step(1-x)"}
>>> nc = cvpack.NumberOfContacts(
... carbons, carbons, num_atoms, **optionals
... )
>>> nc.setUnusedForceGroup(0, model.system)
1
>>> model.system.addForce(nc)
5
>>> platform = openmm.Platform.getPlatformByName('Reference')
>>> context = openmm.Context(
... model.system, openmm.CustomIntegrator(0), platform
... )
>>> context.setPositions(model.positions)
>>> print(nc.getValue(context, digits=6))
6.0 dimensionless
>>> import cvpack
>>> from openmm import unit
>>> from openmmtools import testsystems
>>> model = testsystems.HostGuestExplicit()
>>> group1, group2 = [], []
>>> for residue in model.topology.residues():
... if residue.name != "HOH":
... group = group1 if residue.name == "B2" else group2
... group.extend(atom.index for atom in residue.atoms())
>>> forces = {f.getName(): f for f in model.system.getForces()}
>>> nc = cvpack.NumberOfContacts(
... group1,
... group2,
... forces["NonbondedForce"],
... stepFunction="step(1-x)",
... )
>>> nc.setUnusedForceGroup(0, model.system)
1
>>> model.system.addForce(nc)
5
>>> platform = openmm.Platform.getPlatformByName("Reference")
>>> integrator = openmm.VerletIntegrator(1.0 * mmunit.femtoseconds)
>>> context = openmm.Context(model.system, integrator, platform)
>>> context.setPositions(model.positions)
>>> print(nc.getValue(context, 4))
30.0 dimensionless
>>> nc_normalized = cvpack.NumberOfContacts(
... group1,
... group2,
... forces["NonbondedForce"],
... stepFunction="step(1-x)",
... reference=context,
... )
>>> nc_normalized.setUnusedForceGroup(0, model.system)
2
>>> model.system.addForce(nc_normalized)
6
>>> context.reinitialize(preserveState=True)
>>> print(nc_normalized.getValue(context, 4))
1.0 dimensionless
"""

@mmunit.convert_quantities
def __init__( # pylint: disable=too-many-arguments
self,
group1: t.Sequence[int],
group2: t.Sequence[int],
numAtoms: int,
pbc: bool,
nonbondedForce: openmm.NonbondedForce,
reference: t.Union[mmunit.ScalarQuantity, openmm.Context] = 1.0,
stepFunction: str = "1/(1+x^6)",
thresholdDistance: mmunit.ScalarQuantity = mmunit.Quantity(
0.3, mmunit.nanometers
),
cutoffFactor: float = 2.0,
switchFactor: t.Optional[float] = 1.5,
) -> None:
super().__init__(stepFunction + f"; x=r/{thresholdDistance}")
nonbondedForce = NonbondedForceSurrogate(nonbondedForce)
num_atoms = nonbondedForce.getNumParticles()
pbc = nonbondedForce.usesPeriodicBoundaryConditions()
expression = f"({stepFunction})/1; x=r/{thresholdDistance}"
super().__init__(expression)
nonbonded_method = self.CutoffPeriodic if pbc else self.CutoffNonPeriodic
self.setNonbondedMethod(nonbonded_method)
for _ in range(numAtoms):
for _ in range(num_atoms):
self.addParticle([])
for index in range(nonbondedForce.getNumExceptions()):
i, j, *_ = nonbondedForce.getExceptionParameters(index)
self.addExclusion(i, j)
self.setCutoffDistance(cutoffFactor * thresholdDistance)
use_switching_function = switchFactor is not None
self.setUseSwitchingFunction(use_switching_function)
if use_switching_function:
self.setSwitchingDistance(switchFactor * thresholdDistance)
self.setUseLongRangeCorrection(False)
self.addInteractionGroup(group1, group2)
if isinstance(reference, openmm.Context):
reference = evaluate_in_context(self, reference)
self.setEnergyFunction(expression.replace("/1;", f"/{reference};"))
self._registerCV(
mmunit.dimensionless,
group1,
group2,
numAtoms,
pbc,
nonbondedForce,
reference,
stepFunction,
thresholdDistance,
cutoffFactor,
Expand Down
16 changes: 11 additions & 5 deletions cvpack/tests/test_cvpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,18 +264,24 @@ def test_number_of_contacts():
group2 = [
a.index for a in model.topology.atoms() if a.element == app.element.oxygen
]
forces = {f.getName(): f for f in model.system.getForces()}
nonbonded = forces["NonbondedForce"]
exclusions = set()
for index in range(nonbonded.getNumExceptions()):
i, j, *_ = nonbonded.getExceptionParameters(index)
exclusions.add((i, j) if i < j else (j, i))
pairs = set()
for i, j in it.product(group1, group2):
if j != i and (j, i) not in pairs:
pairs.add((i, j))
if j != i:
pair = (i, j) if i < j else (j, i)
if pair not in exclusions:
pairs.add(pair)
distances = np.array([np.linalg.norm(pos[i] - pos[j]) for i, j in pairs])
contacts = np.where(distances <= 0.6, 1 / (1 + (distances / 0.3) ** 6), 0)
num_atoms = model.topology.getNumAtoms()
number_of_contacts = cvpack.NumberOfContacts(
group1,
group2,
num_atoms,
pbc=False,
forces["NonbondedForce"],
switchFactor=None,
)
number_of_contacts.setUnusedForceGroup(0, model.system)
Expand Down
Loading

0 comments on commit c6642d4

Please sign in to comment.