From a6bde57f21c71d892d0516dcf554180d087590ba Mon Sep 17 00:00:00 2001 From: n0w0f Date: Mon, 18 Mar 2024 06:59:37 +0100 Subject: [PATCH 1/3] feat: helper for attention analysis --- src/xtal2txt/analysis.py | 73 ++++++++++++++++++++++++++++++++++++++ src/xtal2txt/tokenizer.py | 30 ++++++++++++++-- src/xtal2txt/vocabs/1.json | 1 + 3 files changed, 101 insertions(+), 3 deletions(-) create mode 100644 src/xtal2txt/analysis.py create mode 100644 src/xtal2txt/vocabs/1.json diff --git a/src/xtal2txt/analysis.py b/src/xtal2txt/analysis.py new file mode 100644 index 0000000..acb6426 --- /dev/null +++ b/src/xtal2txt/analysis.py @@ -0,0 +1,73 @@ + +ANALYSIS_MASK_TOKENS = { + "atoms" : "[ATOMS]", + "directions" : "[DIR]", + "numbers" : "[NUMS]", + "bonds" : "[BONDS]", + "miscellaneous" : "[MISC]", + "identifier" : "[ID]", + "symmetry" : "[SYM]", + "lattice" : "[LATTICE]", + "composition" : "[COMP]", + None : "[NONE]" +} + +ATOM_LIST_ = [ + "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", + "S", "Cl", "K", "Ar", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Ni", "Co", "Cu", "Zn", + "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", + "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", + "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re", + "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", + "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", + "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og" + ] + +NUMS_ = [ + "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", ".", "+", "-" + ] + +SLICE_ANALYSIS_DICT = { + "atoms": ATOM_LIST_, + "directions": [ + "o o o", "o o +", "o o -", "o + o", "o + +", "o + -", "o - o", "o - +", "o - -", + "+ o o", "+ o +", "+ o -", "+ + o", "+ + +", "+ + -", "+ - o", "+ - +", "+ - -", + "- o o", "- o +", "- o -", "- + o", "- + +", "- + -", "- - o", "- - +", "- - -" + ], + "numbers": NUMS_ +} + + +CRYSTAL_LLM_ANALYSIS_DICT = { + "atoms": ATOM_LIST_, + "numbers": NUMS_, + "miscellaneous": ["\n", " "] +} + +CIF_ANALYSIS_DICT = { + "atoms": ATOM_LIST_, + "numbers": NUMS_, + "lattice": ["_cell_length_a", "_cell_length_b", "_cell_length_c", + "_cell_angle_alpha", "_cell_angle_beta", "_cell_angle_gamma" ], + "identifier": ["loop_"], + "composition": [ + "_chemical_formula_structural", + "_chemical_formula_sum"], + "symmetry": ["_symmetry_space_group_name_H-M", "_symmetry_Int_Tables_number", + "_symmetry_equiv_pos_site_id","_symmetry_equiv_pos_as_xyz",], + "miscellaneous": ["_atom_site_symmetry_multiplicity" , + "_atom_type_symbol", + "_atom_type_oxidation_number", + "_atom_site_type_symbol", + "_atom_site_label", + "_atom_site_symmetry_multiplicity", + "_atom_site_fract_x", + "_atom_site_fract_y", + "_atom_site_fract_z", + "_atom_site_occupancy", + "-", "/", "'", "\"", ",", "'x, y, z'", "x", "y", "z", "-x", "-y", "-z", " ", " ", "\n", "_geom_bond_atom_site_label_1", "_geom_bond_atom_site_label_2", "_geom_bond_distance", "_ccdc_geom_bond_type", "_", "a", "n", "c", "b", "m", "d", "R", "A", "(", ")", "[", "]", "*" + ],} + + + + diff --git a/src/xtal2txt/tokenizer.py b/src/xtal2txt/tokenizer.py index 850346c..e77768f 100644 --- a/src/xtal2txt/tokenizer.py +++ b/src/xtal2txt/tokenizer.py @@ -7,6 +7,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from tokenizers import Tokenizer +from xtal2txt.analysis import ANALYSIS_MASK_TOKENS, SLICE_ANALYSIS_DICT, 
CRYSTAL_LLM_ANALYSIS_DICT, CIF_ANALYSIS_DICT @@ -96,6 +97,10 @@ def add_special_tokens(self, special_tokens): self.vocab[value] = len(self.vocab) self.save_vocabulary(os.path.dirname(self.vocab_file)) + def token_analysis(self,tokens): + """This method should be implemented by the Downstream tokenizers.""" + raise NotImplementedError + def save_vocabulary(self, save_directory, filename_prefix=None): """Save the vocabulary, ensures vocabularies are not overwritten. Filename follow the convention {index}-{filename_prefix}.json. Index keeps track of the latest vocabulary saved.""" index = 0 @@ -139,15 +144,20 @@ class SliceTokenizer(Xtal2txtTokenizer): def __init__(self, vocab_file=SLICE_VOCAB, model_max_length=None, padding_length=None, **kwargs): super(SliceTokenizer, self).__init__(vocab_file, model_max_length=model_max_length, padding_length=padding_length, **kwargs) - - + def token_analysis(self, list_of_tokens): + """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The + token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.""" + analysis_masks = ANALYSIS_MASK_TOKENS + token_type = SLICE_ANALYSIS_DICT + return [analysis_masks[next((k for k, v in token_type.items() if token in v), None)] for token in list_of_tokens] + class CompositionTokenizer(Xtal2txtTokenizer): def __init__(self, vocab_file=COMPOSITION_VOCAB, model_max_length=None, padding_length=None, **kwargs): super(CompositionTokenizer, self).__init__(vocab_file, model_max_length=model_max_length, padding_length=padding_length, **kwargs) def convert_tokens_to_string(self, tokens): return ''.join(tokens) - + class CifTokenizer(Xtal2txtTokenizer): def __init__(self, vocab_file=CIF_VOCAB, model_max_length=None, padding_length=None, **kwargs): @@ -155,6 +165,13 @@ def __init__(self, vocab_file=CIF_VOCAB, model_max_length=None, padding_length=N def convert_tokens_to_string(self, tokens): return ''.join(tokens) + + def token_analysis(self, list_of_tokens): + """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The + token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.""" + analysis_masks = ANALYSIS_MASK_TOKENS + token_type = CIF_ANALYSIS_DICT + return [analysis_masks[next((k for k, v in token_type.items() if token in v), None)] for token in list_of_tokens] class CrysllmTokenizer(Xtal2txtTokenizer): def __init__(self, vocab_file=CRYSTAL_LLM_VOCAB, model_max_length=None, padding_length=None, **kwargs): @@ -162,6 +179,13 @@ def __init__(self, vocab_file=CRYSTAL_LLM_VOCAB, model_max_length=None, padding_ def convert_tokens_to_string(self, tokens): return ''.join(tokens) + + def token_analysis(self, list_of_tokens): + """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. 
The + token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.""" + analysis_masks = ANALYSIS_MASK_TOKENS + token_type = CRYSTAL_LLM_ANALYSIS_DICT + return [analysis_masks[next((k for k, v in token_type.items() if token in v), None)] for token in list_of_tokens] class RobocrysTokenizer(): diff --git a/src/xtal2txt/vocabs/1.json b/src/xtal2txt/vocabs/1.json new file mode 100644 index 0000000..f5c1741 --- /dev/null +++ b/src/xtal2txt/vocabs/1.json @@ -0,0 +1 @@ +{"o o o": 0, "o o +": 1, "o o -": 2, "o + o": 3, "o + +": 4, "o + -": 5, "o - o": 6, "o - +": 7, "o - -": 8, "+ o o": 9, "+ o +": 10, "+ o -": 11, "+ + o": 12, "+ + +": 13, "+ + -": 14, "+ - o": 15, "+ - +": 16, "+ - -": 17, "- o o": 18, "- o +": 19, "- o -": 20, "- + o": 21, "- + +": 22, "- + -": 23, "- - o": 24, "- - +": 25, "- - -": 26, "H": 27, "He": 28, "Li": 29, "Be": 30, "B": 31, "C": 32, "N": 33, "O": 34, "F": 35, "Ne": 36, "Na": 37, "Mg": 38, "Al": 39, "Si": 40, "P": 41, "S": 42, "Cl": 43, "K": 44, "Ar": 45, "Ca": 46, "Sc": 47, "Ti": 48, "V": 49, "Cr": 50, "Mn": 51, "Fe": 52, "Ni": 53, "Co": 54, "Cu": 55, "Zn": 56, "Ga": 57, "Ge": 58, "As": 59, "Se": 60, "Br": 61, "Kr": 62, "Rb": 63, "Sr": 64, "Y": 65, "Zr": 66, "Nb": 67, "Mo": 68, "Tc": 69, "Ru": 70, "Rh": 71, "Pd": 72, "Ag": 73, "Cd": 74, "In": 75, "Sn": 76, "Sb": 77, "Te": 78, "I": 79, "Xe": 80, "Cs": 81, "Ba": 82, "La": 83, "Ce": 84, "Pr": 85, "Nd": 86, "Pm": 87, "Sm": 88, "Eu": 89, "Gd": 90, "Tb": 91, "Dy": 92, "Ho": 93, "Er": 94, "Tm": 95, "Yb": 96, "Lu": 97, "Hf": 98, "Ta": 99, "W": 100, "Re": 101, "Os": 102, "Ir": 103, "Pt": 104, "Au": 105, "Hg": 106, "Tl": 107, "Pb": 108, "Bi": 109, "Th": 110, "Pa": 111, "U": 112, "Np": 113, "Pu": 114, "Am": 115, "Cm": 116, "Bk": 117, "Cf": 118, "Es": 119, "Fm": 120, "Md": 121, "No": 122, "Lr": 123, "Rf": 124, "Db": 125, "Sg": 126, "Bh": 127, "Hs": 128, "Mt": 129, "Ds": 130, "Rg": 131, "Cn": 132, "Nh": 133, "Fl": 134, "Mc": 135, "Lv": 136, "Ts": 137, "Og": 138, "0": 139, "1": 140, "2": 141, "3": 142, "4": 143, "5": 144, "6": 145, "7": 146, "8": 147, "9": 148, "[UNK]": 149, "[PAD]": 150, "[CLS]": 151, "[SEP]": 152, "[MASK]": 153, "[EOS]": 154, "[BOS]": 155} \ No newline at end of file From 9e494d982513f1d1f8b3840a725d86196ebaa06a Mon Sep 17 00:00:00 2001 From: n0w0f Date: Thu, 21 Mar 2024 10:13:15 +0100 Subject: [PATCH 2/3] chore: add analysis tokenizer for comp --- src/xtal2txt/analysis.py | 5 +++++ src/xtal2txt/tokenizer.py | 9 ++++++++- src/xtal2txt/vocabs/1.json | 2 +- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/xtal2txt/analysis.py b/src/xtal2txt/analysis.py index acb6426..0e63afe 100644 --- a/src/xtal2txt/analysis.py +++ b/src/xtal2txt/analysis.py @@ -27,6 +27,11 @@ "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", ".", "+", "-" ] +COMPOSITION_ANALYSIS_DICT = { + "atoms": ATOM_LIST_, + "numbers": NUMS_, +} + SLICE_ANALYSIS_DICT = { "atoms": ATOM_LIST_, "directions": [ diff --git a/src/xtal2txt/tokenizer.py b/src/xtal2txt/tokenizer.py index e77768f..5bead4e 100644 --- a/src/xtal2txt/tokenizer.py +++ b/src/xtal2txt/tokenizer.py @@ -7,7 +7,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from tokenizers import Tokenizer -from xtal2txt.analysis import ANALYSIS_MASK_TOKENS, SLICE_ANALYSIS_DICT, CRYSTAL_LLM_ANALYSIS_DICT, CIF_ANALYSIS_DICT +from xtal2txt.analysis import ANALYSIS_MASK_TOKENS, SLICE_ANALYSIS_DICT, CRYSTAL_LLM_ANALYSIS_DICT, CIF_ANALYSIS_DICT, COMPOSITION_ANALYSIS_DICT @@ -158,6 +158,13 @@ def __init__(self, 
vocab_file=COMPOSITION_VOCAB, model_max_length=None, padding_ def convert_tokens_to_string(self, tokens): return ''.join(tokens) + def token_analysis(self, list_of_tokens): + """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The + token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.""" + analysis_masks = ANALYSIS_MASK_TOKENS + token_type = COMPOSITION_ANALYSIS_DICT + return [analysis_masks[next((k for k, v in token_type.items() if token in v), None)] for token in list_of_tokens] + class CifTokenizer(Xtal2txtTokenizer): def __init__(self, vocab_file=CIF_VOCAB, model_max_length=None, padding_length=None, **kwargs): diff --git a/src/xtal2txt/vocabs/1.json b/src/xtal2txt/vocabs/1.json index f5c1741..2b112a5 100644 --- a/src/xtal2txt/vocabs/1.json +++ b/src/xtal2txt/vocabs/1.json @@ -1 +1 @@ -{"o o o": 0, "o o +": 1, "o o -": 2, "o + o": 3, "o + +": 4, "o + -": 5, "o - o": 6, "o - +": 7, "o - -": 8, "+ o o": 9, "+ o +": 10, "+ o -": 11, "+ + o": 12, "+ + +": 13, "+ + -": 14, "+ - o": 15, "+ - +": 16, "+ - -": 17, "- o o": 18, "- o +": 19, "- o -": 20, "- + o": 21, "- + +": 22, "- + -": 23, "- - o": 24, "- - +": 25, "- - -": 26, "H": 27, "He": 28, "Li": 29, "Be": 30, "B": 31, "C": 32, "N": 33, "O": 34, "F": 35, "Ne": 36, "Na": 37, "Mg": 38, "Al": 39, "Si": 40, "P": 41, "S": 42, "Cl": 43, "K": 44, "Ar": 45, "Ca": 46, "Sc": 47, "Ti": 48, "V": 49, "Cr": 50, "Mn": 51, "Fe": 52, "Ni": 53, "Co": 54, "Cu": 55, "Zn": 56, "Ga": 57, "Ge": 58, "As": 59, "Se": 60, "Br": 61, "Kr": 62, "Rb": 63, "Sr": 64, "Y": 65, "Zr": 66, "Nb": 67, "Mo": 68, "Tc": 69, "Ru": 70, "Rh": 71, "Pd": 72, "Ag": 73, "Cd": 74, "In": 75, "Sn": 76, "Sb": 77, "Te": 78, "I": 79, "Xe": 80, "Cs": 81, "Ba": 82, "La": 83, "Ce": 84, "Pr": 85, "Nd": 86, "Pm": 87, "Sm": 88, "Eu": 89, "Gd": 90, "Tb": 91, "Dy": 92, "Ho": 93, "Er": 94, "Tm": 95, "Yb": 96, "Lu": 97, "Hf": 98, "Ta": 99, "W": 100, "Re": 101, "Os": 102, "Ir": 103, "Pt": 104, "Au": 105, "Hg": 106, "Tl": 107, "Pb": 108, "Bi": 109, "Th": 110, "Pa": 111, "U": 112, "Np": 113, "Pu": 114, "Am": 115, "Cm": 116, "Bk": 117, "Cf": 118, "Es": 119, "Fm": 120, "Md": 121, "No": 122, "Lr": 123, "Rf": 124, "Db": 125, "Sg": 126, "Bh": 127, "Hs": 128, "Mt": 129, "Ds": 130, "Rg": 131, "Cn": 132, "Nh": 133, "Fl": 134, "Mc": 135, "Lv": 136, "Ts": 137, "Og": 138, "0": 139, "1": 140, "2": 141, "3": 142, "4": 143, "5": 144, "6": 145, "7": 146, "8": 147, "9": 148, "[UNK]": 149, "[PAD]": 150, "[CLS]": 151, "[SEP]": 152, "[MASK]": 153, "[EOS]": 154, "[BOS]": 155} \ No newline at end of file +{"H": 0, "He": 1, "Li": 2, "Be": 3, "B": 4, "C": 5, "N": 6, "O": 7, "F": 8, "Ne": 9, "Na": 10, "Mg": 11, "Al": 12, "Si": 13, "P": 14, "S": 15, "Cl": 16, "K": 17, "Ar": 18, "Ca": 19, "Sc": 20, "Ti": 21, "V": 22, "Cr": 23, "Mn": 24, "Fe": 25, "Ni": 26, "Co": 27, "Cu": 28, "Zn": 29, "Ga": 30, "Ge": 31, "As": 32, "Se": 33, "Br": 34, "Kr": 35, "Rb": 36, "Sr": 37, "Y": 38, "Zr": 39, "Nb": 40, "Mo": 41, "Tc": 42, "Ru": 43, "Rh": 44, "Pd": 45, "Ag": 46, "Cd": 47, "In": 48, "Sn": 49, "Sb": 50, "Te": 51, "I": 52, "Xe": 53, "Cs": 54, "Ba": 55, "La": 56, "Ce": 57, "Pr": 58, "Nd": 59, "Pm": 60, "Sm": 61, "Eu": 62, "Gd": 63, "Tb": 64, "Dy": 65, "Ho": 66, "Er": 67, "Tm": 68, "Yb": 69, "Lu": 70, "Hf": 71, "Ta": 72, "W": 73, "Re": 74, "Os": 75, "Ir": 76, "Pt": 77, "Au": 78, "Hg": 79, "Tl": 80, "Pb": 81, "Bi": 82, "Th": 83, "Pa": 84, "U": 85, "Np": 86, "Pu": 87, "Am": 88, "Cm": 89, "Bk": 90, "Cf": 91, "Es": 92, "Fm": 93, "Md": 
94, "No": 95, "Lr": 96, "Rf": 97, "Db": 98, "Sg": 99, "Bh": 100, "Hs": 101, "Mt": 102, "Ds": 103, "Rg": 104, "Cn": 105, "Nh": 106, "Fl": 107, "Mc": 108, "Lv": 109, "Ts": 110, "Og": 111, "0": 112, "1": 113, "2": 114, "3": 115, "4": 116, "5": 117, "6": 118, "7": 119, "8": 120, "9": 121, "data_": 122, "_symmetry_space_group_name_H-M": 123, "_cell_length_a": 124, "_cell_length_b": 125, "_cell_length_c": 126, "_cell_angle_alpha": 127, "_cell_angle_beta": 128, "_cell_angle_gamma": 129, "_symmetry_Int_Tables_number": 130, "_chemical_formula_structural": 131, "_chemical_formula_sum": 132, "_cell_volume": 133, "_cell_formula_units_Z": 134, "loop_": 135, "_symmetry_equiv_pos_site_id": 136, "_symmetry_equiv_pos_as_xyz": 137, "_atom_type_symbol": 138, "_atom_type_oxidation_number": 139, "_atom_site_type_symbol": 140, "_atom_site_label": 141, "_atom_site_symmetry_multiplicity": 142, "_atom_site_fract_x": 143, "_atom_site_fract_y": 144, "_atom_site_fract_z": 145, "_atom_site_occupancy": 146, " ": 147, ".": 148, "+": 149, "-": 150, "/": 151, "'": 152, "\"": 153, ",": 154, "'x, y, z'": 155, "x": 156, "y": 157, "z": 158, "-x": 159, "-y": 160, "-z": 161, " ": 162, " ": 163, "\n": 164, "_geom_bond_atom_site_label_1": 165, "_geom_bond_atom_site_label_2": 166, "_geom_bond_distance": 167, "_ccdc_geom_bond_type": 168, "_": 169, "a": 170, "n": 171, "c": 172, "b": 173, "m": 174, "d": 175, "R": 176, "A": 177, "(": 178, ")": 179, "[": 180, "]": 181, "*": 182, "[UNK]": 183, "[PAD]": 184, "[CLS]": 185, "[SEP]": 186, "[MASK]": 187, "[EOS]": 188, "[BOS]": 189} \ No newline at end of file From ab7eb7337adbe1abc1db152fa9b1cc5141e40841 Mon Sep 17 00:00:00 2001 From: n0w0f Date: Fri, 22 Mar 2024 13:53:01 +0100 Subject: [PATCH 3/3] fix: lint --- src/xtal2txt/analysis.py | 265 ++++++++++++++++++++++++++++++-------- src/xtal2txt/tokenizer.py | 194 +++++++++++++++++++--------- 2 files changed, 347 insertions(+), 112 deletions(-) diff --git a/src/xtal2txt/analysis.py b/src/xtal2txt/analysis.py index 0e63afe..0dc5915 100644 --- a/src/xtal2txt/analysis.py +++ b/src/xtal2txt/analysis.py @@ -1,31 +1,132 @@ - ANALYSIS_MASK_TOKENS = { - "atoms" : "[ATOMS]", - "directions" : "[DIR]", - "numbers" : "[NUMS]", - "bonds" : "[BONDS]", - "miscellaneous" : "[MISC]", - "identifier" : "[ID]", - "symmetry" : "[SYM]", - "lattice" : "[LATTICE]", - "composition" : "[COMP]", - None : "[NONE]" + "atoms": "[ATOMS]", + "directions": "[DIR]", + "numbers": "[NUMS]", + "bonds": "[BONDS]", + "miscellaneous": "[MISC]", + "identifier": "[ID]", + "symmetry": "[SYM]", + "lattice": "[LATTICE]", + "composition": "[COMP]", + None: "[NONE]", } ATOM_LIST_ = [ - "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", - "S", "Cl", "K", "Ar", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Ni", "Co", "Cu", "Zn", - "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", - "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", - "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re", - "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", - "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", - "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og" - ] + "H", + "He", + "Li", + "Be", + "B", + "C", + "N", + "O", + "F", + "Ne", + "Na", + "Mg", + "Al", + "Si", + "P", + "S", + "Cl", + "K", + "Ar", + "Ca", + "Sc", + "Ti", + "V", + "Cr", + "Mn", + "Fe", + "Ni", + "Co", + "Cu", + "Zn", + "Ga", + "Ge", 
+ "As", + "Se", + "Br", + "Kr", + "Rb", + "Sr", + "Y", + "Zr", + "Nb", + "Mo", + "Tc", + "Ru", + "Rh", + "Pd", + "Ag", + "Cd", + "In", + "Sn", + "Sb", + "Te", + "I", + "Xe", + "Cs", + "Ba", + "La", + "Ce", + "Pr", + "Nd", + "Pm", + "Sm", + "Eu", + "Gd", + "Tb", + "Dy", + "Ho", + "Er", + "Tm", + "Yb", + "Lu", + "Hf", + "Ta", + "W", + "Re", + "Os", + "Ir", + "Pt", + "Au", + "Hg", + "Tl", + "Pb", + "Bi", + "Th", + "Pa", + "U", + "Np", + "Pu", + "Am", + "Cm", + "Bk", + "Cf", + "Es", + "Fm", + "Md", + "No", + "Lr", + "Rf", + "Db", + "Sg", + "Bh", + "Hs", + "Mt", + "Ds", + "Rg", + "Cn", + "Nh", + "Fl", + "Mc", + "Lv", + "Ts", + "Og", +] -NUMS_ = [ - "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", ".", "+", "-" - ] +NUMS_ = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", ".", "+", "-"] COMPOSITION_ANALYSIS_DICT = { "atoms": ATOM_LIST_, @@ -35,44 +136,106 @@ SLICE_ANALYSIS_DICT = { "atoms": ATOM_LIST_, "directions": [ - "o o o", "o o +", "o o -", "o + o", "o + +", "o + -", "o - o", "o - +", "o - -", - "+ o o", "+ o +", "+ o -", "+ + o", "+ + +", "+ + -", "+ - o", "+ - +", "+ - -", - "- o o", "- o +", "- o -", "- + o", "- + +", "- + -", "- - o", "- - +", "- - -" + "o o o", + "o o +", + "o o -", + "o + o", + "o + +", + "o + -", + "o - o", + "o - +", + "o - -", + "+ o o", + "+ o +", + "+ o -", + "+ + o", + "+ + +", + "+ + -", + "+ - o", + "+ - +", + "+ - -", + "- o o", + "- o +", + "- o -", + "- + o", + "- + +", + "- + -", + "- - o", + "- - +", + "- - -", ], - "numbers": NUMS_ + "numbers": NUMS_, } CRYSTAL_LLM_ANALYSIS_DICT = { "atoms": ATOM_LIST_, "numbers": NUMS_, - "miscellaneous": ["\n", " "] + "miscellaneous": ["\n", " "], } CIF_ANALYSIS_DICT = { "atoms": ATOM_LIST_, "numbers": NUMS_, - "lattice": ["_cell_length_a", "_cell_length_b", "_cell_length_c", - "_cell_angle_alpha", "_cell_angle_beta", "_cell_angle_gamma" ], + "lattice": [ + "_cell_length_a", + "_cell_length_b", + "_cell_length_c", + "_cell_angle_alpha", + "_cell_angle_beta", + "_cell_angle_gamma", + ], "identifier": ["loop_"], - "composition": [ - "_chemical_formula_structural", - "_chemical_formula_sum"], - "symmetry": ["_symmetry_space_group_name_H-M", "_symmetry_Int_Tables_number", - "_symmetry_equiv_pos_site_id","_symmetry_equiv_pos_as_xyz",], - "miscellaneous": ["_atom_site_symmetry_multiplicity" , - "_atom_type_symbol", - "_atom_type_oxidation_number", - "_atom_site_type_symbol", - "_atom_site_label", - "_atom_site_symmetry_multiplicity", - "_atom_site_fract_x", - "_atom_site_fract_y", - "_atom_site_fract_z", - "_atom_site_occupancy", - "-", "/", "'", "\"", ",", "'x, y, z'", "x", "y", "z", "-x", "-y", "-z", " ", " ", "\n", "_geom_bond_atom_site_label_1", "_geom_bond_atom_site_label_2", "_geom_bond_distance", "_ccdc_geom_bond_type", "_", "a", "n", "c", "b", "m", "d", "R", "A", "(", ")", "[", "]", "*" - ],} - - - - + "composition": ["_chemical_formula_structural", "_chemical_formula_sum"], + "symmetry": [ + "_symmetry_space_group_name_H-M", + "_symmetry_Int_Tables_number", + "_symmetry_equiv_pos_site_id", + "_symmetry_equiv_pos_as_xyz", + ], + "miscellaneous": [ + "_atom_site_symmetry_multiplicity", + "_atom_type_symbol", + "_atom_type_oxidation_number", + "_atom_site_type_symbol", + "_atom_site_label", + "_atom_site_symmetry_multiplicity", + "_atom_site_fract_x", + "_atom_site_fract_y", + "_atom_site_fract_z", + "_atom_site_occupancy", + "-", + "/", + "'", + '"', + ",", + "'x, y, z'", + "x", + "y", + "z", + "-x", + "-y", + "-z", + " ", + " ", + "\n", + "_geom_bond_atom_site_label_1", + "_geom_bond_atom_site_label_2", + 
"_geom_bond_distance", + "_ccdc_geom_bond_type", + "_", + "a", + "n", + "c", + "b", + "m", + "d", + "R", + "A", + "(", + ")", + "[", + "]", + "*", + ], +} diff --git a/src/xtal2txt/tokenizer.py b/src/xtal2txt/tokenizer.py index 5bead4e..4698b2c 100644 --- a/src/xtal2txt/tokenizer.py +++ b/src/xtal2txt/tokenizer.py @@ -1,15 +1,17 @@ +import json import os import re -import json -from pathlib import Path - - -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from tokenizers import Tokenizer -from xtal2txt.analysis import ANALYSIS_MASK_TOKENS, SLICE_ANALYSIS_DICT, CRYSTAL_LLM_ANALYSIS_DICT, CIF_ANALYSIS_DICT, COMPOSITION_ANALYSIS_DICT - +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from xtal2txt.analysis import ( + ANALYSIS_MASK_TOKENS, + CIF_ANALYSIS_DICT, + COMPOSITION_ANALYSIS_DICT, + CRYSTAL_LLM_ANALYSIS_DICT, + SLICE_ANALYSIS_DICT, +) THIS_DIR = os.path.dirname(os.path.abspath(__file__)) SLICE_VOCAB = os.path.join(THIS_DIR, "vocabs", "slice_vocab.txt") @@ -19,11 +21,14 @@ ROBOCRYS_VOCAB = os.path.join(THIS_DIR, "vocabs", "robocrys_vocab.json") - class Xtal2txtTokenizer(PreTrainedTokenizer): - def __init__(self, vocab_file, model_max_length=None, padding_length=None, **kwargs): - super(Xtal2txtTokenizer, self).__init__(model_max_length=model_max_length, **kwargs) - + def __init__( + self, vocab_file, model_max_length=None, padding_length=None, **kwargs + ): + super(Xtal2txtTokenizer, self).__init__( + model_max_length=model_max_length, **kwargs + ) + self.vocab = self.load_vocab(vocab_file) self.vocab_file = vocab_file self.truncation = False @@ -32,16 +37,16 @@ def __init__(self, vocab_file, model_max_length=None, padding_length=None, **kwa def load_vocab(self, vocab_file): _, file_extension = os.path.splitext(vocab_file) - if file_extension == '.txt': - with open(vocab_file, 'r', encoding='utf-8') as file: + if file_extension == ".txt": + with open(vocab_file, "r", encoding="utf-8") as file: vocab = file.read().splitlines() return {token: idx for idx, token in enumerate(vocab)} - elif file_extension == '.json': - with open(vocab_file, 'r', encoding='utf-8') as file: + elif file_extension == ".json": + with open(vocab_file, "r", encoding="utf-8") as file: return json.load(file) else: raise ValueError(f"Unsupported file type: {file_extension}") - + def get_vocab(self): return self.vocab @@ -50,12 +55,12 @@ def tokenize(self, text): string_tokens = [token for token in tokens if isinstance(token, str)] string_tokens.sort(key=len, reverse=True) escaped_tokens = [re.escape(token) for token in string_tokens] - pattern_str = '|'.join(escaped_tokens) + pattern_str = "|".join(escaped_tokens) pattern = re.compile(pattern_str) matches = pattern.findall(text) if self.truncation and len(matches) > self.model_max_length: - matches = matches[:self.model_max_length] + matches = matches[: self.model_max_length] if self.padding and len(matches) < self.padding_length: matches += [self.pad_token] * (self.padding_length - len(matches)) @@ -63,7 +68,7 @@ def tokenize(self, text): return matches def convert_tokens_to_string(self, tokens): - return ' '.join(tokens) + return " ".join(tokens) def _add_tokens(self, new_tokens, **kwargs): for token in new_tokens: @@ -97,7 +102,7 @@ def add_special_tokens(self, special_tokens): self.vocab[value] = len(self.vocab) self.save_vocabulary(os.path.dirname(self.vocab_file)) - def token_analysis(self,tokens): + def token_analysis(self, tokens): """This method should be implemented by the Downstream tokenizers.""" raise 
NotImplementedError @@ -105,16 +110,23 @@ def save_vocabulary(self, save_directory, filename_prefix=None): """Save the vocabulary, ensures vocabularies are not overwritten. Filename follow the convention {index}-{filename_prefix}.json. Index keeps track of the latest vocabulary saved.""" index = 0 if os.path.isdir(save_directory): - vocab_files = list(filter(lambda x: x.endswith(".json"), os.listdir(save_directory))) + vocab_files = list( + filter(lambda x: x.endswith(".json"), os.listdir(save_directory)) + ) for vocab_file in vocab_files: try: - index = max(index, int(vocab_file.split('-')[0])) + index = max(index, int(vocab_file.split("-")[0])) except ValueError: pass # Ignore files that do not start with an integer - vocab_file = os.path.join(save_directory, f"{index + 1}-{filename_prefix}.json" if filename_prefix else f"{index + 1}.json") + vocab_file = os.path.join( + save_directory, + f"{index + 1}-{filename_prefix}.json" + if filename_prefix + else f"{index + 1}.json", + ) - with open(vocab_file, 'w', encoding='utf-8') as f: + with open(vocab_file, "w", encoding="utf-8") as f: json.dump(self.vocab, f, ensure_ascii=False) return (vocab_file,) @@ -123,14 +135,21 @@ def save_vocabulary(self, save_directory, filename_prefix=None): def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): if pretrained_model_name_or_path is not None: if os.path.isdir(pretrained_model_name_or_path): - vocab_files = list(filter(lambda x: x.endswith(".json"), os.listdir(pretrained_model_name_or_path))) - vocab_files.sort(key=lambda x: int(x.split('-')[0])) - vocab_file = os.path.join(pretrained_model_name_or_path, vocab_files[-1]) + vocab_files = list( + filter( + lambda x: x.endswith(".json"), + os.listdir(pretrained_model_name_or_path), + ) + ) + vocab_files.sort(key=lambda x: int(x.split("-")[0])) + vocab_file = os.path.join( + pretrained_model_name_or_path, vocab_files[-1] + ) if vocab_file is None: raise ValueError("You should specify a path to a vocab file") - with open(vocab_file, 'r', encoding='utf-8') as f: + with open(vocab_file, "r", encoding="utf-8") as f: vocab = json.load(f) tokenizer = cls(vocab_file, *inputs, **kwargs) @@ -139,72 +158,125 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): return tokenizer - class SliceTokenizer(Xtal2txtTokenizer): - def __init__(self, vocab_file=SLICE_VOCAB, model_max_length=None, padding_length=None, **kwargs): - super(SliceTokenizer, self).__init__(vocab_file, model_max_length=model_max_length, padding_length=padding_length, **kwargs) + def __init__( + self, + vocab_file=SLICE_VOCAB, + model_max_length=None, + padding_length=None, + **kwargs, + ): + super(SliceTokenizer, self).__init__( + vocab_file, + model_max_length=model_max_length, + padding_length=padding_length, + **kwargs, + ) def token_analysis(self, list_of_tokens): """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. 
The - token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.""" + token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.""" analysis_masks = ANALYSIS_MASK_TOKENS token_type = SLICE_ANALYSIS_DICT - return [analysis_masks[next((k for k, v in token_type.items() if token in v), None)] for token in list_of_tokens] - + return [ + analysis_masks[next((k for k, v in token_type.items() if token in v), None)] + for token in list_of_tokens + ] + + class CompositionTokenizer(Xtal2txtTokenizer): - def __init__(self, vocab_file=COMPOSITION_VOCAB, model_max_length=None, padding_length=None, **kwargs): - super(CompositionTokenizer, self).__init__(vocab_file, model_max_length=model_max_length, padding_length=padding_length, **kwargs) + def __init__( + self, + vocab_file=COMPOSITION_VOCAB, + model_max_length=None, + padding_length=None, + **kwargs, + ): + super(CompositionTokenizer, self).__init__( + vocab_file, + model_max_length=model_max_length, + padding_length=padding_length, + **kwargs, + ) def convert_tokens_to_string(self, tokens): - return ''.join(tokens) - + return "".join(tokens) + def token_analysis(self, list_of_tokens): """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The - token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.""" + token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.""" analysis_masks = ANALYSIS_MASK_TOKENS token_type = COMPOSITION_ANALYSIS_DICT - return [analysis_masks[next((k for k, v in token_type.items() if token in v), None)] for token in list_of_tokens] - + return [ + analysis_masks[next((k for k, v in token_type.items() if token in v), None)] + for token in list_of_tokens + ] + class CifTokenizer(Xtal2txtTokenizer): - def __init__(self, vocab_file=CIF_VOCAB, model_max_length=None, padding_length=None, **kwargs): - super(CifTokenizer, self).__init__(vocab_file, model_max_length=model_max_length, padding_length=padding_length, **kwargs) + def __init__( + self, vocab_file=CIF_VOCAB, model_max_length=None, padding_length=None, **kwargs + ): + super(CifTokenizer, self).__init__( + vocab_file, + model_max_length=model_max_length, + padding_length=padding_length, + **kwargs, + ) def convert_tokens_to_string(self, tokens): - return ''.join(tokens) - + return "".join(tokens) + def token_analysis(self, list_of_tokens): """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. 
The - token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.""" + token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.""" analysis_masks = ANALYSIS_MASK_TOKENS token_type = CIF_ANALYSIS_DICT - return [analysis_masks[next((k for k, v in token_type.items() if token in v), None)] for token in list_of_tokens] + return [ + analysis_masks[next((k for k, v in token_type.items() if token in v), None)] + for token in list_of_tokens + ] + class CrysllmTokenizer(Xtal2txtTokenizer): - def __init__(self, vocab_file=CRYSTAL_LLM_VOCAB, model_max_length=None, padding_length=None, **kwargs): - super(CrysllmTokenizer, self).__init__(vocab_file, model_max_length=model_max_length, padding_length=padding_length, **kwargs) + def __init__( + self, + vocab_file=CRYSTAL_LLM_VOCAB, + model_max_length=None, + padding_length=None, + **kwargs, + ): + super(CrysllmTokenizer, self).__init__( + vocab_file, + model_max_length=model_max_length, + padding_length=padding_length, + **kwargs, + ) def convert_tokens_to_string(self, tokens): - return ''.join(tokens) - + return "".join(tokens) + def token_analysis(self, list_of_tokens): """Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The - token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.""" + token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token.""" analysis_masks = ANALYSIS_MASK_TOKENS token_type = CRYSTAL_LLM_ANALYSIS_DICT - return [analysis_masks[next((k for k, v in token_type.items() if token in v), None)] for token in list_of_tokens] + return [ + analysis_masks[next((k for k, v in token_type.items() if token in v), None)] + for token in list_of_tokens + ] -class RobocrysTokenizer(): +class RobocrysTokenizer: """Tokenizer for Robocrystallographer. Would be BPE tokenizer. trained on the Robocrystallographer dataset. TODO: Implement this tokenizer. """ - def __init__(self, vocab_file=ROBOCRYS_VOCAB, **kwargs): - tokenizer = Tokenizer.from_file(vocab_file) - wrapped_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer) - self._tokenizer = wrapped_tokenizer + def __init__(self, vocab_file=ROBOCRYS_VOCAB, **kwargs): + tokenizer = Tokenizer.from_file(vocab_file) + wrapped_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer) + self._tokenizer = wrapped_tokenizer def tokenize(self, text): return self._tokenizer.tokenize(text) @@ -216,6 +288,6 @@ def decode(self, token_ids, skip_special_tokens=True): # Check if token_ids is a string and convert it to a list of integers if isinstance(token_ids, str): token_ids = [int(token_ids)] - return self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) - - + return self._tokenizer.decode( + token_ids, skip_special_tokens=skip_special_tokens + )