Skip to content

Commit

Permalink
prototype multithreaded openmm
Browse files Browse the repository at this point in the history
  • Loading branch information
axsk committed Jan 22, 2024
1 parent 6a61335 commit 0ba14c6
Showing 1 changed file with 49 additions and 12 deletions.
61 changes: 49 additions & 12 deletions src/openmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
using PyCall
import ISOKANN: propagate

# install / load OpenMM
pyimport_conda("openmm", "openmm", "conda-forge")
pyimport_conda("joblib", "joblib")

# load into namespace
py"""
from joblib import Parallel, delayed
from openmm.app import *
from openmm import *
from openmm.unit import *
Expand All @@ -11,30 +17,59 @@ from sys import stdout

nanometer = py"nanometer"

""" A Simulation wrapping the Python OpenMM Simulation object """
struct OpenMMSimulation
pysim::PyObject
steps::Int
end

###

function propagate(sim::OpenMMSimulation, x0)
steps = sim.steps
pysim = sim.pysim
setcoords(pysim, x0)
pysim.step(steps)
getcoords(pysim)
end

function propagate(sim::OpenMMSimulation, x0::AbstractMatrix, ny)
function propagate(s::OpenMMSimulation, x0::AbstractMatrix, ny)
dim, nx = size(x0)
ys = zeros(dim, nx, ny)
for i in 1:nx, j in 1:ny
ys[:, i, j] = propagate(sim, x0[:, i])
ys[:, i, j] = propagate(s, x0[:, i])
end
return ys
end

function propagate_threaded(s::OpenMMSimulation, x0::AbstractMatrix, ny; nthreads=1)
zs = repeat(x0, outer=[1, ny])
dim, nx = size(x0)

zs = PyReverseDims(reinterpret(Tuple{Float64,Float64,Float64}, zs))
steps = s.steps
sim = s.pysim
n = nx * ny

py"""
def singlerun(i):
s = copy.copy($sim) # TODO: this is not enough
s.context = copy.copy(s.context)
s.context.setPositions($zs[i] * nanometer)
s.step($steps)
z = s.context.getState(getPositions=True).getPositions().value_in_unit(nanometer)
return z
out = Parallel(n_jobs=$nthreads, prefer="threads")(
delayed(singlerun)(i) for i in range($n))
"""

zs = py"out"
zs = reinterpret(Float64, permutedims(zs))
zs = reshape(zs, dim, nx, ny)

return zs
end

function propagate(s::OpenMMSimulation, x0)
setcoords(s.pysim, x0)
s.pysim.step(s.steps)
getcoords(s.pysim)
end

function getcoords(sim::PyObject)
x = sim.context.getState(getPositions=true).getPositions().value_in_unit(nanometer)
reinterpret(Float64, x)
Expand All @@ -47,13 +82,14 @@ end

###

""" Basic construction of a OpenMM Simulation, following the OpenMM documentation example """
function openmm_examplesys(;
temp=300,
friction=1,
step=0.004,
pdb="/home/htc/bzfsikor/.julia/conda/3/share/openmm/examples/input.pdb",
forcefields=["amber14-all.xml", "amber14/tip3pfb.xml"],
steps=100)
steps=1)

py"""
pdb = PDBFile($pdb)
Expand All @@ -63,8 +99,9 @@ function openmm_examplesys(;
integrator = LangevinMiddleIntegrator($temp*kelvin, $friction/picosecond, $step*picoseconds)
simulation = Simulation(pdb.topology, system, integrator)
simulation.context.setPositions(pdb.positions)
simulation.minimizeEnergy()
"""
# simulation.minimizeEnergy()
#
# simulation.reporters.append(PDBReporter('output.pdb', 1000))
# simulation.reporters.append(StateDataReporter(stdout, 1000, step=True,
# potentialEnergy=True, temperature=True))
Expand Down

0 comments on commit 0ba14c6

Please sign in to comment.