diff --git a/abcd/model.py b/abcd/model.py index f4c87b61..4b9dabfd 100644 --- a/abcd/model.py +++ b/abcd/model.py @@ -146,7 +146,9 @@ def __iter__(self): @classmethod def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True): - """ASE's original implementation""" + """Extract data from Atoms info, arrays and results.""" + if not isinstance(atoms, Atoms): + raise ValueError("atoms must be an ASE Atoms object.") reserved_keys = { "n_atoms", @@ -157,11 +159,13 @@ def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True): "derived", "formula", } + arrays_keys = set(atoms.arrays.keys()) info_keys = set(atoms.info.keys()) - results_keys = ( - set(atoms.calc.results.keys()) if store_calc and atoms.calc else {} - ) + if store_calc and atoms.calc: + results_keys = atoms.calc.results.keys() - (arrays_keys | info_keys) + else: + results_keys = set() all_keys = (reserved_keys, arrays_keys, info_keys, results_keys) if len(set.union(*all_keys)) != sum(map(len, all_keys)): @@ -172,46 +176,43 @@ def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True): n_atoms = len(atoms) - dct = { + data = { "n_atoms": n_atoms, "cell": atoms.cell.tolist(), "pbc": atoms.pbc.tolist(), "formula": atoms.get_chemical_formula(), } - info_keys.update({"n_atoms", "cell", "pbc", "formula"}) + info_keys.update(data.keys()) for key, value in atoms.arrays.items(): if isinstance(value, np.ndarray): - dct[key] = value.tolist() + data[key] = value.tolist() else: - dct[key] = value + data[key] = value for key, value in atoms.info.items(): if isinstance(value, np.ndarray): - dct[key] = value.tolist() + data[key] = value.tolist() else: - dct[key] = value + data[key] = value if store_calc and atoms.calc: - dct["calculator_name"] = atoms.calc.__class__.__name__ - dct["calculator_parameters"] = atoms.calc.todict() + data["calculator_name"] = atoms.calc.__class__.__name__ + data["calculator_parameters"] = atoms.calc.todict() info_keys.update({"calculator_name", "calculator_parameters"}) for key, value in atoms.calc.results.items(): - if isinstance(value, np.ndarray): - if value.shape[0] == n_atoms: - arrays_keys.update(key) - else: - info_keys.update(key) - dct[key] = value.tolist() + data[key] = value.tolist() + else: + data[key] = value item.arrays_keys = list(arrays_keys) item.info_keys = list(info_keys) item.results_keys = list(results_keys) - item.update(dct) + item.update(data) if extra_info: item.info_keys.extend(extra_info.keys()) @@ -240,6 +241,7 @@ def to_ase(self): # atoms.calc = get_calculator(data['results']['calculator_name'])(**params) params = self.pop("calculator_parameters", {}) + info_keys -= {"calculator_parameters"} atoms.calc = SinglePointCalculator(atoms, **params) atoms.calc.results.update((key, self[key]) for key in results_keys) @@ -256,14 +258,14 @@ def pre_save(self): if cell: volume = abs(np.linalg.det(cell)) # atoms.get_volume() - self["volume"] = volume self.derived_keys.append("volume") + self["volume"] = volume virial = self.get("virial") if virial: # pressure P = -1/3 Tr(stress) = -1/3 Tr(virials/volume) - self["pressure"] = -1 / 3 * np.trace(virial / volume) self.derived_keys.append("pressure") + self["pressure"] = -1 / 3 * np.trace(virial / volume) # 'elements': Counter(atoms.get_chemical_symbols()), self["elements"] = Counter(str(element) for element in self["numbers"]) diff --git a/pyproject.toml b/pyproject.toml index fc870cce..8de3a1eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ numpy = "^1.26" tqdm = "^4.66" pymongo = "^4.7.3" matplotlib = "^3.9" -ase = "3.22.1" +ase = "^3.23" lark = "^1.1.9" [tool.poetry.group.dev.dependencies] diff --git a/tests/test_abstract_model.py b/tests/test_abstract_model.py new file mode 100644 index 00000000..f9e820e6 --- /dev/null +++ b/tests/test_abstract_model.py @@ -0,0 +1,256 @@ +import io + +import ase +import pytest +from pytest import approx + +from io import StringIO +from ase.io import read, write +import numpy as np + +from abcd.model import AbstractModel +from ase.calculators.lj import LennardJones + + +@pytest.fixture +def extxyz_file(): + return StringIO( + """2 + Properties=species:S:1:pos:R:3:forces:R:3 energy=-1 pbc="F T F" info="test" + Si 0.0 0.0 0.0 0.4 0.6 -0.4 + Si 0.0 0.0 0.0 -0.1 -0.5 -0.6 + """ + ) + + +def test_from_atoms(extxyz_file): + """Test extracting data from ASE Atoms object.""" + expected_forces = np.array([[0.4, 0.6, -0.4], [-0.1, -0.5, -0.6]]) + expected_stress = np.array([-1.0, -1.0, -1.0, -2.1, 2.0, 1.8]) + + atoms = read(extxyz_file, format="extxyz") + atoms.calc.results["stress"] = expected_stress + data = AbstractModel.from_atoms(atoms) + + # Test info + info_keys = { + "pbc", + "n_atoms", + "cell", + "formula", + "calculator_name", + "calculator_parameters", + "info", + } + assert info_keys == set(data.info_keys) + assert data["pbc"] == [False, True, False] + assert data["n_atoms"] == 2 + assert len(data["cell"]) == 3 + assert all(arr == [0.0, 0.0, 0.0] for arr in data["cell"]) + assert data["formula"] == "Si2" + assert data["info"] == "test" + + # Test arrays + assert {"numbers", "positions"} == set(data.arrays_keys) + + # Test results + assert {"energy", "stress", "forces"} == set(data.results_keys) + assert data["energy"] == -1 + assert data["forces"] == pytest.approx(expected_forces) + assert data["stress"] == pytest.approx(expected_stress) + + # Test derived + derived_keys = { + "elements", + "username", + "uploaded", + "modified", + "volume", + "hash", + "hash_structure", + } + assert derived_keys == set(data.derived_keys) + + +def test_from_atoms_no_calc(extxyz_file): + """Test extracting data from ASE Atoms object without results.""" + expected_stress = np.array([-1.0, -1.0, -1.0, -2.1, 2.0, 1.8]) + + atoms = read(extxyz_file, format="extxyz") + atoms.calc.results["stress"] = expected_stress + data = AbstractModel.from_atoms(atoms, store_calc=False) + + # Test info + assert {"pbc", "n_atoms", "cell", "formula", "info"} == set(data.info_keys) + assert data["pbc"] == [False, True, False] + assert data["n_atoms"] == 2 + assert len(data["cell"]) == 3 + assert all(arr == [0.0, 0.0, 0.0] for arr in data["cell"]) + assert data["formula"] == "Si2" + assert data["info"] == "test" + + # Test arrays + assert {"numbers", "positions"} == set(data.arrays_keys) + + # Test results + results_keys = { + "energy", + "forces", + "stress", + "calculator_name", + "calculator_parameters", + } + assert all(key not in data for key in results_keys) + + # Test derived + derived_keys = { + "elements", + "username", + "uploaded", + "modified", + "volume", + "hash", + "hash_structure", + } + assert derived_keys == set(data.derived_keys) + + +def test_to_ase(extxyz_file): + """Test returning data to ASE Atoms object with results.""" + atoms = read(extxyz_file, format="extxyz") + data = AbstractModel.from_atoms(atoms, store_calc=True) + + new_atoms = data.to_ase() + + # Test info set + assert new_atoms.cell == pytest.approx(atoms.cell) + assert new_atoms.pbc == pytest.approx(atoms.pbc) + assert new_atoms.positions == pytest.approx(atoms.positions) + assert new_atoms.numbers == pytest.approx(atoms.numbers) + + assert new_atoms.info["n_atoms"] == len(atoms) + assert new_atoms.info["formula"] == atoms.get_chemical_formula() + + assert new_atoms.calc.results["energy"] == pytest.approx( + atoms.calc.results["energy"] + ) + assert new_atoms.calc.results["forces"] == pytest.approx( + atoms.calc.results["forces"] + ) + + +def test_to_ase_no_results(extxyz_file): + """Test returning data to ASE Atoms object without results.""" + atoms = read(extxyz_file, format="extxyz") + data = AbstractModel.from_atoms(atoms, store_calc=False) + + new_atoms = data.to_ase() + + # Test info set + assert new_atoms.cell == pytest.approx(atoms.cell) + assert new_atoms.pbc == pytest.approx(atoms.pbc) + assert new_atoms.positions == pytest.approx(atoms.positions) + assert new_atoms.numbers == pytest.approx(atoms.numbers) + + assert new_atoms.info["n_atoms"] == len(atoms) + assert new_atoms.info["formula"] == atoms.get_chemical_formula() + + assert new_atoms.calc is None + + +def test_from_atoms_len_atoms_3(): + atoms = ase.Atoms( + "H3", + positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]], + pbc=True, + cell=[2, 2, 2], + ) + atoms.calc = LennardJones() + atoms.calc.calculate(atoms) + + # convert + abcd_data = AbstractModel.from_atoms(atoms, store_calc=True) + + assert set(abcd_data.info_keys) == { + "pbc", + "n_atoms", + "cell", + "formula", + "calculator_name", + "calculator_parameters", + } + assert set(abcd_data.arrays_keys) == {"numbers", "positions"} + assert set(abcd_data.results_keys) == { + "stress", + "energy", + "forces", + "energies", + "stresses", + "free_energy", + } + + # check some values as well + assert abcd_data["energy"] == atoms.get_potential_energy() + assert abcd_data["forces"] == approx(atoms.get_forces()) + + +@pytest.mark.parametrize("store_calc", [True, False]) +def test_write_and_read(store_calc): + # create atoms & add a calculator + atoms = ase.Atoms( + "H3", + positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]], + pbc=True, + cell=[2, 2, 2], + ) + atoms.calc = LennardJones() + atoms.calc.calculate(atoms) + + # dump to XYZ + buffer = io.StringIO() + write(buffer, atoms, format="extxyz") + + # read back + buffer.seek(0) + atoms_read = read(buffer, format="extxyz") + + # read in both of them + abcd_data = AbstractModel.from_atoms(atoms, store_calc=store_calc) + abcd_data_after_read = AbstractModel.from_atoms(atoms_read, store_calc=store_calc) + + # check that all results are the same + for key in ["info_keys", "arrays_keys", "derived_keys", "results_keys"]: + assert set(getattr(abcd_data, key)) == set( + getattr(abcd_data_after_read, key) + ), f"{key} mismatched" + + # info & arrays same, except calc recognised as LJ when not from XYZ + for key in set(abcd_data.info_keys + abcd_data.arrays_keys) - { + "calculator_name", + "calculator_parameters", + }: + assert ( + abcd_data[key] == abcd_data_after_read[key] + ), f"{key}'s value does not match" + + # date & hashed will differ + for key in set(abcd_data.derived_keys) - { + "hash", + "modified", + "uploaded", + "hash_structure", # see issue #118 + }: + assert ( + abcd_data[key] == abcd_data_after_read[key] + ), f"{key}'s value does not match" + + # expected differences - n.b. order of calls above + assert abcd_data_after_read["modified"] > abcd_data["modified"] + assert abcd_data_after_read["uploaded"] > abcd_data["uploaded"] + assert abcd_data_after_read["hash"] != abcd_data["hash"] + + # expect results to match within fp precision + for key in set(abcd_data.results_keys): + assert abcd_data[key] == approx( + np.array(abcd_data_after_read[key]) + ), f"{key}'s value does not match"