Skip to content

Commit

Permalink
Created entry point for registering potentials (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
peastman authored Apr 11, 2024
1 parent df00239 commit 688b1fc
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 12 deletions.
19 changes: 17 additions & 2 deletions openmmml/mlpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Biological Structures at Stanford, funded under the NIH Roadmap for
Medical Research, grant U54 GM072970. See https://simtk.org.
Portions copyright (c) 2021 Stanford University and the Authors.
Portions copyright (c) 2021-2024 Stanford University and the Authors.
Authors: Peter Eastman
Contributors:
Expand Down Expand Up @@ -34,14 +34,23 @@
import openmm.unit as unit
from copy import deepcopy
from typing import Dict, Iterable, Optional
import sys
if sys.version_info < (3, 10):
from importlib_metadata import entry_points
else:
from importlib.metadata import entry_points


class MLPotentialImplFactory(object):
"""Abstract interface for classes that create MLPotentialImpl objects.
If you are defining a new potential function, you need to create subclasses
of MLPotentialImpl and MLPotentialImplFactory, and register an instance of
the factory by calling MLPotential.registerImplFactory().
the factory by calling MLPotential.registerImplFactory(). Alternatively,
if a Python package creates an entry point in the group "openmmml.potentials",
the potential will be registered automatically. The entry point name is the
name of the potential function, and the value should be the name of the
MLPotentialImplFactory subclass.
"""

def createImpl(self, name: str, **args) -> "MLPotentialImpl":
Expand Down Expand Up @@ -417,3 +426,9 @@ def registerImplFactory(name: str, factory: MLPotentialImplFactory):
a factory object that will be used to create MLPotentialImpl objects
"""
MLPotential._implFactories[name] = factory


# Register any potential functions defined by entry points.

for potential in entry_points(group='openmmml.potentials'):
MLPotential.registerImplFactory(potential.name, potential.load()())
3 changes: 0 additions & 3 deletions openmmml/models/anipotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,3 @@ def forward(self, positions, boxvectors: Optional[torch.Tensor] = None):
force.setForceGroup(forceGroup)
force.setUsesPeriodicBoundaryConditions(is_periodic)
system.addForce(force)

MLPotential.registerImplFactory('ani1ccx', ANIPotentialImplFactory())
MLPotential.registerImplFactory('ani2x', ANIPotentialImplFactory())
6 changes: 0 additions & 6 deletions openmmml/models/macepotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,3 @@ def forward(
force.setForceGroup(forceGroup)
force.setUsesPeriodicBoundaryConditions(isPeriodic)
system.addForce(force)


MLPotential.registerImplFactory("mace", MACEPotentialImplFactory())
MLPotential.registerImplFactory("mace-off23-small", MACEPotentialImplFactory())
MLPotential.registerImplFactory("mace-off23-medium", MACEPotentialImplFactory())
MLPotential.registerImplFactory("mace-off23-large", MACEPotentialImplFactory())
13 changes: 12 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,16 @@
classifiers=CLASSIFIERS.splitlines(),
packages=find_packages(),
zip_safe=False,
install_requires=['numpy', 'openmm >= 7.5'])
install_requires=['numpy', 'openmm >= 7.5'],
entry_points={
'openmmml.potentials': [
'ani1ccx = openmmml.models.anipotential:ANIPotentialImplFactory',
'ani2x = openmmml.models.anipotential:ANIPotentialImplFactory',
'mace = openmmml.models.macepotential:MACEPotentialImplFactory',
'mace-off23-small = openmmml.models.macepotential:MACEPotentialImplFactory',
'mace-off23-medium = openmmml.models.macepotential:MACEPotentialImplFactory',
'mace-off23-large = openmmml.models.macepotential:MACEPotentialImplFactory'
]
}
)

0 comments on commit 688b1fc

Please sign in to comment.