Skip to content

Commit

Permalink
Merge pull request #71 from lamalab-org/rename_reps
Browse files Browse the repository at this point in the history
feat: rename representations
  • Loading branch information
n0w0f authored Jun 3, 2024
2 parents 5060df3 + ff0b8d4 commit e1d903a
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 61 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ text_rep = TextRep.from_input(structure)

requested_reps = [
"cif_p1",
"slice",
"atoms",
"atoms_params",
"crystal_llm_rep",
"slices",
"atom_sequences",
"atom_sequences_plusplus",
"crystal_text_llm",
"zmatrix"
]

Expand Down
14 changes: 7 additions & 7 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ text_rep = TextRep(structure)
# Define a list of requested representations
requested_reps = [
"cif_p1",
"slice",
"atoms_params",
"crystal_llm_rep",
"slices",
"atom_sequences_plusplus",
"crystal_text_llm",
"zmatrix"
]

Expand All @@ -42,14 +42,14 @@ requested_text_reps = text_rep.get_requested_text_reps(requested_reps)

The `TextRep` class currently supports the following text representations:

- **SlICES** (`slice`): SLICE representation of the crystal structure.
- **SLICES** (`slices`): SLICES representation of the crystal structure.
- **Composition** (`composition`): Chemical composition in Hill format.
- **CIF Symmetrized** (`cif_symmetrized`): Multi-line CIF representation with symmetrized structure and rounded float numbers.
- **CIF $P_1$** (`cif_p1`): Multi-line CIF representation with the conventional unit cell and rounded float numbers.
- **Crystal-text-LLM Representation** (`crystal_llm_rep`): Representation following the format specified in the cited work.
- **Crystal-text-LLM Representation** (`crystal_text_llm`): Representation following the format specified in the cited work.
- **Robocrystallographer Representation** (`robocrys_rep`): Representation generated by Robocrystallographer.
- **Atom Sequences** (`atoms`): List of atoms inside the unit cell.
- **Atoms Squences++** (`atoms_params`): List of atoms with lattice parameters.
- **Atom Sequences** (`atom_sequences`): List of atoms inside the unit cell.
- **Atom Sequences++** (`atom_sequences_plusplus`): List of atoms with lattice parameters.
- **Z-Matrix** (`zmatrix`): Z-Matrix representation of the crystal structure.
- **Local Env** (`local_env`): List of Wyckoff label and SMILES separated by line breaks for each local environment.

Expand Down
111 changes: 69 additions & 42 deletions src/xtal2txt/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,14 @@ def round_numbers_in_string(original_string: str, decimal_places: int) -> str:
pattern = r"\b\d+\.\d+\b"
matches = re.findall(pattern, original_string)
rounded_numbers = [round(float(match), decimal_places) for match in matches]
new_string = re.sub(pattern, lambda x: str(rounded_numbers.pop(0)), original_string)
new_string = re.sub(
pattern, lambda x: str(rounded_numbers.pop(0)), original_string
)
return new_string

def get_cif_string(self, format: str = "symmetrized", decimal_places: int = 3) -> str:
def get_cif_string(
self, format: str = "symmetrized", decimal_places: int = 3
) -> str:
"""
Generate CIF as string in multi-line format.
Expand Down Expand Up @@ -156,7 +160,9 @@ def get_lattice_parameters(self, decimal_places: int = 3) -> List[str]:
Returns:
List[str]: The lattice parameters.
"""
return [str(round(i, decimal_places)) for i in self.structure.lattice.parameters]
return [
str(round(i, decimal_places)) for i in self.structure.lattice.parameters
]

def get_coords(self, name: str = "cartesian", decimal_places: int = 3) -> List[str]:
"""
Expand All @@ -183,8 +189,8 @@ def get_coords(self, name: str = "cartesian", decimal_places: int = 3) -> List[s
elements.extend(coord)
return elements

def get_slice(self, primitive: bool = True) -> str:
"""Returns SLICE representation of the crystal structure.
def get_slices(self, primitive: bool = True) -> str:
"""Returns SLICES representation of the crystal structure.
https://www.nature.com/articles/s41467-023-42870-7
Args:
Expand Down Expand Up @@ -235,7 +241,7 @@ def get_local_env_rep(self, local_env_kwargs: Optional[dict] = None) -> str:
analyzer = LocalEnvAnalyzer(**local_env_kwargs)
return analyzer.structure_to_local_env_string(self.structure)

def get_crystal_llm_rep(
def get_crystal_text_llm(
self,
permute_atoms: bool = False,
) -> str:
Expand Down Expand Up @@ -338,7 +344,9 @@ def get_wycryst(self):

return output

def get_atoms_params_rep(self, lattice_params: bool = False, decimal_places: int = 1) -> str:
def get_atom_sequences_plusplus(
self, lattice_params: bool = False, decimal_places: int = 1
) -> str:
"""
Generating a string with the elements of composition inside the crystal lattice with the option to
get the lattice parameters as angles (int) and lengths (float) in a string with a space
Expand Down Expand Up @@ -366,7 +374,7 @@ def get_atoms_params_rep(self, lattice_params: bool = False, decimal_places: int
def updated_zmatrix_rep(self, zmatrix, decimal_places=1):
"""
Replace the variables in the Z-matrix with their values and return the updated Z-matrix.
for eg: z-matrix from pymatgen
for eg: z-matrix from pymatgen
'N\nN 1 B1\nN 1 B2 2 A2\nN 1 B3 2 A3 3 D3\n
B1=3.79
B2=6.54
Expand Down Expand Up @@ -423,14 +431,16 @@ def updated_zmatrix_rep(self, zmatrix, decimal_places=1):

def get_zmatrix_rep(self, decimal_places=1):
"""
Generate the Z-matrix representation of the crystal structure.
It provides a description of each atom in terms of its atomic number,
Generate the Z-matrix representation of the crystal structure.
It provides a description of each atom in terms of its atomic number,
bond length, bond angle, and dihedral angle, the so-called internal coordinates.
Disclaimer: The Z-matrix is meant for molecules, current implementation converts atoms within unit cell to molecule.
Hence the current implentation might overlook bonds acrosse unit cells.
Disclaimer: The Z-matrix is meant for molecules, current implementation converts atoms within unit cell to molecule.
Hence the current implementation might overlook bonds across unit cells.
"""
species = [s.element if hasattr(s, 'element') else s for s in self.structure.species]
species = [
s.element if hasattr(s, "element") else s for s in self.structure.species
]
coords = [c for c in self.structure.cart_coords]
molecule_ = Molecule(
species,
Expand All @@ -452,26 +462,28 @@ def get_all_text_reps(self, decimal_places: int = 2):
self.get_cif_string, format="symmetrized", decimal_places=decimal_places
),
"cif_bonding": None,
"slice": self._safe_call(self.get_slice),
"slices": self._safe_call(self.get_slices),
"composition": self._safe_call(self.get_composition),
"crystal_llm_rep": self._safe_call(self.get_crystal_llm_rep),
"crystal_text_llm": self._safe_call(self.get_crystal_text_llm),
"robocrys_rep": self._safe_call(self.get_robocrys_rep),
"wycoff_rep": None,
"atoms": self._safe_call(
self.get_atoms_params_rep,
"atom_sequences": self._safe_call(
self.get_atom_sequences_plusplus,
lattice_params=False,
decimal_places=decimal_places,
),
"atoms_params": self._safe_call(
self.get_atoms_params_rep,
"atom_sequences_plusplus": self._safe_call(
self.get_atom_sequences_plusplus,
lattice_params=True,
decimal_places=decimal_places,
),
"zmatrix": self._safe_call(self.get_zmatrix_rep),
"local_env": self._safe_call(self.get_local_env_rep, local_env_kwargs=None),
}

def get_requested_text_reps(self, requested_reps: List[str], decimal_places: int = 2):
def get_requested_text_reps(
self, requested_reps: List[str], decimal_places: int = 2
):
"""
Returns the requested Text representations of the crystal structure in a dictionary.
Expand All @@ -483,31 +495,46 @@ def get_requested_text_reps(self, requested_reps: List[str], decimal_places: int
dict: A dictionary containing the requested text representations of the crystal structure.
"""

all_reps = {
"cif_p1": lambda: self._safe_call(
if requested_reps == "cif_p1":
return self._safe_call(
self.get_cif_string, format="p1", decimal_places=decimal_places
),
"cif_symmetrized": lambda: self._safe_call(
self.get_cif_string, format="symmetrized", decimal_places=decimal_places
),
"cif_bonding": lambda: None,
"slice": lambda: self._safe_call(self.get_slice),
"composition": lambda: self._safe_call(self.get_composition),
"crystal_llm_rep": lambda: self._safe_call(self.get_crystal_llm_rep),
"robocrys_rep": lambda: self._safe_call(self.get_robocrys_rep),
"wycoff_rep": lambda: None,
"atoms": lambda: self._safe_call(
self.get_atoms_params_rep,
)

elif requested_reps == "cif_symmetrized":
return self._safe_call(
self.get_cif_string,
format="symmetrized",
decimal_places=decimal_places,
)

elif requested_reps == "slices":
return self._safe_call(self.get_slices)

elif requested_reps == "composition":
return self._safe_call(self.get_composition)

elif requested_reps == "crystal_text_llm":
return self._safe_call(self.get_crystal_text_llm)

elif requested_reps == "robocrys_rep":
return self._safe_call(self.get_robocrys_rep)

elif requested_reps == "atom_sequences":
return self._safe_call(
self.get_atom_sequences_plusplus,
lattice_params=False,
decimal_places=decimal_places,
),
"atoms_params": lambda: self._safe_call(
self.get_atoms_params_rep,
)

elif requested_reps == "atom_sequences_plusplus":
return self._safe_call(
self.get_atom_sequences_plusplus,
lattice_params=True,
decimal_places=decimal_places,
),
"zmatrix": lambda: self._safe_call(self.get_zmatrix_rep, decimal_places=1),
"local_env": lambda: self._safe_call(self.get_local_env_rep, local_env_kwargs=None),
}
)

elif requested_reps == "zmatrix":
return self._safe_call(self.get_zmatrix_rep)

return {rep: all_reps[rep]() for rep in requested_reps if rep in all_reps}
elif requested_reps == "local_env":
return self._safe_call(self.get_local_env_rep, local_env_kwargs=None)
8 changes: 5 additions & 3 deletions tests/test_get_atom_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ def test_get_atom_params() -> None:
expected_wo_params = "Sr Ti O O O"
expected_w_params = "Sr Ti O O O 3.9 3.9 3.9 90 90 90"
expected_tiCrSe_p1 = "Tl Tl Cr Cr Cr Cr Cr Cr Cr Cr Cr Cr Se Se Se Se Se Se Se Se Se Se Se Se Se Se Se Se"
assert srtio3_p1.get_atoms_params_rep() == expected_wo_params
assert tiCrSe_p1.get_atoms_params_rep() == expected_tiCrSe_p1
assert srtio3_p1.get_atoms_params_rep(lattice_params=True) == expected_w_params
assert srtio3_p1.get_atom_sequences_plusplus() == expected_wo_params
assert tiCrSe_p1.get_atom_sequences_plusplus() == expected_tiCrSe_p1
assert (
srtio3_p1.get_atom_sequences_plusplus(lattice_params=True) == expected_w_params
)
6 changes: 3 additions & 3 deletions tests/test_textrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ def test_get_lattice_parameters() -> None:
assert N2.get_lattice_parameters() == expected_output


def test_get_slice() -> None:
def test_get_slices() -> None:
expected_output = "N N N N 0 1 - o o 0 1 - + o 0 1 o o o 0 1 o + o 0 2 o + o 0 2 o + + 0 2 + + o 0 2 + + + 0 3 o o - 0 3 o o o 0 3 o + - 0 3 o + o 1 3 o o - 1 3 o o o 1 3 + o - 1 3 + o o 1 2 + o o 1 2 + o + 1 2 + + o 1 2 + + + 2 3 - - - 2 3 - o - 2 3 o - - 2 3 o o - "
assert N2.get_slice() == expected_output
assert N2.get_slices() == expected_output


def test_get_crystal_llm_rep() -> None:
expected_output = "5.6 5.6 5.6\n90 90 90\nN0+\n0.48 0.98 0.52\nN0+\n0.98 0.52 0.48\nN0+\n0.02 0.02 0.02\nN0+\n0.52 0.48 0.98"
assert N2.get_crystal_llm_rep() == expected_output
assert N2.get_crystal_text_llm() == expected_output


def test_robocrys_for_cif_format() -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/tokenizer/test_rt_tokenier.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ def test_encode_decode(
decoded_tokens = cif_rt_tokenizer.decode(token_ids, skip_special_tokens=True)
assert input_string == decoded_tokens

input_string = struct.get_crystal_llm_rep()
input_string = struct.get_crystal_text_llm()
token_ids = crystal_llm_rt_tokenizer.encode(input_string)
decoded_tokens = crystal_llm_rt_tokenizer.decode(
token_ids, skip_special_tokens=True
)
assert input_string == decoded_tokens

input_string = struct.get_slice()
input_string = struct.get_slices()
token_ids = slice_rt_tokenizer.encode(input_string)
decoded_tokens = slice_rt_tokenizer.decode(token_ids, skip_special_tokens=True)
try:
Expand Down

0 comments on commit e1d903a

Please sign in to comment.