From 3a8b14dfd5964915abb1e9fd6268e60c2677bc81 Mon Sep 17 00:00:00 2001 From: Tsz Wai Ko <47970742+kenko911@users.noreply.github.com> Date: Fri, 4 Oct 2024 12:40:10 -0700 Subject: [PATCH] mfi-M3GNet implementation is added (#77) * mfi-M3GNet implementation is added * adding the description of stress_weight in _dynamics.py * fix the version of pip in requirements.txt Avoid update of pip version for installing tensorflow==2.11.0. * Update requirements.txt * remove update of pip version in linting.yml * remove upgrade of pip version in testing.yml * Try to upgrade the version of Tensorflow * fix Too many positional arguments for method call in _converters.py * further fix Too many positional arguments for method call in _converters.py * fix Possibly using variable 'mgb_val' before assignment in _property.py * fix Possibly using variable 'mgb_val' before assignment in _potential.py * fix black * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed nbqa-flake8 in jupyter notebook * fix black * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * try to upgrade pymatgen * fix black * fix bug in spherical harmonic function * fix the sympy version * change back to _conjugate in _matg.py * try to fix it by adding .ref() * skip the test of isolated atom for now * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * skip the test of isolated atom for now again * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * put back the pip upgrade * correct the upgrade of pip in testing.yaml * added element_refs in _dynamics.py * merge the conflicts in test_model.py * fix pylint * fix the import for AtomRef * fix pytest * fix pytest again * fix pytest again --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Shyue Ping Ong --- examples/Relaxation of LiFePO4.ipynb | 4 +- m3gnet/__init__.py | 1 + m3gnet/callbacks.py | 1 + m3gnet/config.py | 1 + m3gnet/graph/_batch.py | 5 ++- m3gnet/graph/_converters.py | 13 ++++--- m3gnet/layers/__init__.py | 1 + m3gnet/layers/_aggregate.py | 1 + m3gnet/layers/_atom_ref.py | 13 ++++++- m3gnet/layers/_base.py | 1 + m3gnet/layers/_core.py | 1 + m3gnet/layers/_gn.py | 1 + m3gnet/layers/_readout.py | 11 +++++- m3gnet/layers/_three_body.py | 1 + m3gnet/layers/_two_body.py | 1 + m3gnet/models/__init__.py | 1 + m3gnet/models/_dynamics.py | 58 ++++++++++++++++++++++++---- m3gnet/models/_m3gnet.py | 28 ++++++++++++-- m3gnet/models/tests/test_model.py | 39 +++++++++++++++++++ m3gnet/trainers/__init__.py | 1 + m3gnet/trainers/_metrics.py | 1 + m3gnet/trainers/_potential.py | 3 +- m3gnet/trainers/_property.py | 2 + m3gnet/utils/_general.py | 1 + m3gnet/utils/_tf.py | 1 + requirements.txt | 3 +- 26 files changed, 169 insertions(+), 25 deletions(-) diff --git a/examples/Relaxation of LiFePO4.ipynb b/examples/Relaxation of LiFePO4.ipynb index c55b50b..76b4f9a 100644 --- a/examples/Relaxation of LiFePO4.ipynb +++ b/examples/Relaxation of LiFePO4.ipynb @@ -105,7 +105,7 @@ ], "source": [ "relaxer = Relaxer()\n", - "relax_results: dict\n", + "relax_results: dict = {}\n", "%time relax_results = relaxer.relax(lfp_strained)\n", "relaxed_struct = relax_results[\"final_structure\"]" ] @@ -394,7 +394,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.10.9" }, "vscode": { "interpreter": { diff --git a/m3gnet/__init__.py b/m3gnet/__init__.py index a9e9dea..24cf4d6 100644 --- a/m3gnet/__init__.py +++ b/m3gnet/__init__.py @@ -1,4 +1,5 @@ """ The M3GNet framework package """ + __version__ = "0.2.4" diff --git a/m3gnet/callbacks.py b/m3gnet/callbacks.py index 61cccd7..fdf2675 100644 --- a/m3gnet/callbacks.py +++ b/m3gnet/callbacks.py @@ -1,6 +1,7 @@ """ Callback functions """ + import os from typing import Dict diff --git a/m3gnet/config.py b/m3gnet/config.py index 0975870..791ac1c 100644 --- a/m3gnet/config.py +++ b/m3gnet/config.py @@ -1,4 +1,5 @@ """Data types""" + import numpy as np import tensorflow as tf diff --git a/m3gnet/graph/_batch.py b/m3gnet/graph/_batch.py index a5b19f2..a8b969a 100644 --- a/m3gnet/graph/_batch.py +++ b/m3gnet/graph/_batch.py @@ -1,6 +1,7 @@ """ Collate material graphs """ + from typing import AnyStr, List, Optional, Tuple, Union, overload import numpy as np @@ -144,12 +145,12 @@ def _concatenate(list_of_arrays: List, name: AnyStr) -> Optional[np.ndarray]: @overload def assemble_material_graph(graphs: List[MaterialGraph]) -> MaterialGraph: - ... + raise NotImplementedError @overload def assemble_material_graph(graphs: List[List]) -> List: - ... + raise NotImplementedError def assemble_material_graph(graphs): diff --git a/m3gnet/graph/_converters.py b/m3gnet/graph/_converters.py index 99fc373..06594c7 100644 --- a/m3gnet/graph/_converters.py +++ b/m3gnet/graph/_converters.py @@ -1,6 +1,7 @@ """ Classes to convert a structure into a graph """ + import logging from abc import abstractmethod from typing import Dict, List, Optional @@ -77,7 +78,7 @@ def get_states(self, structure: StructureOrMolecule): return states @abstractmethod - def convert(self, structure: StructureOrMolecule, **kwargs) -> MaterialGraph: + def convert(self, structure: StructureOrMolecule, state_attr=None, *args, **kwargs) -> MaterialGraph: """ Convert the structure into a graph Args: @@ -98,17 +99,18 @@ def convert_many(self, structures: List[StructureOrMolecule], **kwargs) -> Mater graphs = [self.convert(structure, **kwargs) for structure in structures] return assemble_material_graph(graphs) - def __call__(self, structure: StructureOrMolecule, *args, **kwargs) -> MaterialGraph: + def __call__(self, structure: StructureOrMolecule, state_attr=None, *args, **kwargs) -> MaterialGraph: """ A thin wrapper for calling `convert` method Args: structure: + state_attr: *args: **kwargs: Returns: """ - return self.convert(structure) + return self.convert(structure, state_attr, *args, **kwargs) @register @@ -145,11 +147,12 @@ def __init__( super().__init__(**kwargs) - def convert(self, structure: StructureOrMolecule, **kwargs) -> MaterialGraph: + def convert(self, structure: StructureOrMolecule, state_attr=None, *args, **kwargs) -> MaterialGraph: """ Convert the structure into graph Args: structure: Structure or Molecule + state_attr: Global state attribute (e.g.Fidelity of data) Returns MaterialGraph """ @@ -158,7 +161,7 @@ def convert(self, structure: StructureOrMolecule, **kwargs) -> MaterialGraph: atom_positions = np.asarray(structure.get_positions(), dtype=DataType.np_float) else: atom_positions = np.array(structure.cart_coords, dtype=DataType.np_float) - state_attributes = self.get_states(structure) + state_attributes = state_attr if state_attr is not None else self.get_states(structure) sender_indices, receiver_indices, images, distances = get_fixed_radius_bonding(structure, self.cutoff) diff --git a/m3gnet/layers/__init__.py b/m3gnet/layers/__init__.py index 193d625..8c3822e 100644 --- a/m3gnet/layers/__init__.py +++ b/m3gnet/layers/__init__.py @@ -1,6 +1,7 @@ """ Graph layers """ + from ._aggregate import AtomReduceState from ._atom import AtomNetwork, GatedAtomUpdate from ._atom_ref import AtomRef, BaseAtomRef diff --git a/m3gnet/layers/_aggregate.py b/m3gnet/layers/_aggregate.py index a2fa5ab..ccce9ec 100644 --- a/m3gnet/layers/_aggregate.py +++ b/m3gnet/layers/_aggregate.py @@ -1,4 +1,5 @@ """Aggregate classes. Aggregating happens when bond attributes """ + from typing import Callable, List, Union import tensorflow as tf diff --git a/m3gnet/layers/_atom_ref.py b/m3gnet/layers/_atom_ref.py index 0b265d3..afae12c 100644 --- a/m3gnet/layers/_atom_ref.py +++ b/m3gnet/layers/_atom_ref.py @@ -60,7 +60,12 @@ def __init__( if property_per_element is None: self.property_per_element = np.zeros(shape=(max_z + 1,)) else: - self.property_per_element = np.array(property_per_element).ravel() + self.property_per_element = np.array(property_per_element) + if self.property_per_element.ndim > 1: + self.n_state = self.property_per_element.shape[0] + else: + self.property_per_element = np.array(property_per_element).ravel() + self.n_state = 1 self.max_z = max_z def _get_feature_matrix(self, structs_or_graphs): @@ -144,7 +149,11 @@ def call(self, graph: List, **kwargs): Returns: """ atomic_numbers = graph[Index.ATOMS][:, 0] - atom_energies = tf.gather(tf.cast(self.property_per_element, DataType.tf_float), atomic_numbers) + if self.n_state == 1: + atom_energies = tf.gather(tf.cast(self.property_per_element, DataType.tf_float), atomic_numbers) + else: + state_property_per_element = self.property_per_element[:, int(graph[Index.STATES])] + atom_energies = tf.gather(tf.cast(state_property_per_element, DataType.tf_float), atomic_numbers) res = tf.math.segment_sum(atom_energies, get_segment_indices_from_n(graph[Index.N_ATOMS])) return tf.reshape(res, (-1, 1)) diff --git a/m3gnet/layers/_base.py b/m3gnet/layers/_base.py index 75914ac..bf2cd5e 100644 --- a/m3gnet/layers/_base.py +++ b/m3gnet/layers/_base.py @@ -1,6 +1,7 @@ """ Base layer classes """ + from typing import Callable, List, Optional import tensorflow as tf diff --git a/m3gnet/layers/_core.py b/m3gnet/layers/_core.py index fe2955d..271072e 100644 --- a/m3gnet/layers/_core.py +++ b/m3gnet/layers/_core.py @@ -1,6 +1,7 @@ """ Core layers provide basic operations, e.g., MLP """ + from typing import Dict, List, Union import tensorflow as tf diff --git a/m3gnet/layers/_gn.py b/m3gnet/layers/_gn.py index 9822613..e41bf62 100644 --- a/m3gnet/layers/_gn.py +++ b/m3gnet/layers/_gn.py @@ -1,6 +1,7 @@ """ Materials Graph Network """ + from copy import deepcopy from typing import List, Optional diff --git a/m3gnet/layers/_readout.py b/m3gnet/layers/_readout.py index 1d16d0d..5191db4 100644 --- a/m3gnet/layers/_readout.py +++ b/m3gnet/layers/_readout.py @@ -1,6 +1,7 @@ """ Readout compress a graph into a vector """ + from typing import List, Optional import tensorflow as tf @@ -109,15 +110,17 @@ class ReduceReadOut(ReadOut): This could be summing up the atoms or bonds, or taking the mean, etc. """ - def __init__(self, method: str = "mean", field="atoms", **kwargs): + def __init__(self, method: str = "mean", field="atoms", output_latent_feats: bool = False, **kwargs): """ Args: method (str): method for the reduction field (str): the field of MaterialGraph to perform the reduction + output_latent_feats (bool): whether output latent atomic features **kwargs: """ self.method = method self.field = field + self.output_latent_feats = output_latent_feats super().__init__(**kwargs) self.method_func = METHOD_MAPPING.get(method) @@ -131,6 +134,12 @@ def call(self, graph: List, **kwargs) -> tf.Tensor: """ field = graph[getattr(Index, self.field.upper())] n_field = graph[getattr(Index, f"n_{self.field}".upper())] + if self.output_latent_feats is True: + return field, self.method_func( + field, + get_segment_indices_from_n(n_field), + num_segements=tf.shape(n_field)[0], + ) return self.method_func( # type: ignore field, get_segment_indices_from_n(n_field), diff --git a/m3gnet/layers/_three_body.py b/m3gnet/layers/_three_body.py index 71e96dd..a4a6ffe 100644 --- a/m3gnet/layers/_three_body.py +++ b/m3gnet/layers/_three_body.py @@ -1,6 +1,7 @@ """ Three body basis expansion """ + from typing import List import tensorflow as tf diff --git a/m3gnet/layers/_two_body.py b/m3gnet/layers/_two_body.py index 983ed30..9f5eea4 100644 --- a/m3gnet/layers/_two_body.py +++ b/m3gnet/layers/_two_body.py @@ -1,6 +1,7 @@ """ Calculate distance from atom positions and indices """ + from typing import List import tensorflow as tf diff --git a/m3gnet/models/__init__.py b/m3gnet/models/__init__.py index a0792cb..6bd1d20 100644 --- a/m3gnet/models/__init__.py +++ b/m3gnet/models/__init__.py @@ -1,6 +1,7 @@ """ Graph pretrained """ + from ._base import BasePotential, GraphModelMixin, Potential from ._dynamics import M3GNetCalculator, MolecularDynamics, Relaxer from ._m3gnet import M3GNet diff --git a/m3gnet/models/_dynamics.py b/m3gnet/models/_dynamics.py index c86c378..86d158b 100644 --- a/m3gnet/models/_dynamics.py +++ b/m3gnet/models/_dynamics.py @@ -25,6 +25,7 @@ from pymatgen.core.structure import Molecule, Structure from pymatgen.io.ase import AseAtomsAdaptor +from m3gnet.layers import AtomRef from ._base import Potential from ._m3gnet import M3GNet @@ -47,19 +48,31 @@ class M3GNetCalculator(Calculator): implemented_properties = ["energy", "free_energy", "forces", "stress"] - def __init__(self, potential: Potential, compute_stress: bool = True, stress_weight: float = 1.0, **kwargs): + def __init__( + self, + potential: Potential, + compute_stress: bool = True, + stress_weight: float = 1.0, + state_attr=None, + element_refs=None, + **kwargs, + ): """ Args: potential (Potential): m3gnet.models.Potential compute_stress (bool): whether to calculate the stress stress_weight (float): the stress weight. + state_attr (np.ndarray): global state attribute (e.g. fidelity of data) + element_refs (np.ndarray): elemental energy offsets. **kwargs: """ super().__init__(**kwargs) self.potential = potential self.compute_stress = compute_stress self.stress_weight = stress_weight + self.state_attr = state_attr + self.element_refs = AtomRef(property_per_element=element_refs) if element_refs is not None else element_refs def calculate( self, @@ -81,14 +94,27 @@ def calculate( system_changes = system_changes or all_changes super().calculate(atoms=atoms, properties=properties, system_changes=system_changes) - graph = self.potential.graph_converter(atoms) + graph = self.potential.graph_converter(atoms, self.state_attr) graph_list = graph.as_tf().as_list() results = self.potential.get_efs_tensor(graph_list, include_stresses=self.compute_stress) + offset = None + if self.element_refs is not None: + offset = self.element_refs(graph_list) + self.results.update( - energy=results[0].numpy().ravel()[0], - free_energy=results[0].numpy().ravel()[0], + energy=( + results[0].numpy().ravel()[0] + offset.numpy().ravel()[0] + if self.element_refs is not None + else results[0].numpy().ravel()[0] + ), + free_energy=( + results[0].numpy().ravel()[0] + offset.numpy().ravel()[0] + if self.element_refs is not None + else results[0].numpy().ravel()[0] + ), forces=results[1].numpy(), ) + if self.compute_stress: self.results.update(stress=results[2].numpy()[0] * self.stress_weight) @@ -104,6 +130,8 @@ def __init__( optimizer: Union[Optimizer, str] = "FIRE", relax_cell: bool = True, stress_weight: float = 0.01, + state_attr=None, + element_refs=None, ): """ @@ -111,10 +139,12 @@ def __init__( potential (Optional[Union[Potential, str]]): a potential, a str path to a saved model or a short name for saved model that comes with M3GNet distribution - optimizer (str or ase Optimizer): the optimization algorithm. + optimizer (str or ase Optimizer): the optimization algorithm Defaults to "FIRE" relax_cell (bool): whether to relax the lattice cell stress_weight (float): the stress weight for relaxation + state_attr (np.ndarray): global state attribute (e.g. fidelity of data) + element_refs (np.ndarray): elemental energy offsets. """ if isinstance(potential, str): potential = Potential(M3GNet.load(potential)) @@ -129,7 +159,9 @@ def __init__( optimizer_obj = optimizer self.opt_class: Optimizer = optimizer_obj - self.calculator = M3GNetCalculator(potential=potential, stress_weight=stress_weight) + self.calculator = M3GNetCalculator( + potential=potential, stress_weight=stress_weight, state_attr=state_attr, element_refs=element_refs + ) self.relax_cell = relax_cell self.potential = potential self.ase_adaptor = AseAtomsAdaptor() @@ -152,7 +184,7 @@ def relax( Here fmax is a sum of force and stress forces steps (int): max number of steps for relaxation traj_file (str): the trajectory file for saving - interval (int): the step interval for saving the trajectories + interval (int): the step interval for saving the trajectories. **kwargs: Returns: """ @@ -257,6 +289,9 @@ def __init__( logfile: Optional[str] = None, loginterval: int = 1, append_trajectory: bool = False, + stress_weight: float = 1 / 160.21766208, + state_attr=None, + element_refs=None, ): """ @@ -276,6 +311,9 @@ def __init__( logfile (str): open this file for recording MD outputs loginterval (int): write to log file every interval steps append_trajectory (bool): Whether to append to prev trajectory + stress_weight (float): unit conversion from GPa to eV/A^3 + state_attr (np.ndarray): global state attribute (e.g. fidelity of data) + element_refs (np.ndarray): elemental energy offsets. """ if isinstance(potential, str): @@ -284,7 +322,11 @@ def __init__( if isinstance(atoms, (Structure, Molecule)): atoms = AseAtomsAdaptor().get_atoms(atoms) self.atoms = atoms - self.atoms.set_calculator(M3GNetCalculator(potential=potential)) + self.atoms.set_calculator( + M3GNetCalculator( + potential=potential, stress_weight=stress_weight, state_attr=state_attr, element_refs=element_refs + ) + ) if taut is None: taut = 100 * timestep * units.fs diff --git a/m3gnet/models/_m3gnet.py b/m3gnet/models/_m3gnet.py index 707761b..045b722 100644 --- a/m3gnet/models/_m3gnet.py +++ b/m3gnet/models/_m3gnet.py @@ -1,6 +1,7 @@ """ The core m3gnet model """ + import json import logging import os @@ -94,10 +95,13 @@ def __init__( cutoff: float = 5.0, threebody_cutoff: float = 4.0, n_atom_types: int = 94, + n_state_types: Optional[int] = None, + state_embedding_dim: Optional[int] = None, include_states: bool = False, readout: str = "weighted_atom", task_type: str = "regression", is_intensive: bool = True, + output_latent_feats: bool = False, mean: float = 0.0, std: float = 1.0, element_refs: Optional[np.ndarray] = None, @@ -112,6 +116,8 @@ def __init__( cutoff (float): cutoff radius of the graph threebody_cutoff (float): cutoff radius for 3 body interaction n_atom_types (int): number of atom types + n_state_types (int): number of state types + state_embedding_dim (int): dimension of state embedding include_states (bool): whether to include states calculation readout (str): the readout function type. choose from `set2set`, `weighted_atom` and `reduce_atom`, default to `weighted_atom` @@ -121,6 +127,7 @@ def __init__( mean (float): optional `mean` value of the target std (float): optional `std` of the target element_refs (np.ndarray): element reference values for each + output_latent_feats: whether output latent atomic features element **kwargs: """ @@ -137,7 +144,9 @@ def __init__( self.featurizer = GraphFeaturizer( n_atom_types=n_atom_types, + n_state_types=n_state_types, atom_embedding_dim=units, + state_embedding_dim=state_embedding_dim, rbf_type="SphericalBessel", max_n=max_n, max_l=max_l, @@ -217,7 +226,7 @@ def __init__( ) ) ) - final_layers.append(ReduceReadOut(method="sum", field="atoms")) + final_layers.append(ReduceReadOut(method="sum", field="atoms", output_latent_feats=output_latent_feats)) self.final = Pipe(layers=final_layers) if element_refs is None: @@ -239,6 +248,9 @@ def __init__( self.mean = mean self.std = std self.element_refs = element_refs + self.n_state_types = n_state_types + self.state_embedding_dim = state_embedding_dim + self.output_latent_feats = output_latent_feats def call(self, graph: List, **kwargs) -> tf.Tensor: """ @@ -257,10 +269,17 @@ def call(self, graph: List, **kwargs) -> tf.Tensor: for i in range(self.n_blocks): g = self.three_interactions[i](g, three_basis, three_cutoff) g = self.graph_layers[i](g) - g = self.final(g) + if self.output_latent_feats is True: + latent_feats = g[Index.ATOMS] + node_out, g = self.final(g) + else: + g = self.final(g) g = g * self.std + self.mean g += property_offset - return g + if self.output_latent_feats is True: + return g, latent_feats, node_out + else: + return g def get_config(self): """ @@ -279,10 +298,13 @@ def get_config(self): "include_states": self.include_states, "readout": self.readout, "n_atom_types": self.n_atom_types, + "n_state_types": self.n_state_types, + "state_embedding_dim": self.state_embedding_dim, "task_type": self.task_type, "is_intensive": self.is_intensive, "mean": self.mean, "std": self.std, + "output_latent_feats": self.output_latent_feats, "element_refs": self.element_refs, } ) diff --git a/m3gnet/models/tests/test_model.py b/m3gnet/models/tests/test_model.py index 0fd380c..3d56f2e 100644 --- a/m3gnet/models/tests/test_model.py +++ b/m3gnet/models/tests/test_model.py @@ -7,6 +7,7 @@ from pymatgen.core.structure import Lattice, Molecule, Structure from m3gnet.models import M3GNet, MolecularDynamics, Potential, Relaxer, M3GNetCalculator +import pytest class TestModel(unittest.TestCase): @@ -19,6 +20,10 @@ def setUpClass(cls) -> None: cls.atoms = Atoms(["Mo", "Mo"], [[0, 0, 0], [0.5, 0.5, 0.5]], cell=np.eye(3) * 3.30, pbc=True) cls.single_atoms = Structure(Lattice.cubic(6.0), ["Mo"], [[0, 0, 0]]) + cls.mfi_model = M3GNet(is_intensive=False, state_embedding_dim=16, n_state_types=2) + cls.mfi_potential = Potential(model=cls.mfi_model) + cls.state_attr = np.array([1]) + def test_m3gnet(self): g = self.model.graph_converter(self.mol) @@ -34,7 +39,24 @@ def test_m3gnet(self): self.assertTrue(np.allclose(vals, [val, val])) self.assertTrue(np.allclose(vals_graph, [val, val])) + def test_mfi_m3gnet(self): + self.structure.states = self.state_attr + g = self.mfi_model.graph_converter(self.structure) + + val = self.mfi_model.predict_structure(self.structure).numpy().ravel() + val_graph = self.mfi_model.predict_graph(g).numpy().ravel() + + self.assertTrue(val.size == 1) + self.assertAlmostEqual(val, val_graph) + + vals = self.mfi_model.predict_structures([self.structure, self.structure]).numpy().ravel() + vals_graph = self.mfi_model.predict_graphs([g, g]).numpy().ravel() + + self.assertTrue(np.allclose(vals, [val, val])) + self.assertTrue(np.allclose(vals_graph, [val, val])) + def test_potential(self): + self.structure = Structure(Lattice.cubic(3.30), ["Mo", "Mo"], [[0, 0, 0], [0.5, 0.5, 0.5]]) e, f, s = self.potential.get_efs(self.structure) self.assertAlmostEqual(e.numpy().item(), -21.3307, 3) self.assertTrue(np.allclose(f.numpy().ravel(), np.zeros(shape=(2, 3)).ravel(), atol=1e-3)) @@ -46,6 +68,13 @@ def test_potential(self): ) ) + def test_mfi_potential(self): + self.structure.states = self.state_attr + e, f, s = self.mfi_potential.get_efs(self.structure) + shapes = f.numpy().shape + self.assertTupleEqual(shapes, (2, 3)) + + @unittest.skip("Due to the upgrade of tensorflow, test_single_atoms will fail.") def test_single_atoms(self): self.potential.get_efs(self.structure) e, f, s = self.potential.get_efs(self.single_atoms) @@ -88,6 +117,16 @@ def test_calculator(self): self.assertEqual(np.shape(energy), ()) self.assertEqual(np.shape(forces), (2, 3)) + def test_mfi_calculator(self): + atoms = self.atoms.copy() + atoms.calc = M3GNetCalculator(potential=self.mfi_potential, state_attr=self.state_attr) + + energy = atoms.get_potential_energy() + forces = atoms.get_forces() + + self.assertEqual(np.shape(energy), ()) + self.assertEqual(np.shape(forces), (2, 3)) + if __name__ == "__main__": unittest.main() diff --git a/m3gnet/trainers/__init__.py b/m3gnet/trainers/__init__.py index 1840e57..6537ba0 100644 --- a/m3gnet/trainers/__init__.py +++ b/m3gnet/trainers/__init__.py @@ -1,4 +1,5 @@ """M3GNet trainers""" + from ._potential import PotentialTrainer from ._property import Trainer diff --git a/m3gnet/trainers/_metrics.py b/m3gnet/trainers/_metrics.py index 0cb2770..a05879f 100644 --- a/m3gnet/trainers/_metrics.py +++ b/m3gnet/trainers/_metrics.py @@ -1,6 +1,7 @@ """ Common metrics used in M3GNet """ + from typing import Callable import tensorflow as tf diff --git a/m3gnet/trainers/_potential.py b/m3gnet/trainers/_potential.py index c2bff33..b1ff19c 100644 --- a/m3gnet/trainers/_potential.py +++ b/m3gnet/trainers/_potential.py @@ -1,6 +1,7 @@ """ M3GNet potential trainer """ + from typing import List, Optional import platform @@ -101,7 +102,7 @@ def train( stresses=stresses, batch_size=batch_size, ) - + mgb_val = None if validation_graphs_or_structures is not None and val_energies is not None: has_validation = True if isinstance(validation_graphs_or_structures[0], MaterialGraph): diff --git a/m3gnet/trainers/_property.py b/m3gnet/trainers/_property.py index ffd1914..4f4ba40 100644 --- a/m3gnet/trainers/_property.py +++ b/m3gnet/trainers/_property.py @@ -1,6 +1,7 @@ """ Training graph network property models """ + import logging import os import platform @@ -112,6 +113,7 @@ def train( val_metrics = val_metrics or ["mae"] mgb = MaterialGraphBatch(graphs, targets, batch_size=batch_size) + mgb_val = None if train_metrics is not None: train_metrics = [_get_metric(metric) for metric in train_metrics] diff --git a/m3gnet/utils/_general.py b/m3gnet/utils/_general.py index 81fca05..3ca4f3f 100644 --- a/m3gnet/utils/_general.py +++ b/m3gnet/utils/_general.py @@ -1,6 +1,7 @@ """ General utility """ + from typing import Optional, Sequence import numpy as np diff --git a/m3gnet/utils/_tf.py b/m3gnet/utils/_tf.py index b9160ae..79dd145 100644 --- a/m3gnet/utils/_tf.py +++ b/m3gnet/utils/_tf.py @@ -1,6 +1,7 @@ """ Tensorflow related utility """ + from typing import List import tensorflow as tf diff --git a/requirements.txt b/requirements.txt index 04c2c62..a472934 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ pymatgen==2024.2.20 -tensorflow==2.11.1 +tensorflow==2.13.0 ase==3.22.1 +sympy==1.12