Skip to content

Commit

Permalink
Fixed is_nonstandard(); Corrected get_symbol() for non Species atom t…
Browse files Browse the repository at this point in the history
…ypes
  • Loading branch information
Somerandomguy10111 committed Jun 23, 2024
1 parent 5d20d54 commit 3a5eeed
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 48 deletions.
34 changes: 23 additions & 11 deletions CrystalStructure/crystal/atomic_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,37 @@ class AtomicSite(Serializable):
y: Optional[float]
z: Optional[float]
occupancy : Optional[float]
species : Union[Element,Species, Void, UnknownSite]
atom_type : Union[Element,Species, Void, UnknownSite]
wyckoff_letter : Optional[str] = None

@classmethod
def make_void(cls) -> AtomicSite:
return cls(x=None, y=None, z=None, occupancy=0, species=Void())
return cls(x=None, y=None, z=None, occupancy=0, atom_type=Void())

@classmethod
def make_placeholder(cls):
return cls(x=None, y=None, z=None, occupancy=None, species=UnknownSite())
return cls(x=None, y=None, z=None, occupancy=None, atom_type=UnknownSite())

def is_nonstandard(self) -> bool:
if not isinstance(self.species, Species):
if isinstance(self.atom_type, Void) or isinstance(self.atom_type, UnknownSite):
return True
return False

# ---------------------------------------------------------
# properties

def get_symbol(self) -> str:
if isinstance(self.atom_type, Element):
return self.atom_type.symbol
elif isinstance(self.atom_type, Species):
return self.atom_type.element.symbol
elif isinstance(self.atom_type, Void):
return Void.symbol
elif isinstance(self.atom_type, UnknownSite):
return UnknownSite.symbol
else:
raise ValueError(f'Unknown species type: {self.atom_type}')

def as_list(self) -> list[float]:
site_arr = [*self.get_scattering_params(), self.x, self.y, self.z, self.occupancy]
return site_arr
Expand All @@ -45,15 +57,15 @@ def as_list(self) -> list[float]:
# These are *different* paramters from what you may commonly see e.g. here (https://lampz.tugraz.at/~hadley/ss1/crystaldiffraction/atomicformfactors/formfactors.php)
# since pymatgen uses a different formula to compute the form factor
def get_scattering_params(self) -> ScatteringParams:
if isinstance(self.species, Species) or isinstance(self.species, Element):
values = AtomicConstants.get_scattering_params(species=self.species)
elif isinstance(self.species, Void):
if isinstance(self.atom_type, Species) or isinstance(self.atom_type, Element):
values = AtomicConstants.get_scattering_params(species=self.atom_type)
elif isinstance(self.atom_type, Void):
values = (0, 0), (0, 0), (0, 0), (0, 0)
elif isinstance(self.species, UnknownSite):
elif isinstance(self.atom_type, UnknownSite):
fnan = float('nan')
values = (fnan,fnan), (fnan,fnan), (fnan,fnan), (fnan,fnan)
else:
raise ValueError(f'Unknown species type: {self.species}')
raise ValueError(f'Unknown species type: {self.atom_type}')

(a1, b1), (a2, b2), (a3, b3), (a4, b4) = values
return a1, b1, a2, b2, a3, b3, a4, b4
Expand All @@ -63,7 +75,7 @@ def get_scattering_params(self) -> ScatteringParams:

def to_str(self) -> str:
the_dict = {'x': self.x, 'y': self.y, 'z': self.z, 'occupancy': self.occupancy,
'species': str(self.species),
'species': str(self.atom_type),
'wyckoff_letter': self.wyckoff_letter}

return json.dumps(the_dict)
Expand All @@ -81,5 +93,5 @@ def from_str(cls, s: str):

return cls(x=the_dict['x'], y=the_dict['y'], z=the_dict['z'],
occupancy=the_dict['occupancy'],
species=species,
atom_type=species,
wyckoff_letter=the_dict['wyckoff_letter'])
4 changes: 2 additions & 2 deletions CrystalStructure/crystal/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, atomic_sites : Optional[list[AtomicSite]] = None):
def calculate_atomic_volume(self) -> float:
total_atomic_volume = 0
for site in self.get_non_void_sites():
element_symbol : ElementSymbol = site.species.element.symbol
element_symbol : ElementSymbol = site.get_symbol()
covalent_radius = AtomicConstants.get_covalent(element_symbol=element_symbol)
vdw_radius = AtomicConstants.get_vdw_radius(element_symbol=element_symbol)

Expand All @@ -42,7 +42,7 @@ def as_site_dictionaries(self) -> dict:
coords = (atom_site.x, atom_site.y, atom_site.z)
if not coords in coordinate_map:
coordinate_map[coords] = {}
coordinate_map[coords][atom_site.species] = atom_site.occupancy
coordinate_map[coords][atom_site.atom_type] = atom_site.occupancy

print(f'Coordinate map = {coordinate_map}')
return coordinate_map
Expand Down
6 changes: 3 additions & 3 deletions CrystalStructure/crystal/crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def standardize(self):
new_base = CrystalBase()
for site in self.base:
x,y,z = apply_permutation([site.x, site.y, site.z], sort_permutation)
new_site = AtomicSite(x=x, y=y, z=z, occupancy=site.occupancy, species=site.species)
new_site = AtomicSite(x=x, y=y, z=z, occupancy=site.occupancy, atom_type=site.atom_type)
new_base.append(new_site)


Expand Down Expand Up @@ -102,7 +102,7 @@ def from_pymatgen(cls, pymatgen_structure: Structure) -> CrystalStructure:
site_composition = site.species
for species, occupancy in site_composition.items():
x,y,z = lattice.get_fractional_coords(site.coords)
atomic_site = AtomicSite(x,y,z, occupancy=occupancy, species=species)
atomic_site = AtomicSite(x, y, z, occupancy=occupancy, atom_type=species)
base.append(atomic_site)

crystal_str = cls(lengths=Lengths(a=lattice.a, b=lattice.b, c=lattice.c),
Expand All @@ -120,7 +120,7 @@ def to_pymatgen(self) -> Structure:
lattice = Lattice.from_parameters(a, b, c, alpha, beta, gamma)

non_void_sites = self.base.get_non_void_sites()
atoms = [site.species for site in non_void_sites]
atoms = [site.atom_type for site in non_void_sites]
positions = [(site.x, site.y, site.z) for site in non_void_sites]
return Structure(lattice, atoms, positions)

Expand Down
6 changes: 3 additions & 3 deletions tests/t_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ def test_scattering_params(self):
self.assertEqual(len(params), 8)
for p in params:
self.assertIsInstance(p, float)
if not atomic_site.species in seen_species:
print(f'Scattering params for species \"{atomic_site.species}\" a1, a2, a3, a4, b1, b2, b3, b4 = {params}')
seen_species.add(atomic_site.species)
if not atomic_site.atom_type in seen_species:
print(f'Scattering params for species \"{atomic_site.atom_type}\" a1, a2, a3, a4, b1, b2, b3, b4 = {params}')
seen_species.add(atomic_site.atom_type)


def test_site_dictionaries(self):
Expand Down
42 changes: 18 additions & 24 deletions tests/t_crystal/t_properties.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
import tests.t_crystal.crystaltest as BaseTest


# ---------------------------------------------------------

class TestPropertyCalculation(BaseTest.CrystalTest):
def test_atomic_volume(self):
for crystal in self.crystals:
print(f'self.crystal atomic volume fraction = {crystal.packing_density}')

def test_pymatgen(self):
for struct, crystal in zip(self.pymatgen_structures, self.crystals):
actual = crystal.to_pymatgen()
Expand All @@ -25,23 +20,30 @@ def test_pymatgen(self):

print(f'Composition = {actual.composition}')

def test_spacegroup_calculation(self):
for spg, crystal in zip(self.spgs,self.crystals):
crystal.calculate_properties()
computed_sg = crystal.space_group
print(f'Computed, actual spg = {computed_sg}, {spg}')

if computed_sg != spg:
raise ValueError(f'Computed spg {computed_sg} does not match actual spg {spg} given in cif file')

def test_volumes(self):
for crystal in self.crystals:
crystal.calculate_properties()

def test_volume_uc(self):
expected_volumes = [364.21601704000005, 1205.5]
for crystal, volume_exp in zip(self.crystals, expected_volumes):
self.assertAlmostEqual(crystal.volume_uc, volume_exp, places=5)
self.assertAlmostEqual(crystal.volume_uc, volume_exp, places=1)
for crystal in self.crystals:
print(f'self.crystal atomic volume fraction = {crystal.packing_density}')


def test_wyckoff_symbols(self):
def test_symmetries(self):
for crystal in self.crystals:
crystal.calculate_properties()

expected_space_groups = [57, 14]
for crystal, space_group_exp in zip(self.crystals, expected_space_groups):
self.assertEqual(crystal.space_group, space_group_exp)

expected_systems = ['orthorhombic', 'monoclinic']
for crystal, system_exp in zip(self.crystals, expected_systems):
self.assertEqual(crystal.crystal_system, system_exp)

expected_symbols = [
['d', 'd', 'd', 'd', 'd', 'd', 'd', 'd', 'c', 'c', 'c', 'c', 'd', 'd', 'd', 'd'],
['C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'O', 'O', 'O', 'O', 'O', 'O', 'N', 'N', 'H', 'H',
Expand All @@ -50,15 +52,7 @@ def test_wyckoff_symbols(self):
for crystal, symbols_exp in zip(self.crystals, expected_symbols):
self.assertEqual(crystal.wyckoff_symbols, symbols_exp)

def test_crystal_system(self):
expected_systems = ['orthorhombic', 'monoclinic']
for crystal, system_exp in zip(self.crystals, expected_systems):
self.assertEqual(crystal.crystal_system, system_exp)

def test_space_group(self):
expected_space_groups = [57, 14]
for crystal, space_group_exp in zip(self.crystals, expected_space_groups):
self.assertEqual(crystal.space_group, space_group_exp)

if __name__ == '__main__':
TestPropertyCalculation.execute_all()
10 changes: 5 additions & 5 deletions tests/t_crystal/t_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ def setUp(self):
primitives = Lengths(5, 3, 4)
mock_angles = Angles(90, 90, 90)
mock_base = CrystalBase([
AtomicSite(x=0.5, y=0.5, z=0.5, occupancy=1.0, species=Species("Si")),
AtomicSite(x=0.1, y=0.1, z=0.1, occupancy=1.0, species=Species("O")),
AtomicSite(x=0.9, y=0.9, z=0.9, occupancy=1.0, species=Void())
AtomicSite(x=0.5, y=0.5, z=0.5, occupancy=1.0, atom_type=Species("Si")),
AtomicSite(x=0.1, y=0.1, z=0.1, occupancy=1.0, atom_type=Species("O")),
AtomicSite(x=0.9, y=0.9, z=0.9, occupancy=1.0, atom_type=Void())
])
crystal = CrystalStructure(lengths=primitives, angles=mock_angles, base=mock_base)
crystal.calculate_properties()
Expand Down Expand Up @@ -47,10 +47,10 @@ def test_standardization(self):

@staticmethod
def get_site_symbol(site : AtomicSite):
if isinstance(site.species, Void):
if isinstance(site.atom_type, Void):
symbol = Void.symbol
else:
symbol = site.species.element.symbol
symbol = site.atom_type.element.symbol
return symbol


Expand Down

0 comments on commit 3a5eeed

Please sign in to comment.