forked from lanl/hippynn
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add coarse-graining example (lanl#98)
* first draft * track unwrapped and wrapped positions in MD code when cell is present, fix typo * remove unused imports, update md length * add link to data on Zenodo, change dataset filename
- Loading branch information
1 parent
0004728
commit 10a9055
Showing
6 changed files
with
351 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
The files in this directory allow one to train and run MD with a coarse-grained HIPNN model. Details of this model can be found in the paper "Thermodynamic Transferability in Coarse-Grained Force Fields using Graph Neural Networks" by Shinkle et al., available at <https://doi.org/10.48550/arXiv.2406.12112>. | ||
|
||
Before executing these files, one must download the training data from <https://doi.org/10.5281/zenodo.13717306>. The file should be placed at `datasets/cg_methanol_trajectory.npz` where `datasets/` is at the same level as the hippynn repository. | ||
|
||
1. Run `cg_training.py` to generate a model. This model will be saved in `hippynn/examples/coarse-graining/model`. | ||
2. Run `cg_md.py` to run MD using the model trained in step 1. The resulting trajectory will be saved in `hippynn/examples/coarse-graining/md_results`. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import os | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from ase import units | ||
|
||
from hippynn.experiment.serialization import load_checkpoint_from_cwd | ||
from hippynn.graphs.predictor import Predictor | ||
from hippynn.molecular_dynamics.md import ( | ||
Variable, | ||
NullUpdater, | ||
LangevinDynamics, | ||
MolecularDynamics, | ||
) | ||
from hippynn.tools import active_directory | ||
|
||
# Every tensor created below (and the model later on) uses this floating dtype.
default_dtype = torch.float
torch.set_default_dtype(default_dtype)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load initial conditions
# The dataset path is relative: three directory levels up, then `datasets/`.
training_data_file = os.path.join(
    os.pardir, os.pardir, os.pardir, "datasets", "cg_methanol_trajectory.npz"
)


def _last_frame(archive, key, dtype=default_dtype):
    """Return the last entry of ``archive[key]`` as a tensor on ``device``.

    A new leading axis of length 1 is prepended to the result.
    """
    return torch.as_tensor(archive[key][-1], dtype=dtype, device=device)[None, ...]


with np.load(training_data_file) as data:
    cell = _last_frame(data, "cells")
    masses = _last_frame(data, "masses")
    positions = _last_frame(data, "positions")
    velocities = _last_frame(data, "velocities")
    # Species are stored as integers rather than in the floating dtype.
    species = _last_frame(data, "species", dtype=torch.int)
|
||
# Entries carried by the "positions" variable. The acceleration entry starts
# as zeros shaped like the velocities.
particle_state = {
    "position": positions,
    "velocity": velocities,
    "mass": masses,
    "acceleration": torch.zeros_like(velocities),
    "cell": cell,
}

positions_variable = Variable(
    name="positions",
    data=particle_state,
    # presumably {model input name: key in `data`} -- confirm against Variable docs
    model_input_map={"positions": "position"},
    device=device,
)

# Unit factors handed to the Langevin updater, built from ase.units.
force_unit = units.kcal / units.mol / units.Ang
acceleration_unit = units.Ang / ((1000 * units.fs) ** 2)

position_updater = LangevinDynamics(
    force_db_name="forces",
    temperature=700,
    frix=6,
    units_force=force_unit,
    units_acc=acceleration_unit,
    seed=1993,
)
positions_variable.updater = position_updater


def _fixed_variable(name, data_key, value, model_input):
    """Build a single-entry Variable that is paired with a NullUpdater."""
    return Variable(
        name=name,
        data={data_key: value},
        model_input_map={model_input: data_key},
        device=device,
        updater=NullUpdater(),
    )


# NOTE(review): NullUpdater presumably leaves these variables unchanged
# during the run -- confirm in hippynn.molecular_dynamics.md.
cell_variable = _fixed_variable("cell", "cell", cell, "cells")
species_variable = _fixed_variable("species", "species", species, "species")
|
||
# Load model
# The checkpoint is read from the "model" directory.
with active_directory("model"):
    check = load_checkpoint_from_cwd(model_device=device, restart_db=False)

trained_graph = check["training_modules"].model

# Build the predictor exactly once. Earlier, a Predictor requesting extra
# outputs was constructed and then immediately overwritten by this default
# one, so the extra construction (and the node lookups feeding it) was dead
# code. If additional nodes are wanted in the MD output, fetch them with
# `trained_graph.node_from_name(...)` and pass them to `from_graph` through
# its `additional_outputs=[...]` argument.
model = Predictor.from_graph(trained_graph)

# Match the dtype and device used for the MD variables.
model.to(default_dtype)
model.to(device)

pairs = model.graph.node_from_name("pairs")
pairs.skin = 3  # see hippynn.graphs.nodes.pairs.KDTreePairsMemory documentation
|
||
# Run MD
# Everything below happens with "md_results" as the active directory, so the
# trajectory file is written there.
with active_directory("md_results"):
    md_variables = [positions_variable, species_variable, cell_variable]
    emdee = MolecularDynamics(variables=md_variables, model=model)

    # First leg: no `record_every` is given.
    # NOTE(review): presumably an unrecorded equilibration -- confirm against
    # MolecularDynamics.run.
    emdee.run(dt=0.001, n_steps=20000)
    # Second leg: run with record_every=50.
    emdee.run(dt=0.001, n_steps=50000, record_every=50)

    data = emdee.get_data()

    # Name in the saved archive -> key in the MD data dictionary.
    saved_fields = {
        "positions": "positions_position",
        "velocities": "positions_velocity",
        "masses": "positions_mass",
        "accelerations": "positions_acceleration",
        "cells": "positions_cell",
        "unwrapped_positions": "positions_unwrapped_position",
        "forces": "positions_force",
        "species": "species_species",
    }
    np.savez(
        "hippynn_cg_trajectory.npz",
        **{out_name: data[md_key] for out_name, md_key in saved_fields.items()},
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
import os | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from hippynn.databases import NPZDatabase | ||
from hippynn.experiment import SetupParams, setup_and_train | ||
from hippynn.experiment.assembly import assemble_for_training | ||
from hippynn.experiment.controllers import RaiseBatchSizeOnPlateau, PatienceController | ||
from hippynn.graphs import IdxType | ||
from hippynn.graphs.nodes import loss | ||
from hippynn.graphs.nodes.base.algebra import AddNode | ||
from hippynn.graphs.nodes.indexers import acquire_encoding_padding | ||
from hippynn.graphs.nodes.inputs import SpeciesNode, PositionsNode, CellNode | ||
from hippynn.graphs.nodes.networks import HipnnQuad | ||
from hippynn.graphs.nodes.pairs import KDTreePairsMemory | ||
from hippynn.graphs.nodes.physics import MultiGradientNode | ||
from hippynn.graphs.nodes.targets import HEnergyNode | ||
from hippynn.plotting import PlotMaker, Hist2D, SensitivityPlot | ||
from hippynn.tools import active_directory | ||
|
||
from repulsive_potential import RepulsivePotentialNode | ||
|
||
# The dataset path is relative: three directory levels up, then `datasets/`.
training_data_file = os.path.join(
    os.pardir, os.pardir, os.pardir, "datasets", "cg_methanol_trajectory.npz"
)

# Derive the repulsive-potential parameters from the training data.
with np.load(training_data_file) as data:
    # Index (along the first axis) of the first RDF value above 0.01.
    above_threshold = data["rdf_values"] > 0.01
    idx = np.nonzero(above_threshold)[0][0]
    # The taper point is the RDF bin at that index.
    repulsive_potential_taper_point = data["rdf_bins"][idx]
    # The strength is the mean absolute value over all stored forces.
    repulsive_potential_strength = np.abs(data["forces"]).mean()
|
||
## Initialize needed nodes for network
# Network input nodes (each reads the database entry of the same name)
species = SpeciesNode(name="species", db_name="species")
positions = PositionsNode(name="positions", db_name="positions")
cells = CellNode(name="cells", db_name="cells")

# Network hyperparameters
network_params = dict(
    possible_species=[0, 1],
    n_features=128,
    n_sensitivities=20,
    dist_soft_min=2.0,
    dist_soft_max=13.0,
    dist_hard_max=15.0,
    n_interaction_layers=1,
    n_atom_layers=3,
    sensitivity_type="inverse",
    resnet=True,
)

# Species encoder
enc, pdx = acquire_encoding_padding([species], species_set=[0, 1])

# Pair finder, using the same hard cutoff as the network
pair_parents = (positions, enc, pdx, cells)
pair_finder = KDTreePairsMemory(
    "pairs",
    pair_parents,
    dist_hard_max=network_params["dist_hard_max"],
    skin=0,
)

# HIP-NN-TS node with l=2
network = HipnnQuad(
    "HIPNN",
    (pdx, pair_finder),
    module_kwargs=network_params,
    periodic=True,
)

# Network energy prediction
henergy = HEnergyNode("HEnergy", parents=(network,))

# Repulsive potential, parameterized from the training data
repulse_settings = {
    "taper_point": repulsive_potential_taper_point,
    "strength": repulsive_potential_strength,
    "dr": 0.15,
    "perc": 0.05,
}
repulse = RepulsivePotentialNode("repulse", (pair_finder, pdx), **repulse_settings)

# Combined energy prediction: network energy plus repulsive energy
energy = AddNode(henergy.main_output, repulse.mol_energies)
energy.name = "energies"
# The index state of the sum is set by hand to per-molecule.
energy._index_state = IdxType.Molecules

sys_energy = energy.main_output
sys_energy.name = "sys_energy"

# Force node: gradient of the combined energy w.r.t. positions, with signs=-1
grad = MultiGradientNode("forces", energy, (positions,), signs=-1)
force = grad.children[0]
force.db_name = "forces"
|
||
## Define losses
# All losses compare the force node against its database counterpart.
force_mse = loss.MSELoss.of_node(force)
force_rmse = force_mse ** (1 / 2)
force_mae = loss.MAELoss.of_node(force)
force_rsq = loss.Rsq.of_node(force)

# Training objective: RMSE and MAE of the forces, equally weighted.
total_loss = force_rmse + force_mae

validation_losses = dict(
    ForceRMSE=force_rmse,
    ForceMAE=force_mae,
    ForceRsq=force_rsq,
    TotalLoss=total_loss,
)

# Plots are saved to file (not shown).
force_histogram = Hist2D.compare(force, saved="forces", shown=False)
sensitivity_plot = SensitivityPlot(
    network.torch_module.sensitivity_layers[0],
    saved="sensitivity",
    shown=False,
)
plotters = [force_histogram, sensitivity_plot]

plot_maker = PlotMaker(*plotters, plot_every=10)

## Build network
training_modules, db_info = assemble_for_training(
    total_loss,
    validation_losses,
    plot_maker=plot_maker,
)
|
||
## Load training data
# `db_info` comes from assemble_for_training and is forwarded as keyword
# arguments. 10% of the data goes to validation and 10% to testing.
database = NPZDatabase(
    training_data_file,
    seed=0,
    valid_size=0.1,
    test_size=0.1,
    **db_info,
)

## Set up optimizer
trainable_parameters = training_modules.model.parameters()
optimizer = torch.optim.Adam(trainable_parameters, lr=1e-3)

# On a plateau the scheduler raises the batch size, capped at max_batch_size.
scheduler_settings = dict(max_batch_size=64, patience=10, factor=0.5)
scheduler = RaiseBatchSizeOnPlateau(optimizer=optimizer, **scheduler_settings)

# Training starts at batch size 1; "TotalLoss" is the stopping key.
controller_settings = dict(
    batch_size=1,
    fraction_train_eval=0.2,
    eval_batch_size=1,
    max_epochs=200,
    termination_patience=20,
    stopping_key="TotalLoss",
)
controller = PatienceController(
    optimizer=optimizer,
    scheduler=scheduler,
    **controller_settings,
)

experiment_params = SetupParams(controller=controller)

## Train!
# Training runs with "model" as the active directory.
with active_directory("model"):
    metric_tracker = setup_and_train(
        training_modules=training_modules,
        database=database,
        setup_params=experiment_params,
    )
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import math | ||
import torch | ||
|
||
from hippynn.graphs import IdxType | ||
from hippynn.graphs.nodes.base import ExpandParents | ||
from hippynn.graphs.nodes.base.definition_helpers import AutoKw | ||
from hippynn.graphs.nodes.base.multi import MultiNode | ||
from hippynn.graphs.nodes.tags import PairIndexer, AtomIndexer | ||
from hippynn.layers import pairs | ||
|
||
## Define repulsive potential node for hippynn | ||
class RepulsivePotential(torch.nn.Module):
    """Exponentially decaying pairwise repulsive energy.

    Each pair distance ``r`` contributes an energy ``-g * exp(-a * r)``,
    where ``a`` and ``g`` are fixed constants computed in ``__init__``.
    """

    def __init__(self, taper_point, strength, dr, perc):
        """
        Write F(r) for the force this potential produces between two
        particles separated by a distance r. The constants are chosen so that
            F(taper_point)      = perc * strength
            F(taper_point - dr) = strength
        For instance, taper_point=3, strength=1, dr=0.5 and perc=0.01 give
            F(3)   = 0.01
            F(2.5) = 1
        """
        super().__init__()
        self.t, self.s, self.d, self.p = taper_point, strength, dr, perc

        # Decay rate of the exponential.
        self.a = (1/self.d)*math.log(1/self.p)
        # Prefactor; forward() uses its negation as the energy amplitude.
        self.g = -1 * self.s * self.p * math.exp(self.a * self.t) / self.a

        self.summer = pairs.MolPairSummer()

    def forward(self, pair_dist, pair_first, mol_index, n_molecules):
        """Return ``(mol_energies, atom_energies)``.

        ``atom_energies`` holds one value per entry of ``pair_dist``;
        ``mol_energies`` is what ``self.summer`` produces from those values.
        """
        amplitude = -1 * self.g
        decay = torch.exp(-1 * self.a * pair_dist)
        atom_energies = amplitude * decay
        # NOTE(review): MolPairSummer presumably sums the pair values into
        # per-molecule totals -- confirm in hippynn.layers.pairs.
        mol_energies = self.summer(atom_energies, mol_index, n_molecules, pair_first)
        return (mol_energies, atom_energies)
|
||
class RepulsivePotentialNode(ExpandParents, AutoKw, MultiNode):
    """Graph node wrapping :class:`RepulsivePotential`.

    It has two outputs: ``mol_energies`` (index state ``IdxType.Molecules``)
    and ``atom_energies`` (index state ``IdxType.Pair``).
    """

    _input_names = ("pair_dist", "pair_first", "mol_index", "n_molecules")
    _output_names = ("mol_energies", "atom_energies")
    _auto_module_class = RepulsivePotential
    _output_index_states = (IdxType.Molecules, IdxType.Pair)

    @_parent_expander.match(PairIndexer, AtomIndexer)
    def expansion(self, pairfinder, pidxer, **kwargs):
        """Expand a (pair indexer, atom indexer) pair into the four inputs."""
        return (
            pairfinder.pair_dist,
            pairfinder.pair_first,
            pidxer.mol_index,
            pidxer.n_molecules,
        )

    def __init__(self, name, parents, taper_point, strength, dr, perc, module="auto"):
        """Store the potential's parameters, expand the parents, and build the node.

        :param name: name of the node.
        :param parents: parent nodes; passed through ``expand_parents``.
        :param taper_point: stored as ``module_kwargs["taper_point"]``.
        :param strength: stored as ``module_kwargs["strength"]``.
        :param dr: stored as ``module_kwargs["dr"]``.
        :param perc: stored as ``module_kwargs["perc"]``.
        :param module: forwarded to the parent constructor, defaults to "auto".
        """
        # NOTE(review): presumably AutoKw builds the RepulsivePotential
        # module from these kwargs -- confirm in definition_helpers.
        self.module_kwargs = dict(
            taper_point=taper_point,
            strength=strength,
            dr=dr,
            perc=perc,
        )
        expanded = self.expand_parents(parents, module="auto")
        super().__init__(name, expanded, module=module)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters