Skip to content

Commit

Permalink
openmm: fix girsanov sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
axsk committed Sep 12, 2024
1 parent e394997 commit 470c38f
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions src/simulators/openmm.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module OpenMM

using PyCall, CUDA
using LinearAlgebra: norm
using LinearAlgebra: norm, dot

import JLD2
import ..ISOKANN: ISOKANN, IsoSimulation,
Expand Down Expand Up @@ -413,27 +413,34 @@ function integrate_girsanov(sim::OpenMMSimulation; x0=getcoords(sim), steps=step
# TODO: check units on the following three lines
kB = 0.008314463
dt = stepsize(sim)
gamma = friction(sim)
γ = friction(sim)

sigma = sqrt(2 * gamma * kB * temp(sim))
m = repeat(masses(sim), inner=3)
M = repeat(masses(sim), inner=3)
T = temp(sim)
σ = @. sqrt(2 * kB * T /* M))

x = copy(x0)
g = 0.

z = similar(x, length(x), steps)

for i in 1:steps
g += od_langevin_step_girsanov!(x, F, m, sigma, dt, u, g)
F = force(sim, x)
ux = u(x)
g += od_langevin_step_girsanov!(x, F, M, σ, γ, dt, ux)
z[:, i] = x
end

return x, g
return x, g, z
end

function od_langevin_step_girsanov!(x, F, m, sigma, dt, u, g)
dB = randn(length(x))
ux = u(x)
@. x += 1 / m * ((F + sigma * ux) * dt + sigma * sqrt(dt) * dB)
dg = 1/2 * dt * dot(ux, ux) + sqrt(dt) * dot(ux, dB)
function od_langevin_step_girsanov!(x, F, M, σ, γ, dt, u)
dB = randn(length(x)) * sqrt(dt)
@. x += (1 /* M) * F +* u)) * dt + σ * dB
dg = dot(u, u) / 2 * dt + dot(u, dB) * sqrt(dt)
return dg
end



end #module

0 comments on commit 470c38f

Please sign in to comment.