Skip to content

Commit

Permalink
polish propagate_threaded()
Browse files Browse the repository at this point in the history
move the array conversions to numpy, cleaner and faster. hygiene with py-threadedrun
  • Loading branch information
axsk committed Jan 22, 2024
1 parent 536b519 commit 25136f4
Showing 1 changed file with 25 additions and 36 deletions.
61 changes: 25 additions & 36 deletions src/openmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,21 @@ from openmm.app import *
from openmm import *
from openmm.unit import *
from sys import stdout
import itertools
import numpy as np
def unpack(out):
return list(itertools.chain.from_iterable(out))
"""
# still allocating, but 4x as fast as anything else i could find
mypyvec(out) = reinterpret(Float64, pycall(py"unpack", Vector{Tuple{Float64,Float64,Float64}}, out))
def threadedrun(sim, steps, n, nthreads, xs):
def singlerun(i):
c = Context(sim.system, copy.copy(sim.integrator))
c.setPositions(xs[i])
c.setVelocitiesToTemperature(sim.integrator.getTemperature())
c.getIntegrator().step(steps)
return c.getState(getPositions=True).getPositions(asNumpy=True).value_in_unit(nanometer)
out = Parallel(n_jobs=nthreads, prefer="threads")(delayed(singlerun)(i) for i in range(n))
return np.array(out).flatten()
"""
threadedrun = py"threadedrun"
nanometer = py"nanometer"

""" A Simulation wrapping the Python OpenMM Simulation object """
Expand All @@ -31,6 +38,17 @@ end

###

""" multi-threaded propagation of an `OpenMMSimulation` """
function propagate_threaded(s::OpenMMSimulation, x0::AbstractMatrix, ny; nthreads=1)
dim, nx = size(x0)
xs = repeat(x0, outer=[1, ny])
xs = permutedims(reinterpret(Tuple{Float64,Float64,Float64}, xs), (2, 1))
ys = @pycall threadedrun(s.pysim, s.steps, nx * ny, nthreads, xs)::PyArray
return reshape(ys, dim, nx, ny)
end

###

function propagate(s::OpenMMSimulation, x0::AbstractMatrix, ny)
dim, nx = size(x0)
ys = zeros(dim, nx, ny)
Expand All @@ -40,43 +58,14 @@ function propagate(s::OpenMMSimulation, x0::AbstractMatrix, ny)
return ys
end

# producing nans with nthreads = 1, crashing for nthreads > 1
function propagate_threaded(s::OpenMMSimulation, x0::AbstractMatrix, ny; nthreads=1)
xs = repeat(x0, outer=[1, ny])
dim, nx = size(x0)

xs = PyReverseDims(reinterpret(Tuple{Float64,Float64,Float64}, xs))
steps = s.steps
sim = s.pysim
n = nx * ny

py"""
def singlerun(i):
x = $xs[i]
sim = $sim
steps = $steps
c = Context(sim.system, copy.copy(sim.integrator))
c.setPositions(x)
c.setVelocitiesToTemperature(sim.integrator.getTemperature())
c.getIntegrator().step(steps)
return c.getState(getPositions=True).getPositions().value_in_unit(nanometer)
out = Parallel(n_jobs=$nthreads, prefer="threads")(delayed(singlerun)(i) for i in range($n))
"""
return reshape(mypyvec(py"out"o), dim, nx, ny)


#zs = reshape(reinterpret(Float64, permutedims(py"out")), dim, nx, ny)
#return zs
end

function propagate(s::OpenMMSimulation, x0)
setcoords(s.pysim, x0)
s.pysim.step(s.steps)
getcoords(s.pysim)
end

function getcoords(sim::PyObject)
# TODO: use asNumpy=True
x = sim.context.getState(getPositions=true).getPositions().value_in_unit(nanometer)
reinterpret(Float64, x)
end
Expand Down

0 comments on commit 25136f4

Please sign in to comment.