diff --git a/notebooks/development.ipynb b/notebooks/development.ipynb index 69fef04..8e68041 100644 --- a/notebooks/development.ipynb +++ b/notebooks/development.ipynb @@ -186,9 +186,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-04-20 15:16:52.952786: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2024-04-20 15:16:52.953610: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2024-04-20 15:16:53.256036: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" + "2024-04-20 21:20:11.710080: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "2024-04-20 21:20:11.710904: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "2024-04-20 21:20:11.920831: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] }, { @@ -204,8 +204,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "Reading file /home/so87pot/miniconda3/envs/xtal2/lib/python3.9/site-packages/robocrys/condense/mineral_db.json.gz: 0it [00:00, ?it/s]#####7| 176/180 [00:00<00:00, 615.86it/s]\n", - "Decoding objects from /home/so87pot/miniconda3/envs/xtal2/lib/python3.9/site-packages/robocrys/condense/mineral_db.json.gz: 100%|##########| 180/180 [00:00<00:00, 572.87it/s]\n" + "Reading file /home/so87pot/miniconda3/envs/xtal2/lib/python3.9/site-packages/robocrys/condense/mineral_db.json.gz: 0it [00:00, ?it/s]##3 | 114/180 [00:00<00:00, 604.30it/s]\n", + "Decoding objects from /home/so87pot/miniconda3/envs/xtal2/lib/python3.9/site-packages/robocrys/condense/mineral_db.json.gz: 100%|##########| 180/180 [00:00<00:00, 593.53it/s]\n" ] } ], @@ -545,6 +545,769 @@ " output_var.append(bond_var)\n", " return \"\\n\".join(output) + \"\\n\\n\" + \"\\n\".join(output_var)\n" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from xtal2txt.tokenizer import SliceTokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "#path= \"/work/so87pot/material_db/qmof_text/qmof_filtered_text.json\"\n", + "path=\"/work/so87pot/material_db/qmof_text/bandgap/train.json\"\n", + "ds = load_dataset(\"json\", data_files=path,split=\"train\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['composition', 'pld', 'crystal_llm_rep', 'natoms', 'atoms_params', 'lcd', 'id', 'slice', 'cif_symmetrized', 'volume', 'atoms', 'labels', 'zmatrix', 'cif_p1', 'density'],\n", + " num_rows: 8600\n", + "})" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": 
"execute_result" + } + ], + "source": [ + "ds" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "_DEFAULT_SPECIAL_TOKENS = {\n", + " \"unk_token\": \"[UNK]\",\n", + " \"pad_token\": \"[PAD]\",\n", + " \"cls_token\": \"[CLS]\",\n", + " \"sep_token\": \"[SEP]\",\n", + " \"mask_token\": \"[MASK]\",\n", + " \"eos_token\": \"[EOS]\",\n", + " \"bos_token\": \"[BOS]\",\n", + "}\n", + "tokenizer = SliceTokenizer(\n", + " model_max_length=512, truncation=True, padding=\"max_length\", special_tokens=_DEFAULT_SPECIAL_TOKENS, max_length=512\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def _tokenize( examples):\n", + " # Tokenize the 'crystal_llm' column using the LAMA tokenizer\n", + " tokenized_examples = tokenizer(\n", + " examples[\"slice\"], truncation=True, padding=True\n", + " )\n", + "\n", + " # Convert input_ids list to tensor and then to torch.float16\n", + " # tokenized_examples[\"input_ids\"] = torch.tensor(tokenized_examples[\"input_ids\"]).to(torch.float16)\n", + " # tokenized_examples[\"attention_mask\"] = torch.tensor(tokenized_examples[\"attention_mask\"]).to(torch.float16)\n", + " \n", + " # # Check if \"labels\" key exists before trying to convert it to float16\n", + " # if \"labels\" in tokenized_examples:\n", + " # tokenized_examples[\"labels\"] = torch.tensor(tokenized_examples[\"labels\"]).to(torch.float16)\n", + "\n", + " return tokenized_examples" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Parameter 'function'= of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. 
Subsequent hashing failures won't be showed.\n", + "Map: 100%|██████████| 8600/8600 [00:07<00:00, 1143.87 examples/s]\n" + ] + } + ], + "source": [ + "dataset = ds.map(_tokenize, batched=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['composition', 'pld', 'crystal_llm_rep', 'natoms', 'atoms_params', 'lcd', 'id', 'slice', 'cif_symmetrized', 'volume', 'atoms', 'labels', 'zmatrix', 'cif_p1', 'density', 'input_ids', 'token_type_ids', 'attention_mask'],\n", + " num_rows: 8600\n", + "})" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Cu Cu H H H H H H H H H H H H H H H H H H H H C C C C C C C C C C C C C C C C C C C C C C C C N N N N N N N N O O O O O O O O O O O O 0 48 - - o 0 53 - o o 0 51 o - + 0 46 o o + 1 50 o o o 1 47 o o o 1 49 o o o 1 52 o o o 2 42 o - o 3 43 o o o 4 44 o + o 5 45 o o o 6 38 o o o 7 39 o o o 8 40 o o o 9 41 o o o 10 38 o o o 11 39 o o o 12 40 o o o 13 41 o o o 14 42 o o o 15 43 o o o 16 44 o o o 17 45 o o o 18 62 o o o 19 63 o o o 20 64 o o o 21 65 o o o 22 46 o o o 22 30 o o o 23 31 o o o 23 47 o o o 24 32 o o o 24 48 o o o 25 49 o o o 25 33 o o o 26 30 o o o 26 50 o o o 27 51 o o o 27 31 o o o 28 52 o o o 28 32 o o o 29 33 o o o 29 53 o o o 30 34 o - o 31 35 o o - 32 36 o + o 33 37 o o + 34 54 o o o 34 58 o o o 35 59 o o o 35 55 o o + 36 60 o o o 36 56 o o o 37 57 o o - 37 61 o o o 38 58 o o o 39 59 o o o 40 60 o o o 41 61 o o o '" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds['slice'][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "352" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(tokenizer.tokenize(ds['slice'][0]))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "tokens = dataset['input_ids'][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[151,\n", + " 55,\n", + " 55,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 27,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 32,\n", + " 33,\n", + " 33,\n", + " 33,\n", + " 33,\n", + " 33,\n", + " 33,\n", + " 33,\n", + " 33,\n", + " 34,\n", + " 34,\n", + " 34,\n", + " 34,\n", + " 34,\n", + " 34,\n", + " 34,\n", + " 34,\n", + " 34,\n", + " 34,\n", + " 34,\n", + " 34,\n", + " 139,\n", + " 143,\n", + " 147,\n", + " 24,\n", + " 139,\n", + " 144,\n", + " 142,\n", + " 18,\n", + " 139,\n", + " 144,\n", + " 140,\n", + " 7,\n", + " 139,\n", + " 143,\n", + " 145,\n", + " 1,\n", + " 140,\n", + " 144,\n", + " 139,\n", + " 0,\n", + " 140,\n", + " 143,\n", + " 146,\n", + " 0,\n", + " 140,\n", + " 143,\n", + " 148,\n", + " 0,\n", + " 
140,\n", + " 144,\n", + " 141,\n", + " 0,\n", + " 141,\n", + " 143,\n", + " 141,\n", + " 6,\n", + " 142,\n", + " 143,\n", + " 142,\n", + " 0,\n", + " 143,\n", + " 143,\n", + " 143,\n", + " 3,\n", + " 144,\n", + " 143,\n", + " 144,\n", + " 0,\n", + " 145,\n", + " 142,\n", + " 147,\n", + " 0,\n", + " 146,\n", + " 142,\n", + " 148,\n", + " 0,\n", + " 147,\n", + " 143,\n", + " 139,\n", + " 0,\n", + " 148,\n", + " 143,\n", + " 140,\n", + " 0,\n", + " 140,\n", + " 139,\n", + " 142,\n", + " 147,\n", + " 0,\n", + " 140,\n", + " 140,\n", + " 142,\n", + " 148,\n", + " 0,\n", + " 140,\n", + " 141,\n", + " 143,\n", + " 139,\n", + " 0,\n", + " 140,\n", + " 142,\n", + " 143,\n", + " 140,\n", + " 0,\n", + " 140,\n", + " 143,\n", + " 143,\n", + " 141,\n", + " 0,\n", + " 140,\n", + " 144,\n", + " 143,\n", + " 142,\n", + " 0,\n", + " 140,\n", + " 145,\n", + " 143,\n", + " 143,\n", + " 0,\n", + " 140,\n", + " 146,\n", + " 143,\n", + " 144,\n", + " 0,\n", + " 140,\n", + " 147,\n", + " 145,\n", + " 141,\n", + " 0,\n", + " 140,\n", + " 148,\n", + " 145,\n", + " 142,\n", + " 0,\n", + " 141,\n", + " 139,\n", + " 145,\n", + " 143,\n", + " 0,\n", + " 141,\n", + " 140,\n", + " 145,\n", + " 144,\n", + " 0,\n", + " 141,\n", + " 141,\n", + " 143,\n", + " 145,\n", + " 0,\n", + " 141,\n", + " 141,\n", + " 142,\n", + " 139,\n", + " 0,\n", + " 141,\n", + " 142,\n", + " 142,\n", + " 140,\n", + " 0,\n", + " 141,\n", + " 142,\n", + " 143,\n", + " 146,\n", + " 0,\n", + " 141,\n", + " 143,\n", + " 142,\n", + " 141,\n", + " 0,\n", + " 141,\n", + " 143,\n", + " 143,\n", + " 147,\n", + " 0,\n", + " 141,\n", + " 144,\n", + " 143,\n", + " 148,\n", + " 0,\n", + " 141,\n", + " 144,\n", + " 142,\n", + " 142,\n", + " 0,\n", + " 141,\n", + " 145,\n", + " 142,\n", + " 139,\n", + " 0,\n", + " 141,\n", + " 145,\n", + " 144,\n", + " 139,\n", + " 0,\n", + " 141,\n", + " 146,\n", + " 144,\n", + " 140,\n", + " 0,\n", + " 141,\n", + " 146,\n", + " 142,\n", + " 140,\n", + " 0,\n", + " 141,\n", + " 147,\n", + " 144,\n", + " 141,\n", + " 0,\n", + " 141,\n", + " 147,\n", + " 142,\n", + " 141,\n", + " 0,\n", + " 141,\n", + " 148,\n", + " 142,\n", + " 142,\n", + " 0,\n", + " 141,\n", + " 148,\n", + " 144,\n", + " 142,\n", + " 0,\n", + " 142,\n", + " 139,\n", + " 142,\n", + " 143,\n", + " 6,\n", + " 142,\n", + " 140,\n", + " 142,\n", + " 144,\n", + " 2,\n", + " 142,\n", + " 141,\n", + " 142,\n", + " 145,\n", + " 3,\n", + " 142,\n", + " 142,\n", + " 142,\n", + " 146,\n", + " 1,\n", + " 142,\n", + " 143,\n", + " 144,\n", + " 143,\n", + " 0,\n", + " 142,\n", + " 143,\n", + " 144,\n", + " 147,\n", + " 0,\n", + " 142,\n", + " 144,\n", + " 144,\n", + " 148,\n", + " 0,\n", + " 142,\n", + " 144,\n", + " 144,\n", + " 144,\n", + " 1,\n", + " 142,\n", + " 145,\n", + " 145,\n", + " 139,\n", + " 0,\n", + " 142,\n", + " 145,\n", + " 144,\n", + " 145,\n", + " 0,\n", + " 142,\n", + " 146,\n", + " 144,\n", + " 146,\n", + " 2,\n", + " 142,\n", + " 146,\n", + " 145,\n", + " 140,\n", + " 0,\n", + " 142,\n", + " 147,\n", + " 144,\n", + " 147,\n", + " 0,\n", + " 142,\n", + " 148,\n", + " 144,\n", + " 148,\n", + " 0,\n", + " 143,\n", + " 139,\n", + " 145,\n", + " 139,\n", + " 0,\n", + " 143,\n", + " 140,\n", + " 145,\n", + " 140,\n", + " 0,\n", + " 152,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 
150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150,\n", + " 150]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokens" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "512" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'[CLS] Cu Cu H H H H H H H H H H H H H H H H H H H H C C C C C C C C C C C C C C C C C C C C C C C C N N N N N N N N O O O O O O O O O O O O 0 4 8 - - o 0 5 3 - o o 0 5 1 o - + 0 4 6 o o + 1 5 0 o o o 1 4 7 o o o 1 4 9 o o o 1 5 2 o o o 2 4 2 o - o 3 4 3 o o o 4 4 4 o + o 5 4 5 o o o 6 3 8 o o o 7 3 9 o o o 8 4 0 o o o 9 4 1 o o o 1 0 3 8 o o o 1 1 3 9 o o o 1 2 4 0 o o o 1 3 4 1 o o o 1 4 4 2 o o o 1 5 4 3 o o o 1 6 4 4 o o o 1 7 4 5 o o o 1 8 6 2 o o o 1 9 6 3 o o o 2 0 6 4 o o o 2 1 6 5 o o o 2 2 4 6 o o o 2 2 3 0 o o o 2 3 3 1 o o o 2 3 4 7 o o o 2 4 3 2 o o o 2 4 4 8 o o o 2 5 4 9 o o o 2 5 3 3 o o o 2 6 3 0 o o o 2 6 5 0 o o o 2 7 5 1 o o o 2 7 3 1 o o o 2 8 5 2 o o o 2 8 3 2 o o o 2 9 3 3 o o o 2 9 5 3 o o o 3 0 3 4 o - o 3 1 3 5 o o - 3 2 3 6 o + o 3 3 3 7 o o + 3 4 5 4 o o o 3 4 5 8 o o o 3 5 5 9 o o o 3 5 5 5 o o + 3 6 6 0 o o o 3 6 5 6 o o o 3 7 5 7 o o - 3 7 6 1 o o o 3 8 5 8 o o o 3 9 5 9 o o o 4 0 6 0 o o o 4 1 6 1 o o o [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] 
[PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.decode(tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/src/xtal2txt/tokenizer.py b/src/xtal2txt/tokenizer.py index 4698b2c..cfdcb76 100644 --- a/src/xtal2txt/tokenizer.py +++ b/src/xtal2txt/tokenizer.py @@ -23,7 +23,7 @@ class Xtal2txtTokenizer(PreTrainedTokenizer): def __init__( - self, vocab_file, model_max_length=None, padding_length=None, **kwargs + self, vocab_file, special_tokens=None, model_max_length=None, padding_length=None, **kwargs ): super(Xtal2txtTokenizer, self).__init__( model_max_length=model_max_length, **kwargs @@ -34,6 +34,10 @@ def __init__( self.truncation = False self.padding = False self.padding_length = padding_length + + # Initialize special tokens + self.special_tokens = special_tokens if special_tokens is not None else {} + self.add_special_tokens(self.special_tokens) def load_vocab(self, vocab_file): _, file_extension = os.path.splitext(vocab_file) @@ -59,8 +63,15 @@ def tokenize(self, text): pattern = re.compile(pattern_str) matches = pattern.findall(text) + # Add [CLS] and [SEP] tokens if present in the vocabulary + if self.cls_token is not None: + matches = [self.cls_token] + matches + if self.truncation and len(matches) > self.model_max_length: - matches = matches[: self.model_max_length] + matches = matches[: self.model_max_length-1] # -1 since we add sep token later + + if self.sep_token is not None: + matches += [self.sep_token] if self.padding and len(matches) < self.padding_length: matches += [self.pad_token] * (self.padding_length - len(matches)) @@ -100,6 +111,16 @@ def add_special_tokens(self, special_tokens): if value not in self.vocab: setattr(self, token, value) self.vocab[value] = len(self.vocab) + + # Ensure [CLS] and [SEP] tokens are added + cls_token = special_tokens.get("cls_token", None) + sep_token = special_tokens.get("sep_token", None) + if cls_token is not None and cls_token not in self.vocab: + setattr(self, "cls_token", cls_token) + self.vocab[cls_token] = len(self.vocab) + if sep_token is not None and sep_token not in self.vocab: + setattr(self, "sep_token", sep_token) + self.vocab[sep_token] = len(self.vocab) self.save_vocabulary(os.path.dirname(self.vocab_file)) def token_analysis(self, tokens): diff --git a/src/xtal2txt/vocabs/1.json b/src/xtal2txt/vocabs/1.json index 2b112a5..f5c1741 100644 --- a/src/xtal2txt/vocabs/1.json +++ b/src/xtal2txt/vocabs/1.json @@ -1 +1 @@ -{"H": 0, "He": 1, "Li": 2, "Be": 3, "B": 4, "C": 5, "N": 6, "O": 7, "F": 8, "Ne": 9, "Na": 10, "Mg": 11, "Al": 12, "Si": 13, "P": 14, "S": 15, "Cl": 16, "K": 17, "Ar": 18, "Ca": 19, "Sc": 20, "Ti": 21, "V": 22, "Cr": 23, "Mn": 24, "Fe": 25, "Ni": 26, "Co": 27, "Cu": 28, "Zn": 29, "Ga": 30, "Ge": 31, "As": 32, "Se": 
33, "Br": 34, "Kr": 35, "Rb": 36, "Sr": 37, "Y": 38, "Zr": 39, "Nb": 40, "Mo": 41, "Tc": 42, "Ru": 43, "Rh": 44, "Pd": 45, "Ag": 46, "Cd": 47, "In": 48, "Sn": 49, "Sb": 50, "Te": 51, "I": 52, "Xe": 53, "Cs": 54, "Ba": 55, "La": 56, "Ce": 57, "Pr": 58, "Nd": 59, "Pm": 60, "Sm": 61, "Eu": 62, "Gd": 63, "Tb": 64, "Dy": 65, "Ho": 66, "Er": 67, "Tm": 68, "Yb": 69, "Lu": 70, "Hf": 71, "Ta": 72, "W": 73, "Re": 74, "Os": 75, "Ir": 76, "Pt": 77, "Au": 78, "Hg": 79, "Tl": 80, "Pb": 81, "Bi": 82, "Th": 83, "Pa": 84, "U": 85, "Np": 86, "Pu": 87, "Am": 88, "Cm": 89, "Bk": 90, "Cf": 91, "Es": 92, "Fm": 93, "Md": 94, "No": 95, "Lr": 96, "Rf": 97, "Db": 98, "Sg": 99, "Bh": 100, "Hs": 101, "Mt": 102, "Ds": 103, "Rg": 104, "Cn": 105, "Nh": 106, "Fl": 107, "Mc": 108, "Lv": 109, "Ts": 110, "Og": 111, "0": 112, "1": 113, "2": 114, "3": 115, "4": 116, "5": 117, "6": 118, "7": 119, "8": 120, "9": 121, "data_": 122, "_symmetry_space_group_name_H-M": 123, "_cell_length_a": 124, "_cell_length_b": 125, "_cell_length_c": 126, "_cell_angle_alpha": 127, "_cell_angle_beta": 128, "_cell_angle_gamma": 129, "_symmetry_Int_Tables_number": 130, "_chemical_formula_structural": 131, "_chemical_formula_sum": 132, "_cell_volume": 133, "_cell_formula_units_Z": 134, "loop_": 135, "_symmetry_equiv_pos_site_id": 136, "_symmetry_equiv_pos_as_xyz": 137, "_atom_type_symbol": 138, "_atom_type_oxidation_number": 139, "_atom_site_type_symbol": 140, "_atom_site_label": 141, "_atom_site_symmetry_multiplicity": 142, "_atom_site_fract_x": 143, "_atom_site_fract_y": 144, "_atom_site_fract_z": 145, "_atom_site_occupancy": 146, " ": 147, ".": 148, "+": 149, "-": 150, "/": 151, "'": 152, "\"": 153, ",": 154, "'x, y, z'": 155, "x": 156, "y": 157, "z": 158, "-x": 159, "-y": 160, "-z": 161, " ": 162, " ": 163, "\n": 164, "_geom_bond_atom_site_label_1": 165, "_geom_bond_atom_site_label_2": 166, "_geom_bond_distance": 167, "_ccdc_geom_bond_type": 168, "_": 169, "a": 170, "n": 171, "c": 172, "b": 173, "m": 174, "d": 175, "R": 176, "A": 177, "(": 178, ")": 179, "[": 180, "]": 181, "*": 182, "[UNK]": 183, "[PAD]": 184, "[CLS]": 185, "[SEP]": 186, "[MASK]": 187, "[EOS]": 188, "[BOS]": 189} \ No newline at end of file +{"o o o": 0, "o o +": 1, "o o -": 2, "o + o": 3, "o + +": 4, "o + -": 5, "o - o": 6, "o - +": 7, "o - -": 8, "+ o o": 9, "+ o +": 10, "+ o -": 11, "+ + o": 12, "+ + +": 13, "+ + -": 14, "+ - o": 15, "+ - +": 16, "+ - -": 17, "- o o": 18, "- o +": 19, "- o -": 20, "- + o": 21, "- + +": 22, "- + -": 23, "- - o": 24, "- - +": 25, "- - -": 26, "H": 27, "He": 28, "Li": 29, "Be": 30, "B": 31, "C": 32, "N": 33, "O": 34, "F": 35, "Ne": 36, "Na": 37, "Mg": 38, "Al": 39, "Si": 40, "P": 41, "S": 42, "Cl": 43, "K": 44, "Ar": 45, "Ca": 46, "Sc": 47, "Ti": 48, "V": 49, "Cr": 50, "Mn": 51, "Fe": 52, "Ni": 53, "Co": 54, "Cu": 55, "Zn": 56, "Ga": 57, "Ge": 58, "As": 59, "Se": 60, "Br": 61, "Kr": 62, "Rb": 63, "Sr": 64, "Y": 65, "Zr": 66, "Nb": 67, "Mo": 68, "Tc": 69, "Ru": 70, "Rh": 71, "Pd": 72, "Ag": 73, "Cd": 74, "In": 75, "Sn": 76, "Sb": 77, "Te": 78, "I": 79, "Xe": 80, "Cs": 81, "Ba": 82, "La": 83, "Ce": 84, "Pr": 85, "Nd": 86, "Pm": 87, "Sm": 88, "Eu": 89, "Gd": 90, "Tb": 91, "Dy": 92, "Ho": 93, "Er": 94, "Tm": 95, "Yb": 96, "Lu": 97, "Hf": 98, "Ta": 99, "W": 100, "Re": 101, "Os": 102, "Ir": 103, "Pt": 104, "Au": 105, "Hg": 106, "Tl": 107, "Pb": 108, "Bi": 109, "Th": 110, "Pa": 111, "U": 112, "Np": 113, "Pu": 114, "Am": 115, "Cm": 116, "Bk": 117, "Cf": 118, "Es": 119, "Fm": 120, "Md": 121, "No": 122, "Lr": 123, "Rf": 124, "Db": 125, "Sg": 126, "Bh": 
127, "Hs": 128, "Mt": 129, "Ds": 130, "Rg": 131, "Cn": 132, "Nh": 133, "Fl": 134, "Mc": 135, "Lv": 136, "Ts": 137, "Og": 138, "0": 139, "1": 140, "2": 141, "3": 142, "4": 143, "5": 144, "6": 145, "7": 146, "8": 147, "9": 148, "[UNK]": 149, "[PAD]": 150, "[CLS]": 151, "[SEP]": 152, "[MASK]": 153, "[EOS]": 154, "[BOS]": 155} \ No newline at end of file