Skip to content

Commit

Permalink
Fix data model (#117)
Browse files Browse the repository at this point in the history
* Update ASE

* Fix duplicate calculated results

* Fix getting Atoms results

* Fix derived data

* Test extxyz file

* Fix to ASE function

* Test to ASE function

* additional tests

* Update tests/test_abstract_model.py

---------

Co-authored-by: ElliottKasoar <ElliottKasoar@users.noreply.github.com>
Co-authored-by: Tamas K Stenczel <tks32@cam.ac.uk>
  • Loading branch information
3 people authored Dec 10, 2024
1 parent e317bee commit 4f1b0f6
Show file tree
Hide file tree
Showing 3 changed files with 280 additions and 22 deletions.
44 changes: 23 additions & 21 deletions abcd/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def __iter__(self):

@classmethod
def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True):
"""ASE's original implementation"""
"""Extract data from Atoms info, arrays and results."""
if not isinstance(atoms, Atoms):
raise ValueError("atoms must be an ASE Atoms object.")

reserved_keys = {
"n_atoms",
Expand All @@ -157,11 +159,13 @@ def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True):
"derived",
"formula",
}

arrays_keys = set(atoms.arrays.keys())
info_keys = set(atoms.info.keys())
results_keys = (
set(atoms.calc.results.keys()) if store_calc and atoms.calc else {}
)
if store_calc and atoms.calc:
results_keys = atoms.calc.results.keys() - (arrays_keys | info_keys)
else:
results_keys = set()

all_keys = (reserved_keys, arrays_keys, info_keys, results_keys)
if len(set.union(*all_keys)) != sum(map(len, all_keys)):
Expand All @@ -172,46 +176,43 @@ def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True):

n_atoms = len(atoms)

dct = {
data = {
"n_atoms": n_atoms,
"cell": atoms.cell.tolist(),
"pbc": atoms.pbc.tolist(),
"formula": atoms.get_chemical_formula(),
}

info_keys.update({"n_atoms", "cell", "pbc", "formula"})
info_keys.update(data.keys())

for key, value in atoms.arrays.items():
if isinstance(value, np.ndarray):
dct[key] = value.tolist()
data[key] = value.tolist()
else:
dct[key] = value
data[key] = value

for key, value in atoms.info.items():
if isinstance(value, np.ndarray):
dct[key] = value.tolist()
data[key] = value.tolist()
else:
dct[key] = value
data[key] = value

if store_calc and atoms.calc:
dct["calculator_name"] = atoms.calc.__class__.__name__
dct["calculator_parameters"] = atoms.calc.todict()
data["calculator_name"] = atoms.calc.__class__.__name__
data["calculator_parameters"] = atoms.calc.todict()
info_keys.update({"calculator_name", "calculator_parameters"})

for key, value in atoms.calc.results.items():

if isinstance(value, np.ndarray):
if value.shape[0] == n_atoms:
arrays_keys.update(key)
else:
info_keys.update(key)
dct[key] = value.tolist()
data[key] = value.tolist()
else:
data[key] = value

item.arrays_keys = list(arrays_keys)
item.info_keys = list(info_keys)
item.results_keys = list(results_keys)

item.update(dct)
item.update(data)

if extra_info:
item.info_keys.extend(extra_info.keys())
Expand Down Expand Up @@ -240,6 +241,7 @@ def to_ase(self):
# atoms.calc = get_calculator(data['results']['calculator_name'])(**params)

params = self.pop("calculator_parameters", {})
info_keys -= {"calculator_parameters"}

atoms.calc = SinglePointCalculator(atoms, **params)
atoms.calc.results.update((key, self[key]) for key in results_keys)
Expand All @@ -256,14 +258,14 @@ def pre_save(self):

if cell:
volume = abs(np.linalg.det(cell)) # atoms.get_volume()
self["volume"] = volume
self.derived_keys.append("volume")
self["volume"] = volume

virial = self.get("virial")
if virial:
# pressure P = -1/3 Tr(stress) = -1/3 Tr(virials/volume)
self["pressure"] = -1 / 3 * np.trace(virial / volume)
self.derived_keys.append("pressure")
self["pressure"] = -1 / 3 * np.trace(virial / volume)

# 'elements': Counter(atoms.get_chemical_symbols()),
self["elements"] = Counter(str(element) for element in self["numbers"])
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ numpy = "^1.26"
tqdm = "^4.66"
pymongo = "^4.7.3"
matplotlib = "^3.9"
ase = "3.22.1"
ase = "^3.23"
lark = "^1.1.9"

[tool.poetry.group.dev.dependencies]
Expand Down
256 changes: 256 additions & 0 deletions tests/test_abstract_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
import io

import ase
import pytest
from pytest import approx

from io import StringIO
from ase.io import read, write
import numpy as np

from abcd.model import AbstractModel
from ase.calculators.lj import LennardJones


@pytest.fixture
def extxyz_file():
return StringIO(
"""2
Properties=species:S:1:pos:R:3:forces:R:3 energy=-1 pbc="F T F" info="test"
Si 0.0 0.0 0.0 0.4 0.6 -0.4
Si 0.0 0.0 0.0 -0.1 -0.5 -0.6
"""
)


def test_from_atoms(extxyz_file):
"""Test extracting data from ASE Atoms object."""
expected_forces = np.array([[0.4, 0.6, -0.4], [-0.1, -0.5, -0.6]])
expected_stress = np.array([-1.0, -1.0, -1.0, -2.1, 2.0, 1.8])

atoms = read(extxyz_file, format="extxyz")
atoms.calc.results["stress"] = expected_stress
data = AbstractModel.from_atoms(atoms)

# Test info
info_keys = {
"pbc",
"n_atoms",
"cell",
"formula",
"calculator_name",
"calculator_parameters",
"info",
}
assert info_keys == set(data.info_keys)
assert data["pbc"] == [False, True, False]
assert data["n_atoms"] == 2
assert len(data["cell"]) == 3
assert all(arr == [0.0, 0.0, 0.0] for arr in data["cell"])
assert data["formula"] == "Si2"
assert data["info"] == "test"

# Test arrays
assert {"numbers", "positions"} == set(data.arrays_keys)

# Test results
assert {"energy", "stress", "forces"} == set(data.results_keys)
assert data["energy"] == -1
assert data["forces"] == pytest.approx(expected_forces)
assert data["stress"] == pytest.approx(expected_stress)

# Test derived
derived_keys = {
"elements",
"username",
"uploaded",
"modified",
"volume",
"hash",
"hash_structure",
}
assert derived_keys == set(data.derived_keys)


def test_from_atoms_no_calc(extxyz_file):
"""Test extracting data from ASE Atoms object without results."""
expected_stress = np.array([-1.0, -1.0, -1.0, -2.1, 2.0, 1.8])

atoms = read(extxyz_file, format="extxyz")
atoms.calc.results["stress"] = expected_stress
data = AbstractModel.from_atoms(atoms, store_calc=False)

# Test info
assert {"pbc", "n_atoms", "cell", "formula", "info"} == set(data.info_keys)
assert data["pbc"] == [False, True, False]
assert data["n_atoms"] == 2
assert len(data["cell"]) == 3
assert all(arr == [0.0, 0.0, 0.0] for arr in data["cell"])
assert data["formula"] == "Si2"
assert data["info"] == "test"

# Test arrays
assert {"numbers", "positions"} == set(data.arrays_keys)

# Test results
results_keys = {
"energy",
"forces",
"stress",
"calculator_name",
"calculator_parameters",
}
assert all(key not in data for key in results_keys)

# Test derived
derived_keys = {
"elements",
"username",
"uploaded",
"modified",
"volume",
"hash",
"hash_structure",
}
assert derived_keys == set(data.derived_keys)


def test_to_ase(extxyz_file):
"""Test returning data to ASE Atoms object with results."""
atoms = read(extxyz_file, format="extxyz")
data = AbstractModel.from_atoms(atoms, store_calc=True)

new_atoms = data.to_ase()

# Test info set
assert new_atoms.cell == pytest.approx(atoms.cell)
assert new_atoms.pbc == pytest.approx(atoms.pbc)
assert new_atoms.positions == pytest.approx(atoms.positions)
assert new_atoms.numbers == pytest.approx(atoms.numbers)

assert new_atoms.info["n_atoms"] == len(atoms)
assert new_atoms.info["formula"] == atoms.get_chemical_formula()

assert new_atoms.calc.results["energy"] == pytest.approx(
atoms.calc.results["energy"]
)
assert new_atoms.calc.results["forces"] == pytest.approx(
atoms.calc.results["forces"]
)


def test_to_ase_no_results(extxyz_file):
"""Test returning data to ASE Atoms object without results."""
atoms = read(extxyz_file, format="extxyz")
data = AbstractModel.from_atoms(atoms, store_calc=False)

new_atoms = data.to_ase()

# Test info set
assert new_atoms.cell == pytest.approx(atoms.cell)
assert new_atoms.pbc == pytest.approx(atoms.pbc)
assert new_atoms.positions == pytest.approx(atoms.positions)
assert new_atoms.numbers == pytest.approx(atoms.numbers)

assert new_atoms.info["n_atoms"] == len(atoms)
assert new_atoms.info["formula"] == atoms.get_chemical_formula()

assert new_atoms.calc is None


def test_from_atoms_len_atoms_3():
atoms = ase.Atoms(
"H3",
positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]],
pbc=True,
cell=[2, 2, 2],
)
atoms.calc = LennardJones()
atoms.calc.calculate(atoms)

# convert
abcd_data = AbstractModel.from_atoms(atoms, store_calc=True)

assert set(abcd_data.info_keys) == {
"pbc",
"n_atoms",
"cell",
"formula",
"calculator_name",
"calculator_parameters",
}
assert set(abcd_data.arrays_keys) == {"numbers", "positions"}
assert set(abcd_data.results_keys) == {
"stress",
"energy",
"forces",
"energies",
"stresses",
"free_energy",
}

# check some values as well
assert abcd_data["energy"] == atoms.get_potential_energy()
assert abcd_data["forces"] == approx(atoms.get_forces())


@pytest.mark.parametrize("store_calc", [True, False])
def test_write_and_read(store_calc):
# create atoms & add a calculator
atoms = ase.Atoms(
"H3",
positions=[[0, 0, 0], [0, 0, 1], [0, 1, 0]],
pbc=True,
cell=[2, 2, 2],
)
atoms.calc = LennardJones()
atoms.calc.calculate(atoms)

# dump to XYZ
buffer = io.StringIO()
write(buffer, atoms, format="extxyz")

# read back
buffer.seek(0)
atoms_read = read(buffer, format="extxyz")

# read in both of them
abcd_data = AbstractModel.from_atoms(atoms, store_calc=store_calc)
abcd_data_after_read = AbstractModel.from_atoms(atoms_read, store_calc=store_calc)

# check that all results are the same
for key in ["info_keys", "arrays_keys", "derived_keys", "results_keys"]:
assert set(getattr(abcd_data, key)) == set(
getattr(abcd_data_after_read, key)
), f"{key} mismatched"

# info & arrays same, except calc recognised as LJ when not from XYZ
for key in set(abcd_data.info_keys + abcd_data.arrays_keys) - {
"calculator_name",
"calculator_parameters",
}:
assert (
abcd_data[key] == abcd_data_after_read[key]
), f"{key}'s value does not match"

# date & hashed will differ
for key in set(abcd_data.derived_keys) - {
"hash",
"modified",
"uploaded",
"hash_structure", # see issue #118
}:
assert (
abcd_data[key] == abcd_data_after_read[key]
), f"{key}'s value does not match"

# expected differences - n.b. order of calls above
assert abcd_data_after_read["modified"] > abcd_data["modified"]
assert abcd_data_after_read["uploaded"] > abcd_data["uploaded"]
assert abcd_data_after_read["hash"] != abcd_data["hash"]

# expect results to match within fp precision
for key in set(abcd_data.results_keys):
assert abcd_data[key] == approx(
np.array(abcd_data_after_read[key])
), f"{key}'s value does not match"

0 comments on commit 4f1b0f6

Please sign in to comment.