diff --git a/src/xtal2txt/analysis.py b/src/xtal2txt/analysis.py
index 0e63afe..0dc5915 100644
--- a/src/xtal2txt/analysis.py
+++ b/src/xtal2txt/analysis.py
@@ -1,31 +1,132 @@
-
 ANALYSIS_MASK_TOKENS = {
-    "atoms" : "[ATOMS]",
-    "directions" : "[DIR]",
-    "numbers" : "[NUMS]",
-    "bonds" : "[BONDS]",
-    "miscellaneous" : "[MISC]",
-    "identifier" : "[ID]",
-    "symmetry" : "[SYM]",
-    "lattice" : "[LATTICE]",
-    "composition" : "[COMP]",
-    None : "[NONE]"
+    "atoms": "[ATOMS]",
+    "directions": "[DIR]",
+    "numbers": "[NUMS]",
+    "bonds": "[BONDS]",
+    "miscellaneous": "[MISC]",
+    "identifier": "[ID]",
+    "symmetry": "[SYM]",
+    "lattice": "[LATTICE]",
+    "composition": "[COMP]",
+    None: "[NONE]",
 }
 
 ATOM_LIST_ = [
-    "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P",
-    "S", "Cl", "K", "Ar", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Ni", "Co", "Cu", "Zn",
-    "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh",
-    "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
-    "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re",
-    "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm",
-    "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg",
-    "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og"
-    ]
+    "H",
+    "He",
+    "Li",
+    "Be",
+    "B",
+    "C",
+    "N",
+    "O",
+    "F",
+    "Ne",
+    "Na",
+    "Mg",
+    "Al",
+    "Si",
+    "P",
+    "S",
+    "Cl",
+    "K",
+    "Ar",
+    "Ca",
+    "Sc",
+    "Ti",
+    "V",
+    "Cr",
+    "Mn",
+    "Fe",
+    "Ni",
+    "Co",
+    "Cu",
+    "Zn",
+    "Ga",
+    "Ge",
+    "As",
+    "Se",
+    "Br",
+    "Kr",
+    "Rb",
+    "Sr",
+    "Y",
+    "Zr",
+    "Nb",
+    "Mo",
+    "Tc",
+    "Ru",
+    "Rh",
+    "Pd",
+    "Ag",
+    "Cd",
+    "In",
+    "Sn",
+    "Sb",
+    "Te",
+    "I",
+    "Xe",
+    "Cs",
+    "Ba",
+    "La",
+    "Ce",
+    "Pr",
+    "Nd",
+    "Pm",
+    "Sm",
+    "Eu",
+    "Gd",
+    "Tb",
+    "Dy",
+    "Ho",
+    "Er",
+    "Tm",
+    "Yb",
+    "Lu",
+    "Hf",
+    "Ta",
+    "W",
+    "Re",
+    "Os",
+    "Ir",
+    "Pt",
+    "Au",
+    "Hg",
+    "Tl",
+    "Pb",
+    "Bi",
+    "Th",
+    "Pa",
+    "U",
+    "Np",
+    "Pu",
+    "Am",
+    "Cm",
+    "Bk",
+    "Cf",
+    "Es",
+    "Fm",
+    "Md",
+    "No",
+    "Lr",
+    "Rf",
+    "Db",
+    "Sg",
+    "Bh",
+    "Hs",
+    "Mt",
+    "Ds",
+    "Rg",
+    "Cn",
+    "Nh",
+    "Fl",
+    "Mc",
+    "Lv",
+    "Ts",
+    "Og",
+]
 
-NUMS_ = [
-    "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", ".", "+", "-"
-    ]
+NUMS_ = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", ".", "+", "-"]
 
 COMPOSITION_ANALYSIS_DICT = {
     "atoms": ATOM_LIST_,
@@ -35,44 +136,106 @@
 SLICE_ANALYSIS_DICT = {
     "atoms": ATOM_LIST_,
     "directions": [
-        "o o o", "o o +", "o o -", "o + o", "o + +", "o + -", "o - o", "o - +", "o - -",
-        "+ o o", "+ o +", "+ o -", "+ + o", "+ + +", "+ + -", "+ - o", "+ - +", "+ - -",
-        "- o o", "- o +", "- o -", "- + o", "- + +", "- + -", "- - o", "- - +", "- - -"
+        "o o o",
+        "o o +",
+        "o o -",
+        "o + o",
+        "o + +",
+        "o + -",
+        "o - o",
+        "o - +",
+        "o - -",
+        "+ o o",
+        "+ o +",
+        "+ o -",
+        "+ + o",
+        "+ + +",
+        "+ + -",
+        "+ - o",
+        "+ - +",
+        "+ - -",
+        "- o o",
+        "- o +",
+        "- o -",
+        "- + o",
+        "- + +",
+        "- + -",
+        "- - o",
+        "- - +",
+        "- - -",
     ],
-    "numbers": NUMS_
+    "numbers": NUMS_,
 }
 
 
 CRYSTAL_LLM_ANALYSIS_DICT = {
     "atoms": ATOM_LIST_,
     "numbers": NUMS_,
-    "miscellaneous": ["\n", " "]
+    "miscellaneous": ["\n", " "],
 }
 
 CIF_ANALYSIS_DICT = {
     "atoms": ATOM_LIST_,
     "numbers": NUMS_,
-    "lattice": ["_cell_length_a", "_cell_length_b", "_cell_length_c",
-                "_cell_angle_alpha", "_cell_angle_beta", "_cell_angle_gamma" ],
+    "lattice": [
+        "_cell_length_a",
+        "_cell_length_b",
+        "_cell_length_c",
+        "_cell_angle_alpha",
+        "_cell_angle_beta",
+        "_cell_angle_gamma",
+    ],
     "identifier": ["loop_"],
-    "composition": [
-        "_chemical_formula_structural",
-        "_chemical_formula_sum"],
-    "symmetry": ["_symmetry_space_group_name_H-M", "_symmetry_Int_Tables_number",
-                 "_symmetry_equiv_pos_site_id","_symmetry_equiv_pos_as_xyz",],
-    "miscellaneous": ["_atom_site_symmetry_multiplicity" ,
-                      "_atom_type_symbol",
-                      "_atom_type_oxidation_number",
-                      "_atom_site_type_symbol",
-                      "_atom_site_label",
-                      "_atom_site_symmetry_multiplicity",
-                      "_atom_site_fract_x",
-                      "_atom_site_fract_y",
-                      "_atom_site_fract_z",
-                      "_atom_site_occupancy",
-                      "-", "/", "'", "\"", ",", "'x, y, z'", "x", "y", "z", "-x", "-y", "-z", " ", " ", "\n", "_geom_bond_atom_site_label_1", "_geom_bond_atom_site_label_2", "_geom_bond_distance", "_ccdc_geom_bond_type", "_", "a", "n", "c", "b", "m", "d", "R", "A", "(", ")", "[", "]", "*"
-    ],}
-
-
-
-
+    "composition": ["_chemical_formula_structural", "_chemical_formula_sum"],
+    "symmetry": [
+        "_symmetry_space_group_name_H-M",
+        "_symmetry_Int_Tables_number",
+        "_symmetry_equiv_pos_site_id",
+        "_symmetry_equiv_pos_as_xyz",
+    ],
+    "miscellaneous": [
+        "_atom_site_symmetry_multiplicity",
+        "_atom_type_symbol",
+        "_atom_type_oxidation_number",
+        "_atom_site_type_symbol",
+        "_atom_site_label",
+        "_atom_site_symmetry_multiplicity",
+        "_atom_site_fract_x",
+        "_atom_site_fract_y",
+        "_atom_site_fract_z",
+        "_atom_site_occupancy",
+        "-",
+        "/",
+        "'",
+        '"',
+        ",",
+        "'x, y, z'",
+        "x",
+        "y",
+        "z",
+        "-x",
+        "-y",
+        "-z",
+        " ",
+        " ",
+        "\n",
+        "_geom_bond_atom_site_label_1",
+        "_geom_bond_atom_site_label_2",
+        "_geom_bond_distance",
+        "_ccdc_geom_bond_type",
+        "_",
+        "a",
+        "n",
+        "c",
+        "b",
+        "m",
+        "d",
+        "R",
+        "A",
+        "(",
+        ")",
+        "[",
+        "]",
+        "*",
+    ],
+}
diff --git a/src/xtal2txt/tokenizer.py b/src/xtal2txt/tokenizer.py
index 5bead4e..4698b2c 100644
--- a/src/xtal2txt/tokenizer.py
+++ b/src/xtal2txt/tokenizer.py
@@ -1,15 +1,17 @@
+import json
 import os
 import re
-import json
-from pathlib import Path
-
-
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from tokenizers import Tokenizer
-from xtal2txt.analysis import ANALYSIS_MASK_TOKENS, SLICE_ANALYSIS_DICT, CRYSTAL_LLM_ANALYSIS_DICT, CIF_ANALYSIS_DICT, COMPOSITION_ANALYSIS_DICT
-
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from xtal2txt.analysis import (
+    ANALYSIS_MASK_TOKENS,
+    CIF_ANALYSIS_DICT,
+    COMPOSITION_ANALYSIS_DICT,
+    CRYSTAL_LLM_ANALYSIS_DICT,
+    SLICE_ANALYSIS_DICT,
+)
 
 
 THIS_DIR = os.path.dirname(os.path.abspath(__file__))
 SLICE_VOCAB = os.path.join(THIS_DIR, "vocabs", "slice_vocab.txt")
@@ -19,11 +21,14 @@
 ROBOCRYS_VOCAB = os.path.join(THIS_DIR, "vocabs", "robocrys_vocab.json")
 
 
-
 class Xtal2txtTokenizer(PreTrainedTokenizer):
-    def __init__(self, vocab_file, model_max_length=None, padding_length=None, **kwargs):
-        super(Xtal2txtTokenizer, self).__init__(model_max_length=model_max_length, **kwargs)
-
+    def __init__(
+        self, vocab_file, model_max_length=None, padding_length=None, **kwargs
+    ):
+        super(Xtal2txtTokenizer, self).__init__(
+            model_max_length=model_max_length, **kwargs
+        )
+
         self.vocab = self.load_vocab(vocab_file)
         self.vocab_file = vocab_file
         self.truncation = False
@@ -32,16 +37,16 @@ def __init__(self, vocab_file, model_max_length=None, padding_length=None, **kwa
 
     def load_vocab(self, vocab_file):
         _, file_extension = os.path.splitext(vocab_file)
-        if file_extension == '.txt':
-            with open(vocab_file, 'r', encoding='utf-8') as file:
+        if file_extension == ".txt":
+            with open(vocab_file, "r", encoding="utf-8") as file:
                 vocab = file.read().splitlines()
                 return {token: idx for idx, token in enumerate(vocab)}
-        elif file_extension == '.json':
-            with open(vocab_file, 'r', encoding='utf-8') as file:
+        elif file_extension == ".json":
+            with open(vocab_file, "r", encoding="utf-8") as file:
                 return json.load(file)
         else:
             raise ValueError(f"Unsupported file type: {file_extension}")
-
+
     def get_vocab(self):
         return self.vocab
 
@@ -50,12 +55,12 @@ def tokenize(self, text):
         string_tokens = [token for token in tokens if isinstance(token, str)]
         string_tokens.sort(key=len, reverse=True)
         escaped_tokens = [re.escape(token) for token in string_tokens]
-        pattern_str = '|'.join(escaped_tokens)
+        pattern_str = "|".join(escaped_tokens)
         pattern = re.compile(pattern_str)
         matches = pattern.findall(text)
 
         if self.truncation and len(matches) > self.model_max_length:
-            matches = matches[:self.model_max_length]
+            matches = matches[: self.model_max_length]
 
         if self.padding and len(matches) < self.padding_length:
             matches += [self.pad_token] * (self.padding_length - len(matches))
@@ -63,7 +68,7 @@ def tokenize(self, text):
         return matches
 
     def convert_tokens_to_string(self, tokens):
-        return ' '.join(tokens)
+        return " ".join(tokens)
 
     def _add_tokens(self, new_tokens, **kwargs):
         for token in new_tokens:
@@ -97,7 +102,7 @@ def add_special_tokens(self, special_tokens):
                 self.vocab[value] = len(self.vocab)
         self.save_vocabulary(os.path.dirname(self.vocab_file))
 
-    def token_analysis(self,tokens):
+    def token_analysis(self, tokens):
        """This method should be implemented by the Downstream tokenizers."""
        raise NotImplementedError
 
@@ -105,16 +110,23 @@ def save_vocabulary(self, save_directory, filename_prefix=None):
         """Save the vocabulary, ensures vocabularies are not overwritten. Filename follow the convention {index}-{filename_prefix}.json. Index keeps track of the latest vocabulary saved."""
         index = 0
         if os.path.isdir(save_directory):
-            vocab_files = list(filter(lambda x: x.endswith(".json"), os.listdir(save_directory)))
+            vocab_files = list(
+                filter(lambda x: x.endswith(".json"), os.listdir(save_directory))
+            )
             for vocab_file in vocab_files:
                 try:
-                    index = max(index, int(vocab_file.split('-')[0]))
+                    index = max(index, int(vocab_file.split("-")[0]))
                 except ValueError:
                     pass  # Ignore files that do not start with an integer
 
-        vocab_file = os.path.join(save_directory, f"{index + 1}-{filename_prefix}.json" if filename_prefix else f"{index + 1}.json")
+        vocab_file = os.path.join(
+            save_directory,
+            f"{index + 1}-{filename_prefix}.json"
+            if filename_prefix
+            else f"{index + 1}.json",
+        )
 
-        with open(vocab_file, 'w', encoding='utf-8') as f:
+        with open(vocab_file, "w", encoding="utf-8") as f:
             json.dump(self.vocab, f, ensure_ascii=False)
 
         return (vocab_file,)
@@ -123,14 +135,21 @@ def save_vocabulary(self, save_directory, filename_prefix=None):
     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         if pretrained_model_name_or_path is not None:
             if os.path.isdir(pretrained_model_name_or_path):
-                vocab_files = list(filter(lambda x: x.endswith(".json"), os.listdir(pretrained_model_name_or_path)))
-                vocab_files.sort(key=lambda x: int(x.split('-')[0]))
-                vocab_file = os.path.join(pretrained_model_name_or_path, vocab_files[-1])
+                vocab_files = list(
+                    filter(
+                        lambda x: x.endswith(".json"),
+                        os.listdir(pretrained_model_name_or_path),
+                    )
+                )
+                vocab_files.sort(key=lambda x: int(x.split("-")[0]))
+                vocab_file = os.path.join(
+                    pretrained_model_name_or_path, vocab_files[-1]
+                )
 
             if vocab_file is None:
                 raise ValueError("You should specify a path to a vocab file")
 
-            with open(vocab_file, 'r', encoding='utf-8') as f:
+            with open(vocab_file, "r", encoding="utf-8") as f:
                 vocab = json.load(f)
 
         tokenizer = cls(vocab_file, *inputs, **kwargs)
@@ -139,72 +158,125 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         return tokenizer
 
 
-
 class SliceTokenizer(Xtal2txtTokenizer):
-    def __init__(self, vocab_file=SLICE_VOCAB, model_max_length=None, padding_length=None, **kwargs):
-        super(SliceTokenizer, self).__init__(vocab_file, model_max_length=model_max_length, padding_length=padding_length, **kwargs)
+    def __init__(
+        self,
+        vocab_file=SLICE_VOCAB,
+        model_max_length=None,
+        padding_length=None,
+        **kwargs,
+    ):
+        super(SliceTokenizer, self).__init__(
+            vocab_file,
+            model_max_length=model_max_length,
+            padding_length=padding_length,
+            **kwargs,
+        )
 
     def token_analysis(self, list_of_tokens):
         """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
-            token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
+        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
         analysis_masks = ANALYSIS_MASK_TOKENS
         token_type = SLICE_ANALYSIS_DICT
-        return [analysis_masks[next((k for k, v in token_type.items() if token in v), None)] for token in list_of_tokens]
-
+        return [
+            analysis_masks[next((k for k, v in token_type.items() if token in v), None)]
+            for token in list_of_tokens
+        ]
+
+
 class CompositionTokenizer(Xtal2txtTokenizer):
-    def __init__(self, vocab_file=COMPOSITION_VOCAB, model_max_length=None, padding_length=None, **kwargs):
-        super(CompositionTokenizer, self).__init__(vocab_file, model_max_length=model_max_length, padding_length=padding_length, **kwargs)
+    def __init__(
+        self,
+        vocab_file=COMPOSITION_VOCAB,
+        model_max_length=None,
+        padding_length=None,
+        **kwargs,
+    ):
+        super(CompositionTokenizer, self).__init__(
+            vocab_file,
+            model_max_length=model_max_length,
+            padding_length=padding_length,
+            **kwargs,
+        )
 
     def convert_tokens_to_string(self, tokens):
-        return ''.join(tokens)
-
+        return "".join(tokens)
+
     def token_analysis(self, list_of_tokens):
         """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
-            token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
+        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
         analysis_masks = ANALYSIS_MASK_TOKENS
         token_type = COMPOSITION_ANALYSIS_DICT
-        return [analysis_masks[next((k for k, v in token_type.items() if token in v), None)] for token in list_of_tokens]
-
+        return [
+            analysis_masks[next((k for k, v in token_type.items() if token in v), None)]
+            for token in list_of_tokens
+        ]
+
+
 class CifTokenizer(Xtal2txtTokenizer):
-    def __init__(self, vocab_file=CIF_VOCAB, model_max_length=None, padding_length=None, **kwargs):
-        super(CifTokenizer, self).__init__(vocab_file, model_max_length=model_max_length, padding_length=padding_length, **kwargs)
+    def __init__(
+        self, vocab_file=CIF_VOCAB, model_max_length=None, padding_length=None, **kwargs
+    ):
+        super(CifTokenizer, self).__init__(
+            vocab_file,
+            model_max_length=model_max_length,
+            padding_length=padding_length,
+            **kwargs,
+        )
 
     def convert_tokens_to_string(self, tokens):
-        return ''.join(tokens)
-
+        return "".join(tokens)
+
     def token_analysis(self, list_of_tokens):
         """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
-            token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
+        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
         analysis_masks = ANALYSIS_MASK_TOKENS
         token_type = CIF_ANALYSIS_DICT
-        return [analysis_masks[next((k for k, v in token_type.items() if token in v), None)] for token in list_of_tokens]
+        return [
+            analysis_masks[next((k for k, v in token_type.items() if token in v), None)]
+            for token in list_of_tokens
+        ]
+
 
 class CrysllmTokenizer(Xtal2txtTokenizer):
-    def __init__(self, vocab_file=CRYSTAL_LLM_VOCAB, model_max_length=None, padding_length=None, **kwargs):
-        super(CrysllmTokenizer, self).__init__(vocab_file, model_max_length=model_max_length, padding_length=padding_length, **kwargs)
+    def __init__(
+        self,
+        vocab_file=CRYSTAL_LLM_VOCAB,
+        model_max_length=None,
+        padding_length=None,
+        **kwargs,
+    ):
+        super(CrysllmTokenizer, self).__init__(
+            vocab_file,
+            model_max_length=model_max_length,
+            padding_length=padding_length,
+            **kwargs,
+        )
 
     def convert_tokens_to_string(self, tokens):
-        return ''.join(tokens)
-
+        return "".join(tokens)
+
     def token_analysis(self, list_of_tokens):
         """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
-            token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
+        token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
         analysis_masks = ANALYSIS_MASK_TOKENS
         token_type = CRYSTAL_LLM_ANALYSIS_DICT
-        return [analysis_masks[next((k for k, v in token_type.items() if token in v), None)] for token in list_of_tokens]
+        return [
+            analysis_masks[next((k for k, v in token_type.items() if token in v), None)]
+            for token in list_of_tokens
+        ]
 
 
-class RobocrysTokenizer():
+class RobocrysTokenizer:
     """Tokenizer for Robocrystallographer. Would be BPE tokenizer.
     trained on the Robocrystallographer dataset.
    TODO: Implement this tokenizer.
    """
 
-    def __init__(self, vocab_file=ROBOCRYS_VOCAB, **kwargs):
-       tokenizer = Tokenizer.from_file(vocab_file)
-       wrapped_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
-       self._tokenizer = wrapped_tokenizer
+    def __init__(self, vocab_file=ROBOCRYS_VOCAB, **kwargs):
+        tokenizer = Tokenizer.from_file(vocab_file)
+        wrapped_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
+        self._tokenizer = wrapped_tokenizer
 
     def tokenize(self, text):
         return self._tokenizer.tokenize(text)
@@ -216,6 +288,6 @@ def decode(self, token_ids, skip_special_tokens=True):
         # Check if token_ids is a string and convert it to a list of integers
         if isinstance(token_ids, str):
             token_ids = [int(token_ids)]
-        return self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
-
-
+        return self._tokenizer.decode(
+            token_ids, skip_special_tokens=skip_special_tokens
+        )
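Usage sketch (not part of the patch): the reformatting above leaves the public tokenizer API unchanged, so the classes can still be driven as before. The snippet below is a minimal, hedged example; the SLICE input string and the model_max_length value are illustrative assumptions, and the exact tokens and mask labels produced depend on the vocab files bundled under src/xtal2txt/vocabs.

# Illustrative only -- assumes the xtal2txt package and its bundled vocab files are installed.
from xtal2txt.tokenizer import SliceTokenizer

tokenizer = SliceTokenizer(model_max_length=512)

slice_string = "Ga N 0 1 o o o"  # hypothetical SLICE fragment, not taken from the patch
tokens = tokenizer.tokenize(slice_string)  # longest-match regex over the vocab entries
masks = tokenizer.token_analysis(tokens)   # maps each token to a mask such as [ATOMS], [NUMS], [DIR]; unknown tokens map to [NONE]

print(tokens)
print(masks)
print(tokenizer.convert_tokens_to_string(tokens))  # base class joins tokens with spaces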