Skip to content

Commit

Permalink
Excise openff dependency from openmm testing (#993)
Browse files Browse the repository at this point in the history
* Excise openff dependency from openmm testing

* Remove commmented out code

* Update src/atomate2/openmm/jobs/base.py

Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>

* Respond to Yanosh PR and fix type of OpenMM Flow

* Fix typo, lint

* Add dataclass tag where needed

---------

Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>
  • Loading branch information
orionarcher and janosh authored Sep 27, 2024
1 parent 96b2b82 commit 9600cef
Show file tree
Hide file tree
Showing 11 changed files with 195 additions and 90 deletions.
4 changes: 2 additions & 2 deletions src/atomate2/openmm/flows/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import TYPE_CHECKING

from emmet.core.openmm import Calculation, OpenMMInterchange, OpenMMTaskDocument
from jobflow import Flow, Job, Response
from jobflow import Flow, Job, Maker, Response
from monty.json import MontyDecoder, MontyEncoder

from atomate2.openmm.jobs.base import openmm_job
Expand Down Expand Up @@ -68,7 +68,7 @@ def collect_outputs(


@dataclass
class OpenMMFlowMaker:
class OpenMMFlowMaker(Maker):
"""Run a production simulation.
This flexible flow links together any flows of OpenMM jobs in
Expand Down
11 changes: 9 additions & 2 deletions src/atomate2/openmm/jobs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,11 @@ def _update_interchange(
interchange.box = state.getPeriodicBoxVectors(asNumpy=True)
elif isinstance(interchange, OpenMMInterchange):
interchange.state = XmlSerializer.serialize(state)
else:
raise TypeError(
f"Interchange must be an Interchange or "
f"OpenMMInterchange object, got {type(interchange).__name__}"
)

def _create_structure(
self, sim: Simulation, prev_task: OpenMMTaskDocument | None = None
Expand Down Expand Up @@ -607,8 +612,10 @@ def _create_task_doc(

prev_task = prev_task or OpenMMTaskDocument()

interchange_json = interchange.json()
# interchange_bytes = interchange_json.encode("utf-8")
if isinstance(interchange, Interchange):
interchange_json = interchange.json()
else:
interchange_json = interchange.model_dump_json()

return OpenMMTaskDocument(
tags=tags,
Expand Down
39 changes: 39 additions & 0 deletions src/atomate2/openmm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@

from __future__ import annotations

import io
import re
import tempfile
import time
import warnings
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import openmm.unit as omm_unit
from emmet.core.openmm import OpenMMInterchange
from openmm import LangevinMiddleIntegrator, XmlSerializer
from openmm.app import PDBFile

if TYPE_CHECKING:
from emmet.core.openmm import OpenMMTaskDocument
from openff.interchange import Interchange


def download_opls_xml(
Expand Down Expand Up @@ -132,3 +140,34 @@ def task_reports(task: OpenMMTaskDocument, traj_or_state: str = "traj") -> bool:
else:
raise ValueError("traj_or_state must be 'traj' or 'state'")
return calc_input.n_steps >= report_freq


def openff_to_openmm_interchange(
openff_interchange: Interchange,
) -> OpenMMInterchange:
"""Convert an OpenFF Interchange object to an OpenMM Interchange object."""
integrator = LangevinMiddleIntegrator(
300 * omm_unit.kelvin,
10.0 / omm_unit.picoseconds,
1.0 * omm_unit.femtoseconds,
)
sim = openff_interchange.to_openmm_simulation(integrator)
state = sim.context.getState(
getPositions=True,
getVelocities=True,
enforcePeriodicBox=True,
)
with io.StringIO() as buffer:
PDBFile.writeFile(
sim.topology,
np.zeros(shape=(sim.topology.getNumAtoms(), 3)),
file=buffer,
)
buffer.seek(0)
pdb = buffer.read()

return OpenMMInterchange(
system=XmlSerializer.serialize(sim.system),
state=XmlSerializer.serialize(state),
topology=pdb,
)
5 changes: 4 additions & 1 deletion tests/openff_md/test_core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pytest
from emmet.core.openff import ClassicalMDTaskDocument, MoleculeSpec
from openff.interchange import Interchange

from atomate2.openff.core import generate_interchange

pytest.importorskip("openff.toolkit")
from openff.interchange import Interchange # noqa: E402


def test_generate_interchange(mol_specs_small, run_job):
mass_density = 1
Expand Down
28 changes: 15 additions & 13 deletions tests/openff_md/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,8 @@
import numpy as np
import openff.toolkit as tk
import pymatgen
import pytest
from emmet.core.openff import MoleculeSpec
from openff.interchange import Interchange
from openff.toolkit.topology import Topology
from openff.toolkit.topology.molecule import Molecule
from openff.units import Quantity
from pymatgen.analysis.graphs import MoleculeGraph
from pymatgen.io.openff import (
add_conformer,
assign_partial_charges,
create_openff_mol,
get_atom_map,
infer_openff_mol,
mol_graph_to_openff_mol,
)

from atomate2.openff.utils import (
counts_from_box_size,
Expand All @@ -24,6 +11,21 @@
merge_specs_by_name_and_smiles,
)

pytest.importorskip("openff.toolkit")
import openff.toolkit as tk # noqa: E402
from openff.interchange import Interchange # noqa: E402
from openff.toolkit.topology import Topology # noqa: E402
from openff.toolkit.topology.molecule import Molecule # noqa: E402
from openff.units import Quantity # noqa: E402
from pymatgen.io.openff import ( # noqa: E402
add_conformer,
assign_partial_charges,
create_openff_mol,
get_atom_map,
infer_openff_mol,
mol_graph_to_openff_mol,
)


def test_molgraph_to_openff_pf6(mol_files):
"""transform a water MoleculeGraph to a OpenFF water molecule"""
Expand Down
85 changes: 52 additions & 33 deletions tests/openmm_md/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
import openff.toolkit as tk
import pytest
from emmet.core.openmm import OpenMMInterchange
from jobflow import run_locally
from openff.interchange import Interchange
from openff.interchange.components._packmol import pack_box
from openff.toolkit import ForceField
from openff.units import unit

from atomate2.openff.utils import create_mol_spec, merge_specs_by_name_and_smiles


@pytest.fixture
Expand All @@ -18,37 +12,62 @@ def run_job(job):
return run_job


@pytest.fixture
@pytest.fixture(scope="package")
def openmm_data(test_dir):
return test_dir / "openmm"


@pytest.fixture(scope="package")
def interchange():
o = create_mol_spec("O", 300, charge_method="mmff94")
cco = create_mol_spec("CCO", 10, charge_method="mmff94")
cco2 = create_mol_spec("CCO", 20, name="cco2", charge_method="mmff94")
mol_specs = [o, cco, cco2]
mol_specs.sort(
key=lambda x: tk.Molecule.from_json(x.openff_mol).to_smiles() + x.name
)

topology = pack_box(
molecules=[tk.Molecule.from_json(spec.openff_mol) for spec in mol_specs],
number_of_copies=[spec.count for spec in mol_specs],
mass_density=0.8 * unit.grams / unit.milliliter,
)

mol_specs = merge_specs_by_name_and_smiles(mol_specs)

return Interchange.from_smirnoff(
force_field=ForceField("openff_unconstrained-2.1.1.offxml"),
topology=topology,
charge_from_molecules=[
tk.Molecule.from_json(spec.openff_mol) for spec in mol_specs
],
allow_nonintegral_charges=True,
)
def interchange(openmm_data):
# we use openff to generate the interchange object that we test on
# but we don't want to create a logical dependency on openff, in
# case the user has another way of generating the interchange object
regenerate_test_data = False
if regenerate_test_data:
import openff.toolkit as tk
from openff.interchange import Interchange
from openff.interchange.components._packmol import pack_box
from openff.toolkit import ForceField
from openff.units import unit

from atomate2.openff.utils import (
create_mol_spec,
merge_specs_by_name_and_smiles,
)
from atomate2.openmm.utils import openff_to_openmm_interchange

o = create_mol_spec("O", 300, charge_method="mmff94")
cco = create_mol_spec("CCO", 10, charge_method="mmff94")
cco2 = create_mol_spec("CCO", 20, name="cco2", charge_method="mmff94")
mol_specs = [o, cco, cco2]
mol_specs.sort(
key=lambda x: tk.Molecule.from_json(x.openff_mol).to_smiles() + x.name
)

topology = pack_box(
molecules=[tk.Molecule.from_json(spec.openff_mol) for spec in mol_specs],
number_of_copies=[spec.count for spec in mol_specs],
mass_density=0.8 * unit.grams / unit.milliliter,
)

mol_specs = merge_specs_by_name_and_smiles(mol_specs)

openff_interchange = Interchange.from_smirnoff(
force_field=ForceField("openff_unconstrained-2.1.1.offxml"),
topology=topology,
charge_from_molecules=[
tk.Molecule.from_json(spec.openff_mol) for spec in mol_specs
],
allow_nonintegral_charges=True,
)

openmm_interchange = openff_to_openmm_interchange(openff_interchange)

with open(openmm_data / "interchange.json", "w") as file:
file.write(openmm_interchange.model_dump_json())

with open(openmm_data / "interchange.json") as file:
return OpenMMInterchange.model_validate_json(file.read())


@pytest.fixture
Expand Down
25 changes: 9 additions & 16 deletions tests/openmm_md/flows/test_core.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from __future__ import annotations

import io
import json
from pathlib import Path

import numpy as np
import pytest
from emmet.core.openmm import OpenMMTaskDocument
from emmet.core.openmm import OpenMMInterchange, OpenMMTaskDocument
from jobflow import Flow
from MDAnalysis import Universe
from monty.json import MontyDecoder
from openff.interchange import Interchange
from openmm.app import PDBFile

from atomate2.openmm.flows.core import OpenMMFlowMaker
from atomate2.openmm.jobs import EnergyMinimizationMaker, NPTMaker, NVTMaker
Expand Down Expand Up @@ -156,22 +157,13 @@ def test_flow_maker(interchange, run_job):
calc_output = task_doc.calcs_reversed[0].output
assert len(calc_output.steps_reported) == 5

all_steps = [calc.output.steps_reported for calc in task_doc.calcs_reversed]
assert all_steps == [
[1, 2, 3, 4, 5],
[1],
[1, 2],
[1, 2],
[1, 2, 3, 4, 5],
None,
]
# Test that the state interval is respected
assert calc_output.steps_reported == list(range(1, 6))
assert calc_output.steps_reported == list(range(11, 16))
assert calc_output.traj_file == "trajectory5.dcd"
assert calc_output.state_file == "state5.csv"

interchange = Interchange.parse_raw(task_doc.interchange)
topology = interchange.to_openmm_topology()
interchange = OpenMMInterchange.model_validate_json(task_doc.interchange)
topology = PDBFile(io.StringIO(interchange.topology)).getTopology()
u = Universe(topology, str(Path(task_doc.dir_name) / "trajectory5.dcd"))

assert len(u.trajectory) == 5
Expand All @@ -184,8 +176,9 @@ def test_traj_blob_embed(interchange, run_job, tmp_path):
nvt_job = nvt.make(interchange)
task_doc = run_job(nvt_job)

interchange = Interchange.parse_raw(task_doc.interchange)
topology = interchange.to_openmm_topology()
interchange = OpenMMInterchange.model_validate_json(task_doc.interchange)
topology = PDBFile(io.StringIO(interchange.topology)).getTopology()

u = Universe(topology, str(Path(task_doc.dir_name) / "trajectory.dcd"))

assert len(u.trajectory) == 2
Expand Down
27 changes: 18 additions & 9 deletions tests/openmm_md/jobs/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from emmet.core.openmm import Calculation, CalculationInput, OpenMMTaskDocument
from jobflow import Flow, Job
from mdareporter import MDAReporter
from openmm import XmlSerializer
from openmm.app import Simulation, StateDataReporter
from openmm.openmm import LangevinMiddleIntegrator
from openmm.unit import kelvin, picoseconds
Expand Down Expand Up @@ -70,22 +71,30 @@ def test_create_simulation(interchange):
def test_update_interchange(interchange):
interchange = copy.deepcopy(interchange)
maker = BaseOpenMMMaker(wrap_traj=True)

sim = maker._create_simulation(interchange) # noqa: SLF001
start_positions = interchange.positions
start_velocities = interchange.velocities
start_box = interchange.box

state = XmlSerializer.deserialize(interchange.state)
start_positions = state.getPositions(asNumpy=True)
start_velocities = state.getVelocities(asNumpy=True)
start_box = state.getPeriodicBoxVectors()

# Run the simulation for one step
sim.step(1)
sim.step(2)

maker._update_interchange(interchange, sim, None) # noqa: SLF001

assert interchange.positions.shape == start_positions.shape
assert interchange.velocities.shape == (1170, 3)
new_state = XmlSerializer.deserialize(interchange.state)
new_positions = new_state.getPositions(asNumpy=True)
new_velocities = new_state.getVelocities(asNumpy=True)
new_box = new_state.getPeriodicBoxVectors()

assert new_positions.shape == start_positions.shape
assert new_velocities.shape == start_velocities.shape

assert np.any(interchange.positions != start_positions)
assert np.any(interchange.velocities != start_velocities)
assert np.all(interchange.box == start_box)
assert not np.all(new_positions == start_positions)
assert not np.all(new_velocities == start_velocities)
assert np.all(new_box == start_box)


def test_create_task_doc(interchange, tmp_path):
Expand Down
Loading

0 comments on commit 9600cef

Please sign in to comment.