Skip to content

Commit

Permalink
crystal: Fixed get_symbol and get_scatternig params for nonstandard s…
Browse files Browse the repository at this point in the history
…ites; Added tests for scattering params for real, void and placeholder sites
  • Loading branch information
Somerandomguy10111 committed Jun 30, 2024
1 parent 292de08 commit d0a8575
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 23 deletions.
15 changes: 8 additions & 7 deletions CrystalStructure/crystal/atomic_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __post_init__(self):

@property
def pymatgen_species(self) -> SpeciesLike:
return self.atom_type.specifier
return self.atom_type.pymatgen_species

@property
def element_symbol(self) -> str:
Expand Down Expand Up @@ -88,18 +88,19 @@ def from_str(cls, s: str):

class AtomType:
void_symbol = '⊥'
placeholder_symbol = 'NaN'
placeholder_symbol = '*'

def __init__(self, symbol : str):
if symbol == self.void_symbol:
self.specifier : SpeciesLike = DummySpecies(symbol=self.void_symbol)
self.pymatgen_species : SpeciesLike = DummySpecies(symbol=self.void_symbol, oxidation_state=None)
elif symbol == self.placeholder_symbol:
self.specifier : SpeciesLike = DummySpecies(symbol=self.placeholder_symbol)
self.pymatgen_species : SpeciesLike = DummySpecies(symbol=self.placeholder_symbol, oxidation_state=None)
else:
self.specifier : SpeciesLike = Species.from_str(species_string=symbol)
self.pymatgen_species : SpeciesLike = Species.from_str(species_string=symbol)


def get_symbol(self):
return str(self.specifier)
return str(self.pymatgen_species)

@property
def scattering_params(self) -> ScatteringParams:
Expand All @@ -111,7 +112,7 @@ def scattering_params(self) -> ScatteringParams:
else:
# TODO: This casting only currently exists beacuse the scattering param table only has values for (unoxidized) elements, not ions
# TODO: Normally would simply be species_symbol=str(self.species_like)
species_symbol = self.specifier.element.symbol
species_symbol = self.pymatgen_species.element.symbol
values = AtomicConstants.get_scattering_params(species_symbol=species_symbol)

(a1, b1), (a2, b2), (a3, b3), (a4, b4) = values
Expand Down
3 changes: 3 additions & 0 deletions CrystalStructure/crystal/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,6 @@ def from_str(cls, s: str):

def to_str(self) -> str:
return json.dumps([site.to_str() for site in self])

def __str__(self):
return str([x for x in self])
2 changes: 1 addition & 1 deletion CrystalStructure/crystal/crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def to_pymatgen(self) -> Structure:
lattice = Lattice.from_parameters(a, b, c, alpha, beta, gamma)

non_void_sites = self.base.get_non_void_sites()
atoms = [site.atom_type.specifier for site in non_void_sites]
atoms = [site.atom_type.pymatgen_species for site in non_void_sites]
positions = [(site.x, site.y, site.z) for site in non_void_sites]
return Structure(lattice, atoms, positions)

Expand Down
8 changes: 4 additions & 4 deletions CrystalStructure/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@

class CrystalExamples:
@staticmethod
def get_crystal(num: int, mute: bool = False):
def get_crystal(num: int, verbose: bool = False):
cif_content = CrystalExamples.get_cif_content(num=num)
crystal_structure = CrystalStructure.from_cif(cif_content=cif_content)
if not mute:
if not verbose:
print(f'--> Cif content:\n {cif_content}')
print(f'--> Crystal structure:\n {crystal_structure}')
return crystal_structure

@staticmethod
def get_base(num : int = 1, mute : bool = True) -> CrystalBase:
crystal_stucture = CrystalExamples.get_crystal(num=num, mute=mute)
def get_base(num : int = 1, verbose : bool = False) -> CrystalBase:
crystal_stucture = CrystalExamples.get_crystal(num=num, verbose=verbose)
return crystal_stucture.base

@staticmethod
Expand Down
32 changes: 21 additions & 11 deletions tests/t_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from CrystalStructure.crystal.atomic_site import AtomType

from CrystalStructure.crystal import CrystalBase, AtomicSite
from holytools.devtools import Unittest
from pymatgen.core import Species

Expand All @@ -7,17 +10,24 @@

class TestCrystalBase(Unittest):
def test_scattering_params(self):
base = CrystalExamples.get_base()
seen_species = set()

for atomic_site in base:
params = atomic_site.get_scattering_params()
self.assertEqual(len(params), 8)
for p in params:
self.assertIsInstance(p, float)
if not atomic_site.atom_type in seen_species:
print(f'Scattering params for species \"{atomic_site.species_str}:\n a1, a2, a3, a4, b1, b2, b3, b4 = {params}')
seen_species.add(atomic_site.atom_type)
mock_base = CrystalBase([
AtomicSite(x=0.5, y=0.5, z=0.5, occupancy=1.0, species_str="Si0+"),
AtomicSite(x=0.1, y=0.1, z=0.1, occupancy=1.0, species_str=AtomType.placeholder_symbol),
AtomicSite(x=0.9, y=0.9, z=0.9, occupancy=1.0, species_str=AtomType.void_symbol)
])
real_base = CrystalExamples.get_base()

for base in [mock_base, real_base]:
seen_species = set()

for atomic_site in base:
params = atomic_site.get_scattering_params()
self.assertEqual(len(params), 8)
for p in params:
self.assertIsInstance(p, float)
if not atomic_site.atom_type in seen_species:
print(f'Scattering params for species \"{atomic_site.species_str}:\n a1, a2, a3, a4, b1, b2, b3, b4 = {params}')
seen_species.add(atomic_site.atom_type)


if __name__ == '__main__':
Expand Down
File renamed without changes.

0 comments on commit d0a8575

Please sign in to comment.