diff --git a/CrystalStructure/crystal/atomic_site.py b/CrystalStructure/crystal/atomic_site.py index 93afb34..e9da2de 100644 --- a/CrystalStructure/crystal/atomic_site.py +++ b/CrystalStructure/crystal/atomic_site.py @@ -32,6 +32,14 @@ class AtomicSite(Serializable): def __post_init__(self): self.atom_type : AtomType = AtomType(symbol=self.species_symbol) + @property + def pymatgen_species(self) -> SpeciesLike: + return self.atom_type.specifier + + @property + def element_symbol(self) -> str: + return self.pymatgen_species.element.symbol + @classmethod def make_void(cls) -> AtomicSite: return cls(x=None, y=None, z=None, occupancy=0.0, species_symbol=AtomType.void_symbol) diff --git a/CrystalStructure/crystal/base.py b/CrystalStructure/crystal/base.py index 38dca49..e25cabb 100644 --- a/CrystalStructure/crystal/base.py +++ b/CrystalStructure/crystal/base.py @@ -23,7 +23,7 @@ def __init__(self, atomic_sites : Optional[list[AtomicSite]] = None): def calculate_atomic_volume(self) -> float: total_atomic_volume = 0 for site in self.get_non_void_sites(): - element_symbol : str = site.species_symbol + element_symbol : str = site.element_symbol covalent_radius = AtomicConstants.get_covalent(element_symbol=element_symbol) vdw_radius = AtomicConstants.get_vdw_radius(element_symbol=element_symbol) diff --git a/tests/t_crystal/t_properties.py b/tests/t_crystal/t_properties.py index 48f821c..22833d5 100644 --- a/tests/t_crystal/t_properties.py +++ b/tests/t_crystal/t_properties.py @@ -1,7 +1,8 @@ import math +from typing import Union import numpy as np -from pymatgen.core import Lattice, Structure, Composition, PeriodicSite +from pymatgen.core import Lattice, Structure, Composition, PeriodicSite, Site, Species, Element import tests.t_crystal.crystal_test as BaseTest from CrystalStructure.crystal import AtomicSite, CrystalStructure @@ -14,16 +15,14 @@ def test_pymatgen(self): for struct, crystal in zip(self.pymatgen_structures, self.crystals): actual = self.to_clustered_pymatgen(crystal) expected = struct - self.assertEqual(actual.lattice, expected.lattice) - print(f'Actual sites = {actual.sites}; Expected sites = {expected.sites}') self.assertEqual(len(actual.sites), len(expected.sites)) - actual_sites = sorted(actual.sites, key=dist_from_origin) expected_sites = sorted(expected.sites, key=dist_from_origin) + print(f'Expected, actual sites = {expected_sites}, {actual_sites}') for s1,s2 in zip(actual_sites, expected_sites): - self.assertEqual(s1,s2) + self.check_sites_equal(s1,s2) print(f'Composition = {actual.composition}') @@ -59,6 +58,26 @@ def test_symmetries(self): # --------------------------------------------------------- + def check_sites_equal(self, s1 : Site, s2 : Site): + s1, s2 = self.standardize_site(s1), self.standardize_site(s2) + self.assertEqual(s1, s2) + + @staticmethod + def standardize_site(site : Site): + def cast_to_species(x : Union[Species, Element]): + if isinstance(x, Element): + x = Species(symbol=str(x), oxidation_state=0) + return x + + if isinstance(site.species, Composition): + comp = site.species + species = Composition({cast_to_species(x) : occ for x, occ in comp.items()}) + site = Site(species=species, coords=site.coords) + + return site + + + @staticmethod def to_clustered_pymatgen(crystal : CrystalStructure) -> Structure: a, b, c = crystal.lengths.as_tuple() @@ -85,13 +104,14 @@ def matching_cluster(the_site): site_comps = [] for clust in clusters: - comp = Composition({site.atom_type: site.occupancy for site in clust}) + comp = Composition({site.pymatgen_species : site.occupancy for site in clust}) site_comps.append(comp) positions = [(c[0].x, c[0].y, c[0].z) for c in clusters] return Structure(lattice=lattice, species=site_comps, coords=positions) + def dist_from_origin(site : PeriodicSite | AtomicSite): return math.sqrt(site.x ** 2 + site.y ** 2 + site.z ** 2) diff --git a/tests/t_crystal/t_standard.py b/tests/t_crystal/t_standard.py index 6063db2..7c9b617 100644 --- a/tests/t_crystal/t_standard.py +++ b/tests/t_crystal/t_standard.py @@ -35,7 +35,7 @@ def test_scaling(self): def test_standardization(self): self.mock_crystal.standardize() expected_species_list = ['O', 'Si', AtomType.void_symbol] - acrual_species_list = [site.species_symbol for site in self.mock_crystal.base] + acrual_species_list = [site.species_symbol1 for site in self.mock_crystal.base] self.assertEqual(acrual_species_list, expected_species_list) actual_primitives = self.mock_crystal.lengths.as_tuple()