From 43f5b56740344589719707df03cf091f1af91153 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers Date: Mon, 9 Sep 2024 17:46:43 -0600 Subject: [PATCH] Many documentation updates, small tweaks to database interface. --- AUTHORS.txt | 4 +- CHANGELOG.rst | 32 ++- COPYRIGHT.txt | 10 + LICENSE.txt | 10 - README.rst | 1 + docs/source/conf.py | 4 +- docs/source/examples/controller.rst | 1 - docs/source/examples/index.rst | 6 +- docs/source/examples/lightning.rst | 20 ++ docs/source/index.rst | 40 +++- docs/source/installation.rst | 4 +- docs/source/user_guide/ckernels.rst | 2 +- docs/source/user_guide/concepts.rst | 5 +- docs/source/user_guide/databases.rst | 43 +++- docs/source/user_guide/features.rst | 15 +- docs/source/user_guide/settings.rst | 5 + examples/ani1x_training.py | 2 +- hippynn/__init__.py | 17 +- hippynn/databases/__init__.py | 9 +- hippynn/databases/database.py | 261 ++++++++++++++++-------- hippynn/databases/h5_pyanitools.py | 76 ++++--- hippynn/databases/ondisk.py | 2 +- hippynn/experiment/__init__.py | 5 +- hippynn/experiment/lightning_trainer.py | 101 ++++++++- hippynn/experiment/routines.py | 4 +- hippynn/layers/pairs/filters.py | 5 +- hippynn/molecular_dynamics/__init__.py | 5 +- hippynn/molecular_dynamics/md.py | 1 + hippynn/tools.py | 19 +- 29 files changed, 530 insertions(+), 179 deletions(-) create mode 100644 COPYRIGHT.txt create mode 100644 docs/source/examples/lightning.rst diff --git a/AUTHORS.txt b/AUTHORS.txt index 4e0d97f7..8b6c4bac 100644 --- a/AUTHORS.txt +++ b/AUTHORS.txt @@ -19,7 +19,7 @@ Emily Shinkle (LANL) Michael G. Taylor (LANL) Jan Janssen (LANL) Cagri Kaymak (LANL) -Shuhao Zhang (CMU, LANL) +Shuhao Zhang (CMU, LANL) - Batched Optimization routines Also thanks to testing and feedback from: @@ -36,3 +36,5 @@ David Rosenberger Michael Tynes Drew Rohskopf Neil Mehta +Alice E A Allen + diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 1618d1bd..6842e0d5 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -3,23 +3,47 @@ Breaking changes: ----------------- +- set_e0_values has been renamed hierarchical_energy_initialization. The old name is + still provided but deprecated, and will be removed. + New Features: ------------- -- Added a new custom cuda kernel implementation using triton. These are highly performant and now the default implementation. -- Exporting a database to NPZ or H5 format after preprocessing is now just a function call away. -- SNAPjson format can now support an optional number of comment lines. -- Added Batch optimizer features in order to optimize geometries in parallel on the GPU. Algorithms include FIRE and BFGS. +- Added a new custom cuda kernel implementation using triton. + These are highly performant and now the default implementation. +- Exporting any database to NPZ or H5 format after preprocessing can be done with a method call. +- Database states can be cached to disk to simplify the restarting of training. +- Added batch geometry optimizer features in order to optimize geometries + in parallel on the GPU. Algorithms include FIRE, Newton-Raphson, and BFGS. +- Added experiment pytorch lightning trainer to provide for simple parallelized training. +- Added a molecular dynamics engine which includes the ability to batch over systems. +- Added examples pertaining to coarse graining. +- Added pair finders based on scipy KDTree for training to large systems. +- Added tool to drastically simplify creating ensemble models. The ensemblized graphs + are compatible with molecular dynamics codes such ASE and LAMMPS. 
+- Added the ability to weight different systems/atoms/bonds in a loss function. + Improvements: ------------- - Eliminated dependency on pyanitools for loading ANI-style H5 datasets. +- SNAPjson format can now support an optional number of comment lines. +- Added unit conversion options to the LAMMPS interface. +- Improved performance of bond order regression. +- It is now possible to limit the memory usage of the MLIAP interface in LAMMPS + using a library setting. +- Provide tunable regularization of HIP-NN-TS with an epsilon parameter, and + set the default to use a better value for epsilon. + Bug Fixes: ---------- - Fixed bug where custom kernels were not launching properly on non-default GPUs +- Fixed error when LAMMPS interface is in kokkos mode and the kokkos device was set to CPU. +- MLIAPInterface objects +- Fixed bug with RDF computer automatic initialization. 0.0.3 ======= diff --git a/COPYRIGHT.txt b/COPYRIGHT.txt new file mode 100644 index 00000000..aea758d6 --- /dev/null +++ b/COPYRIGHT.txt @@ -0,0 +1,10 @@ + +Copyright 2019. Triad National Security, LLC. All rights reserved. +This program was produced under U.S. Government contract 89233218CNA000001 for Los Alamos +National Laboratory (LANL), which is operated by Triad National Security, LLC for the U.S. +Department of Energy/National Nuclear Security Administration. All rights in the program are +reserved by Triad National Security, LLC, and the U.S. Department of Energy/National Nuclear +Security Administration. The Government is granted for itself and others acting on its behalf a +nonexclusive, paid-up, irrevocable worldwide license in this material to reproduce, prepare +derivative works, distribute copies to the public, perform publicly and display publicly, and to permit +others to do so. diff --git a/LICENSE.txt b/LICENSE.txt index af0925a1..2f40f860 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,15 +1,5 @@ -Copyright 2019. Triad National Security, LLC. All rights reserved. -This program was produced under U.S. Government contract 89233218CNA000001 for Los Alamos -National Laboratory (LANL), which is operated by Triad National Security, LLC for the U.S. -Department of Energy/National Nuclear Security Administration. All rights in the program are -reserved by Triad National Security, LLC, and the U.S. Department of Energy/National Nuclear -Security Administration. The Government is granted for itself and others acting on its behalf a -nonexclusive, paid-up, irrevocable worldwide license in this material to reproduce, prepare -derivative works, distribute copies to the public, perform publicly and display publicly, and to permit -others to do so. - This program is open source under the BSD-3 License. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/README.rst b/README.rst index 3bf6a020..45252206 100644 --- a/README.rst +++ b/README.rst @@ -106,6 +106,7 @@ The Journal of chemical physics, 148(24), 241715. See AUTHORS.txt for information on authors. See LICENSE.txt for licensing information. hippynn is licensed under the BSD-3 license. +See COPYRIGHT.txt for copyright information. Triad National Security, LLC (Triad) owns the copyright to hippynn, which it identifies as project number LA-CC-19-093. 
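The exporting and caching features listed in the changelog above correspond to the ``Database.write_npz``, ``Database.write_h5``, and ``Database.make_database_cache`` methods touched later in this patch. A minimal sketch of the intended usage (file names are illustrative, and ``database`` is assumed to be an already-constructed and already-split ``Database`` object)::

    # Export every split (plus split masks) to a compressed NPZ file.
    database.write_npz(file="my_dataset.npz", record_split_masks=True, overwrite=True)

    # Or export to the pyanitools-style H5 format (requires h5py).
    database.write_h5(split=True, h5path="my_dataset.h5", species_key="species", overwrite=True)

    # Cache the processed database to disk and re-open it, so that later runs
    # can restart from the cache instead of re-processing the raw data.
    database = database.make_database_cache(file="./hippynn_db_cache.npz", overwrite=True)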
diff --git a/docs/source/conf.py b/docs/source/conf.py
index a47dfe54..8e707bdf 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -45,9 +45,11 @@
     "no-show-inheritance": True,
     "special-members": "__init__",
 }
+autodoc_member_order = 'bysource'
+
 # The following are highly optional, so we mock them for doc purposes.
-autodoc_mock_imports = ["pyanitools", "seqm", "schnetpack", "cupy", "lammps", "numba", "triton", "pytorch_lightning"]
+autodoc_mock_imports = ["pyanitools", "seqm", "schnetpack", "cupy", "lammps", "numba", "triton", "pytorch_lightning", "scipy"]
 # -- Options for HTML output -------------------------------------------------
diff --git a/docs/source/examples/controller.rst b/docs/source/examples/controller.rst
index 83de14a5..cc2d5016 100644
--- a/docs/source/examples/controller.rst
+++ b/docs/source/examples/controller.rst
@@ -1,7 +1,6 @@
 Controller
 ==========
-
 How to define a controller for more customized control of the training process.
 We assume that there is a set of ``training_modules`` assembled and a ``database`` object has been constructed.
diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst
index 78703eac..d72ee5f3 100644
--- a/docs/source/examples/index.rst
+++ b/docs/source/examples/index.rst
@@ -3,8 +3,8 @@ Examples
 Here are some examples about how to use various features in ``hippynn``.
 Besides the :doc:`/examples/minimal_workflow` example,
-the examples are just snippets. For runnable example scripts, see
-`the examples at the hippynn github repository`_
+the examples are just snippets, rather than full scripts.
+For runnable example scripts, see `the examples at the hippynn github repository`_
 .. _`the examples at the hippynn github repository`: https://github.com/lanl/hippynn/tree/development/examples
@@ -23,5 +23,5 @@ the examples are just snippets. For runnable example scripts, see
     mliap_unified
     excited_states
     weighted_loss
-
+    lightning
diff --git a/docs/source/examples/lightning.rst b/docs/source/examples/lightning.rst
new file mode 100644
index 00000000..bb572426
--- /dev/null
+++ b/docs/source/examples/lightning.rst
@@ -0,0 +1,20 @@
+Pytorch Lightning module
+========================
+
+
+Hippynn includes support for distributed training using `pytorch-lightning`_.
+This can be accessed using the :class:`hippynn.experiment.HippynnLightningModule` class.
+The class has two class-methods for creating the lightning module using the same
+types of arguments that would be used for an ordinary hippynn experiment.
+These are :meth:`hippynn.experiment.HippynnLightningModule.from_experiment_setup`
+and :meth:`hippynn.experiment.HippynnLightningModule.from_train_setup`.
+Alternatively, you may construct and supply the arguments for the module yourself.
+
+Finally, in addition to the usual pytorch lightning arguments,
+the hippynn lightning module saves an additional file, ``experiment_structure.pt``,
+which needs to be provided as an argument to the
+:meth:`hippynn.experiment.HippynnLightningModule.load_from_checkpoint` constructor.
+
+
+.. _pytorch-lightning: https://github.com/Lightning-AI/pytorch-lightning
+
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 50bcd450..fc17eb26 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -8,31 +8,53 @@ We hope you enjoy your stay.
 What is hippynn?
 ================
-`hippynn` is a python library for machine learning on atomistic systems.
+``hippynn`` is a python library for machine learning on atomistic systems
+using `pytorch`_.
 We aim to provide high-performance modular design so that different components
 can be re-used, extended, or added to. You can find more information
-at the :doc:`/user_guide/features` page. The development home is located
-at `the hippynn github repository`_, which also contains `many example files`_
+about overall library features at the :doc:`/user_guide/features` page.
+The development home is located at `the hippynn github repository`_, which also contains `many example files`_.
+Additionally, the :doc:`user guide ` aims to describe abstract
+aspects of the library, while the
+:doc:`examples documentation section ` aims to show
+more concretely how to perform tasks with hippynn. Finally, the
+:doc:`api documentation ` contains a comprehensive
+listing of the library components and their documentation.
 The main components of hippynn are constructing models, loading databases,
 training the models to those databases, making predictions on new databases,
-and interfacing with other atomistic codes. In particular, we provide interfaces
-to `ASE`_ (prediction), `PYSEQM`_ (training/prediction), and `LAMMPS`_ (prediction).
+and interfacing with other atomistic codes for operations such as molecular dynamics.
+In particular, we provide interfaces to `ASE`_ (prediction),
+`PYSEQM`_ (training/prediction), and `LAMMPS`_ (prediction).
 hippynn is also used within `ALF`_ for generating machine learned potentials
 along with their training data completely from scratch.
-Multiple formats for training data are supported, including
-Numpy arrays, the ASE Database, `fitSNAP`_ JSON format, and `ANI HDF5 files`_.
+Multiple :doc:`database formats ` for training data are supported, including
+Numpy arrays, `ASE`_-compatible formats, `FitSNAP`_ JSON format, and `ANI HDF5 files`_.
+
+``hippynn`` includes many tools, such as an :doc:`ASE calculator`,
+a :doc:`LAMMPS MLIAP interface`,
+:doc:`batched prediction ` and batched geometry optimization,
+:doc:`automatic ensemble creation `,
+:doc:`restarting training from checkpoints `,
+:doc:`sample-weighted loss functions `,
+:doc:`distributed training with pytorch lightning `,
+and more.
+
+``hippynn`` is highly modular, and if you are a model developer, interfacing your
+pytorch model into the hippynn node/graph system will make it simple and easy for users
+to build models of energy, charge, bond order, excited state energies, and more.
 .. _`ASE`: https://wiki.fysik.dtu.dk/ase/
 .. _`PYSEQM`: https://github.com/lanl/PYSEQM/
 .. _`LAMMPS`: https://www.lammps.org
-.. _`fitSNAP`: https://github.com/FitSNAP/FitSNAP
+.. _`FitSNAP`: https://github.com/FitSNAP/FitSNAP
 .. _`ANI HDF5 files`: https://doi.org/10.1038/s41597-020-0473-z
 .. _`ALF`: https://github.com/lanl/ALF/
 .. _`the hippynn github repository`: https://github.com/lanl/hippynn/
 .. _`many example files`: https://github.com/lanl/hippynn/tree/development/examples
+.. _`pytorch`: https://pytorch.org
 .. toctree::
diff --git a/docs/source/installation.rst b/docs/source/installation.rst
index 4064fea9..c8a07152 100644
--- a/docs/source/installation.rst
+++ b/docs/source/installation.rst
@@ -2,7 +2,6 @@
 Installation
 ============
-
 Requirements
 ^^^^^^^^^^^^
@@ -43,6 +42,8 @@ Interfacing codes:
 .. _LAMMPS: https://www.lammps.org/
 .. _PYSEQM: https://github.com/lanl/PYSEQM
 .. _pytorch-lightning: https://github.com/Lightning-AI/pytorch-lightning
+..
_hippynn: https://github.com/lanl/hippynn/ + Installation Instructions ^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -67,7 +68,6 @@ Clone the hippynn_ repository and navigate into it, e.g.:: $ git clone https://github.com/lanl/hippynn.git $ cd hippynn -.. _hippynn: https://github.com/lanl/hippynn/ Dependencies using conda diff --git a/docs/source/user_guide/ckernels.rst b/docs/source/user_guide/ckernels.rst index c810bbcd..eb504da7 100644 --- a/docs/source/user_guide/ckernels.rst +++ b/docs/source/user_guide/ckernels.rst @@ -60,7 +60,7 @@ The three custom kernels correspond to the interaction sum in hip-nn: .. math:: - a'_{i,a} = = \sum_{\nu,b} V^\nu_{a,b} e^{\nu}_{i,b} + a'_{i,a} = \sum_{\nu,b} V^\nu_{a,b} e^{\nu}_{i,b} e^{\nu}_{i,a} = \sum_p s^\nu_{p} z_{p_j,a} diff --git a/docs/source/user_guide/concepts.rst b/docs/source/user_guide/concepts.rst index 62cd8ff5..79b4faf6 100644 --- a/docs/source/user_guide/concepts.rst +++ b/docs/source/user_guide/concepts.rst @@ -45,8 +45,9 @@ Graphs A :class:`~hippynn.graphs.GraphModule` is a 'compiled' set of nodes; a ``torch.nn.Module`` that executes the graph. -GraphModules are used in a number of places within hippynn. - +GraphModules are used in a number of places within hippynn, +such as the model, the loss, the evaluator, the predictor, the ASE interface, +and the LAMMPS interface objects all use GraphModules. Experiment ^^^^^^^^^^ diff --git a/docs/source/user_guide/databases.rst b/docs/source/user_guide/databases.rst index 4448033b..2b339cf7 100644 --- a/docs/source/user_guide/databases.rst +++ b/docs/source/user_guide/databases.rst @@ -31,12 +31,45 @@ the [i,j] element of the cell gives the j cartesian coordinate of cell vector i. massive difficulties fitting to periodic boundary conditions, you may check the transposed version of your cell data, or compute the RDF. +Database Formats and notes +--------------------------- -ASE Objects Database handling ----------------------------------------------------------- -If your training data is stored as ASE files of any type (.json,.db,.xyz,.traj ... etc.) it can be loaded directly -a Database for hippynn. +Numpy arrays on disk +........................ + +see :class:`hippynn.databases.NPZDatabase` (if arrays are stored +in a `.npz` dictionary) or :class:`hippynn.databases.DirectoryDatabase` +(if each array is in its own file). + +Numpy arrays in memory +........................ + +Use the base :class:`hippynn.databases.Database` class directly to initialize +a database from a dictionary mapping db_names to numpy arrays. + +pyanitools H5 files +........................ + +See :class:`hippynn.databases.PyAniFileDB` and see :class:`hippynn.databases.PyAniDirectoryDB`. + +This format requires ``h5py`` and ``ase`` to be installed. + +Snap JSON Format +........................ + +See :class:`hippynn.databases.SNAPDirectoryDatabase`. This format requires ``ase`` to be installed. + +For more information on this format, see the FitSNAP_ software. + +.. _FitSNAP: https://fitsnap.github.io + +ASE Database +........................ + +If your training data is stored as ASE files of any type, +(.json,.db,.xyz,.traj ... etc.) it can be loaded directly +as a Database for hippynn. The ASE database :class:`~hippynn.databases.AseDatabase` can be loaded with ASE installed. -See ~/examples/ase_db_example.py for a basic example utilzing the class. \ No newline at end of file +See ~/examples/ase_db_example.py for a basic example utilizing the class. 
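To make the in-memory route above concrete, here is a small sketch of building the base ``Database`` class directly from numpy arrays; the array names and shapes are illustrative, and the keyword arguments follow the constructor signature updated later in this patch::

    import numpy as np
    from hippynn.databases import Database

    # The db_names of the arrays must match the db_names of the model's
    # input and target nodes.
    arr_dict = {
        "species": np.load("species.npy"),          # (n_systems, n_atoms), zero-padded
        "coordinates": np.load("coordinates.npy"),  # (n_systems, n_atoms, 3)
        "energy": np.load("energy.npy"),            # (n_systems, 1)
    }

    database = Database(
        arr_dict=arr_dict,
        inputs=["species", "coordinates"],
        targets=["energy"],
        seed=0,            # seeds the random state used for splitting
        test_size=0.1,     # fraction (or count) held out for testing
        valid_size=0.1,    # fraction (or count) held out for validation
    )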
\ No newline at end of file diff --git a/docs/source/user_guide/features.rst b/docs/source/user_guide/features.rst index 06fac16f..b95d6158 100644 --- a/docs/source/user_guide/features.rst +++ b/docs/source/user_guide/features.rst @@ -11,7 +11,7 @@ Modular set of pytorch layers for atomistic operations if you want to use them in your scripts without using the rest of the features provided here -- no problem! -API documentation for :mod:`~hippynn.layers` +API documentation for :mod:`~hippynn.layers` and :mod:`~hippynn.networks` Graph level API for simple and flexible construction of models from pytorch components. --------------------------------------------------------------------------------------- @@ -26,6 +26,12 @@ Graph level API for simple and flexible construction of models from pytorch comp API documentation for :mod:`~hippynn.graphs` +For more information on nodes and graphs, see the `graph exploration ipython notebook`_ which can also +be found in the example files. + +.. _graph exploration ipython notebook: https://github.com/lanl/hippynn/blob/development/examples/graph_exploration.ipynb + + Plot level API for tracking your training. ---------------------------------------------------------- - Using the graph API, define quantities to evaluate before, during, or after training as @@ -46,7 +52,7 @@ API documentation for :mod:`~hippynn.experiment` Custom Kernels for fast execution ---------------------------------------------------------- - Certain operations are not efficiently written in pure pytorch, we provide - alternative implementations with ``numba`` + alternative implementations. - These are directly linked in with pytorch Autograd -- use them like native pytorch functions. - These provide advantages in memory footprint and speed - Includes CPU and GPU execution for custom kernels @@ -55,7 +61,8 @@ More information at :doc:`this page ` Interfaces ---------------------------------------------------------- -- ASE: Define `ase` calculators based on the graph-level API. -- PYSEQM: Use `pyseqm` calculations as nodes in a graph. +- ASE: Define ``ase`` calculators based on the graph-level API. +- PYSEQM: Use ``pyseqm`` calculations as nodes in a graph. +- LAMMPS: Create a file for use as a `pair style mliap` object. API documentation for :mod:`~hippynn.interfaces` \ No newline at end of file diff --git a/docs/source/user_guide/settings.rst b/docs/source/user_guide/settings.rst index c6764206..a4c8fcef 100644 --- a/docs/source/user_guide/settings.rst +++ b/docs/source/user_guide/settings.rst @@ -69,3 +69,8 @@ The following settings are available: - float between 0 and 1 - 1.0 - no + * - TIMEPLOT_AUTOSCALING + - If True, only provide log-scaled plots of training quantities over time if warranted by the data. If False, always produce all plots in linear, log, and loglog scales. 
+ - bool + - True + - yes diff --git a/examples/ani1x_training.py b/examples/ani1x_training.py index f97f6114..6d1608d0 100644 --- a/examples/ani1x_training.py +++ b/examples/ani1x_training.py @@ -108,7 +108,7 @@ def load_db(db_info, en_name, force_name, seed, anidata_location, n_workers): found_indices = ~np.isnan(database.arr_dict[en_name]) database.arr_dict = {k: v[found_indices] for k, v in database.arr_dict.items()} - database.make_trainvalidtest_split(0.1, 0.1) + database.make_trainvalidtest_split(test_size=0.1, valid_size=0.1) return database diff --git a/hippynn/__init__.py b/hippynn/__init__.py index 520356ff..bf59fa7d 100644 --- a/hippynn/__init__.py +++ b/hippynn/__init__.py @@ -7,27 +7,36 @@ from . import _version __version__ = _version.get_versions()['version'] -# Configurational settings +# Configuration settings from ._settings_setup import settings - # Pytorch modules from . import layers -from . import networks +from . import networks # wait this one is different from the other one. # Graph abstractions from . import graphs +from .graphs import nodes, IdxType, GraphModule, Predictor # Database loading from . import databases +from .databases import Database, NPZDatabase, DirectoryDatabase # Training/testing routines from . import experiment -from .experiment import setup_and_train +from .experiment import setup_and_train, train_model, setup_training,\ + test_model, load_model_from_cwd, load_checkpoint, load_checkpoint_from_cwd + +# Other subpackages +from . import molecular_dynamics +from . import optimizer # Custom Kernels from . import custom_kernels +from .custom_kernels import set_custom_kernels from . import pretraining +from .pretraining import hierarchical_energy_initialization from . import tools +from .tools import active_directory, log_terminal diff --git a/hippynn/databases/__init__.py b/hippynn/databases/__init__.py index e97ad715..91aca915 100644 --- a/hippynn/databases/__init__.py +++ b/hippynn/databases/__init__.py @@ -23,14 +23,15 @@ pass if has_ase: - from ..interfaces.ase_interface import AseDatabase - if has_h5: - from .h5_pyanitools import PyAniFileDB, PyAniDirectoryDB + from ..interfaces.ase_interface import AseDatabase + from .SNAPJson import SNAPDirectoryDatabase + if has_h5: + from .h5_pyanitools import PyAniFileDB, PyAniDirectoryDB all_list = ["Database", "DirectoryDatabase", "NPZDatabase"] if has_ase: - all_list += ["AseDatabase"] + all_list += ["AseDatabase", "SNAPDirectoryDatabase"] if has_h5: all_list += ["PyAniFileDB", "PyAniDirectoryDB"] __all__ = all_list diff --git a/hippynn/databases/database.py b/hippynn/databases/database.py index fa503763..15bdd93b 100644 --- a/hippynn/databases/database.py +++ b/hippynn/databases/database.py @@ -15,6 +15,7 @@ _AUTO_SPLIT_PREFIX = "split_mask_" + class Database: """ Class for holding a pytorch dataset, splitting it, generating dataloaders, etc." 
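The re-exports added to ``hippynn/__init__.py`` above are intended to make the common training workflow reachable from the package root. A rough sketch of that workflow, assuming ``training_modules``, ``database``, and ``setup_params`` have been assembled as in the repository's example scripts (the log file name is illustrative)::

    import hippynn

    # Tee stdout/stderr to a log file while running the experiment.
    with hippynn.log_terminal("training_log.txt", "wt"):
        hippynn.setup_and_train(
            training_modules=training_modules,
            database=database,
            setup_params=setup_params,
        )

    # Later, a finished run can be reloaded from its working directory using
    # the newly re-exported checkpoint helpers (assumed to return the model graph).
    model = hippynn.load_model_from_cwd()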
@@ -22,18 +23,18 @@ class Database: def __init__( self, - arr_dict: dict[str,torch.Tensor], + arr_dict: dict[str, np.ndarray], inputs: list[str], targets: list[str], - seed: [int,np.random.RandomState,tuple], - test_size: Union[float,int]=None, - valid_size: Union[float,int]=None, - num_workers: int=0, - pin_memory: bool=True, - allow_unfound:bool =False, - auto_split:bool =False, - device: torch.device=None, - dataloader_kwargs:dict[str,object]=None, + seed: [int, np.random.RandomState, tuple], + test_size: Union[float, int] = None, + valid_size: Union[float, int] = None, + num_workers: int = 0, + pin_memory: bool = True, + allow_unfound: bool = False, + auto_split: bool = False, + device: torch.device = None, + dataloader_kwargs: dict[str, object] = None, quiet=False, ): """ @@ -56,7 +57,7 @@ def __init__( :param quiet: If True, print little or nothing while loading. """ - # Restartable Children of this class should change this after super().__init__ . + # Restartable Children of this class should change this after calling super().__init__() . self.restarter = NoRestart() self.inputs = inputs @@ -75,11 +76,12 @@ def __init__( _var_list = self.var_list except RuntimeError: if not quiet: - print("Database inputs and/or targets not specified. " - "The database will not be checked against and model inputs and targets (db_info).") + print( + "Database inputs and/or targets not specified. " + "The database will not be checked against and model inputs and targets (db_info)." + ) _var_list = [] - for k in _var_list: if k not in arr_dict and k not in ("indices", "split_indices"): if allow_unfound: @@ -113,13 +115,15 @@ def __init__( if self.auto_split: if test_size is not None or valid_size is not None: - warnings.warn(f"Auto split was set but test and valid size was also set." - f" Ignoring supplied test and validation sizes ({test_size} and {valid_size}.") + warnings.warn( + f"Auto split was set but test and valid size was also set." + f" Ignoring supplied test and validation sizes ({test_size} and {valid_size}." + ) self.make_automatic_splits() if test_size is not None or valid_size is not None: if test_size is None or valid_size is None: - raise ValueError("Both test and valid size must be set for auto-splitting based on fractions") + raise ValueError("Both test_size and valid_size must be set for splitting when creating a database.") else: self.make_trainvalidtest_split(test_size=test_size, valid_size=valid_size) @@ -142,11 +146,16 @@ def var_list(self): raise RuntimeError(f"Database inputs not defined, set {Database}.targets.") return self.inputs + self.targets - def send_to_device(self, device=None): + def send_to_device(self, device: torch.device = None): """ Move the database to an accelerator device if possible. In some circumstances this can accelerate training. + .. Note:: + If the database is moved to a GPU, + pin_memory will be set to False + and num_workers will be set to 0. + :param device: device to move to, if None, try to auto-detect. :return: """ @@ -167,11 +176,13 @@ def send_to_device(self, device=None): for split, arrdict in self.splits.items(): for k in arrdict: arrdict[k] = arrdict[k].to(device) + return - def make_random_split(self, evaluation_mode, split_size): + def make_random_split(self, split_name: str, split_size: Union[int, float]): """ + Make a random split using self.random_state to select items. 
- :param evaluation_mode: String naming the split, can be anything, but 'train', 'valid', and 'test' are special.s + :param split_name: String naming the split, can be anything, but 'train', 'valid', and 'test' are special. :param split_size: int (number of items) or float<1, fraction of samples. :return: """ @@ -185,9 +196,25 @@ def make_random_split(self, evaluation_mode, split_size): split_indices.sort() - return self.make_explicit_split(evaluation_mode, split_indices) + return self.make_explicit_split(split_name, split_indices) + + def make_trainvalidtest_split(self, *, test_size: Union[int, float], valid_size: Union[int, float]): + """ + Make a split for train, valid, and test out of any remaining unsplit entries in the database. + The size is specified in terms of test and valid splits; the train split will be the remainder. + + If you wish to specify precise rows for each split, see `make_explict_split` + or `make_explicit_split_bool`. + + This function takes keyword-arguments only in order to prevent confusion over which + size is which. + + The types of both test_size and valid_size parameters must match. - def make_trainvalidtest_split(self, test_size, valid_size): + :param test_size: int (count) or float (fraction) of data to assign to test split + :param valid_size: int (count) or float (fraction) of data to assign to valid split + :return: None + """ if self.splitting_completed: raise RuntimeError("Database already split!") @@ -196,19 +223,18 @@ def make_trainvalidtest_split(self, test_size, valid_size): raise ValueError("If train or valid size is set as a fraction, then set test_size as a fraction") else: if valid_size + test_size > 1: - raise ValueError( - f"Test fraction ({test_size}) plus valid fraction " f"({valid_size}) are greater than 1!" - ) + raise ValueError(f"Test fraction ({test_size}) plus valid fraction " f"({valid_size}) are greater than 1!") valid_size /= 1 - test_size self.make_random_split("test", test_size) self.make_random_split("valid", valid_size) self.split_the_rest("train") + return - def make_explicit_split(self, evaluation_mode, split_indices): + def make_explicit_split(self, split_name:str, split_indices: np.ndarray): """ - :param evaluation_mode: name for split, typically 'train', 'valid', 'test' + :param split_name: name for split, typically 'train', 'valid', 'test' :param split_indices: the indices of the items for the split :return: """ @@ -227,18 +253,18 @@ def make_explicit_split(self, evaluation_mode, split_indices): where_complement = np.where(complement_mask) # Split off data, and keep the rest. 
-        self.splits[evaluation_mode] = {k: torch.from_numpy(self.arr_dict[k][where_index]) for k in self.arr_dict}
-        if "split_indices" not in self.splits[evaluation_mode]:
+        self.splits[split_name] = {k: torch.from_numpy(self.arr_dict[k][where_index]) for k in self.arr_dict}
+        if "split_indices" not in self.splits[split_name]:
             if not self.quiet:
-                print(f"Adding split indices for split: {evaluation_mode}")
-            self.splits[evaluation_mode]["split_indices"] = torch.arange(len(split_indices), dtype=torch.int64)
+                print(f"Adding split indices for split: {split_name}")
+            self.splits[split_name]["split_indices"] = torch.arange(len(split_indices), dtype=torch.int64)
         for k, v in self.arr_dict.items():
             self.arr_dict[k] = v[where_complement]
         if not self.quiet:
-            print(f"Arrays for split: {evaluation_mode}")
-            prettyprint_arrays(self.splits[evaluation_mode])
+            print(f"Arrays for split: {split_name}")
+            prettyprint_arrays(self.splits[split_name])
         if arrdict_len(self.arr_dict) == 0:
             if not self.quiet:
@@ -246,25 +272,28 @@ def make_explicit_split(self, evaluation_mode, split_indices):
-    def make_explicit_split_bool(self, evaluation_mode, split_mask):
+    def make_explicit_split_bool(self, split_name: str,
+                                 split_mask: Union[np.ndarray, torch.Tensor]):
         """
-        :param evaluation_mode: name for split, typically 'train', 'valid', 'test'
+        :param split_name: name for split, typically 'train', 'valid', 'test'
         :param split_mask: a boolean array for where to split
         :return:
         """
+        if isinstance(split_mask, torch.Tensor):
+            split_mask = split_mask.numpy()
         if split_mask.dtype != np.bool_:
             if not np.isin(split_mask, [0, 1]).all():
                 raise ValueError(f"Mask function contains invalid values. Values found: {np.unique(split_mask)}")
             else:
                 split_mask = split_mask.astype(np.bool_)
-        indices = self.arr_dict['indices'][split_mask]
-        self.make_explicit_split(evaluation_mode, indices)
+        indices = self.arr_dict["indices"][split_mask]
+        self.make_explicit_split(split_name, indices)
         return
-    def split_the_rest(self, evaluation_mode):
-        self.make_explicit_split(evaluation_mode, self.arr_dict["indices"])
+    def split_the_rest(self, split_name: str):
+        self.make_explicit_split(split_name, self.arr_dict["indices"])
         self.splitting_completed = True
         return
@@ -296,9 +325,9 @@ def add_split_masks(self, dict_to_add_to=None, split_prefix=None):
         for sprime, split in self.splits.items():
             if sprime == s:
-                mask = np.ones_like(split['indices'], dtype=np.bool_)
+                mask = np.ones_like(split["indices"], dtype=np.bool_)
             else:
-                mask = np.zeros_like(split['indices'], dtype=np.bool_)
+                mask = np.zeros_like(split["indices"], dtype=np.bool_)
             if write_tensor:
                 mask = torch.as_tensor(mask)
@@ -336,7 +365,7 @@ def make_automatic_splits(self, split_prefix=None, dry_run=False):
             if k.startswith(split_prefix):
                 if arr.ndim != 1:
                     raise ValueError(f"Split mask for '{k}' has too many dimensions.
Shape: {arr.shape=}") - if arr.dtype == np.dtype('bool'): + if arr.dtype == np.dtype("bool"): mask_vars.add(k) elif arr.dtype is np.int and arr.ndim == 1: if np.isin(arr, [0, 1]).all(): @@ -350,7 +379,7 @@ def make_automatic_splits(self, split_prefix=None, dry_run=False): if not len(mask_vars): raise ValueError("No split mask detected.") - masks = {k[len(split_prefix):]: self.arr_dict[k].astype(bool) for k in mask_vars} + masks = {k[len(split_prefix) :]: self.arr_dict[k].astype(bool) for k in mask_vars} if not self.quiet: print("Auto-detected splits:", list(masks.keys())) @@ -369,13 +398,15 @@ def make_automatic_splits(self, split_prefix=None, dry_run=False): mask_counts += arr.astype(int) if not (mask_counts == 1).all(): set_of_counts = set(mask_counts) - raise ValueError(f" Auto-splitting requires unique split for each item." + - f" Items with the following split counts were detected: {set_of_counts}") + raise ValueError( + f" Auto-splitting requires unique split for each item." + + f" Items with the following split counts were detected: {set_of_counts}" + ) if dry_run: return - masks = {k: self.arr_dict['indices'][m] for k, m in masks.items()} + masks = {k: self.arr_dict["indices"][m] for k, m in masks.items()} for k, m in masks.items(): self.make_explicit_split(k, m) @@ -388,11 +419,18 @@ def make_automatic_splits(self, split_prefix=None, dry_run=False): return - def make_generator(self, split_type, evaluation_mode, batch_size=None, subsample=False): + def make_generator(self, + split_name: str, + evaluation_mode: str, + batch_size: Union[int, None] = None, + subsample: Union[float, bool] = False + ): """ Makes a dataloader for the given type of split and evaluation mode of the model. - :param split_type: str; "train", "valid", or "test" ; selects data to use + In most cases, you do not need to call this function directly as a user. + + :param split_name: str; "train", "valid", or "test" ; selects data to use :param evaluation_mode: str; "train" or "eval". Used for whether to shuffle. :param batch_size: passed to pytorch :param subsample: fraction to subsample @@ -402,16 +440,14 @@ def make_generator(self, split_type, evaluation_mode, batch_size=None, subsample if not self.splitting_completed: raise ValueError("Database has not yet been split.") - if split_type not in self.splits: - raise ValueError(f"Split {split_type} Invalid. Current splits:{list(self.splits.keys())}") + if split_name not in self.splits: + raise ValueError(f"Split {split_name} Invalid. Current splits:{list(self.splits.keys())}") - data = [self.splits[split_type][k] for k in self.var_list] + data = [self.splits[split_name][k] for k in self.var_list] if evaluation_mode == "train": - if split_type != "train": - raise ValueError( - "evaluation mode 'train' can only be used with training data." "(got {})".format(split_type) - ) + if split_name != "train": + raise ValueError("evaluation mode 'train' can only be used with training data." 
"(got {})".format(split_name)) shuffle = True elif evaluation_mode == "eval": shuffle = False @@ -423,9 +459,7 @@ def make_generator(self, split_type, evaluation_mode, batch_size=None, subsample n_total = data[0].shape[0] n_selected = int(n_total * subsample) sampled_indices = torch.argsort(torch.rand(n_total))[:n_selected] - # sampled_indices = torch.rand(data[0].shape[0]) < subsample dataset = Subset(dataset, sampled_indices) - # data = [a[sampled_indices] for a in data] generator = DataLoader( dataset, @@ -463,8 +497,8 @@ def _array_stat_helper(self, key, species_key, atomwise, norm_per_atom, norm_axi n_atoms = (self.arr_dict[species_key] > 0).sum(axis=1) # Transposes broadcast the result rightwards instead of leftwards. # numpy transpose on higher-order arrays reverses all dimensions. - prop = (prop.T/n_atoms).T - stat_prop = (stat_prop.T/n_atoms).T + prop = (prop.T / n_atoms).T + stat_prop = (stat_prop.T / n_atoms).T mean = stat_prop.mean() std = stat_prop.std() @@ -473,8 +507,16 @@ def _array_stat_helper(self, key, species_key, atomwise, norm_per_atom, norm_axi return prop, mean, std - - def remove_high_property(self, key, atomwise, norm_per_atom=False, species_key=None, cut=None, std_factor=10, norm_axis=None): + def remove_high_property( + self, + key: str, + atomwise: bool, + norm_per_atom: bool = False, + species_key: str = None, + cut: Union[float, None] = None, + std_factor: Union[float, None] = 10, + norm_axis: Union[int, None] = None, + ): """ For removing outliers from a dataset. Use with caution; do not inadvertently remove outliers from benchmarks! @@ -505,7 +547,7 @@ def remove_high_property(self, key, atomwise, norm_per_atom=False, species_key=N if std_factor is not None: prop, mean, std = self._array_stat_helper(key, species_key, atomwise, norm_per_atom, norm_axis) - large_property_mask = np.abs(prop - mean)/std > std_factor + large_property_mask = np.abs(prop - mean) / std > std_factor # Scan over all non-batch indices. non_batch_axes = tuple(range(1, prop.ndim)) drop_mask = np.sum(large_property_mask, axis=non_batch_axes) > 0 @@ -514,7 +556,23 @@ def remove_high_property(self, key, atomwise, norm_per_atom=False, species_key=N print(f"Removed {drop_mask.astype(int).sum()} outlier systems in variable {key} due to std. factor.") self.make_explicit_split(f"failed_std_fac_{key}", indices) - def write_h5(self, split=None, h5path=None, species_key='species', overwrite=False): + def write_h5(self, + split: Union[str, None] = None, + h5path: Union[str, None] = None, + species_key: str = "species", + overwrite:bool = False): + """ + Write this database to the pyanitools h5 format. + See :func:`hippynn.databases.h5_pyanitools.write_h5` for details. + + Note: This function will error if h5py is not installed. 
+ + :param split: + :param h5path: + :param species_key: + :param overwrite: + :return: + """ try: from .h5_pyanitools import write_h5 as write_h5_function @@ -523,33 +581,41 @@ def write_h5(self, split=None, h5path=None, species_key='species', overwrite=Fal return write_h5_function(self, split=split, file=h5path, species_key=species_key, overwrite=overwrite) - def write_npz(self, file: str, record_split_masks: bool = True, compressed:bool =True, overwrite: bool = False, split_prefix=None, return_only=False): + def write_npz( + self, + file: str, + record_split_masks: bool = True, + compressed: bool = True, + overwrite: bool = False, + split_prefix: Union[str, None] = None, + return_only: bool = False, + ): """ :param file: str, Path, or file object compatible with np.save - :param record_split_masks: + :param record_split_masks: whether to generate and place masks for the splits into the saved database. + :param compressed: whether to use np.savez_compressed (True) or np.savez :param overwrite: Whether to accept an existing path. Only used if fname is str or path. - :param split_prefix: optionally change the prefix for the masks computed by the splits. + :param split_prefix: optionally override the prefix for the masks computed by the splits. :param return_only: if True, ignore the file string and just return the resulting dictionary of numpy arrays. - :return: + :return: """ if split_prefix is None: split_prefix = _AUTO_SPLIT_PREFIX if not self.splitting_completed: - raise ValueError("Cannot write an incompletely split database to npz file.\n" + - "You can split the rest using `database.split_the_rest('other_data')`\n" + - "to put the remaining data into a new split named 'other_data'") + raise ValueError( + "Cannot write an incompletely split database to npz file.\n" + + "You can split the rest using `database.split_the_rest('other_data')`\n" + + "to put the remaining data into a new split named 'other_data'" + ) # get combined dictionary of arrays. - np_dict = {sname: - {arr_name: array.to('cpu').numpy() for arr_name, array in split.items()} - for sname, split in self.splits.items()} + np_dict = {sname: {arr_name: array.to("cpu").numpy() for arr_name, array in split.items()} for sname, split in self.splits.items()} # insert split masks if requested. if record_split_masks: self.add_split_masks(dict_to_add_to=np_dict, split_prefix=split_prefix) - # Stack numpy arrays: arr_dict = {} a_split = list(np_dict.values())[0] @@ -577,10 +643,12 @@ def write_npz(self, file: str, record_split_masks: bool = True, compressed:bool return arr_dict - def sort_by_index(self, index_name='indices'): + def sort_by_index(self, index_name: str = "indices"): """ + Sort arrays in each split of the database by an index key. - The default is 'indices', also possible is 'split_indices', or any other variable name in the database. + + The default is 'indices', also possible is 'split_indices', or any other variable name in the database. :param index_name: :return: None @@ -592,12 +660,14 @@ def sort_by_index(self, index_name='indices'): for k, v in split.items(): split[k] = v[ind_order] - def trim_by_species(self, species_key: str, keep_splits_same_size: bool =True): + def trim_by_species(self, species_key: str, keep_splits_same_size: bool = True): """ Remove any excess padding in a database. + :param species_key: what array to use to mark atom presence. - :param keep_splits_same_size: true: trim by the minimum amount across splits, false: trim by the maximum amount for each split. 
- :return: + :param keep_splits_same_size: true: trim by the minimum amount across splits, + false: trim by the maximum amount for each split. + :return: None """ if not self.splitting_completed: raise ValueError("Cannot trim arrays until splitting has been completed.") @@ -653,7 +723,12 @@ def trim_by_species(self, species_key: str, keep_splits_same_size: bool =True): return - def get_device(self): + def get_device(self) -> torch.device: + """ + Determine what device the database resides on. Raises ValueError if multiple devices are encountered. + + :return: device. + """ if not self.splitting_completed: raise ValueError("Device should not be changed before splitting is complete.") @@ -664,11 +739,15 @@ def get_device(self): device = devices.pop() return device - def make_database_cache(self, file="./hippynn_db_cache.npz", overwrite=False, **override_kwargs): + def make_database_cache(self, file: str = "./hippynn_db_cache.npz", overwrite: bool = False, **override_kwargs) -> "Database": """ Cache the database as-is, and re-open it. Useful for creating an easy restart script if the storage space is available. + The new datatbase will by default inherit the properties of this database. + + usage: + >>> database = database.make_database_cache() :param file: where to store the database :param overwrite: whether to overwrite an existing cache file with this name. @@ -702,14 +781,20 @@ def make_database_cache(self, file="./hippynn_db_cache.npz", overwrite=False, ** if not self.quiet: print("Writing Cached database to", file) - self.write_npz(file=file, - record_split_masks=True, # allows inheriting of splits from this db. - overwrite=overwrite, - return_only=False) + self.write_npz( + file=file, record_split_masks=True, overwrite=overwrite, return_only=False # allows inheriting of splits from this db. + ) # now reload cached file. return NPZDatabase(**arguments) -def compute_index_mask(indices, index_pool): + +def compute_index_mask(indices: np.ndarray, index_pool: np.ndarray) -> np.ndarray: + """ + + :param indices: + :param index_pool: + :return: + """ if not np.all(np.isin(indices, index_pool)): raise ValueError("Provided indices not in database") @@ -723,9 +808,9 @@ def compute_index_mask(indices, index_pool): return index_mask -def prettyprint_arrays(arr_dict): +def prettyprint_arrays(arr_dict: dict[str: np.ndarray]): """ - Pretty-print array dictionary + Pretty-print array dictionary. :return: None """ column_format = "| {:<30} | {:<18} | {:<28} |" diff --git a/hippynn/databases/h5_pyanitools.py b/hippynn/databases/h5_pyanitools.py index fd4c6a27..a50a9f2d 100644 --- a/hippynn/databases/h5_pyanitools.py +++ b/hippynn/databases/h5_pyanitools.py @@ -1,7 +1,7 @@ """ -Read Databases in the ANI H5 format. -Note: You will need `pyanitools.py` to be importable to import this module. +Read Databases in the pyanitools H5 format. 
+ """ import os @@ -37,8 +37,7 @@ def extract_full_file(self, file, species_key="species"): for c in progress_bar(x, desc="Data Groups", unit="group", total=x.group_size()): batch_dict = {} if species_key not in c: - raise ValueError(f"Species key '{species_key}' not found' in file {file}!\n" - f"\tFound keys: {set(c.keys())}") + raise ValueError(f"Species key '{species_key}' not found' in file {file}!\n" f"\tFound keys: {set(c.keys())}") for k, v in c.items(): # Filter things we don't need if k in self._IGNORE_KEYS: @@ -104,18 +103,19 @@ def determine_key_structure(self, batch_list, sys_count, n_atoms_max, species_ke shape_scheme[k][axis] = n_atoms_max shape_scheme[k][0] = sys_count - padding_scheme['sys_number'] = [] + padding_scheme["sys_number"] = [] return padding_scheme, shape_scheme, bkey def process_batches(self, batches, n_atoms_max, sys_count, species_key="species"): # Get padding abd shape info and batch size key - padding_scheme, shape_scheme, size_key =\ - self.determine_key_structure(batches, sys_count, n_atoms_max, species_key=species_key) + padding_scheme, shape_scheme, size_key = self.determine_key_structure(batches, sys_count, n_atoms_max, species_key=species_key) # add system numbers to the final arrays - shape_scheme['sys_number'] = [sys_count, ] - batches[0]['sys_number'] = np.asarray([0], dtype=np.int64) + shape_scheme["sys_number"] = [ + sys_count, + ] + batches[0]["sys_number"] = np.asarray([0], dtype=np.int64) arr_dict = {} for k, shape in shape_scheme.items(): @@ -126,7 +126,7 @@ def process_batches(self, batches, n_atoms_max, sys_count, species_key="species" for i, b in enumerate(progress_bar(batches, desc="Processing Batches", unit="batch")): # Get batch metadata n_sys = b[size_key].shape[0] - b['sys_number'] = np.asarray([i], dtype=np.int64) + b["sys_number"] = np.asarray([i], dtype=np.int64) sys_end = sys_start + n_sys # n_atoms_batch = b[species_key].shape[1] # don't need this! 
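The batch-processing helpers above feed the user-facing loader classes defined in the following hunks. A minimal sketch of loading an ANI-style H5 file with ``PyAniFileDB`` (the file name and db_names are illustrative; the splitting keywords are forwarded to the base ``Database``)::

    from hippynn.databases import PyAniFileDB

    database = PyAniFileDB(
        file="my_ani_dataset.h5",
        inputs=["species", "coordinates"],
        targets=["energies"],
        seed=2024,
        test_size=0.1,
        valid_size=0.1,
    )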
@@ -173,7 +173,7 @@ def filter_arrays(self, arr_dict, allow_unfound=False, quiet=False): class PyAniFileDB(Database, PyAniMethods, Restartable): - def __init__(self, file, inputs, targets, *args, allow_unfound=False, species_key="species", quiet=False, driver='core', **kwargs): + def __init__(self, file, inputs, targets, *args, allow_unfound=False, species_key="species", quiet=False, driver="core", **kwargs): """ :param file: @@ -197,7 +197,14 @@ def __init__(self, file, inputs, targets, *args, allow_unfound=False, species_ke super().__init__(arr_dict, inputs, targets, *args, **kwargs, quiet=quiet, allow_unfound=allow_unfound) self.restarter = self.make_restarter( - file, inputs, targets, *args, **kwargs, driver=driver, quiet=quiet, allow_unfound=allow_unfound, + file, + inputs, + targets, + *args, + **kwargs, + driver=driver, + quiet=quiet, + allow_unfound=allow_unfound, species_key=species_key, ) @@ -211,8 +218,19 @@ def load_arrays(self, allow_unfound=False, quiet=False): class PyAniDirectoryDB(Database, PyAniMethods, Restartable): - def __init__(self, directory, inputs, targets, *args, files=None, allow_unfound=False, species_key="species", - quiet=False, driver='core', **kwargs): + def __init__( + self, + directory, + inputs, + targets, + *args, + files=None, + allow_unfound=False, + species_key="species", + quiet=False, + driver="core", + **kwargs, + ): self.directory = directory self.files = files @@ -221,11 +239,10 @@ def __init__(self, directory, inputs, targets, *args, files=None, allow_unfound= self.species_key = species_key self.driver = driver - arr_dict = self.load_arrays(allow_unfound=allow_unfound,quiet=quiet) + arr_dict = self.load_arrays(allow_unfound=allow_unfound, quiet=quiet) super().__init__(arr_dict, inputs, targets, *args, **kwargs, quiet=quiet, allow_unfound=allow_unfound) - self.restarter = self.make_restarter(directory, inputs, targets, *args, files=files, quiet=quiet, - species_key=species_key, **kwargs) + self.restarter = self.make_restarter(directory, inputs, targets, *args, files=files, quiet=quiet, species_key=species_key, **kwargs) def load_arrays(self, allow_unfound=False, quiet=False): @@ -257,24 +274,29 @@ def load_arrays(self, allow_unfound=False, quiet=False): return arr_dict -def write_h5(database: Database, split: str = None, file: Path = None, species_key: str = 'species', overwrite=False): +def write_h5( + database: Database, + split: str = None, + file: Path = None, + species_key: str = "species", + overwrite: bool = False, +) -> dict: """ - :param database: database to get + :param database: Database to use :param split: str, None, or True; selects data split to save. If None, contents of arr_dict are used. If True, save all splits and save split masks as well. - :param file: where to save the database. + :param file: where to save the database. if None, does not save the file. :param species_key: the key used for system contents (padding and chemical formulas) :param overwrite: boolean; enables over-writing of h5 file. - :return: dictionary of ANI-style systems. + :return: dictionary of pyanitools-format systems. 
""" if split is True: database = database.write_npz("", record_split_masks=True, return_only=True) - print("writenpz", database.keys()) elif split in database.splits: database = database.splits[split] - database = {k: v.to('cpu').numpy() for k,v in database.items()} + database = {k: v.to("cpu").numpy() for k, v in database.items()} elif split is None: database = database.arr_dict else: @@ -297,10 +319,8 @@ def write_h5(database: Database, split: str = None, file: Path = None, species_k n_atoms_max = db_species.shape[1] # determine which keys have second shape of N atoms - is_atom_var = { - k: (len(k_arr.shape) > 1) and (k_arr.shape[1] == n_atoms_max) for k, k_arr in database.items() - } - del (is_atom_var[species_key]) # species handled separately + is_atom_var = {k: (len(k_arr.shape) > 1) and (k_arr.shape[1] == n_atoms_max) for k, k_arr in database.items()} + del is_atom_var[species_key] # species handled separately # Create the data dictionary # Maps hashes of system chemical formulas to dictionaries of system information. @@ -343,7 +363,7 @@ def write_h5(database: Database, split: str = None, file: Path = None, species_k mol[k] = np.asarray(mol[k]) if np.issubdtype(mol[k].dtype, np.unicode_): - mol[k] = [el.encode('utf-8') for el in list(mol[k])] + mol[k] = [el.encode("utf-8") for el in list(mol[k])] mol[k] = np.array(mol[k]) # Store data if packer is not None: diff --git a/hippynn/databases/ondisk.py b/hippynn/databases/ondisk.py index fce60bdc..6fce1d37 100644 --- a/hippynn/databases/ondisk.py +++ b/hippynn/databases/ondisk.py @@ -14,7 +14,7 @@ class DirectoryDatabase(Database, Restartable): """ - Database stored as NPY files in a diectory. + Database stored as NPY files in a directory. :param directory: directory path where the files are stored :param name: prefix for the arrays. diff --git a/hippynn/experiment/__init__.py b/hippynn/experiment/__init__.py index 3a222e9b..f42a8091 100644 --- a/hippynn/experiment/__init__.py +++ b/hippynn/experiment/__init__.py @@ -12,9 +12,10 @@ from .assembly import assemble_for_training from .routines import setup_and_train, setup_training, train_model, test_model, SetupParams +from .serialization import load_checkpoint, load_checkpoint_from_cwd, load_model_from_cwd - -__all__ = ["assemble_for_training", "setup_and_train", "setup_training", "train_model", "test_model", "SetupParams",] +__all__ = ["assemble_for_training", "setup_and_train", "setup_training", "train_model", "test_model", "SetupParams", + "load_checkpoint", "load_checkpoint_from_cwd", "load_model_from_cwd"] try: from .lightning_trainer import HippynnLightningModule diff --git a/hippynn/experiment/lightning_trainer.py b/hippynn/experiment/lightning_trainer.py index ead8eb57..3b6e1d52 100644 --- a/hippynn/experiment/lightning_trainer.py +++ b/hippynn/experiment/lightning_trainer.py @@ -31,6 +31,9 @@ class HippynnLightningModule(pl.LightningModule): + """ + A pytorch lightning module for running a hippynn experiment. + """ def __init__( self, model: GraphModule, @@ -84,6 +87,15 @@ def __init__( @classmethod def from_experiment_setup(cls, training_modules: TrainingModules, database: Database, setup_params: SetupParams, **kwargs): + """ + Create a lightning module using the same arguments as for :func:`hippynn.experiment.setup_and_train`. 
+ + :param training_modules: + :param database: + :param setup_params: + :param kwargs: + :return: + """ training_modules, controller, metric_tracker = setup_training(training_modules, setup_params) return cls.from_train_setup(training_modules, database, controller, metric_tracker, **kwargs) @@ -98,6 +110,19 @@ def from_train_setup( batch_callbacks=None, **kwargs, ): + """ + Create a lightning module from the same arguments as for :func:`hippynn.experiment.train_model`. + + :param training_modules: + :param database: + :param controller: + :param metric_tracker: + :param callbacks: + :param batch_callbacks: + :param kwargs: + :return: + """ + model, loss, evaluator = training_modules @@ -131,6 +156,11 @@ def from_train_setup( return trainer, HippynnDataModule(database, controller.batch_size) def on_save_checkpoint(self, checkpoint) -> None: + """ + + :param checkpoint: + :return: + """ # Note to future developers: # trainer.log_dir property needs to be called on all ranks! This is weird but important; @@ -163,6 +193,16 @@ def on_save_checkpoint(self, checkpoint) -> None: @classmethod def load_from_checkpoint(cls, checkpoint_path, map_location=None, structure_file=None, hparams_file=None, strict=True, **kwargs): + """ + + :param checkpoint_path: + :param map_location: + :param structure_file: + :param hparams_file: + :param strict: + :param kwargs: + :return: + """ if structure_file is None: # Assume checkpoint_path is like /version_/checkpoints/.chkpt @@ -178,11 +218,20 @@ def load_from_checkpoint(cls, checkpoint_path, map_location=None, structure_file ) def on_load_checkpoint(self, checkpoint) -> None: + """ + + :param checkpoint: + :return: + """ cstate = checkpoint.pop("controller_state") self.controller.load_state_dict(cstate) return def configure_optimizers(self): + """ + + :return: + """ scheduler_list = [] for s in self.scheduler_list: @@ -201,14 +250,24 @@ def configure_optimizers(self): return optimizer_list, scheduler_list def on_train_epoch_start(self): + """ + + :return: + """ for optimizer in self.optimizer_list: print_lr(optimizer, print_=self.print) self.print("Batch size:", self.trainer.train_dataloader.batch_size) def training_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + :return: + """ batch_inputs = batch[: self.n_inputs] - batch_targets = batch[-self.n_targets :] + batch_targets = batch[-self.n_targets:] batch_model_outputs = self.model(*batch_inputs) batch_train_loss = self.loss(*batch_model_outputs, *batch_targets)[0] @@ -232,9 +291,21 @@ def _eval_step(self, batch, batch_idx): return batch_predictions def validation_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + :return: + """ return self._eval_step(batch, batch_idx) def test_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + :return: + """ return self._eval_step(batch, batch_idx) def _eval_epoch_end(self, prefix): @@ -259,10 +330,18 @@ def _eval_epoch_end(self, prefix): return def on_validation_epoch_end(self): + """ + + :return: + """ self._eval_epoch_end(prefix="valid_") return def on_test_epoch_end(self): + """ + + :return: + """ self._eval_epoch_end(prefix="test_") return @@ -326,10 +405,18 @@ def _eval_end(self, prefix, when=None) -> None: return def on_validation_end(self): + """ + + :return: + """ self._eval_end(prefix="valid_") return def on_test_end(self): + """ + + :return: + """ self._eval_end(prefix="test_", when="test") return @@ -360,10 +447,22 @@ def __init__(self, database: Database, batch_size): self.batch_size = 
batch_size def train_dataloader(self): + """ + + :return: + """ return self.database.make_generator("train", "train", self.batch_size) def val_dataloader(self): + """ + + :return: + """ return self.database.make_generator("valid", "eval", self.batch_size) def test_dataloader(self): + """ + + :return: + """ return self.database.make_generator("test", "eval", self.batch_size) diff --git a/hippynn/experiment/routines.py b/hippynn/experiment/routines.py index 84faa5c0..ed2e7746 100644 --- a/hippynn/experiment/routines.py +++ b/hippynn/experiment/routines.py @@ -21,6 +21,8 @@ from .. import tools from .assembly import TrainingModules from .step_functions import get_step_function +from ..databases import Database + from .. import custom_kernels @@ -101,7 +103,7 @@ def __post_init__(self): def setup_and_train( training_modules: TrainingModules, - database, + database: Database, setup_params: SetupParams, store_all_better=False, store_best=True, diff --git a/hippynn/layers/pairs/filters.py b/hippynn/layers/pairs/filters.py index d1aebedc..653cb505 100644 --- a/hippynn/layers/pairs/filters.py +++ b/hippynn/layers/pairs/filters.py @@ -4,10 +4,11 @@ from .open import _PairIndexer class FilterDistance(_PairIndexer): - """ Filters a list of tensors in *pair_lists by distance. + """ + Filters a list of tensors in pair_tensors by distance. pair_dist is first positional argument. - :param _PairIndexer: FilterDistance subclasses _PairIndexer so that the + FilterDistance subclasses _PairIndexer so that the FilterPairIndexers behave as regular PairIndexers. """ diff --git a/hippynn/molecular_dynamics/__init__.py b/hippynn/molecular_dynamics/__init__.py index 3bcc1722..622cb6ce 100644 --- a/hippynn/molecular_dynamics/__init__.py +++ b/hippynn/molecular_dynamics/__init__.py @@ -2,4 +2,7 @@ Molecular dynamics driver with great flexibility and customizability regarding which quantities which are evolved and what algorithms are used to evolve them. Calls a hippynn `Predictor` on current state during each MD step. """ -from .md import * \ No newline at end of file +from .md import MolecularDynamics, Variable, NullUpdater, VelocityVerlet, LangevinDynamics + + +__all__ = ["MolecularDynamics", "Variable", "NullUpdater", "VelocityVerlet", "LangevinDynamics"] diff --git a/hippynn/molecular_dynamics/md.py b/hippynn/molecular_dynamics/md.py index 8752990c..1375fd77 100644 --- a/hippynn/molecular_dynamics/md.py +++ b/hippynn/molecular_dynamics/md.py @@ -10,6 +10,7 @@ from ..graphs import Predictor from ..layers.pairs.periodic import wrap_systems_torch + class Variable: """ Tracks the state of a quantity (eg. position, cell, species, diff --git a/hippynn/tools.py b/hippynn/tools.py index df1507cb..15e4768e 100644 --- a/hippynn/tools.py +++ b/hippynn/tools.py @@ -11,7 +11,7 @@ from . import settings -class teed_file_output: +class TeedFileOutput: def __init__(self, *streams): self.streams = streams @@ -42,8 +42,8 @@ def log_terminal(file, *args, **kwargs): file = open(file, *args, **kwargs) else: close_on_exit = False - teed_stderr = teed_file_output(file, sys.stderr) - teed_stdout = teed_file_output(file, sys.stdout) + teed_stderr = TeedFileOutput(file, sys.stderr) + teed_stdout = TeedFileOutput(file, sys.stdout) with contextlib.redirect_stderr(teed_stderr): with contextlib.redirect_stdout(teed_stdout): try: @@ -102,6 +102,16 @@ def active_directory(dirname, create=None): def progress_bar(iterable, *args, **kwargs): + """ + Wrap an iterable in a progress bar according to hippynn's current progress bar settings. 
+ + for args and kwargs, see tqdm documentation. + + :param iterable: + :param args: + :param kwargs: + :return: + """ if settings.PROGRESS is None: return iterable else: @@ -166,9 +176,12 @@ def unsqueeze_multiple(tensor, dims: tuple): tensor = tensor.unsqueeze(d) dims = tuple(d+1 for d in rest) return tensor + + def np_of_torchdefaultdtype(): return torch.ones(1, dtype=torch.get_default_dtype()).numpy().dtype + def is_equal_state_dict(d1, d2, raise_where=False): """ Checks if two pytorch state dictionaries are equal. Calls itself recursively