Skip to content

Commit

Permalink
Serialize property values (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
peastman authored Aug 23, 2024
1 parent 8893c0f commit 9ceb9a0
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
10 changes: 8 additions & 2 deletions serialization/src/TorchForceProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ TorchForceProxy::TorchForceProxy() : SerializationProxy("TorchForce") {
}

void TorchForceProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 3);
node.setIntProperty("version", 4);
const TorchForce& force = *reinterpret_cast<const TorchForce*>(object);
node.setStringProperty("file", force.getFile());
try {
Expand All @@ -95,11 +95,14 @@ void TorchForceProxy::serialize(const void* object, SerializationNode& node) con
SerializationNode& paramDerivs = node.createChildNode("ParameterDerivatives");
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
paramDerivs.createChildNode("Parameter").setStringProperty("name", force.getEnergyParameterDerivativeName(i));
SerializationNode& properties = node.createChildNode("Properties");
for (auto& prop : force.getProperties())
properties.createChildNode("Property").setStringProperty("name", prop.first).setStringProperty("value", prop.second);
}

void* TorchForceProxy::deserialize(const SerializationNode& node) const {
int storedVersion = node.getIntProperty("version");
if (storedVersion > 3)
if (storedVersion > 4)
throw OpenMMException("Unsupported version number");
TorchForce* force;
if (storedVersion == 1) {
Expand All @@ -126,6 +129,9 @@ void* TorchForceProxy::deserialize(const SerializationNode& node) const {
if (child.getName() == "ParameterDerivatives")
for (auto& parameter : child.getChildren())
force->addEnergyParameterDerivative(parameter.getStringProperty("name"));
if (child.getName() == "Properties")
for (auto& property : child.getChildren())
force->setProperty(property.getStringProperty("name"), property.getStringProperty("value"));
}
return force;
}
5 changes: 5 additions & 0 deletions serialization/tests/TestSerializeTorchForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ void serializeAndDeserialize(TorchForce force) {
force.setUsesPeriodicBoundaryConditions(true);
force.setOutputsForces(true);
force.addEnergyParameterDerivative("y");
force.setProperty("useCUDAGraphs", "true");
force.setProperty("CUDAGraphWarmupSteps", "5");

// Serialize and then deserialize it.

Expand All @@ -77,6 +79,9 @@ void serializeAndDeserialize(TorchForce force) {
ASSERT_EQUAL(force.getNumEnergyParameterDerivatives(), force2.getNumEnergyParameterDerivatives());
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
ASSERT_EQUAL(force.getEnergyParameterDerivativeName(i), force2.getEnergyParameterDerivativeName(i));
ASSERT_EQUAL(force.getProperties().size(), force2.getProperties().size());
for (auto& prop : force.getProperties())
ASSERT_EQUAL(prop.second, force2.getProperties().at(prop.first));
}

void testSerializationFromModule() {
Expand Down

0 comments on commit 9ceb9a0

Please sign in to comment.