Skip to content

Commit

Permalink
chore: add analysis tokenizer for comp
Browse files Browse the repository at this point in the history
  • Loading branch information
n0w0f committed Mar 21, 2024
1 parent a6bde57 commit 9e494d9
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
5 changes: 5 additions & 0 deletions src/xtal2txt/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", ".", "+", "-"
]

COMPOSITION_ANALYSIS_DICT = {
"atoms": ATOM_LIST_,
"numbers": NUMS_,
}

SLICE_ANALYSIS_DICT = {
"atoms": ATOM_LIST_,
"directions": [
Expand Down
9 changes: 8 additions & 1 deletion src/xtal2txt/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from tokenizers import Tokenizer
from xtal2txt.analysis import ANALYSIS_MASK_TOKENS, SLICE_ANALYSIS_DICT, CRYSTAL_LLM_ANALYSIS_DICT, CIF_ANALYSIS_DICT
from xtal2txt.analysis import ANALYSIS_MASK_TOKENS, SLICE_ANALYSIS_DICT, CRYSTAL_LLM_ANALYSIS_DICT, CIF_ANALYSIS_DICT, COMPOSITION_ANALYSIS_DICT



Expand Down Expand Up @@ -158,6 +158,13 @@ def __init__(self, vocab_file=COMPOSITION_VOCAB, model_max_length=None, padding_
def convert_tokens_to_string(self, tokens):
return ''.join(tokens)

def token_analysis(self, list_of_tokens):
"""Takes tokens after tokenize and returns a list with replacing the tokens with their MASK token. The
token type is determined from the dict declared globally, and the token is replaced with the corresponding MASK token."""
analysis_masks = ANALYSIS_MASK_TOKENS
token_type = COMPOSITION_ANALYSIS_DICT
return [analysis_masks[next((k for k, v in token_type.items() if token in v), None)] for token in list_of_tokens]


class CifTokenizer(Xtal2txtTokenizer):
def __init__(self, vocab_file=CIF_VOCAB, model_max_length=None, padding_length=None, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion src/xtal2txt/vocabs/1.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"o o o": 0, "o o +": 1, "o o -": 2, "o + o": 3, "o + +": 4, "o + -": 5, "o - o": 6, "o - +": 7, "o - -": 8, "+ o o": 9, "+ o +": 10, "+ o -": 11, "+ + o": 12, "+ + +": 13, "+ + -": 14, "+ - o": 15, "+ - +": 16, "+ - -": 17, "- o o": 18, "- o +": 19, "- o -": 20, "- + o": 21, "- + +": 22, "- + -": 23, "- - o": 24, "- - +": 25, "- - -": 26, "H": 27, "He": 28, "Li": 29, "Be": 30, "B": 31, "C": 32, "N": 33, "O": 34, "F": 35, "Ne": 36, "Na": 37, "Mg": 38, "Al": 39, "Si": 40, "P": 41, "S": 42, "Cl": 43, "K": 44, "Ar": 45, "Ca": 46, "Sc": 47, "Ti": 48, "V": 49, "Cr": 50, "Mn": 51, "Fe": 52, "Ni": 53, "Co": 54, "Cu": 55, "Zn": 56, "Ga": 57, "Ge": 58, "As": 59, "Se": 60, "Br": 61, "Kr": 62, "Rb": 63, "Sr": 64, "Y": 65, "Zr": 66, "Nb": 67, "Mo": 68, "Tc": 69, "Ru": 70, "Rh": 71, "Pd": 72, "Ag": 73, "Cd": 74, "In": 75, "Sn": 76, "Sb": 77, "Te": 78, "I": 79, "Xe": 80, "Cs": 81, "Ba": 82, "La": 83, "Ce": 84, "Pr": 85, "Nd": 86, "Pm": 87, "Sm": 88, "Eu": 89, "Gd": 90, "Tb": 91, "Dy": 92, "Ho": 93, "Er": 94, "Tm": 95, "Yb": 96, "Lu": 97, "Hf": 98, "Ta": 99, "W": 100, "Re": 101, "Os": 102, "Ir": 103, "Pt": 104, "Au": 105, "Hg": 106, "Tl": 107, "Pb": 108, "Bi": 109, "Th": 110, "Pa": 111, "U": 112, "Np": 113, "Pu": 114, "Am": 115, "Cm": 116, "Bk": 117, "Cf": 118, "Es": 119, "Fm": 120, "Md": 121, "No": 122, "Lr": 123, "Rf": 124, "Db": 125, "Sg": 126, "Bh": 127, "Hs": 128, "Mt": 129, "Ds": 130, "Rg": 131, "Cn": 132, "Nh": 133, "Fl": 134, "Mc": 135, "Lv": 136, "Ts": 137, "Og": 138, "0": 139, "1": 140, "2": 141, "3": 142, "4": 143, "5": 144, "6": 145, "7": 146, "8": 147, "9": 148, "[UNK]": 149, "[PAD]": 150, "[CLS]": 151, "[SEP]": 152, "[MASK]": 153, "[EOS]": 154, "[BOS]": 155}
{"H": 0, "He": 1, "Li": 2, "Be": 3, "B": 4, "C": 5, "N": 6, "O": 7, "F": 8, "Ne": 9, "Na": 10, "Mg": 11, "Al": 12, "Si": 13, "P": 14, "S": 15, "Cl": 16, "K": 17, "Ar": 18, "Ca": 19, "Sc": 20, "Ti": 21, "V": 22, "Cr": 23, "Mn": 24, "Fe": 25, "Ni": 26, "Co": 27, "Cu": 28, "Zn": 29, "Ga": 30, "Ge": 31, "As": 32, "Se": 33, "Br": 34, "Kr": 35, "Rb": 36, "Sr": 37, "Y": 38, "Zr": 39, "Nb": 40, "Mo": 41, "Tc": 42, "Ru": 43, "Rh": 44, "Pd": 45, "Ag": 46, "Cd": 47, "In": 48, "Sn": 49, "Sb": 50, "Te": 51, "I": 52, "Xe": 53, "Cs": 54, "Ba": 55, "La": 56, "Ce": 57, "Pr": 58, "Nd": 59, "Pm": 60, "Sm": 61, "Eu": 62, "Gd": 63, "Tb": 64, "Dy": 65, "Ho": 66, "Er": 67, "Tm": 68, "Yb": 69, "Lu": 70, "Hf": 71, "Ta": 72, "W": 73, "Re": 74, "Os": 75, "Ir": 76, "Pt": 77, "Au": 78, "Hg": 79, "Tl": 80, "Pb": 81, "Bi": 82, "Th": 83, "Pa": 84, "U": 85, "Np": 86, "Pu": 87, "Am": 88, "Cm": 89, "Bk": 90, "Cf": 91, "Es": 92, "Fm": 93, "Md": 94, "No": 95, "Lr": 96, "Rf": 97, "Db": 98, "Sg": 99, "Bh": 100, "Hs": 101, "Mt": 102, "Ds": 103, "Rg": 104, "Cn": 105, "Nh": 106, "Fl": 107, "Mc": 108, "Lv": 109, "Ts": 110, "Og": 111, "0": 112, "1": 113, "2": 114, "3": 115, "4": 116, "5": 117, "6": 118, "7": 119, "8": 120, "9": 121, "data_": 122, "_symmetry_space_group_name_H-M": 123, "_cell_length_a": 124, "_cell_length_b": 125, "_cell_length_c": 126, "_cell_angle_alpha": 127, "_cell_angle_beta": 128, "_cell_angle_gamma": 129, "_symmetry_Int_Tables_number": 130, "_chemical_formula_structural": 131, "_chemical_formula_sum": 132, "_cell_volume": 133, "_cell_formula_units_Z": 134, "loop_": 135, "_symmetry_equiv_pos_site_id": 136, "_symmetry_equiv_pos_as_xyz": 137, "_atom_type_symbol": 138, "_atom_type_oxidation_number": 139, "_atom_site_type_symbol": 140, "_atom_site_label": 141, "_atom_site_symmetry_multiplicity": 142, "_atom_site_fract_x": 143, "_atom_site_fract_y": 144, "_atom_site_fract_z": 145, "_atom_site_occupancy": 146, " ": 147, ".": 148, "+": 149, "-": 150, "/": 151, "'": 152, "\"": 153, ",": 154, "'x, y, z'": 155, "x": 156, "y": 157, "z": 158, "-x": 159, "-y": 160, "-z": 161, " ": 162, " ": 163, "\n": 164, "_geom_bond_atom_site_label_1": 165, "_geom_bond_atom_site_label_2": 166, "_geom_bond_distance": 167, "_ccdc_geom_bond_type": 168, "_": 169, "a": 170, "n": 171, "c": 172, "b": 173, "m": 174, "d": 175, "R": 176, "A": 177, "(": 178, ")": 179, "[": 180, "]": 181, "*": 182, "[UNK]": 183, "[PAD]": 184, "[CLS]": 185, "[SEP]": 186, "[MASK]": 187, "[EOS]": 188, "[BOS]": 189}

0 comments on commit 9e494d9

Please sign in to comment.