Commit

update examples, fix glitch with pyanidirectoryDB

lubbersnick committed Aug 14, 2024
1 parent 9a1a662 commit dcd78d6
Showing 14 changed files with 31 additions and 42 deletions.
3 changes: 1 addition & 2 deletions examples/allegro_ag_example.py
@@ -20,7 +20,6 @@
 """
 import os
 import torch
-import ase.io
 import time
 
 import hippynn
@@ -30,7 +29,7 @@
 torch.set_default_dtype(torch.float32)
 hippynn.settings.WARN_LOW_DISTANCES = False
 
-max_epochs=500
+max_epochs = 500
 
 network_params = {
     "possible_species": [0, 47],
6 changes: 0 additions & 6 deletions examples/ani1x_training.py
@@ -3,8 +3,6 @@
 This script was designed for an external dataset available at
 https://doi.org/10.6084/m9.figshare.c.4712477
-pyanitools reader available at
-https://github.com/aiqm/ANI1x_datasets
 For info on the dataset, see the following publication:
 Smith, J.S., Zubatyuk, R., Nebgen, B. et al.
@@ -20,10 +18,6 @@
 import hippynn
 import ase.units
 
-import sys
-sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py
-
-import pyanitools
 
 def make_model(network_params,tensor_order):
     """
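The example scripts no longer prepend a local checkout of pyanitools to sys.path; this data format is read through hippynn's h5_pyanitools module instead. As a hedged sketch (not taken from this commit), loading an ANI-1x-style file might look like the following, where the filename and key names are placeholders that depend on the actual dataset:

from hippynn.databases.h5_pyanitools import PyAniFileDB

# Sketch only: "ani1x-release.h5" and the key names below are assumptions,
# not verified against the dataset; adjust them to the keys in your file.
database = PyAniFileDB(
    file="ani1x-release.h5",
    species_key="atomic_numbers",
    inputs=["atomic_numbers", "coordinates"],
    targets=["wb97x_dz.energy"],
    seed=1001,
)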
4 changes: 0 additions & 4 deletions examples/ani_aluminum_example.py
@@ -22,10 +22,6 @@
 """
 
-import sys
-
-sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py
-import pyanitools # Check if pyanitools is found early
 
 import torch
 
2 changes: 0 additions & 2 deletions examples/ani_aluminum_example_multilayer.py
@@ -24,8 +24,6 @@
 
 import sys
 
-sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py
-import pyanitools # Check if pyanitools is found early
 
 import torch
 
2 changes: 1 addition & 1 deletion examples/ase_example.py
@@ -28,7 +28,7 @@
 # Load the files
 try:
     with active_directory("TEST_ALUMINUM_MODEL", create=False):
-        bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False)
+        bundle = load_checkpoint_from_cwd(map_location="cpu")
 except FileNotFoundError:
     raise FileNotFoundError("Model not found, run ani_aluminum_example.py first!")
 
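For reference, a minimal sketch of the loading pattern these scripts now use; with restart_db defaulting to False (see the serialization changes below), the flag can simply be omitted when the training database is not needed. The import paths are the usual hippynn locations, assumed rather than shown in this diff:

from hippynn.tools import active_directory
from hippynn.experiment.serialization import load_checkpoint_from_cwd

# Sketch: assumes a trained-model directory produced by ani_aluminum_example.py.
with active_directory("TEST_ALUMINUM_MODEL", create=False):
    bundle = load_checkpoint_from_cwd(map_location="cpu")

model = bundle["training_modules"].model  # the restored torch module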
2 changes: 1 addition & 1 deletion examples/ase_example_multilayer.py
@@ -30,7 +30,7 @@
 # Load the files
 try:
     with active_directory("TEST_ALUMINUM_MODEL_MULTILAYER", create=False):
-        bundle = load_checkpoint_from_cwd(map_location='cpu',restore_db=False)
+        bundle = load_checkpoint_from_cwd(map_location='cpu')
 except FileNotFoundError:
     raise FileNotFoundError("Model not found, run ani_aluminum_example_multilayer.py first!")
 
9 changes: 2 additions & 7 deletions examples/close_contact_finding.py
@@ -13,10 +13,6 @@
 before running this script.
 """
-import sys
-
-sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py
-import pyanitools # Check if pyanitools is found early
 
 ### Loading the database
 from hippynn.databases.h5_pyanitools import PyAniDirectoryDB
@@ -59,12 +55,11 @@
 
 #### How to remove and separate low distance configurations
 dist_thresh = 1.7 # Note: what threshold to use may be highly problem-dependent.
-low_dist_configs = min_dist_array < dist_thresh
-where_low_dist = database.arr_dict["indices"][low_dist_configs]
+low_dist_config_mask = min_dist_array < dist_thresh
 
 # This makes the low distance configurations
 # into their own split, separate from train/valid/test.
-database.make_explicit_split("LOW_DISTANCE_FILTER", where_low_dist)
+database.make_explicit_split_bool("LOW_DISTANCE_FILTER", low_dist_config_mask)
 
 # This deletes the new split, although deleting it is not necessary;
 # this data will not be included in train/valid/test splits
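To make the mask-versus-indices distinction above concrete, here is a small self-contained sketch with illustrative stand-in arrays (the method names come from the diff; the numbers are invented):

import numpy as np

# Stand-ins for the per-configuration minimum distances computed in the script.
min_dist_array = np.array([1.2, 2.5, 1.6, 3.0])
dist_thresh = 1.7

# New pattern: a boolean mask, one entry per configuration, consumed
# directly by database.make_explicit_split_bool(name, mask).
low_dist_config_mask = min_dist_array < dist_thresh  # [ True False  True False]

# Old pattern: recover integer indices first, as expected by
# database.make_explicit_split(name, indices).
where_low_dist = np.arange(len(min_dist_array))[low_dist_config_mask]  # [0 2]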
2 changes: 1 addition & 1 deletion examples/lammps/pickle_mliap_unified_hippynn_Al.py
@@ -11,7 +11,7 @@
 # Load trained model
 try:
     with active_directory("../TEST_ALUMINUM_MODEL", create=False):
-        bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False)
+        bundle = load_checkpoint_from_cwd(map_location="cpu")
 except FileNotFoundError:
     raise FileNotFoundError("Model not found, run ani_aluminum_example.py first!")
 
2 changes: 1 addition & 1 deletion examples/lammps/pickle_mliap_unified_hippynn_Al_multilayer.py
@@ -11,7 +11,7 @@
 # Load trained model
 try:
     with active_directory("../TEST_ALUMINUM_MODEL_MULTILAYER", create=False):
-        bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False)
+        bundle = load_checkpoint_from_cwd(map_location="cpu")
 except FileNotFoundError:
     raise FileNotFoundError("Model not found, run ani_aluminum_example_multilayer.py first!")
 
2 changes: 1 addition & 1 deletion examples/lammps/pickle_mliap_unified_hippynn_InP.py
@@ -11,7 +11,7 @@
 # Load trained model
 try:
     with active_directory("../TEST_INP_MODEL", create=False):
-        bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False)
+        bundle = load_checkpoint_from_cwd(map_location="cpu")
 except FileNotFoundError:
     raise FileNotFoundError("Model not found, run lammps_train_model_InP.py first!")
 
2 changes: 1 addition & 1 deletion examples/molecular_dynamics.py
@@ -43,7 +43,7 @@
 # Load the pre-trained model
 try:
     with active_directory("TEST_ALUMINUM_MODEL", create=False):
-        bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False)
+        bundle = load_checkpoint_from_cwd(map_location="cpu")
 except FileNotFoundError:
     raise FileNotFoundError("Model not found, run ani_aluminum_example.py first!")
 
11 changes: 7 additions & 4 deletions hippynn/databases/h5_pyanitools.py
@@ -212,13 +212,15 @@ def load_arrays(self, allow_unfound=False, quiet=False):
 
 class PyAniDirectoryDB(Database, PyAniMethods, Restartable):
     def __init__(self, directory, inputs, targets, *args, files=None, allow_unfound=False, species_key="species",
-                 quiet=False,**kwargs):
+                 quiet=False, driver='core', **kwargs):
 
         self.directory = directory
         self.files = files
         self.inputs = inputs
         self.targets = targets
         self.species_key = species_key
+        self.driver = driver
 
         arr_dict = self.load_arrays(allow_unfound=allow_unfound,quiet=quiet)
 
         super().__init__(arr_dict, inputs, targets, *args, **kwargs, quiet=quiet, allow_unfound=allow_unfound)
@@ -242,14 +244,15 @@ def load_arrays(self, allow_unfound=False, quiet=False):
 
         file_batches = []
         for f in progress_bar(files, desc="Data Files", unit="file"):
-            file_batches.append(self.extract_full_file(f,species_key=self.species_key))
+            file_batches.append(self.extract_full_file(f, species_key=self.species_key))
 
-        data, max_atoms_list = zip(*file_batches)
+        data, max_atoms_list, sys_count = zip(*file_batches)
 
         n_atoms_max = max(max_atoms_list)
         batches = [item for fb in data for item in fb]
+        sys_count = sum(sys_count)
 
-        arr_dict = self.process_batches(batches, n_atoms_max, species_key=self.species_key)
+        arr_dict = self.process_batches(batches, n_atoms_max, sys_count, species_key=self.species_key)
         arr_dict = self.filter_arrays(arr_dict, quiet=quiet, allow_unfound=allow_unfound)
         return arr_dict
 

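A hedged construction sketch for the class above; the directory path, input/target names, and split sizes are placeholders, and only driver is the keyword introduced in this commit:

from hippynn.databases.h5_pyanitools import PyAniDirectoryDB

# Sketch: "./ani_data/" stands in for a directory of pyanitools-style .h5 files;
# seed/test_size/valid_size are forwarded to the Database base class.
database = PyAniDirectoryDB(
    directory="./ani_data/",
    inputs=["species", "coordinates"],
    targets=["energies"],
    seed=0,
    test_size=0.1,
    valid_size=0.1,
    driver="core",  # stored as self.driver in __init__ above
)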
6 changes: 3 additions & 3 deletions hippynn/databases/restarter.py
@@ -16,12 +16,12 @@
 
 
 class Restarter:
-    def attempt_reload(self):
+    def attempt_restart(self):
         return NotImplemented
 
 
 class NoRestart(Restarter):
-    def attempt_reload(self):
+    def attempt_restart(self):
         print("Couldn't reload database. It might have been generated in-memory.")
         return None
 
@@ -54,7 +54,7 @@ def __setstate__(self, state):
         for k, v in state.items():
             setattr(self, k, v)
 
-    def attempt_reload(self):
+    def attempt_restart(self):
         print("restarting", self.cls)
         if isinstance(self.cls, str):
             raise RuntimeError(f"Not restartable due to class error: {self.cls}")
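Given the rename, anything that subclasses Restarter should now implement attempt_restart rather than attempt_reload. A minimal illustrative subclass (not part of the library) mirroring NoRestart above:

from hippynn.databases.restarter import Restarter

class SkipRestart(Restarter):
    # Illustrative: report that no on-disk source is available and return None,
    # which restore_checkpoint will store in place of the database.
    def attempt_restart(self):
        print("No stored database source; skipping restart.")
        return None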
20 changes: 12 additions & 8 deletions hippynn/experiment/serialization.py
@@ -67,20 +67,23 @@ def create_structure_file(
     torch.save(structure, pfile)
 
 
-def restore_checkpoint(structure: dict, state: dict, restore_db=True) -> dict:
+def restore_checkpoint(structure: dict, state: dict, restart_db=False) -> dict:
     """
     This function loads the parameters from the state dictionary into the modules,
     optionally tries to restart the database, and sets the RNG state.
 
     :param structure: experiment structure object
     :param state: experiment state object
-    :param restore_db: Attempt to restore database (true/false)
+    :param restart_db: Attempt to restart the database (true/false)
     :return: experiment structure
     """
     structure["training_modules"][0].load_state_dict(state["model"])
     structure["controller"].load_state_dict(state["controller"])
 
-    if "database" in structure and restore_db:
-        structure["database"] = structure["database"].attempt_reload()
+    if "database" in structure and restart_db:
+        structure["database"] = structure["database"].attempt_restart()
 
     structure["metric_tracker"] = state["metric_tracker"]
     torch.random.set_rng_state(state["torch_rng_state"])
@@ -109,7 +112,8 @@ def check_mapping_devices(map_location, model_device):
 
 
 def load_saved_tensors(structure_fname: str, state_fname: str, **kwargs) -> Tuple[dict, dict]:
-    """Load torch tensors from file.
+    """
+    Load torch tensors from file.
 
     :param structure_fname: name of the structure file
     :param state_fname: name of the state file
@@ -125,7 +129,7 @@ def load_saved_tensors(structure_fname: str, state_fname: str, **kwargs) -> Tuple[dict, dict]:
 
 
 def load_checkpoint(
-    structure_fname: str, state_fname: str, restore_db=True, map_location=None, model_device=None, **kwargs
+    structure_fname: str, state_fname: str, restart_db=False, map_location=None, model_device=None, **kwargs
 ) -> dict:
     """
     Load checkpoint file from given filename.
@@ -134,7 +138,7 @@ def load_checkpoint(
     :param structure_fname: name of the structure file
     :param state_fname: name of the state file
-    :param restore_db: restore database or not, defaults to True
+    :param restart_db: restart the database or not, defaults to False
     :param map_location: device mapping argument for ``torch.load``, defaults to None
     :param model_device: automatically handle device mapping, defaults to None
     :return: experiment structure
@@ -146,7 +150,7 @@ def load_checkpoint(
 
     structure, state = load_saved_tensors(structure_fname, state_fname, **kwargs)
 
     # transfer stuff back to model_device
-    structure = restore_checkpoint(structure, state, restore_db=restore_db)
+    structure = restore_checkpoint(structure, state, restart_db=restart_db)
     # no transfer happens in either case, as the tensors are on the target devices already
     if model_device == "cpu" or map_location != None:
         evaluator = structure["training_modules"].evaluator
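A usage sketch of the renamed flag; the checkpoint filenames below are assumptions (hippynn's conventional defaults), not taken from this diff:

from hippynn.experiment.serialization import load_checkpoint

# Default now leaves the database alone (restart_db=False):
bundle = load_checkpoint("experiment_structure.pt", "best_checkpoint.pt",
                         map_location="cpu")

# Opt in to re-attaching the on-disk database:
bundle = load_checkpoint("experiment_structure.pt", "best_checkpoint.pt",
                         restart_db=True, map_location="cpu")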
