diff --git a/examples/allegro_ag_example.py b/examples/allegro_ag_example.py index 71f44f4f..0a891cab 100644 --- a/examples/allegro_ag_example.py +++ b/examples/allegro_ag_example.py @@ -20,7 +20,6 @@ """ import os import torch -import ase.io import time import hippynn @@ -30,7 +29,7 @@ torch.set_default_dtype(torch.float32) hippynn.settings.WARN_LOW_DISTANCES = False -max_epochs=500 +max_epochs = 500 network_params = { "possible_species": [0, 47], diff --git a/examples/ani1x_training.py b/examples/ani1x_training.py index c8fcba68..f97f6114 100644 --- a/examples/ani1x_training.py +++ b/examples/ani1x_training.py @@ -3,8 +3,6 @@ This script was designed for an external dataset available at https://doi.org/10.6084/m9.figshare.c.4712477 -pyanitools reader available at -https://github.com/aiqm/ANI1x_datasets For info on the dataset, see the following publication: Smith, J.S., Zubatyuk, R., Nebgen, B. et al. @@ -20,10 +18,6 @@ import hippynn import ase.units -import sys -sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py - -import pyanitools def make_model(network_params,tensor_order): """ diff --git a/examples/ani_aluminum_example.py b/examples/ani_aluminum_example.py index 1d0c7d44..da4cb2c9 100644 --- a/examples/ani_aluminum_example.py +++ b/examples/ani_aluminum_example.py @@ -22,10 +22,6 @@ """ -import sys - -sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py -import pyanitools # Check if pyanitools is found early import torch diff --git a/examples/ani_aluminum_example_multilayer.py b/examples/ani_aluminum_example_multilayer.py index b1ca647f..90984dc7 100644 --- a/examples/ani_aluminum_example_multilayer.py +++ b/examples/ani_aluminum_example_multilayer.py @@ -24,8 +24,6 @@ import sys -sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py -import pyanitools # Check if pyanitools is found early import torch diff --git a/examples/ase_example.py b/examples/ase_example.py index 
99b1bf33..987b39c5 100644 --- a/examples/ase_example.py +++ b/examples/ase_example.py @@ -28,7 +28,7 @@ # Load the files try: with active_directory("TEST_ALUMINUM_MODEL", create=False): - bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False) + bundle = load_checkpoint_from_cwd(map_location="cpu") except FileNotFoundError: raise FileNotFoundError("Model not found, run ani_aluminum_example.py first!") diff --git a/examples/ase_example_multilayer.py b/examples/ase_example_multilayer.py index 8249388a..c873745e 100644 --- a/examples/ase_example_multilayer.py +++ b/examples/ase_example_multilayer.py @@ -30,7 +30,7 @@ # Load the files try: with active_directory("TEST_ALUMINUM_MODEL_MULTILAYER", create=False): - bundle = load_checkpoint_from_cwd(map_location='cpu',restore_db=False) + bundle = load_checkpoint_from_cwd(map_location='cpu') except FileNotFoundError: raise FileNotFoundError("Model not found, run ani_aluminum_example_multilayer.py first!") diff --git a/examples/close_contact_finding.py b/examples/close_contact_finding.py index 36d3597f..e98fd668 100644 --- a/examples/close_contact_finding.py +++ b/examples/close_contact_finding.py @@ -13,10 +13,6 @@ before running this script. """ -import sys - -sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py -import pyanitools # Check if pyanitools is found early ### Loading the database from hippynn.databases.h5_pyanitools import PyAniDirectoryDB @@ -59,12 +55,11 @@ #### How to remove and separate low distance configurations dist_thresh = 1.7 # Note: what threshold to use may be highly problem-dependent. -low_dist_configs = min_dist_array < dist_thresh -where_low_dist = database.arr_dict["indices"][low_dist_configs] +low_dist_config_mask = min_dist_array < dist_thresh # This makes the low distance configurations # into their own split, separate from train/valid/test. 
-database.make_explicit_split("LOW_DISTANCE_FILTER", where_low_dist) +database.make_explicit_split_bool("LOW_DISTANCE_FILTER", low_dist_config_mask) # This deletes the new split, although deleting it is not necessary; # this data will not be included in train/valid/test splits diff --git a/examples/lammps/pickle_mliap_unified_hippynn_Al.py b/examples/lammps/pickle_mliap_unified_hippynn_Al.py index 364ce6e7..e79d779f 100644 --- a/examples/lammps/pickle_mliap_unified_hippynn_Al.py +++ b/examples/lammps/pickle_mliap_unified_hippynn_Al.py @@ -11,7 +11,7 @@ # Load trained model try: with active_directory("../TEST_ALUMINUM_MODEL", create=False): - bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False) + bundle = load_checkpoint_from_cwd(map_location="cpu") except FileNotFoundError: raise FileNotFoundError("Model not found, run ani_aluminum_example.py first!") diff --git a/examples/lammps/pickle_mliap_unified_hippynn_Al_multilayer.py b/examples/lammps/pickle_mliap_unified_hippynn_Al_multilayer.py index 85d645f9..186b2d61 100644 --- a/examples/lammps/pickle_mliap_unified_hippynn_Al_multilayer.py +++ b/examples/lammps/pickle_mliap_unified_hippynn_Al_multilayer.py @@ -11,7 +11,7 @@ # Load trained model try: with active_directory("../TEST_ALUMINUM_MODEL_MULTILAYER", create=False): - bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False) + bundle = load_checkpoint_from_cwd(map_location="cpu") except FileNotFoundError: raise FileNotFoundError("Model not found, run ani_aluminum_example_multilayer.py first!") diff --git a/examples/lammps/pickle_mliap_unified_hippynn_InP.py b/examples/lammps/pickle_mliap_unified_hippynn_InP.py index 70ea7515..2ff7f8f7 100644 --- a/examples/lammps/pickle_mliap_unified_hippynn_InP.py +++ b/examples/lammps/pickle_mliap_unified_hippynn_InP.py @@ -11,7 +11,7 @@ # Load trained model try: with active_directory("../TEST_INP_MODEL", create=False): - bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False) + 
bundle = load_checkpoint_from_cwd(map_location="cpu") except FileNotFoundError: raise FileNotFoundError("Model not found, run lammps_train_model_InP.py first!") diff --git a/examples/molecular_dynamics.py b/examples/molecular_dynamics.py index 18a38028..bb44cf63 100644 --- a/examples/molecular_dynamics.py +++ b/examples/molecular_dynamics.py @@ -43,7 +43,7 @@ # Load the pre-trained model try: with active_directory("TEST_ALUMINUM_MODEL", create=False): - bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False) + bundle = load_checkpoint_from_cwd(map_location="cpu") except FileNotFoundError: raise FileNotFoundError("Model not found, run ani_aluminum_example.py first!") diff --git a/hippynn/databases/h5_pyanitools.py b/hippynn/databases/h5_pyanitools.py index 41c1d2cd..fd4c6a27 100644 --- a/hippynn/databases/h5_pyanitools.py +++ b/hippynn/databases/h5_pyanitools.py @@ -212,13 +212,15 @@ def load_arrays(self, allow_unfound=False, quiet=False): class PyAniDirectoryDB(Database, PyAniMethods, Restartable): def __init__(self, directory, inputs, targets, *args, files=None, allow_unfound=False, species_key="species", - quiet=False,**kwargs): + quiet=False, driver='core', **kwargs): self.directory = directory self.files = files self.inputs = inputs self.targets = targets self.species_key = species_key + self.driver = driver + arr_dict = self.load_arrays(allow_unfound=allow_unfound,quiet=quiet) super().__init__(arr_dict, inputs, targets, *args, **kwargs, quiet=quiet, allow_unfound=allow_unfound) @@ -242,14 +244,15 @@ def load_arrays(self, allow_unfound=False, quiet=False): file_batches = [] for f in progress_bar(files, desc="Data Files", unit="file"): - file_batches.append(self.extract_full_file(f,species_key=self.species_key)) + file_batches.append(self.extract_full_file(f, species_key=self.species_key)) - data, max_atoms_list = zip(*file_batches) + data, max_atoms_list, sys_count = zip(*file_batches) n_atoms_max = max(max_atoms_list) batches = [item for fb in 
data for item in fb] + sys_count = sum(sys_count) - arr_dict = self.process_batches(batches, n_atoms_max, species_key=self.species_key) + arr_dict = self.process_batches(batches, n_atoms_max, sys_count, species_key=self.species_key) arr_dict = self.filter_arrays(arr_dict, quiet=quiet, allow_unfound=allow_unfound) return arr_dict diff --git a/hippynn/databases/restarter.py b/hippynn/databases/restarter.py index 0886bafc..6cdb451d 100644 --- a/hippynn/databases/restarter.py +++ b/hippynn/databases/restarter.py @@ -16,12 +16,12 @@ class Restarter: - def attempt_reload(self): + def attempt_restart(self): return NotImplemented class NoRestart(Restarter): - def attempt_reload(self): + def attempt_restart(self): print("Couldn't reload database. It might have been generated in-memory.") return None @@ -54,7 +54,7 @@ def __setstate__(self, state): for k, v in state.items(): setattr(self, k, v) - def attempt_reload(self): + def attempt_restart(self): print("restarting", self.cls) if isinstance(self.cls, str): raise RuntimeError(f"Not restartable due to class error: {self.cls}") diff --git a/hippynn/experiment/serialization.py b/hippynn/experiment/serialization.py index 132a0c30..c4d73c1a 100644 --- a/hippynn/experiment/serialization.py +++ b/hippynn/experiment/serialization.py @@ -67,20 +67,23 @@ def create_structure_file( torch.save(structure, pfile) -def restore_checkpoint(structure: dict, state: dict, restore_db=True) -> dict: +def restore_checkpoint(structure: dict, state: dict, restart_db=False) -> dict: """ + This function loads the parameters from the state dictionary into the modules, + optionally tries to restart the database, and sets the RNG state. 
+ :param structure: experiment structure object :param state: experiment state object - :param restore_db: Attempt to restore database (true/false) + :param restart_db: Attempt to restore database (true/false) :return: experiment structure """ structure["training_modules"][0].load_state_dict(state["model"]) structure["controller"].load_state_dict(state["controller"]) - if "database" in structure and restore_db: - structure["database"] = structure["database"].attempt_reload() + if "database" in structure and restart_db: + structure["database"] = structure["database"].attempt_restart() structure["metric_tracker"] = state["metric_tracker"] torch.random.set_rng_state(state["torch_rng_state"]) @@ -109,7 +112,8 @@ def check_mapping_devices(map_location, model_device): def load_saved_tensors(structure_fname: str, state_fname: str, **kwargs) -> Tuple[dict, dict]: - """Load torch tensors from file. + """ + Load torch tensors from file. :param structure_fname: name of the structure file :param state_fname: name of the state file @@ -125,7 +129,7 @@ def load_saved_tensors(structure_fname: str, state_fname: str, **kwargs) -> Tupl def load_checkpoint( - structure_fname: str, state_fname: str, restore_db=True, map_location=None, model_device=None, **kwargs + structure_fname: str, state_fname: str, restart_db=False, map_location=None, model_device=None, **kwargs ) -> dict: """ Load checkpoint file from given filename. @@ -134,7 +138,7 @@ def load_checkpoint( :param structure_fname: name of the structure file :param state_fname: name of the state file - :param restore_db: restore database or not, defaults to True + :param restart_db: restore database or not, defaults to False :param map_location: device mapping argument for ``torch.load``, defaults to None :param model_device: automatically handle device mapping. 
Defaults to None, defaults to None :return: experiment structure @@ -146,7 +150,7 @@ def load_checkpoint( structure, state = load_saved_tensors(structure_fname, state_fname, **kwargs) # transfer stuff back to model_device - structure = restore_checkpoint(structure, state, restore_db=restore_db) + structure = restore_checkpoint(structure, state, restart_db=restart_db) # no transfer happens in either case, as the tensors are on the target devices already if model_device == "cpu" or map_location != None: evaluator = structure["training_modules"].evaluator