Skip to content

Commit

Permalink
Address peastman's comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Aug 16, 2024
1 parent 3f6d6ff commit 3de4a19
Showing 1 changed file with 23 additions and 45 deletions.
68 changes: 23 additions & 45 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ torch_force.setOutputsForces(True)
Computing energy derivatives with respect to global parameters
--------------------------------------------------------------

Its possible to query `TorchForce` for the derivative of the energy with respect to global parameters. In order to do so the global parameters must be registered as energy derivatives. This is done by calling `addEnergyParameterDerivative()` for each parameter.
TorchForce can compute derivatives of the energy with respect to global parameters.. In order to do so the global parameters must be registered as energy derivatives. This is done by calling `addEnergyParameterDerivative()` for each parameter.

The parameter derivatives can be queried by calling `getEnergyParameterDerivatives()` on the `State` object returned by `Context.getState()`. The result is a dictionary with the parameter names as keys and the derivatives as values.

Expand All @@ -274,50 +274,28 @@ class ForceWithParameters(pt.nn.Module):
def __init__(self):
super(ForceWithParameters, self).__init__()

def forward(
self, positions: Tensor, parameter1: Tensor, parameter2: Tensor
) -> Tensor:
x2 = positions.pow(2).sum(dim=1)
u_harmonic = ((parameter1 + parameter2**2) * x2).sum()
return u_harmonic


def example():
numParticles = 10
system = mm.System()
positions = np.random.rand(numParticles, 3)
for _ in range(numParticles):
system.addParticle(1.0)

pt_force = ForceWithParameters()
model = pt.jit.script(pt_force)
tforce = TorchForce(model)
parameter1 = 1.0
parameter2 = 1.0
force.setOutputsForces(False)
force.addGlobalParameter("parameter1", parameter1)
force.addEnergyParameterDerivative("parameter1")
force.addGlobalParameter("parameter2", parameter2)
force.addEnergyParameterDerivative("parameter2")
system.addForce(force)
integ = mm.VerletIntegrator(1.0)
platform = mm.Platform.getPlatformByName(platform)
context = mm.Context(system, integ, platform)
context.setPositions(positions)
state = context.getState(
getEnergy=True, getForces=True, getParameterDerivatives=True
)
# The network defines a potential of the form E(r) = (parameter1 + parameter2**2)*|r|^2
r2 = np.sum(positions * positions)
expectedEnergy = (parameter1 + parameter2**2) * r2
assert np.allclose(
r2,
state.getEnergyParameterDerivatives()["parameter1"],
)
assert np.allclose(
2 * parameter2 * r2,
state.getEnergyParameterDerivatives()["parameter2"],
)
def forward(self, positions: Tensor, k: Tensor) -> Tensor:
return k*torch.sum(positions**2)


numParticles = 10
system = mm.System()
for _ in range(numParticles):
system.addParticle(1.0)

model = pt.jit.script(ForceWithParameters())
tforce = TorchForce(model)
force.setOutputsForces(False)
force.addGlobalParameter("k", 2.0)
force.addEnergyParameterDerivative("k")
system.addForce(force)
integ = mm.VerletIntegrator(1.0)
platform = mm.Platform.getPlatformByName(platform)
context = mm.Context(system, integ, platform)
positions = np.random.rand(numParticles, 3)
context.setPositions(positions)
state = context.getState(getParameterDerivatives=True)
dEdk = state.getEnergyParameterDerivatives()["k"]
```


Expand Down

0 comments on commit 3de4a19

Please sign in to comment.