forked from lanl/hippynn
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'upstream/development' into development
- Loading branch information
Showing
26 changed files
with
1,017 additions
and
193 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Check that package builds | ||
name: Build Checks | ||
|
||
on: | ||
push: | ||
pull_request: | ||
workflow_dispatch: | ||
|
||
jobs: | ||
build: | ||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- uses: actions/checkout@v2 | ||
- uses: actions/setup-python@v2 | ||
with: | ||
python-version: "3.10" | ||
|
||
- name: Install dependencies | ||
run: >- | ||
python -m pip install --user --upgrade setuptools wheel | ||
- name: Build | ||
run: >- | ||
python setup.py sdist bdist_wheel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,4 +4,4 @@ | |
__pycache__/ | ||
*.pyc | ||
build/ | ||
hippynn.egg-info/* | ||
hippynn.egg-info/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
Non-Adiabiatic Excited States | ||
============================= | ||
|
||
`hippynn` has features for training to excited-state energies, transition dipoles, and | ||
the non-adiabatic coupling vectors (NACR). These features can be found in | ||
:mod:`~hippynn.graphs.nodes.excited`. | ||
|
||
For a more detailed description, please see the paper [Li2023]_ | ||
|
||
Multi-targets nodes are recommended over the usage of one node per target. | ||
|
||
For energies, the node can be constructed just like the ground-state | ||
counterpart:: | ||
|
||
energy = targets.HEnergyNode("E", network, module_kwargs={"n_target": n_states + 1}) | ||
mol_energy = energy.mol_energy | ||
mol_energy.db_name = "E" | ||
|
||
Note that a `multi-target node` is used here, defined by the keyword | ||
``module_kwargs={"n_target": n_states + 1}``. Here, `n_states` is the number of | ||
*excited* states in consideration. The extra state is for the ground state, which is often | ||
useful. The database name is simply `E` with a shape of ``(n_molecules, | ||
n_states+1)``. | ||
|
||
Predicting the transition dipoles is also similar to the ground-state permanent | ||
dipole:: | ||
|
||
charge = targets.HChargeNode("Q", network, module_kwargs={"n_target": n_states}) | ||
dipole = physics.DipoleNode("D", (charge, positions), db_name="D") | ||
|
||
The database name is `D` with a shape of ``(n_molecules, n_states, 3)``. | ||
|
||
For NACR, to avoid singularity problems, we enforcing the training of NACR*ΔE | ||
instead:: | ||
|
||
nacr = excited.NACRMultiStateNode( | ||
"ScaledNACR", | ||
(charge, positions, energy), | ||
db_name="ScaledNACR", | ||
module_kwargs={"n_target": n_states}, | ||
) | ||
|
||
For NACR between state `i` and `j`, :math:`\boldsymbol{d}_{ij}`, it is expressed | ||
in the following way | ||
|
||
.. math:: | ||
\boldsymbol{d}_{ij}\Delta E_{ij} = \Delta E_{ij}\boldsymbol{q}_i \frac{\partial\boldsymbol{q}_j}{\partial\boldsymbol{R}} | ||
:math:`E_{ij}` is energy difference between state `i` and `j`, which is | ||
calculated internally in the NACR node based on the input of the ``energy`` | ||
node. :math:`\boldsymbol{R}` corresponding the ``positions`` node in the code. | ||
:math:`\boldsymbol{q}_{i}` and :math:`\boldsymbol{q}_{j}` are the transition | ||
atomic charges for state `i` and `j` contained in the ``charge`` node. This | ||
charge node can be constructed from scratch or reused from the dipole | ||
predictions. The database name is `ScaledNACR` with a shape of ``(n_molecules, | ||
n_states*(n_states-1)/2, 3*n_atoms)``. | ||
|
||
Due to the phase problem, when the loss function is constructed, the | ||
`phase-less` version of MAE or RMSE should be used:: | ||
|
||
energy_mae = loss.MAELoss.of_node(energy) | ||
dipole_mae = excited.MAEPhaseLoss.of_node(dipole) | ||
nacr_mae = excited.MAEPhaseLoss.of_node(nacr) | ||
|
||
:class:`~hippynn.graphs.nodes.excited.MAEPhaseLoss` and | ||
:class:`~hippynn.graphs.nodes.excited.MSEPhaseLoss` are the `phase-less` version MAE | ||
and MSE, which take the minimum error over the possible signs of the output. | ||
|
||
For a complete script, please take a look at ``examples/excited_states_azomethane.py``. | ||
|
||
.. [Li2023] | Machine Learning Framework for Modeling Exciton-Polaritons in Molecular Materials. | ||
| Li et. al, 2023. https://arxiv.org/abs/2306.02523 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
""" | ||
close_contact_finding.py | ||
This example shows how to use the hippynn function calculate_min_dists | ||
to find distances in the dataset where there are close contacts. | ||
Such a procedure is often useful in active learning to identify | ||
outlier data which inhibits training. | ||
This script was designed for an external dataset available at | ||
https://github.com/atomistic-ml/ani-al | ||
Note: It is necessary to untar the h5 data files in ani-al/data/ | ||
before running this script. | ||
""" | ||
import sys | ||
|
||
sys.path.append("../../datasets/ani-al/readers/lib/") | ||
import pyanitools # Check if pyanitools is found early | ||
|
||
### Loading the database | ||
from hippynn.databases.h5_pyanitools import PyAniDirectoryDB | ||
|
||
database = PyAniDirectoryDB( | ||
directory="../../datasets/ani-al/data/", | ||
seed=0, | ||
quiet=False, | ||
allow_unfound=True, | ||
inputs=None, | ||
targets=None, | ||
) | ||
|
||
### Calculating minimum distance array | ||
from hippynn.pretraining import calculate_min_dists | ||
|
||
min_dist_array = calculate_min_dists( | ||
database.arr_dict, | ||
species_name="species", | ||
positions_name="coordinates", | ||
cell_name="cell", # for open boundaries, do not pass a cell name (or pass None) | ||
dist_hard_max=4.0, # check for distances up to this radius | ||
batch_size=50, | ||
) | ||
|
||
print("Minimum distance in configurations:") | ||
print(f"{min_dist_array.dtype=}") | ||
print(f"{min_dist_array.shape=}") | ||
print("First 100 values:", min_dist_array[:100]) | ||
|
||
### Making a plot of the minimum distance for each configuration | ||
import matplotlib.pyplot as plt | ||
|
||
plt.hist(min_dist_array, bins=100) | ||
plt.title("Minimum distance per config") | ||
plt.xlabel("Distance") | ||
plt.ylabel("Count") | ||
plt.yscale("log") | ||
plt.show() | ||
|
||
#### How to remove and separate low distance configurations | ||
dist_thresh = 1.7 # Note: what threshold to use may be highly problem-dependent. | ||
low_dist_configs = min_dist_array < dist_thresh | ||
where_low_dist = database.arr_dict["indices"][low_dist_configs] | ||
|
||
# This makes the low distance configurations | ||
# into their own split, separate from train/valid/test. | ||
database.make_explicit_split("LOW_DISTANCE_FILTER", where_low_dist) | ||
|
||
# This deletes the new split, although deleting it is not necessary; | ||
# this data iwll not be included in train/valid/test splits | ||
del database.splits["LOW_DISTANCE_FILTER"] | ||
|
||
### Continue on with data processing, e.g. | ||
database.make_trainvalidtest_split(test_size=0.1, valid_size=0.1) |
Oops, something went wrong.