diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
new file mode 100644
index 00000000..e8abe0c7
--- /dev/null
+++ b/.github/workflows/build.yml
@@ -0,0 +1,24 @@
+# Check that package builds
+name: Build Checks
+
+on:
+  push:
+  pull_request:
+  workflow_dispatch:
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+        with:
+          python-version: "3.10"
+
+      - name: Install dependencies
+        run: >-
+          python -m pip install --user --upgrade setuptools wheel
+      - name: Build
+        run: >-
+          python setup.py sdist bdist_wheel
diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml
index 3e95dd6f..e4044c6a 100644
--- a/.github/workflows/deploy.yml
+++ b/.github/workflows/deploy.yml
@@ -5,8 +5,10 @@ name: PyPi Release
 
 on:
   push:
-  pull_request:
-  workflow_dispatch:
+    tags:
+      - hippynn-*
+  release:
+    types: [published]
 
 jobs:
   build:
diff --git a/.gitignore b/.gitignore
index 2818d631..299d9aea 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,4 @@
 __pycache__/
 *.pyc
 build/
-hippynn.egg-info/*
\ No newline at end of file
+hippynn.egg-info/*
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 04ae5769..97d6e938 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,3 +1,17 @@
+0.0.2a3
+=======
+
+New Features:
+-------------
+
+- Add nodes for non-adiabatic coupling vectors (NACR) and phase-less loss.
+  See /examples/excited_states_azomethane.py.
+
+Improvements
+------------
+
+- Multi-target dipole node now has a shape of (n_molecules, n_targets, 3).
+
 0.0.2a2
 =======
 
diff --git a/docs/source/examples/excited_states.rst b/docs/source/examples/excited_states.rst
new file mode 100644
index 00000000..7a1e7a58
--- /dev/null
+++ b/docs/source/examples/excited_states.rst
@@ -0,0 +1,72 @@
+Non-Adiabatic Excited States
+============================
+
+`hippynn` has features for training to excited-state energies, transition dipoles, and
+the non-adiabatic coupling vectors (NACR). These features can be found in
+:mod:`~hippynn.graphs.nodes.excited`.
+
+For a more detailed description, please see the paper [Li2023]_.
+
+Multi-target nodes are recommended over using one node per target.
+
+For energies, the node can be constructed just like the ground-state
+counterpart::
+
+    energy = targets.HEnergyNode("E", network, module_kwargs={"n_target": n_states + 1})
+    mol_energy = energy.mol_energy
+    mol_energy.db_name = "E"
+
+Note that a `multi-target node` is used here, defined by the keyword
+``module_kwargs={"n_target": n_states + 1}``. Here, `n_states` is the number of
+*excited* states in consideration. The extra state is for the ground state, which is often
+useful. The database name is simply `E` with a shape of ``(n_molecules,
+n_states+1)``.
+
+Predicting the transition dipoles is also similar to the ground-state permanent
+dipole::
+
+    charge = targets.HChargeNode("Q", network, module_kwargs={"n_target": n_states})
+    dipole = physics.DipoleNode("D", (charge, positions), db_name="D")
+
+The database name is `D` with a shape of ``(n_molecules, n_states, 3)``.
+
+For NACR, to avoid singularity problems, training is performed on NACR*ΔE
+instead::
+
+    nacr = excited.NACRMultiStateNode(
+        "ScaledNACR",
+        (charge, positions, energy),
+        db_name="ScaledNACR",
+        module_kwargs={"n_target": n_states},
+    )
+
+The NACR between states `i` and `j`, :math:`\boldsymbol{d}_{ij}`, is expressed
+in the following way:
+
+.. math::
+    \boldsymbol{d}_{ij}\Delta E_{ij} = \Delta E_{ij}\boldsymbol{q}_i \frac{\partial\boldsymbol{q}_j}{\partial\boldsymbol{R}}
+
+:math:`\Delta E_{ij}` is the energy difference between states `i` and `j`, which is
+calculated internally in the NACR node based on the input of the ``energy``
+node. :math:`\boldsymbol{R}` corresponds to the ``positions`` node in the code.
+:math:`\boldsymbol{q}_{i}` and :math:`\boldsymbol{q}_{j}` are the transition
+atomic charges for states `i` and `j` contained in the ``charge`` node. This
+charge node can be constructed from scratch or reused from the dipole
+predictions. The database name is `ScaledNACR` with a shape of ``(n_molecules,
+n_states*(n_states-1)/2, 3*n_atoms)``.
+
+Due to the phase problem, when the loss function is constructed, the
+`phase-less` version of MAE or RMSE should be used::
+
+    energy_mae = loss.MAELoss.of_node(energy)
+    dipole_mae = excited.MAEPhaseLoss.of_node(dipole)
+    nacr_mae = excited.MAEPhaseLoss.of_node(nacr)
+
+:class:`~hippynn.graphs.nodes.excited.MAEPhaseLoss` and
+:class:`~hippynn.graphs.nodes.excited.MSEPhaseLoss` are the `phase-less` versions of MAE
+and MSE, which take the minimum error over the possible signs of the output.
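+
+The following minimal sketch illustrates the idea behind these phase-less errors on plain
+tensors (in actual training scripts, use the node-based losses shown above)::
+
+    import torch
+
+    def phaseless_mae(predict, true):
+        # keep, per entry, the smaller error of the two possible signs of the reference
+        errors = torch.minimum(
+            torch.linalg.norm(true - predict, ord=1, dim=-1),
+            torch.linalg.norm(true + predict, ord=1, dim=-1),
+        )
+        return torch.sum(errors) / predict.numel()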
+
+For a complete script, please take a look at ``examples/excited_states_azomethane.py``.
+
+.. [Li2023] | Machine Learning Framework for Modeling Exciton-Polaritons in Molecular Materials.
+            | Li et al., 2023. https://arxiv.org/abs/2306.02523
diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst
index f7b4ba50..548b884b 100644
--- a/docs/source/examples/index.rst
+++ b/docs/source/examples/index.rst
@@ -18,4 +18,5 @@ the examples are just snippets. For fully-fledged examples see the
     restarting
     ase_calculator
     mliap_unified
+    excited_states
diff --git a/examples/close_contact_finding.py b/examples/close_contact_finding.py
new file mode 100644
index 00000000..18ff366d
--- /dev/null
+++ b/examples/close_contact_finding.py
@@ -0,0 +1,74 @@
+"""
+close_contact_finding.py
+
+This example shows how to use the hippynn function calculate_min_dists
+to find configurations in a dataset that contain close interatomic contacts.
+Such a procedure is often useful in active learning to identify
+outlier data which inhibits training.
+
+This script was designed for an external dataset available at
+https://github.com/atomistic-ml/ani-al
+
+Note: It is necessary to untar the h5 data files in ani-al/data/
+before running this script.
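+
+Note: The pyanitools reader library from the ani-al repository must also be importable;
+this script appends its expected location to sys.path before importing it.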
+
+"""
+import sys
+
+sys.path.append("../../datasets/ani-al/readers/lib/")
+import pyanitools  # Check early that pyanitools can be found.
+
+### Loading the database
+from hippynn.databases.h5_pyanitools import PyAniDirectoryDB
+
+database = PyAniDirectoryDB(
+    directory="../../datasets/ani-al/data/",
+    seed=0,
+    quiet=False,
+    allow_unfound=True,
+    inputs=None,
+    targets=None,
+)
+
+### Calculating minimum distance array
+from hippynn.pretraining import calculate_min_dists
+
+min_dist_array = calculate_min_dists(
+    database.arr_dict,
+    species_name="species",
+    positions_name="coordinates",
+    cell_name="cell",  # for open boundaries, do not pass a cell name (or pass None)
+    dist_hard_max=4.0,  # check for distances up to this radius
+    batch_size=50,
+)
+
+print("Minimum distance in configurations:")
+print(f"{min_dist_array.dtype=}")
+print(f"{min_dist_array.shape=}")
+print("First 100 values:", min_dist_array[:100])
+
+### Making a plot of the minimum distance for each configuration
+import matplotlib.pyplot as plt
+
+plt.hist(min_dist_array, bins=100)
+plt.title("Minimum distance per config")
+plt.xlabel("Distance")
+plt.ylabel("Count")
+plt.yscale("log")
+plt.show()
+
+### How to remove and separate low distance configurations
+dist_thresh = 1.7  # Note: what threshold to use may be highly problem-dependent.
+low_dist_configs = min_dist_array < dist_thresh
+where_low_dist = database.arr_dict["indices"][low_dist_configs]
+
+# This makes the low distance configurations
+# into their own split, separate from train/valid/test.
+database.make_explicit_split("LOW_DISTANCE_FILTER", where_low_dist)
+
+# This deletes the new split, although deleting it is not necessary;
+# this data will not be included in the train/valid/test splits.
+del database.splits["LOW_DISTANCE_FILTER"]
+
+### Continue on with data processing, e.g.
+database.make_trainvalidtest_split(test_size=0.1, valid_size=0.1)
diff --git a/examples/excited_states_azomethane.py b/examples/excited_states_azomethane.py
new file mode 100644
index 00000000..44b3daa7
--- /dev/null
+++ b/examples/excited_states_azomethane.py
@@ -0,0 +1,188 @@
+"""
+
+Example training script to predict excited-state energies, transition dipoles, and
+non-adiabatic coupling vectors (NACR).
+
+The dataset used in this example can be found at https://doi.org/10.5281/zenodo.7076420.
+
+This script is set up to assume the "release" folder from the zenodo record
+is placed in ../../datasets/azomethane/ relative to this script.
+
+For more information on the modeling techniques, please see the paper:
+Machine Learning Framework for Modeling Exciton-Polaritons in Molecular Materials
+Li, et al.
(2023) +https://arxiv.org/abs/2306.02523 + +""" +import json + +import matplotlib +import numpy as np +import torch + +import hippynn +from hippynn import plotting +from hippynn.experiment import setup_training, train_model +from hippynn.experiment.controllers import PatienceController, RaiseBatchSizeOnPlateau +from hippynn.graphs import inputs, loss, networks, physics, targets, excited + +matplotlib.use("Agg") +# default types for torch +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_default_dtype(torch.float32) + +hippynn.settings.WARN_LOW_DISTANCES = False +hippynn.settings.TRANSPARENT_PLOT = True + +n_atoms = 10 +n_states = 3 +plot_frequency = 100 +dipole_weight = 4 +nacr_weight = 2 +l2_weight = 2e-5 + +# Hyperparameters for the network +# Note: These hyperparameters were generated via +# a tuning algorithm, hence their somewhat arbitrary nature. +network_params = { + "possible_species": [0, 1, 6, 7], + "n_features": 30, + "n_sensitivities": 28, + "dist_soft_min": 0.7665723566179274, + "dist_soft_max": 3.4134447177301515, + "dist_hard_max": 4.6860240434651805, + "n_interaction_layers": 3, + "n_atom_layers": 3, +} +# dump parameters to the log file +print("Network parameters\n\n", json.dumps(network_params, indent=4)) + +with hippynn.tools.active_directory("TEST_AZOMETHANE_MODEL"): + with hippynn.tools.log_terminal("training_log.txt", "wt"): + # build network + species = inputs.SpeciesNode(db_name="Z") + positions = inputs.PositionsNode(db_name="R") + network = networks.Hipnn("hipnn_model", (species, positions), module_kwargs=network_params) + # add energy + energy = targets.HEnergyNode("E", network, module_kwargs={"n_target": n_states + 1}) + mol_energy = energy.mol_energy + mol_energy.db_name = "E" + # add dipole + charge = targets.HChargeNode("Q", network, module_kwargs={"n_target": n_states}) + dipole = physics.DipoleNode("D", (charge, positions), db_name="D") + # add NACR + nacr = excited.NACRMultiStateNode( + "ScaledNACR", + (charge, positions, energy), + db_name="ScaledNACR", + module_kwargs={"n_target": n_states}, + ) + # set up plotter + plotter = [] + for node in [mol_energy, dipole, nacr]: + plotter.append(plotting.Hist2D.compare(node, saved=True, shown=False)) + for i in range(network_params["n_interaction_layers"]): + plotter.append( + plotting.SensitivityPlot( + network.torch_module.sensitivity_layers[i], + saved=f"Sensitivity_{i}.pdf", + shown=False, + ) + ) + plotter = plotting.PlotMaker(*plotter, plot_every=plot_frequency) + # build the loss function + validation_losses = {} + # energy + energy_rmse = loss.MSELoss.of_node(energy) ** 0.5 + validation_losses["E-RMSE"] = energy_rmse + energy_mae = loss.MAELoss.of_node(energy) + validation_losses["E-MAE"] = energy_mae + energy_loss = energy_rmse + energy_mae + validation_losses["E-Loss"] = energy_loss + total_loss = energy_loss + # dipole + dipole_rmse = excited.MSEPhaseLoss.of_node(dipole) ** 0.5 + validation_losses["D-RMSE"] = dipole_rmse + dipole_mae = excited.MAEPhaseLoss.of_node(dipole) + validation_losses["D-MAE"] = dipole_mae + dipole_loss = dipole_rmse / np.sqrt(3) + dipole_mae + validation_losses["D-Loss"] = dipole_loss + total_loss += dipole_weight * dipole_loss + # nacr + nacr_rmse = excited.MSEPhaseLoss.of_node(nacr) ** 0.5 + validation_losses["NACR-RMSE"] = nacr_rmse + nacr_mae = excited.MAEPhaseLoss.of_node(nacr) + validation_losses["NACR-MAE"] = nacr_mae + nacr_loss = nacr_rmse / np.sqrt(3 * n_atoms) + nacr_mae + validation_losses["NACR-Loss"] = nacr_loss + 
total_loss += nacr_weight * nacr_loss + # l2 regularization + l2_reg = loss.l2reg(network) + validation_losses["L2"] = l2_reg + loss_regularization = l2_weight * l2_reg + # add total loss to the dictionary + validation_losses["Loss_wo_L2"] = total_loss + validation_losses["Loss"] = total_loss + loss_regularization + + # set up experiment + training_modules, db_info = hippynn.experiment.assemble_for_training( + validation_losses["Loss"], + validation_losses, + plot_maker=plotter, + ) + # set up the optimizer + optimizer = torch.optim.AdamW(training_modules.model.parameters(), lr=1e-3) + # use higher patience for production runs + scheduler = RaiseBatchSizeOnPlateau(optimizer=optimizer, max_batch_size=2048, patience=10, factor=0.5) + controller = PatienceController( + optimizer=optimizer, + scheduler=scheduler, + batch_size=32, + eval_batch_size=2048, + # use higher max_epochs for production runs + max_epochs=100, + stopping_key="Loss", + fraction_train_eval=0.1, + # use higher termination_patience for production runs + termination_patience=10, + ) + experiment_params = hippynn.experiment.SetupParams(controller=controller) + + # load database + database = hippynn.databases.DirectoryDatabase( + name="azo_", # Prefix for arrays in the directory + directory="../../../datasets/azomethane/release/training/", + seed=114514, # Random seed for splitting data + **db_info, # Adds the inputs and targets db_names from the model as things to load + ) + # use 10% of the dataset just for quick testing purpose + database.make_random_split("train", 0.07) + database.make_random_split("valid", 0.02) + database.make_random_split("test", 0.01) + database.splitting_completed = True + # split the whole dataset into train, valid, test in the ratio of 7:2:1 + # database.make_trainvalidtest_split(0.1, 0.2) + + # set up training + training_modules, controller, metric_tracker = setup_training( + training_modules=training_modules, + setup_params=experiment_params, + ) + # train model + metric_tracker = train_model( + training_modules, + database, + controller, + metric_tracker, + callbacks=None, + batch_callbacks=None, + ) + + del network_params["possible_species"] + network_params["metric"] = metric_tracker.best_metric_values + network_params["avg_epoch_time"] = np.average(metric_tracker.epoch_times) + network_params["Loss"] = metric_tracker.best_metric_values["valid"]["Loss"] + + with open("training_summary.json", "w") as out: + json.dump(network_params, out, indent=4) diff --git a/examples/periodic_pairfinding.py b/examples/periodic_pairfinding.py new file mode 100644 index 00000000..81c8c988 --- /dev/null +++ b/examples/periodic_pairfinding.py @@ -0,0 +1,42 @@ +import torch +from hippynn.graphs import GraphModule +from hippynn.graphs.nodes.inputs import SpeciesNode, PositionsNode, CellNode +from hippynn.graphs.nodes.indexers import acquire_encoding_padding +from hippynn.graphs.nodes.pairs import PeriodicPairIndexer + + +n_atom = 30 +n_system = 30 +n_dim = 3 +distance_cutoff = 0.3 + +if torch.cuda.is_available(): + device = "cuda" +else: + device = "cpu" + +floatX = torch.float32 + +# Set up input nodes +sp = SpeciesNode("Z") +pos = PositionsNode("R") +cell = CellNode("C") + +# Set up and compile calculation +enc, pidxer = acquire_encoding_padding(sp, species_set=[0, 1]) +pairfinder = PeriodicPairIndexer("pair finder", (pos, enc, pidxer, cell), dist_hard_max=distance_cutoff) +computer = GraphModule([sp, pos, cell], [*pairfinder.children]) +computer.to(device) + +# Get some random inputs +species_tensor = 
torch.ones(n_system, n_atom, device=device, dtype=torch.int64)
+pos_tensor = torch.rand(n_system, n_atom, 3, device=device, dtype=floatX)
+cell_tensor = torch.eye(3, 3, device=device, dtype=floatX).unsqueeze(0).expand(n_system, n_dim, n_dim).clone()
+
+# Run calculation
+outputs = computer(species_tensor, pos_tensor, cell_tensor)
+
+# Print outputs
+output_as_dict = {c.name: o for c, o in zip(pairfinder.children, outputs)}
+for k, v in output_as_dict.items():
+    print(k, v.shape, v.dtype, v.min(), v.max())
diff --git a/hippynn/custom_kernels/env_numba.py b/hippynn/custom_kernels/env_numba.py
index ae98df39..7581732e 100644
--- a/hippynn/custom_kernels/env_numba.py
+++ b/hippynn/custom_kernels/env_numba.py
@@ -86,7 +86,7 @@ def kernel(sens, feat, pfirst, psecond, atom1_ids, atom1_starts, env):
 
     @staticmethod
     @via_numpy
-    @numba.jit(parallel=True)
+    @numba.jit(nopython=True, parallel=True)
     def cpu_kernel(sens, feat, pfirst, psecond, atom_ids, atom_starts):
         n_pairs, n_nu = sens.shape
 
@@ -152,7 +152,7 @@ def kernel(env, feat, pfirst, psecond, sense):
 
     @staticmethod
     @via_numpy
-    @numba.jit(parallel=True)
+    @numba.jit(nopython=True, parallel=True)
     def cpu_kernel(env, feat, pfirst, psecond):
         n_atom, n_nu, n_feat = env.shape
         (n_pairs,) = pfirst.shape
@@ -247,7 +247,7 @@ def kernel(env, sense, pfirst, psecond, atom2_ids, atom2_starts, feat):
 
     @staticmethod
     @via_numpy
-    @numba.jit(parallel=True)
+    @numba.jit(nopython=True, parallel=True)
     def cpu_kernel(
         env,
         sens,
diff --git a/hippynn/databases/database.py b/hippynn/databases/database.py
index a9b282de..4eada5b9 100644
--- a/hippynn/databases/database.py
+++ b/hippynn/databases/database.py
@@ -191,7 +191,6 @@ def make_generator(self, split_type, evaluation_mode, batch_size=None, subsample
 
         if not self.splitting_completed:
             raise ValueError("Database has not yet been split.")
-
         if split_type not in self.splits:
             raise ValueError(f"Split {split_type} Invalid. Current splits:{list(self.splits.keys())}")
 
@@ -226,6 +225,51 @@ def make_generator(self, split_type, evaluation_mode, batch_size=None, subsample
         )
         return generator
 
+
+    def trim_all_arrays(self, index):
+        """
+        To be used in conjunction with remove_high_property.
+        """
+        for key in self.arr_dict:
+            self.arr_dict[key] = self.arr_dict[key][index]
+
+    def remove_high_property(self, key, perAtom, species_key=None, cut=None, std_factor=10):
+        """
+        This function removes outlier data from the dataset.
+        It must be called before the database is split.
+        "key": the property key in the dataset to check for high values
+        "perAtom": True if the property is defined per atom along axis 1; otherwise the property is treated as per-system
+        "species_key": key for the species array; required when perAtom is True
+        "cut": systems whose values differ from the mean by more than this amount are removed. None to skip this step. This step is done first.
+        "std_factor": systems whose values differ from the mean by more than this many standard deviations are removed. None to skip this step.
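+
+        Example (hypothetical array key "T" holding a per-system energy)::
+
+            database.remove_high_property("T", perAtom=False, std_factor=10)
+            database.make_trainvalidtest_split(test_size=0.1, valid_size=0.1)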
+ """ + if perAtom: + if species_key==None: + raise RuntimeError("species_key must be defined to trim a per atom quantity") + atom_ind = self.arr_dict[species_key] > 0 + ndim = len(self.arr_dict[key].shape) + if cut!=None: + if perAtom: + Kmean = np.mean(self.arr_dict[key][atom_ind]) + else: + Kmean = np.mean(self.arr_dict[key]) + failArr = np.abs(self.arr_dict[key]-Kmean)>cut + #This does nothing with ndim=1 + trimArr = np.sum(failArr,axis=tuple(range(1,ndim)))==0 + self.trim_all_arrays(trimArr) + + if std_factor!=None: + if perAtom: + atom_ind = self.arr_dict[species_key] > 0 + Kmean = np.mean(self.arr_dict[key][atom_ind]) + std_cut = np.std(self.arr_dict[key][atom_ind]) * std_factor + else: + Kmean = np.mean(self.arr_dict[key]) + std_cut = np.std(self.arr_dict[key]) * std_factor + failArr = np.abs(self.arr_dict[key]-Kmean)>std_cut + #This does nothing with ndim=1 + trimArr = np.sum(failArr,axis=tuple(range(1,ndim)))==0 + self.trim_all_arrays(trimArr) + def compute_index_mask(indices, index_pool): diff --git a/hippynn/experiment/routines.py b/hippynn/experiment/routines.py index 4eebf9ad..88dcba29 100644 --- a/hippynn/experiment/routines.py +++ b/hippynn/experiment/routines.py @@ -20,6 +20,7 @@ from .device import set_devices from .. import tools from .assembly import TrainingModules +from .step_functions import get_step_function from .. import custom_kernels @@ -419,12 +420,13 @@ def training_loop( epoch = metric_tracker.current_epoch device = evaluator.model_device + step_function = get_step_function(controller.optimizer) + optimizer = controller.optimizer continue_training = True # Assume that nobody ran this function without wanting at least 1 epoch. while continue_training: - optimizer = controller.optimizer qprint("_" * 50) qprint("Epoch {}:".format(epoch)) tools.print_lr(optimizer) @@ -442,20 +444,13 @@ def training_loop( batch_targets = batch[-n_targets:] batch_targets = [x.requires_grad_(False) for x in batch_targets] - optimizer.zero_grad(set_to_none=True) - batch_model_outputs = model(*batch_inputs) - - # The extra .mean call here deals with an edge case for multi-GPU DataParallel with scalar outputs - batch_train_loss = loss(*batch_model_outputs, *batch_targets)[0].mean() - - batch_train_loss.backward() - optimizer.step() + batch_model_outputs = step_function(optimizer, model, loss, batch_inputs, batch_targets) if batch_callbacks: for cb in batch_callbacks: cb(batch_inputs, batch_model_outputs, batch_targets) # Allow garbage collection of computed values. - del batch_model_outputs, batch_train_loss + del batch_model_outputs elapsed_epoch_run_time = timeit.default_timer() - epoch_run_time qprint("Training time: ", round(elapsed_epoch_run_time, 2), "s") diff --git a/hippynn/experiment/step_functions.py b/hippynn/experiment/step_functions.py new file mode 100644 index 00000000..c03350e7 --- /dev/null +++ b/hippynn/experiment/step_functions.py @@ -0,0 +1,90 @@ +""" +This file implements various stepping protocols used by different optimizer APIs. + +In particular: + - The "standard" step function which only requires that backwards has been called. + - The "closure" step function for when line search is required (currently only active on LBFGS) + - The "two step" style of Sharpness Aware Minimization algorithms + +The main output function here is `get_step_function(optimizer)-> callable`. + +The various step functions are provided as classes that act with staticmethods. 
+This is to provide for the possibility of extension, for example, to schemes with +stepping schemes that require additional state, or for the possibility to specifiy +the step function explicitly within the controller. +""" +from torch.optim import Optimizer, LBFGS + + +def standard_step_fn(optimizer, model, loss, batch_inputs, batch_targets): + optimizer.zero_grad(set_to_none=True) + batch_model_outputs = model(*batch_inputs) + + # The extra .mean call here deals with an edge case for multi-GPU DataParallel with scalar outputs + batch_train_loss = loss(*batch_model_outputs, *batch_targets)[0].mean() + + batch_train_loss.backward() + optimizer.step() + return batch_model_outputs + + +def twostep_step_fn(optimizer, model, loss, batch_inputs, batch_targets): + # Step function for SAM algorithm. + optimizer.zero_grad(set_to_none=True) + + batch_model_outputs = model(*batch_inputs) + batch_train_loss = loss(*batch_model_outputs, *batch_targets)[0].mean() + batch_train_loss.backward() + optimizer.first_step(zero_grad=True) + + batch_model_outputs_2 = model(*batch_inputs) + loss(*batch_model_outputs_2, *batch_targets)[0].mean().backward() + optimizer.second_step(zero_grad=True) + return batch_model_outputs + + +def closure_step_fn(optimizer, model, loss, batch_inputs, batch_targets): + return_outputs = None + + def closure(): + nonlocal return_outputs + optimizer.zero_grad(set_to_none=True) + batch_model_outputs = model(*batch_inputs) + if return_outputs is None: + return_outputs = batch_model_outputs + batch_train_loss = loss(*batch_model_outputs, *batch_targets)[0].mean() + batch_train_loss.backward() + return batch_train_loss + + optimizer.step(closure) + return return_outputs + + +# Note: The staticmethod version here can be re-written using class parameters +# and __init_subclass, but will they always be staticmethods? +class StepFn: + step = NotImplemented + + def __call__(self, *args, **kwargs): + return self.step(*args, **kwargs) + + +class StandardStep(StepFn): + step = staticmethod(standard_step_fn) + + +class TwoStep(StepFn): + step = staticmethod(twostep_step_fn) + + +class ClosureStep(StepFn): + step = staticmethod(closure_step_fn) + + +def get_step_function(optimizer: Optimizer) -> callable: + if type(optimizer).__name__ == "SAM": + return TwoStep() + if isinstance(optimizer, (LBFGS,)): + return ClosureStep() + else: + return StandardStep() diff --git a/hippynn/graphs/__init__.py b/hippynn/graphs/__init__.py index eed24245..4fc3bd66 100644 --- a/hippynn/graphs/__init__.py +++ b/hippynn/graphs/__init__.py @@ -11,11 +11,13 @@ from . import indextypes from .indextypes import clear_index_cache, IdxType -from .nodes import base, inputs, networks, targets, physics, loss +from .nodes import base, inputs from .nodes.base import find_unique_relative, find_relatives, get_connected_nodes from .gops import get_subgraph, copy_subgraph, replace_node, compute_evaluation_order +from .nodes import networks, targets, physics, loss, excited + # Needed to populate the registry of index transformers. # This has to happen before the indextypes package can work, # however, we don't want the indextypes package to depend on actual diff --git a/hippynn/graphs/nodes/excited.py b/hippynn/graphs/nodes/excited.py new file mode 100644 index 00000000..142ba6fc --- /dev/null +++ b/hippynn/graphs/nodes/excited.py @@ -0,0 +1,169 @@ +from typing import Tuple +import torch + +from ...layers import excited as excited_layers +from .. 
import IdxType, find_unique_relative +from .base import AutoKw, SingleNode, ExpandParents, MultiNode +from .loss import _BaseCompareLoss +from .tags import Energies, HAtomRegressor, Network, AtomIndexer +from ...layers import physics as physics_layers + + +class NACRNode(AutoKw, SingleNode): + """ + Compute the non-adiabatic coupling vector multiplied by the energy difference + between two states. + """ + + _input_names = "charges i", "charges j", "coordinates", "energy i", "energy j" + _auto_module_class = excited_layers.NACR + + def __init__(self, name: str, parents: Tuple, module="auto", module_kwargs=None, **kwargs): + """Automatically build the node for calculating NACR * ΔE between two states i + and j. + + :param name: name of the node + :type name: str + :param parents: parents of the NACR node in the sequence of (charges i, \ + charges j, positions, energy i, energy j) + :type parents: Tuple + :param module: _description_, defaults to "auto" + :type module: str, optional + :param module_kwargs: keyword arguments passed to the corresponding layer, + defaults to None + :type module_kwargs: dict, optional + """ + + self.module_kwargs = {} + if module_kwargs is not None: + self.module_kwargs.update(module_kwargs) + charges1, charges2, positions, energy1, energy2 = parents + positions.requires_grad = True + self._index_state = IdxType.Molecules + # self._index_state = positions._index_state + parents = ( + charges1.main_output, + charges2.main_output, + positions, + energy1.main_output, + energy2.main_output, + ) + super().__init__(name, parents, module=module, **kwargs) + + +class NACRMultiStateNode(AutoKw, SingleNode): + """ + Compute the non-adiabatic coupling vector multiplied by the energy difference + between all pairs of states. + """ + + _input_names = "charges", "coordinates", "energies" + _auto_module_class = excited_layers.NACRMultiState + + def __init__(self, name, parents, module="auto", module_kwargs=None, **kwargs): + """Automatically build the node for calculating NACR * ΔE between all pairs of + states. + + :param name: name of the node + :type name: str + :param parents: parents of the NACR node in the sequence of (charges, \ + positions, energies) + :type parents: Tuple + :param module: _description_, defaults to "auto" + :type module: str, optional + :param module_kwargs: keyword arguments passed to the corresponding layer, + defaults to None + :type module_kwargs: dict, optional + """ + + self.module_kwargs = {} + if module_kwargs is not None: + self.module_kwargs.update(module_kwargs) + charges, positions, energies = parents + positions.requires_grad = True + self._index_state = IdxType.Molecules + # self._index_state = positions._index_state + parents = ( + charges.main_output, + positions, + energies.main_output, + ) + super().__init__(name, parents, module=module, **kwargs) + + +class LocalEnergyNode(Energies, ExpandParents, HAtomRegressor, MultiNode): + """ + Predict a localized energy, with contributions from implicitly computed atoms. 
+ """ + + _input_names = "hier_features", "mol_index", "atom index", "n_molecules", "n_atoms_max" + _output_names = "mol_energy", "atom_energy", "atom_preenergy", "atom_probabilities", "atom_propensities" + _main_output = "mol_energy" + _output_index_states = IdxType.Molecules, IdxType.Atoms, IdxType.Atoms, IdxType.Atoms, IdxType.Atoms + _auto_module_class = excited_layers.LocalEnergy + + @_parent_expander.match(Network) + def expansion0(self, net, *, purpose, **kwargs): + pdindexer = find_unique_relative(net, AtomIndexer, why_desc=purpose) + return net, pdindexer + + @_parent_expander.match(Network, AtomIndexer) + def expansion1(self, net, pdindexer, **kwargs): + return net, pdindexer.mol_index, pdindexer.atom_index, pdindexer.n_molecules, pdindexer.n_atoms_max + + _parent_expander.assertlen(5) + + def __init__(self, name, parents, first_is_interacting=False, module="auto", **kwargs): + parents = self.expand_parents(parents) + self.module_kwargs = {"first_is_interacting": first_is_interacting} + super().__init__(name, parents, module=module, **kwargs) + + def auto_module(self): + network = find_unique_relative(self, Network).torch_module + return self._auto_module_class(network.feature_sizes, **self.module_kwargs) + + +def _mae_with_phases(predict: torch.Tensor, true: torch.Tensor): + """MAE with phases + + :param predict: predicted values + :type predict: torch.Tensor + :param true: true values + :type true: torch.Tensor + :return: MAE with phases + :rtype: torch.Tensor + """ + + errors = torch.minimum( + torch.linalg.norm(true - predict, ord=1, dim=-1), + torch.linalg.norm(true + predict, ord=1, dim=-1), + ) + # errors = absolute_errors(predict, true) + return torch.sum(errors) / predict.numel() + + +def _mse_with_phases(predict: torch.Tensor, true: torch.Tensor): + """MSE with phases + + :param predict: predicted values + :type predict: torch.Tensor + :param true: true values + :type true: torch.Tensor + :return: MSE with phases + :rtype: torch.Tensor + """ + + errors = torch.minimum( + torch.linalg.norm(true - predict, dim=-1), + torch.linalg.norm(true + predict, dim=-1), + ) + # errors = absolute_errors(predict, true) ** 2 + return torch.sum(errors**2) / predict.numel() + + +class MAEPhaseLoss(_BaseCompareLoss, op=_mae_with_phases): + pass + + +class MSEPhaseLoss(_BaseCompareLoss, op=_mse_with_phases): + pass diff --git a/hippynn/graphs/nodes/loss.py b/hippynn/graphs/nodes/loss.py index 2320df1d..e11ade2b 100644 --- a/hippynn/graphs/nodes/loss.py +++ b/hippynn/graphs/nodes/loss.py @@ -134,3 +134,21 @@ def l2reg(network): def l1reg(network): return lpreg(network, p=1) + +# For loss functions with phases +def absolute_errors(predict: torch.Tensor, true: torch.Tensor): + """Compute the absolute errors with phases between predicted and true values. In + other words, prediction should be close to the absolute value of true, and the sign + does not matter. 
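+    For example, with true = 2.0, both predict = 2.1 and predict = -2.1 give an error of
+    0.1, because the smaller of |true - predict| and |true + predict| is kept.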
+ + :param predict: predicted values + :type predict: torch.Tensor + :param true: true values + :type true: torch.Tensor + :return: errors + :rtype: torch.Tensor + """ + + return torch.minimum(torch.abs(true - predict), torch.abs(true + predict)) + + diff --git a/hippynn/graphs/nodes/physics.py b/hippynn/graphs/nodes/physics.py index 6ceed115..5044324f 100644 --- a/hippynn/graphs/nodes/physics.py +++ b/hippynn/graphs/nodes/physics.py @@ -3,17 +3,24 @@ """ import warnings -from .base import SingleNode, MultiNode, AutoNoKw, AutoKw, ExpandParents, find_unique_relative, _BaseNode +from ...layers import indexers as index_layers +from ...layers import pairs as pair_layers +from ...layers import physics as physics_layers +from ..indextypes import IdxType, elementwise_compare_reduce, index_type_coercion +from .base import ( + AutoKw, + AutoNoKw, + ExpandParents, + MultiNode, + SingleNode, + _BaseNode, + find_unique_relative, +) from .base.node_functions import NodeNotFound from .indexers import AtomIndexer, PaddingIndexer, acquire_encoding_padding -from .pairs import OpenPairIndexer -from .tags import Encoder, PairIndexer, Charges, Energies from .inputs import PositionsNode, SpeciesNode - -from ..indextypes import IdxType, index_type_coercion, elementwise_compare_reduce -from ...layers import indexers as index_layers -from ...layers import physics as physics_layers -from ...layers import pairs as pair_layers +from .pairs import OpenPairIndexer +from .tags import Charges, Encoder, Energies, PairIndexer class GradientNode(AutoKw, SingleNode): @@ -283,7 +290,6 @@ def __init__(self, name, parents, module="auto", **kwargs): # TODO: This seems broken for parent expanders, check the signature of the layer. class BondToMolSummmer(ExpandParents, AutoNoKw, SingleNode): - _input_names = "pairfeatures", "mol_index", "n_molecules", "pair_first" _auto_module_class = pair_layers.MolPairSummer _index_state = IdxType.Molecules @@ -329,17 +335,20 @@ def __init__(self, name, parents, module="auto", **kwargs): super().__init__(name, parents, module=module, **kwargs) - class CombineEnergyNode(Energies, AutoKw, ExpandParents, MultiNode): """ - Combines Local atom energies from different Energy Nodes. + Combines Local atom energies from different Energy Nodes. """ + _input_names = "input_atom_energy_1", "input_atom_energy_2", "mol_index", "n_molecules" _output_names = "mol_energy", "atom_energies" _main_output = "mol_energy" - _output_index_states = IdxType.Molecules, IdxType.Atoms, + _output_index_states = ( + IdxType.Molecules, + IdxType.Atoms, + ) _auto_module_class = physics_layers.CombineEnergy - + @_parent_expander.match(_BaseNode, Energies) def expansion0(self, energy_1, energy_2, **kwargs): return energy_1, energy_2.atom_energies @@ -364,4 +373,3 @@ def __init__(self, name, parents, module="auto", module_kwargs=None, **kwargs): self.module_kwargs = {} if module_kwargs is None else module_kwargs parents = self.expand_parents(parents, **kwargs) super().__init__(name, parents=parents, module=module, **kwargs) - diff --git a/hippynn/graphs/nodes/targets.py b/hippynn/graphs/nodes/targets.py index b666e97d..adb86a50 100644 --- a/hippynn/graphs/nodes/targets.py +++ b/hippynn/graphs/nodes/targets.py @@ -1,8 +1,8 @@ """ Nodes for prediction of variables from network features. 
""" + from .base import MultiNode, AutoKw, ExpandParents, find_unique_relative, _BaseNode -from .indexers import PaddingIndexer from .tags import AtomIndexer, Network, PairIndexer, HAtomRegressor, Charges, Energies from .indexers import PaddingIndexer from ..indextypes import IdxType, index_type_coercion @@ -102,33 +102,3 @@ def __init__(self, name, parents, module="auto", module_kwargs=None, **kwargs): super().__init__(name, parents, module=module, **kwargs) -class LocalEnergyNode(Energies, ExpandParents, HAtomRegressor, MultiNode): - """ - Predict a localized energy, with contributions from implicitly computed atoms. - """ - - _input_names = "hier_features", "mol_index", "atom index", "n_molecules", "n_atoms_max" - _output_names = "mol_energy", "atom_energy", "atom_preenergy", "atom_probabilities", "atom_propensities" - _main_output = "mol_energy" - _output_index_states = IdxType.Molecules, IdxType.Atoms, IdxType.Atoms, IdxType.Atoms, IdxType.Atoms - _auto_module_class = target_modules.LocalEnergy - - @_parent_expander.match(Network) - def expansion0(self, net, *, purpose, **kwargs): - pdindexer = find_unique_relative(net, AtomIndexer, why_desc=purpose) - return net, pdindexer - - @_parent_expander.match(Network, AtomIndexer) - def expansion1(self, net, pdindexer, **kwargs): - return net, pdindexer.mol_index, pdindexer.atom_index, pdindexer.n_molecules, pdindexer.n_atoms_max - - _parent_expander.assertlen(5) - - def __init__(self, name, parents, first_is_interacting=False, module="auto", **kwargs): - parents = self.expand_parents(parents) - self.module_kwargs = {"first_is_interacting": first_is_interacting} - super().__init__(name, parents, module=module, **kwargs) - - def auto_module(self): - network = find_unique_relative(self, Network).torch_module - return self._auto_module_class(network.feature_sizes, **self.module_kwargs) diff --git a/hippynn/interfaces/ase_interface/ase_database.py b/hippynn/interfaces/ase_interface/ase_database.py index 1f808595..e193f347 100644 --- a/hippynn/interfaces/ase_interface/ase_database.py +++ b/hippynn/interfaces/ase_interface/ase_database.py @@ -144,7 +144,10 @@ def load_arrays(self, directory, filename, inputs, targets, quiet=False, allow_u natom = len(record["numbers"]) for k, v in record.items(): if isinstance(v, np.ndarray): - shape = array_dict[k].shape + if array_dict.get(k,None) is not None: + shape = array_dict[k].shape + else: + shape=[0] # Note this assumes the maximum number of atoms greater than the length of property of interest # E.g. 3 for dipole (make sure your training set has something with more than 3 atoms) # Or 6 for stress tensor (make sure your training set has something with more than 6 atoms) @@ -156,9 +159,11 @@ def load_arrays(self, directory, filename, inputs, targets, quiet=False, allow_u array_dict[k][i, :, :] = v elif (len(shape) == 3) and (shape[1] == max_n_atom): # 2D array, e.g. 
positions, forces array_dict[k][i, :natom, :] = v + elif (len(shape) == 1): + print('Skipping {}'.format(k)) else: raise ValueError("Shape of Numpy array for key: {} unknown.".format(k)) - elif isinstance(array_dict[k], np.ndarray): # Energy, float or integers only + elif isinstance(array_dict.get(k,None), np.ndarray): # Energy, float or integers only array_dict[k][i] = v if (k == "energy") and (("energy_per_atom" in var_list) or (allow_unfound)): # Add in per-atom-energy array_dict["energy_per_atom"][i] = v / natom diff --git a/hippynn/interfaces/ase_interface/calculator.py b/hippynn/interfaces/ase_interface/calculator.py index 586fdf14..b4d1f671 100644 --- a/hippynn/interfaces/ase_interface/calculator.py +++ b/hippynn/interfaces/ase_interface/calculator.py @@ -4,7 +4,6 @@ import warnings import torch -from ase.calculators import interface from ase.calculators.calculator import compare_atoms, PropertyNotImplementedError, Calculator # Calculator is required to allow HIPNN to be used with ASE Mixing Calculators from hippynn.graphs import find_relatives, find_unique_relative, get_subgraph, copy_subgraph, replace_node, GraphModule diff --git a/hippynn/interfaces/lammps_interface/mliap_interface.py b/hippynn/interfaces/lammps_interface/mliap_interface.py index acc27e32..0e9c7335 100644 --- a/hippynn/interfaces/lammps_interface/mliap_interface.py +++ b/hippynn/interfaces/lammps_interface/mliap_interface.py @@ -2,31 +2,32 @@ Interface for creating LAMMPS MLIAP Unified models. """ import pickle +import warnings import numpy as np import torch -torch.set_default_dtype(torch.float32) from lammps.mliap.mliap_unified_abc import MLIAPUnified import hippynn -from hippynn.graphs import (find_relatives,find_unique_relative, - get_subgraph, copy_subgraph, replace_node, IdxType, - GraphModule) +from hippynn.tools import device_fallback +from hippynn.graphs import find_relatives, find_unique_relative, get_subgraph, copy_subgraph, replace_node, IdxType, GraphModule from hippynn.graphs.indextypes import index_type_coercion from hippynn.graphs.gops import check_link_consistency -from hippynn.graphs.nodes.base import InputNode, MultiNode, AutoNoKw, ExpandParents +from hippynn.graphs.nodes.base import InputNode, SingleNode, MultiNode, AutoNoKw, ExpandParents from hippynn.graphs.nodes.tags import Encoder, PairIndexer +from hippynn.graphs.nodes.indexers import PaddingIndexer from hippynn.graphs.nodes.physics import GradientNode, VecMag from hippynn.graphs.nodes.inputs import SpeciesNode from hippynn.graphs.nodes.pairs import PairFilter + class MLIAPInterface(MLIAPUnified): """ Class for creating ML-IAP Unified model based on hippynn graphs. 
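+
+    A rough construction sketch (the element list and file name are hypothetical, and
+    ``energy_node`` comes from a trained hippynn graph)::
+
+        unified = MLIAPInterface(energy_node, ["Al"], model_device=torch.device("cpu"))
+        with open("hippynn_mliap_model.pkl", "wb") as f:
+            pickle.dump(unified, f)
+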
""" - def __init__(self, energy_node, element_types, ndescriptors=1, - model_device=torch.device("cpu")): + + def __init__(self, energy_node, element_types, ndescriptors=1, model_device=torch.device("cpu"), compute_dtype=torch.float32): """ :param energy_node: Node for energy :param element_types: list of atomic symbols corresponding to element types @@ -41,16 +42,20 @@ def __init__(self, energy_node, element_types, ndescriptors=1, # Build the calculator self.rcutfac, self.species_set, self.graph = setup_LAMMPS_graph(energy_node) self.nparams = sum(p.nelement() for p in self.graph.parameters()) - self.graph.to(torch.float64) + self.compute_dtype = compute_dtype + self.graph.to(compute_dtype) def compute_gradients(self, data): pass - + def compute_descriptors(self, data): pass - - def as_tensor(self,array): - return torch.as_tensor(array,device=self.model_device) + + def as_tensor(self, array): + return torch.as_tensor(array, device=self.model_device) + + def empty_tensor(self,dimentions): + return torch.empty(dimentions,device=self.model_device) def compute_forces(self, data): """ @@ -58,47 +63,62 @@ def compute_forces(self, data): :return None This function writes results to the input `data`. """ - elems = self.as_tensor(data.elems).type(torch.int64).reshape(1, data.ntotal) - z_vals = self.species_set[elems+1] - pair_i = self.as_tensor(data.pair_i).type(torch.int64) - pair_j = self.as_tensor(data.pair_j).type(torch.int64) - rij = self.as_tensor(data.rij).type(torch.float64) - nlocal = self.as_tensor(data.nlistatoms) - - # note your sign for rij might need to be +1 or -1, depending on how your implementation works - inputs = [z_vals, pair_i, pair_j, -rij, nlocal] - atom_energy, total_energy, fij = self.graph(*inputs) - - # Test if we are using lammps-kokkos or not. Is there a more clear way to do that? - if isinstance(data.elems,np.ndarray): - return_device = 'cpu' - else: - # Hope that kokkos device and pytorch device are the same (default cuda) - return_device = elems.device - - atom_energy = atom_energy.squeeze(1).detach().to(return_device) - total_energy = total_energy.detach().to(return_device) - - f = self.as_tensor(data.f) - fij = fij.type(f.dtype).detach().to(return_device) - - if return_device=="cpu": - fij = fij.numpy() - data.eatoms = atom_energy.numpy().astype(np.double) - else: - eatoms = torch.as_tensor(data.eatoms,device=return_device) - eatoms.copy_(atom_energy) - - data.update_pair_forces(fij) - data.energy = total_energy.item() + nlocal = self.as_tensor(data.nlistatoms) + if nlocal.item() > 0: + #If there are no local atoms, do nothing + elems = self.as_tensor(data.elems).type(torch.int64).reshape(1, data.ntotal) + z_vals = self.species_set[elems + 1] + npairs = data.npairs + if npairs > 0: + pair_i = self.as_tensor(data.pair_i).type(torch.int64) + pair_j = self.as_tensor(data.pair_j).type(torch.int64) + rij = self.as_tensor(data.rij).type(self.compute_dtype) + else: + pair_i = self.empty_tensor(0).type(torch.int64) + pair_j = self.empty_tensor(0).type(torch.int64) + rij = self.empty_tensor([0,3]).type(self.compute_dtype) + + # note your sign for rij might need to be +1 or -1, depending on how your implementation works + inputs = [z_vals, pair_i, pair_j, -rij, nlocal] + atom_energy, total_energy, fij = self.graph(*inputs) + + # Test if we are using lammps-kokkos or not. Is there a more clear way to do that? 
+ if isinstance(data.elems, np.ndarray): + return_device = "cpu" + else: + # Hope that kokkos device and pytorch device are the same (default cuda) + return_device = elems.device + + atom_energy = atom_energy.squeeze(1).detach().to(return_device) + total_energy = total_energy.detach().to(return_device) + + f = self.as_tensor(data.f) + fij = fij.type(f.dtype).detach().to(return_device) + + if return_device == "cpu": + fij = fij.numpy() + data.eatoms = atom_energy.numpy().astype(np.double) + else: + eatoms = torch.as_tensor(data.eatoms, device=return_device) + eatoms.copy_(atom_energy) + if npairs > 0: + data.update_pair_forces(fij) + data.energy = total_energy.item() def __getstate__(self): self.species_set = self.species_set.to(torch.device("cpu")) self.graph.to(torch.device("cpu")) return self.__dict__.copy() - + def __setstate__(self, state): self.__dict__.update(state) + try: + torch.ones(0).to(self.model_device) + except RuntimeError: + fallback = device_fallback() + warnings.warn(f"Model device ({self.model_device}) not found, falling back to f{fallback}") + self.model_device = fallback + self.species_set = self.species_set.to(self.model_device) self.graph.to(self.model_device) @@ -114,19 +134,21 @@ def setup_LAMMPS_graph(energy): why = "Generating LAMMPS Calculator interface" subgraph = get_subgraph(required_nodes) - search_fn = lambda targ,sg: lambda n: n in sg and isinstance(n,targ) + search_fn = lambda targ, sg: lambda n: n in sg and isinstance(n, targ) pair_indexers = find_relatives(required_nodes, search_fn(PairIndexer, subgraph), why_desc=why) new_required, new_subgraph = copy_subgraph(required_nodes, assume_inputed=pair_indexers, tag="LAMMPS") pair_indexers = find_relatives(new_required, search_fn(PairIndexer, new_subgraph), why_desc=why) - species = find_unique_relative(new_required, search_fn(SpeciesNode, new_subgraph),why_desc=why) + species = find_unique_relative(new_required, search_fn(SpeciesNode, new_subgraph), why_desc=why) encoder = find_unique_relative(species, search_fn(Encoder, new_subgraph), why_desc=why) + padding_indexer = find_unique_relative(species, search_fn(PaddingIndexer, new_subgraph), why_desc=why) + inv_real_atoms = padding_indexer.inv_real_atoms + species_set = torch.as_tensor(encoder.species_set).to(torch.int64) min_radius = max(p.dist_hard_max for p in pair_indexers) - ############################################################### # Set up graph to accept external pair indices and shifts @@ -139,15 +161,17 @@ def setup_LAMMPS_graph(energy): in_nlocal = InputNode("(LAMMPS)nlocal") in_nlocal._index_state = hippynn.graphs.IdxType.Scalar pair_dist = VecMag("(LAMMPS)pair_dist", in_pair_coord) + mapped_pair_first = ReIndexAtomNode("pair_first_internal", (in_pair_first, inv_real_atoms)) + mapped_pair_second = ReIndexAtomNode("pair_second_internal", (in_pair_second, inv_real_atoms)) - new_inputs = [species,in_pair_first,in_pair_second,in_pair_coord,in_nlocal] - - # Construct Filters and replace the existing pair indexers with the + new_inputs = [species, in_pair_first, in_pair_second, in_pair_coord, in_nlocal] + + # Construct Filters and replace the existing pair indexers with the # corresponding new (filtered) node that accepts external pairs of atoms for pi in pair_indexers: if pi.dist_hard_max == min_radius: - replace_node(pi.pair_first, in_pair_first, disconnect_old=False) - replace_node(pi.pair_second, in_pair_second, disconnect_old=False) + replace_node(pi.pair_first, mapped_pair_first, disconnect_old=False) + replace_node(pi.pair_second, 
mapped_pair_second, disconnect_old=False) replace_node(pi.pair_coord, in_pair_coord, disconnect_old=False) replace_node(pi.pair_dist, pair_dist, disconnect_old=False) pi.disconnect() @@ -155,7 +179,7 @@ def setup_LAMMPS_graph(energy): mapped_node = PairFilter( "DistanceFilter-LAMMPS", (pair_dist, in_pair_first, in_pair_second, in_pair_coord), - dist_hard_max=pi.dist_hard_max, + dist_hard_max=pi.dist_hard_max, ) replace_node(pi.pair_first, mapped_node.pair_first, disconnect_old=False) replace_node(pi.pair_second, mapped_node.pair_second, disconnect_old=False) @@ -177,7 +201,6 @@ def setup_LAMMPS_graph(energy): "an object with an `atom_energies` attribute." ) - local_atom_energy = LocalAtomEnergyNode("(LAMMPS)local_atom_energy", (atom_energies, in_nlocal)) grad_rij = GradientNode("(LAMMPS)grad_rij", (local_atom_energy.total_local_energy, in_pair_coord), -1) @@ -190,10 +213,25 @@ def setup_LAMMPS_graph(energy): return min_radius / 2, species_set, mod +class ReIndexAtomMod(torch.nn.Module): + def forward(self, raw_atom_index_array, inverse_real_atoms): + return inverse_real_atoms[raw_atom_index_array] + + +class ReIndexAtomNode(AutoNoKw, SingleNode): + _input_names = "raw_atom_index_array", "inverse_real_atoms" + _main_output = "total_local_energy" + _auto_module_class = ReIndexAtomMod + + def __init__(self, name, parents, module="auto", **kwargs): + self._index_state = parents[0]._index_state + super().__init__(name, parents, module=module, **kwargs) + + class LocalAtomsEnergy(torch.nn.Module): def __init__(self): super().__init__() - + def forward(self, all_atom_energies, nlocal): local_atom_energies = all_atom_energies[:nlocal] total_local_energy = torch.sum(local_atom_energies) @@ -211,6 +249,6 @@ class LocalAtomEnergyNode(AutoNoKw, ExpandParents, MultiNode): _parent_expander.get_main_outputs() _parent_expander.require_idx_states(IdxType.Atoms, IdxType.Scalar) - def __init__(self, name, parents, module='auto', **kwargs): + def __init__(self, name, parents, module="auto", **kwargs): parents = self.expand_parents(parents) super().__init__(name, parents, module=module, **kwargs) diff --git a/hippynn/layers/__init__.py b/hippynn/layers/__init__.py index 58d549e0..c2180c26 100644 --- a/hippynn/layers/__init__.py +++ b/hippynn/layers/__init__.py @@ -6,3 +6,4 @@ from . import targets from . import transform from . import physics +from . import excited \ No newline at end of file diff --git a/hippynn/layers/excited.py b/hippynn/layers/excited.py new file mode 100644 index 00000000..8096e743 --- /dev/null +++ b/hippynn/layers/excited.py @@ -0,0 +1,128 @@ +import torch +from . import indexers +from torch import Tensor + + +class NACR(torch.nn.Module): + """ + Compute NAC vector * ΔE. Originally in hippynn.layers.physics. + """ + + def __init__(self): + super().__init__() + + def forward( + self, + charges1: Tensor, + charges2: Tensor, + positions: Tensor, + energy1: Tensor, + energy2: Tensor, + ): + dE = energy2 - energy1 + nacr = torch.autograd.grad( + charges2, [positions], grad_outputs=[charges1], create_graph=True + )[0].reshape(len(dE), -1) + return nacr * dE + + +class NACRMultiState(torch.nn.Module): + """ + Compute NAC vector * ΔE for all paris of states. Originally in hippynn.layers.physics. 
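+
+    The output has shape (n_molecules, n_pairs, 3 * n_atoms), where
+    n_pairs = n_target * (n_target - 1) / 2 and each pair of states is scaled by its energy gap.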
+ """ + + def __init__(self, n_target=1): + self.n_target = n_target + super().__init__() + + def forward(self, charges: Tensor, positions: Tensor, energies: Tensor): + # charges shape: n_molecules, n_atoms, n_targets + # positions shape: n_molecules, n_atoms, 3 + # energies shape: n_molecules, n_targets + # dE shape: n_molecules, n_targets, n_targets + dE = energies.unsqueeze(1) - energies.unsqueeze(2) + # take the upper triangle excluding the diagonal + indices = torch.triu_indices( + self.n_target, self.n_target, offset=1, device=dE.device + ) + # dE shape: n_molecules, n_pairs + # n_pairs = n_targets * (n_targets - 1) / 2 + dE = dE[..., indices[0], indices[1]] + # compute q1 * dq2/dR + nacr_ij = [] + for i, j in zip(*indices): + nacr = torch.autograd.grad( + charges[..., j], + positions, + grad_outputs=charges[..., i], + create_graph=True, + )[0] + nacr_ij.append(nacr) + # nacr shape: n_molecules, n_atoms, 3, n_pairs + nacr = torch.stack(nacr_ij, dim=1) + n_molecule, n_pairs, n_atoms, n_dims = nacr.shape + nacr = nacr.reshape(n_molecule, n_pairs, n_atoms * n_dims) + # multiply dE + return nacr * dE.unsqueeze(2) + + +class LocalEnergy(torch.nn.Module): + def __init__(self, feature_sizes, first_is_interacting=False): + + super().__init__() + self.first_is_interacting = first_is_interacting + if first_is_interacting: + feature_sizes = feature_sizes[1:] + + self.feature_sizes = feature_sizes + + self.summer = indexers.MolSummer() + self.n_terms = len(feature_sizes) + biases = (first_is_interacting, *(True for _ in range(self.n_terms - 1))) + + self.layers = torch.nn.ModuleList(torch.nn.Linear(nf, 1, bias=bias) for nf, bias in zip(feature_sizes, biases)) + self.players = torch.nn.ModuleList(torch.nn.Linear(nf, 1, bias=False) for nf in feature_sizes) + self.ninf = float("-inf") + + def forward(self, all_features, mol_index, atom_index, n_molecules, n_atoms_max): + """ + :param all_features: list of feature tensors + :param mol_index: which molecule is the atom + :param atom_index: which atom in the molecule is that atom + :param n_molecules: total number of molecules in the batch + :param n_atoms_max: maximum number of atoms in the batch + :return: contributed_energy, atom_energy, atom_preenergy, prob, propensity + """ + + if self.first_is_interacting: + all_features = all_features[1:] + + partial_preenergy = [lay(x) for x, lay in zip(all_features, self.layers)] + atom_preenergy = sum(partial_preenergy) + partial_potentials = [lay(x) for x, lay in zip(all_features, self.players)] + propensity = sum(partial_potentials) # Keep in mind that this has shape (natoms,1) + + # This segment does not need gradients, we are constructing the subtraction parameters for softmax + # which results in a calculation that does not under or overflow; the result is most accurate this way + # But actually invariant to the subtraction used, so it does not require a grad. + # It's a standard SoftMax technique, however, the implementation is not built into pytorch for + # the molecule/atom framework. 
+ with torch.autograd.no_grad(): + propensity_molatom = all_features[0].new_full((n_molecules, n_atoms_max, 1), self.ninf) + propensity_molatom[mol_index, atom_index] = propensity + propensity_norms = propensity_molatom.max(dim=1)[0] # first element is max vals, 2nd is max position + propensity_norm_atoms = propensity_norms[mol_index] + + propensity_normed = propensity - propensity_norm_atoms + + # Calculate probabilities with molecule version of softmax + relative_prob = torch.exp(propensity_normed) + z_factor_permol = self.summer(relative_prob, mol_index, n_molecules) + atom_zfactor = z_factor_permol[mol_index] + prob = relative_prob / atom_zfactor + + # Find molecular sum + atom_energy = prob * atom_preenergy + contributed_energy = self.summer(atom_energy, mol_index, n_molecules) + + return contributed_energy, atom_energy, atom_preenergy, prob, propensity diff --git a/hippynn/layers/physics.py b/hippynn/layers/physics.py index 8a577967..0d02e86e 100644 --- a/hippynn/layers/physics.py +++ b/hippynn/layers/physics.py @@ -6,8 +6,7 @@ import torch from torch import Tensor -from . import pairs -from . import indexers +from . import indexers, pairs class Gradient(torch.nn.Module): @@ -54,8 +53,8 @@ def __init__(self): def forward(self, charges: Tensor, positions: Tensor, mol_index: Tensor, n_molecules: int): if charges.shape[1] > 1: # charges contain multiple targets, so set up broadcasting - charges = charges.unsqueeze(1) - positions = positions.unsqueeze(2) + charges = charges.unsqueeze(2) + positions = positions.unsqueeze(1) # shape is (n_atoms, 3, n_targets) in multi-target mode # shape is (n_atoms, 3) in single target mode @@ -259,9 +258,7 @@ def forward(self, features, species): class VecMag(torch.nn.Module): def forward(self, vector_feature): - return torch.norm(vector_feature, dim=1) - - + return torch.norm(vector_feature, dim=1).unsqueeze(1) class CombineEnergy(torch.nn.Module): diff --git a/hippynn/layers/targets.py b/hippynn/layers/targets.py index 4b8128d0..f8d262da 100644 --- a/hippynn/layers/targets.py +++ b/hippynn/layers/targets.py @@ -67,7 +67,7 @@ def forward(self, all_features, mol_index, n_molecules): total_hier = torch.zeros_like(total_energies) mol_hier = torch.zeros_like(total_energies) total_atom_hier = torch.zeros_like(total_atomen) - batch_hier = torch.zeros(1,dtype=total_energies.dtype,device=total_energies.dtype) + batch_hier = torch.zeros(1,dtype=total_energies.dtype,device=total_energies.device) return total_energies, total_atomen, partial_sums, total_hier, total_atom_hier, mol_hier, batch_hier @@ -235,63 +235,3 @@ def forward(self, all_features, pair_first, pair_second, pair_dist): return total_bonds, bond_hier -class LocalEnergy(torch.nn.Module): - def __init__(self, feature_sizes, first_is_interacting=False): - - super().__init__() - self.first_is_interacting = first_is_interacting - if first_is_interacting: - feature_sizes = feature_sizes[1:] - - self.feature_sizes = feature_sizes - - self.summer = indexers.MolSummer() - self.n_terms = len(feature_sizes) - biases = (first_is_interacting, *(True for _ in range(self.n_terms - 1))) - - self.layers = torch.nn.ModuleList(torch.nn.Linear(nf, 1, bias=bias) for nf, bias in zip(feature_sizes, biases)) - self.players = torch.nn.ModuleList(torch.nn.Linear(nf, 1, bias=False) for nf in feature_sizes) - self.ninf = float("-inf") - - def forward(self, all_features, mol_index, atom_index, n_molecules, n_atoms_max): - """ - :param all_features: list of feature tensors - :param mol_index: which molecule is the atom - :param 
atom_index: which atom in the molecule is that atom - :param n_molecules: total number of molecules in the batch - :param n_atoms_max: maximum number of atoms in the batch - :return: contributed_energy, atom_energy, atom_preenergy, prob, propensity - """ - - if self.first_is_interacting: - all_features = all_features[1:] - - partial_preenergy = [lay(x) for x, lay in zip(all_features, self.layers)] - atom_preenergy = sum(partial_preenergy) - partial_potentials = [lay(x) for x, lay in zip(all_features, self.players)] - propensity = sum(partial_potentials) # Keep in mind that this has shape (natoms,1) - - # This segment does not need gradients, we are constructing the subtraction parameters for softmax - # which results in a calculation that does not under or overflow; the result is most accurate this way - # But actually invariant to the subtraction used, so it does not require a grad. - # It's a standard SoftMax technique, however, the implementation is not built into pytorch for - # the molecule/atom framework. - with torch.autograd.no_grad(): - propensity_molatom = all_features[0].new_full((n_molecules, n_atoms_max, 1), self.ninf) - propensity_molatom[mol_index, atom_index] = propensity - propensity_norms = propensity_molatom.max(dim=1)[0] # first element is max vals, 2nd is max position - propensity_norm_atoms = propensity_norms[mol_index] - - propensity_normed = propensity - propensity_norm_atoms - - # Calculate probabilities with molecule version of softmax - relative_prob = torch.exp(propensity_normed) - z_factor_permol = self.summer(relative_prob, mol_index, n_molecules) - atom_zfactor = z_factor_permol[mol_index] - prob = relative_prob / atom_zfactor - - # Find molecular sum - atom_energy = prob * atom_preenergy - contributed_energy = self.summer(atom_energy, mol_index, n_molecules) - - return contributed_energy, atom_energy, atom_preenergy, prob, propensity diff --git a/hippynn/networks/hipnn.py b/hippynn/networks/hipnn.py index 792e8b9c..b9b629b6 100644 --- a/hippynn/networks/hipnn.py +++ b/hippynn/networks/hipnn.py @@ -18,7 +18,7 @@ # computes E0 for the energy layer. -def compute_hipnn_e0(encoder, Z_Data, en_data, peratom=False): +def compute_hipnn_e0(encoder, Z_Data, en_data, peratom=False, fit_dtype=torch.float64): """ :param encoder: encoder of species to features (one-hot representation, probably) @@ -28,6 +28,8 @@ def compute_hipnn_e0(encoder, Z_Data, en_data, peratom=False): :return: energy per species as shape (n_features_encoded, 1) """ + original_dtype = en_data.dtype + en_data = en_data.to(fit_dtype) x, nonblank = encoder(Z_Data) sums = x.sum(dim=1).to(en_data.dtype) @@ -53,6 +55,7 @@ def compute_hipnn_e0(encoder, Z_Data, en_data, peratom=False): # if n_targets is included e_per_species = e_per_species.T + e_per_species = e_per_species.to(original_dtype) return e_per_species