Skip to content

Commit

Permalink
mfi-M3GNet implementation is added (#77)
Browse files Browse the repository at this point in the history
* mfi-M3GNet implementation is added

* adding the description of stress_weight in _dynamics.py

* fix the version of pip in requirements.txt

Avoid update of pip version for installing tensorflow==2.11.0.

* Update requirements.txt

* remove update of pip version in linting.yml

* remove upgrade of pip version in testing.yml

* Try to upgrade the version of Tensorflow

* fix Too many positional arguments for method call in _converters.py

* further fix Too many positional arguments for method call in _converters.py

* fix Possibly using variable 'mgb_val' before assignment in _property.py

* fix Possibly using variable 'mgb_val' before assignment in _potential.py

* fix black

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed nbqa-flake8 in jupyter notebook

* fix black

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* try to upgrade pymatgen

* fix black

* fix bug in spherical harmonic function

* fix the sympy version

* change back to _conjugate in _matg.py

* try to fix it by adding .ref()

* skip the test of isolated atom for now

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* skip the test of isolated atom for now again

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* put back the pip upgrade

* correct the upgrade of pip in testing.yaml

* added element_refs in _dynamics.py

* merge the conflicts in test_model.py

* fix pylint

* fix the import for AtomRef

* fix pytest

* fix pytest again

* fix pytest again

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Shyue Ping Ong <shyuep@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 4, 2024
1 parent 3666db8 commit 3a8b14d
Show file tree
Hide file tree
Showing 26 changed files with 169 additions and 25 deletions.
4 changes: 2 additions & 2 deletions examples/Relaxation of LiFePO4.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
],
"source": [
"relaxer = Relaxer()\n",
"relax_results: dict\n",
"relax_results: dict = {}\n",
"%time relax_results = relaxer.relax(lfp_strained)\n",
"relaxed_struct = relax_results[\"final_structure\"]"
]
Expand Down Expand Up @@ -394,7 +394,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.10.9"
},
"vscode": {
"interpreter": {
Expand Down
1 change: 1 addition & 0 deletions m3gnet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""
The M3GNet framework package
"""

__version__ = "0.2.4"
1 change: 1 addition & 0 deletions m3gnet/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Callback functions
"""

import os
from typing import Dict

Expand Down
1 change: 1 addition & 0 deletions m3gnet/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Data types"""

import numpy as np
import tensorflow as tf

Expand Down
5 changes: 3 additions & 2 deletions m3gnet/graph/_batch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Collate material graphs
"""

from typing import AnyStr, List, Optional, Tuple, Union, overload

import numpy as np
Expand Down Expand Up @@ -144,12 +145,12 @@ def _concatenate(list_of_arrays: List, name: AnyStr) -> Optional[np.ndarray]:

@overload
def assemble_material_graph(graphs: List[MaterialGraph]) -> MaterialGraph:
...
raise NotImplementedError


@overload
def assemble_material_graph(graphs: List[List]) -> List:
...
raise NotImplementedError


def assemble_material_graph(graphs):
Expand Down
13 changes: 8 additions & 5 deletions m3gnet/graph/_converters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Classes to convert a structure into a graph
"""

import logging
from abc import abstractmethod
from typing import Dict, List, Optional
Expand Down Expand Up @@ -77,7 +78,7 @@ def get_states(self, structure: StructureOrMolecule):
return states

@abstractmethod
def convert(self, structure: StructureOrMolecule, **kwargs) -> MaterialGraph:
def convert(self, structure: StructureOrMolecule, state_attr=None, *args, **kwargs) -> MaterialGraph:
"""
Convert the structure into a graph
Args:
Expand All @@ -98,17 +99,18 @@ def convert_many(self, structures: List[StructureOrMolecule], **kwargs) -> Mater
graphs = [self.convert(structure, **kwargs) for structure in structures]
return assemble_material_graph(graphs)

def __call__(self, structure: StructureOrMolecule, *args, **kwargs) -> MaterialGraph:
def __call__(self, structure: StructureOrMolecule, state_attr=None, *args, **kwargs) -> MaterialGraph:
"""
A thin wrapper for calling `convert` method
Args:
structure:
state_attr:
*args:
**kwargs:
Returns:
"""
return self.convert(structure)
return self.convert(structure, state_attr, *args, **kwargs)


@register
Expand Down Expand Up @@ -145,11 +147,12 @@ def __init__(

super().__init__(**kwargs)

def convert(self, structure: StructureOrMolecule, **kwargs) -> MaterialGraph:
def convert(self, structure: StructureOrMolecule, state_attr=None, *args, **kwargs) -> MaterialGraph:
"""
Convert the structure into graph
Args:
structure: Structure or Molecule
state_attr: Global state attribute (e.g.Fidelity of data)
Returns
MaterialGraph
"""
Expand All @@ -158,7 +161,7 @@ def convert(self, structure: StructureOrMolecule, **kwargs) -> MaterialGraph:
atom_positions = np.asarray(structure.get_positions(), dtype=DataType.np_float)
else:
atom_positions = np.array(structure.cart_coords, dtype=DataType.np_float)
state_attributes = self.get_states(structure)
state_attributes = state_attr if state_attr is not None else self.get_states(structure)

sender_indices, receiver_indices, images, distances = get_fixed_radius_bonding(structure, self.cutoff)

Expand Down
1 change: 1 addition & 0 deletions m3gnet/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Graph layers
"""

from ._aggregate import AtomReduceState
from ._atom import AtomNetwork, GatedAtomUpdate
from ._atom_ref import AtomRef, BaseAtomRef
Expand Down
1 change: 1 addition & 0 deletions m3gnet/layers/_aggregate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Aggregate classes. Aggregating happens when bond attributes """

from typing import Callable, List, Union

import tensorflow as tf
Expand Down
13 changes: 11 additions & 2 deletions m3gnet/layers/_atom_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@ def __init__(
if property_per_element is None:
self.property_per_element = np.zeros(shape=(max_z + 1,))
else:
self.property_per_element = np.array(property_per_element).ravel()
self.property_per_element = np.array(property_per_element)
if self.property_per_element.ndim > 1:
self.n_state = self.property_per_element.shape[0]
else:
self.property_per_element = np.array(property_per_element).ravel()
self.n_state = 1
self.max_z = max_z

def _get_feature_matrix(self, structs_or_graphs):
Expand Down Expand Up @@ -144,7 +149,11 @@ def call(self, graph: List, **kwargs):
Returns:
"""
atomic_numbers = graph[Index.ATOMS][:, 0]
atom_energies = tf.gather(tf.cast(self.property_per_element, DataType.tf_float), atomic_numbers)
if self.n_state == 1:
atom_energies = tf.gather(tf.cast(self.property_per_element, DataType.tf_float), atomic_numbers)
else:
state_property_per_element = self.property_per_element[:, int(graph[Index.STATES])]
atom_energies = tf.gather(tf.cast(state_property_per_element, DataType.tf_float), atomic_numbers)
res = tf.math.segment_sum(atom_energies, get_segment_indices_from_n(graph[Index.N_ATOMS]))
return tf.reshape(res, (-1, 1))

Expand Down
1 change: 1 addition & 0 deletions m3gnet/layers/_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Base layer classes
"""

from typing import Callable, List, Optional

import tensorflow as tf
Expand Down
1 change: 1 addition & 0 deletions m3gnet/layers/_core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Core layers provide basic operations, e.g., MLP
"""

from typing import Dict, List, Union

import tensorflow as tf
Expand Down
1 change: 1 addition & 0 deletions m3gnet/layers/_gn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Materials Graph Network
"""

from copy import deepcopy
from typing import List, Optional

Expand Down
11 changes: 10 additions & 1 deletion m3gnet/layers/_readout.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Readout compress a graph into a vector
"""

from typing import List, Optional

import tensorflow as tf
Expand Down Expand Up @@ -109,15 +110,17 @@ class ReduceReadOut(ReadOut):
This could be summing up the atoms or bonds, or taking the mean, etc.
"""

def __init__(self, method: str = "mean", field="atoms", **kwargs):
def __init__(self, method: str = "mean", field="atoms", output_latent_feats: bool = False, **kwargs):
"""
Args:
method (str): method for the reduction
field (str): the field of MaterialGraph to perform the reduction
output_latent_feats (bool): whether output latent atomic features
**kwargs:
"""
self.method = method
self.field = field
self.output_latent_feats = output_latent_feats
super().__init__(**kwargs)
self.method_func = METHOD_MAPPING.get(method)

Expand All @@ -131,6 +134,12 @@ def call(self, graph: List, **kwargs) -> tf.Tensor:
"""
field = graph[getattr(Index, self.field.upper())]
n_field = graph[getattr(Index, f"n_{self.field}".upper())]
if self.output_latent_feats is True:
return field, self.method_func(
field,
get_segment_indices_from_n(n_field),
num_segements=tf.shape(n_field)[0],
)
return self.method_func( # type: ignore
field,
get_segment_indices_from_n(n_field),
Expand Down
1 change: 1 addition & 0 deletions m3gnet/layers/_three_body.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Three body basis expansion
"""

from typing import List

import tensorflow as tf
Expand Down
1 change: 1 addition & 0 deletions m3gnet/layers/_two_body.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Calculate distance from atom positions and indices
"""

from typing import List

import tensorflow as tf
Expand Down
1 change: 1 addition & 0 deletions m3gnet/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Graph pretrained
"""

from ._base import BasePotential, GraphModelMixin, Potential
from ._dynamics import M3GNetCalculator, MolecularDynamics, Relaxer
from ._m3gnet import M3GNet
Expand Down
58 changes: 50 additions & 8 deletions m3gnet/models/_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from pymatgen.core.structure import Molecule, Structure
from pymatgen.io.ase import AseAtomsAdaptor

from m3gnet.layers import AtomRef
from ._base import Potential
from ._m3gnet import M3GNet

Expand All @@ -47,19 +48,31 @@ class M3GNetCalculator(Calculator):

implemented_properties = ["energy", "free_energy", "forces", "stress"]

def __init__(self, potential: Potential, compute_stress: bool = True, stress_weight: float = 1.0, **kwargs):
def __init__(
self,
potential: Potential,
compute_stress: bool = True,
stress_weight: float = 1.0,
state_attr=None,
element_refs=None,
**kwargs,
):
"""
Args:
potential (Potential): m3gnet.models.Potential
compute_stress (bool): whether to calculate the stress
stress_weight (float): the stress weight.
state_attr (np.ndarray): global state attribute (e.g. fidelity of data)
element_refs (np.ndarray): elemental energy offsets.
**kwargs:
"""
super().__init__(**kwargs)
self.potential = potential
self.compute_stress = compute_stress
self.stress_weight = stress_weight
self.state_attr = state_attr
self.element_refs = AtomRef(property_per_element=element_refs) if element_refs is not None else element_refs

def calculate(
self,
Expand All @@ -81,14 +94,27 @@ def calculate(
system_changes = system_changes or all_changes
super().calculate(atoms=atoms, properties=properties, system_changes=system_changes)

graph = self.potential.graph_converter(atoms)
graph = self.potential.graph_converter(atoms, self.state_attr)
graph_list = graph.as_tf().as_list()
results = self.potential.get_efs_tensor(graph_list, include_stresses=self.compute_stress)
offset = None
if self.element_refs is not None:
offset = self.element_refs(graph_list)

self.results.update(
energy=results[0].numpy().ravel()[0],
free_energy=results[0].numpy().ravel()[0],
energy=(
results[0].numpy().ravel()[0] + offset.numpy().ravel()[0]
if self.element_refs is not None
else results[0].numpy().ravel()[0]
),
free_energy=(
results[0].numpy().ravel()[0] + offset.numpy().ravel()[0]
if self.element_refs is not None
else results[0].numpy().ravel()[0]
),
forces=results[1].numpy(),
)

if self.compute_stress:
self.results.update(stress=results[2].numpy()[0] * self.stress_weight)

Expand All @@ -104,17 +130,21 @@ def __init__(
optimizer: Union[Optimizer, str] = "FIRE",
relax_cell: bool = True,
stress_weight: float = 0.01,
state_attr=None,
element_refs=None,
):
"""
Args:
potential (Optional[Union[Potential, str]]): a potential,
a str path to a saved model or a short name for saved model
that comes with M3GNet distribution
optimizer (str or ase Optimizer): the optimization algorithm.
optimizer (str or ase Optimizer): the optimization algorithm
Defaults to "FIRE"
relax_cell (bool): whether to relax the lattice cell
stress_weight (float): the stress weight for relaxation
state_attr (np.ndarray): global state attribute (e.g. fidelity of data)
element_refs (np.ndarray): elemental energy offsets.
"""
if isinstance(potential, str):
potential = Potential(M3GNet.load(potential))
Expand All @@ -129,7 +159,9 @@ def __init__(
optimizer_obj = optimizer

self.opt_class: Optimizer = optimizer_obj
self.calculator = M3GNetCalculator(potential=potential, stress_weight=stress_weight)
self.calculator = M3GNetCalculator(
potential=potential, stress_weight=stress_weight, state_attr=state_attr, element_refs=element_refs
)
self.relax_cell = relax_cell
self.potential = potential
self.ase_adaptor = AseAtomsAdaptor()
Expand All @@ -152,7 +184,7 @@ def relax(
Here fmax is a sum of force and stress forces
steps (int): max number of steps for relaxation
traj_file (str): the trajectory file for saving
interval (int): the step interval for saving the trajectories
interval (int): the step interval for saving the trajectories.
**kwargs:
Returns:
"""
Expand Down Expand Up @@ -257,6 +289,9 @@ def __init__(
logfile: Optional[str] = None,
loginterval: int = 1,
append_trajectory: bool = False,
stress_weight: float = 1 / 160.21766208,
state_attr=None,
element_refs=None,
):
"""
Expand All @@ -276,6 +311,9 @@ def __init__(
logfile (str): open this file for recording MD outputs
loginterval (int): write to log file every interval steps
append_trajectory (bool): Whether to append to prev trajectory
stress_weight (float): unit conversion from GPa to eV/A^3
state_attr (np.ndarray): global state attribute (e.g. fidelity of data)
element_refs (np.ndarray): elemental energy offsets.
"""

if isinstance(potential, str):
Expand All @@ -284,7 +322,11 @@ def __init__(
if isinstance(atoms, (Structure, Molecule)):
atoms = AseAtomsAdaptor().get_atoms(atoms)
self.atoms = atoms
self.atoms.set_calculator(M3GNetCalculator(potential=potential))
self.atoms.set_calculator(
M3GNetCalculator(
potential=potential, stress_weight=stress_weight, state_attr=state_attr, element_refs=element_refs
)
)

if taut is None:
taut = 100 * timestep * units.fs
Expand Down
Loading

0 comments on commit 3a8b14d

Please sign in to comment.