Skip to content

Commit

Permalink
support nnpops with expanded tests (#35)
Browse files Browse the repository at this point in the history
* integrate NNPOps into

* remove

* fix species with explicit atomic number

* fix  to

* adding energy validation test and anipotential implementation with nnpops

* fixing tests; CUDA ERROR Problem again

* rename lambda to scale for naming collision purposes and remove ml-force from CustomCVForce because nnpops doesn't like it

* demonstration of replica exchange sampler free energy calc with a mixed MM/NNP system

* add repex functionality with test

* up-to-date without simulation capacity

* test nnpops

* replacing anipotential.py with this in openmm-ml main (from openmm) will throw an openmmException CUDA handling error (once the platform in the test is switched to CUDA) on line 26, suggesting that there is something problematic with the  inside a

* expand tests, add nnpops, rename lambda because of collision

* fixing platform getter and removing old comments/test utilities.

* removing explicit reference to

Co-authored-by: dominicrufa <dominic.rufa@gmail.com>
  • Loading branch information
dominicrufa and dominicrufa authored Jul 1, 2022
1 parent f19d746 commit 5b577c2
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 29 deletions.
12 changes: 6 additions & 6 deletions openmmml/mlpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ def createMixedSystem(self,
3. For every NonbondedForce, a corresponding CustomBondForce to compute the
nonbonded interactions within the ML subset.
The CustomCVForce defines a global parameter called "lambda" that interpolates
between the two potentials. When lambda=0, the energy is computed entirely with
the conventional force field. When lambda=1, the energy is computed entirely with
The CustomCVForce defines a global parameter called "lambda_interpolate" that interpolates
between the two potentials. When lambda_interpolate=0, the energy is computed entirely with
the conventional force field. When lambda_interpolate=1, the energy is computed entirely with
the ML potential. You can set its value by calling setParameter() on the Context.
Parameters
Expand Down Expand Up @@ -269,7 +269,7 @@ def createMixedSystem(self,
# Create a CustomCVForce and put the ML forces inside it.

cv = openmm.CustomCVForce('')
cv.addGlobalParameter('lambda', 1)
cv.addGlobalParameter('lambda_interpolate', 1)
tempSystem = openmm.System()
self._impl.addForces(topology, tempSystem, atomList, forceGroup, **args)
mlVarNames = []
Expand Down Expand Up @@ -331,11 +331,11 @@ def createMixedSystem(self,
cv.addCollectiveVariable(name, internalNonbonded)
mmVarNames.append(name)

# Configure the CustomCVForce so lambda interpolates between the conventional and ML potentials.
# Configure the CustomCVForce so lambda_interpolate interpolates between the conventional and ML potentials.

mlSum = '+'.join(mlVarNames) if len(mlVarNames) > 0 else '0'
mmSum = '+'.join(mmVarNames) if len(mmVarNames) > 0 else '0'
cv.setEnergyFunction(f'lambda*({mlSum}) + (1-lambda)*({mmSum})')
cv.setEnergyFunction(f'lambda_interpolate*({mlSum}) + (1-lambda_interpolate)*({mmSum})')
newSystem.addForce(cv)
return newSystem

Expand Down
34 changes: 25 additions & 9 deletions openmmml/models/anipotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from openmmml.mlpotential import MLPotential, MLPotentialImpl, MLPotentialImplFactory
import openmm
from typing import Iterable, Optional
from typing import Iterable, Optional, Union

class ANIPotentialImplFactory(MLPotentialImplFactory):
"""This is the factory that creates ANIPotentialImpl objects."""
Expand All @@ -57,32 +57,45 @@ class ANIPotentialImpl(MLPotentialImpl):
def __init__(self, name):
self.name = name


def addForces(self,
topology: openmm.app.Topology,
system: openmm.System,
atoms: Optional[Iterable[int]],
forceGroup: int,
filename: str = 'animodel.pt',
implementation : str = 'nnpops',
**args):
# Create the TorchANI model.

import torchani
import torch
import openmmtorch

# `nnpops` throws error if `periodic_table_index`=False if one passes `species` as `species_to_tensor` from `element`
_kwarg_dict = {'periodic_table_index': True}
if self.name == 'ani1ccx':
model = torchani.models.ANI1ccx()
model = torchani.models.ANI1ccx(**_kwarg_dict)
elif self.name == 'ani2x':
model = torchani.models.ANI2x()
model = torchani.models.ANI2x(**_kwarg_dict)
else:
raise ValueError('Unsupported ANI model: '+self.name)

# Create the PyTorch model that will be invoked by OpenMM.

includedAtoms = list(topology.atoms())
if atoms is not None:
includedAtoms = [includedAtoms[i] for i in atoms]
elements = [atom.element.symbol for atom in includedAtoms]
species = model.species_to_tensor(elements).unsqueeze(0)
species = torch.tensor([[atom.element.atomic_number for atom in includedAtoms]])
if implementation == 'nnpops':
try:
from NNPOps import OptimizedTorchANI
model = OptimizedTorchANI(model, species)
except Exception as e:
print(f"failed to equip `nnpops` with error: {e}")
elif implementation == "torchani":
pass # do nothing
else:
raise NotImplementedError(f"implementation {implementation} is not supported")

class ANIForce(torch.nn.Module):

Expand All @@ -102,16 +115,20 @@ def __init__(self, model, species, atoms, periodic):

def forward(self, positions, boxvectors: Optional[torch.Tensor] = None):
positions = positions.to(torch.float32)
#print(f"(boxvectors, scale): {boxvectors, scale}")
if self.indices is not None:
positions = positions[self.indices]
if boxvectors is None:
_, energy = self.model((self.species, 10.0*positions.unsqueeze(0)))
else:
boxvectors = boxvectors.to(torch.float32)
_, energy = self.model((self.species, 10.0*positions.unsqueeze(0)), cell=10.0*boxvectors, pbc=self.pbc)

return self.energyScale*energy

aniForce = ANIForce(model, species, atoms, topology.getPeriodicBoxVectors() is not None)
# is_periodic...
is_periodic = (topology.getPeriodicBoxVectors() is not None) or system.usesPeriodicBoundaryConditions()
aniForce = ANIForce(model, species, atoms, is_periodic)

# Convert it to TorchScript and save it.

Expand All @@ -122,8 +139,7 @@ def forward(self, positions, boxvectors: Optional[torch.Tensor] = None):

force = openmmtorch.TorchForce(filename)
force.setForceGroup(forceGroup)
if topology.getPeriodicBoxVectors() is not None:
force.setUsesPeriodicBoundaryConditions(True)
force.setUsesPeriodicBoundaryConditions(is_periodic)
system.addForce(force)

MLPotential.registerImplFactory('ani1ccx', ANIPotentialImplFactory())
Expand Down
29 changes: 15 additions & 14 deletions test/TestMLPotential.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
import openmm as mm
import numpy as np
import openmm.app as app
import openmm.unit as unit
from openmmml import MLPotential
import unittest
import pytest
import itertools
rtol=1e-5
platform_ints = range(mm.Platform.getNumPlatforms())

class TestMLPotential(unittest.TestCase):

def testCreateMixedSystem(self):
@pytest.mark.parametrize("implementation,platform_name", list(itertools.product(['nnpops', 'torchani'], list(platform_ints))))
class TestMLPotential:

def testCreateMixedSystem(self, implementation, platform_int):
pdb = app.PDBFile('alanine-dipeptide-explicit.pdb')
ff = app.ForceField('amber14-all.xml', 'amber14/tip3pfb.xml')
mmSystem = ff.createSystem(pdb.topology, nonbondedMethod=app.PME)
potential = MLPotential('ani2x')
mlAtoms = [a.index for a in next(pdb.topology.chains()).atoms()]
mixedSystem = potential.createMixedSystem(pdb.topology, mmSystem, mlAtoms, interpolate=False)
interpSystem = potential.createMixedSystem(pdb.topology, mmSystem, mlAtoms, interpolate=True)
platform = mm.Platform.getPlatformByName('Reference')
mixedSystem = potential.createMixedSystem(pdb.topology, mmSystem, mlAtoms, interpolate=False, implementation=implementation)
interpSystem = potential.createMixedSystem(pdb.topology, mmSystem, mlAtoms, interpolate=True, implementation=implementation)
platform = mm.Platform.getPlatform(platform_int)
mmContext = mm.Context(mmSystem, mm.VerletIntegrator(0.001), platform)
mixedContext = mm.Context(mixedSystem, mm.VerletIntegrator(0.001), platform)
interpContext = mm.Context(interpSystem, mm.VerletIntegrator(0.001), platform)
Expand All @@ -24,12 +30,7 @@ def testCreateMixedSystem(self):
mmEnergy = mmContext.getState(getEnergy=True).getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
mixedEnergy = mixedContext.getState(getEnergy=True).getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
interpEnergy1 = interpContext.getState(getEnergy=True).getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
interpContext.setParameter('lambda', 0)
interpContext.setParameter('lambda_interpolate', 0)
interpEnergy2 = interpContext.getState(getEnergy=True).getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
self.assertAlmostEqual(mixedEnergy, interpEnergy1, delta=1e-5*abs(mixedEnergy))
self.assertAlmostEqual(mmEnergy, interpEnergy2, delta=1e-5*abs(mmEnergy))


if __name__ == '__main__':
unittest.main()

assert np.isclose(mixedEnergy, interpEnergy1, rtol=rtol)
assert np.isclose(mmEnergy, interpEnergy2, rtol=rtol)

0 comments on commit 5b577c2

Please sign in to comment.