diff --git a/odyssey/data/DataProcessor.ipynb b/odyssey/data/DataProcessor.ipynb index b4d81bf..8b02ba7 100644 --- a/odyssey/data/DataProcessor.ipynb +++ b/odyssey/data/DataProcessor.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2024-03-13T16:14:45.546088300Z", @@ -10,15 +10,7 @@ }, "collapsed": true }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[rank: 0] Seed set to 23\n" - ] - } - ], + "outputs": [], "source": [ "import os\n", "import sys\n", @@ -34,7 +26,7 @@ "DATASET = f\"{DATA_ROOT}/patient_sequences/patient_sequences_2048.parquet\"\n", "MAX_LEN = 2048\n", "\n", - "os.chdir(DATA_ROOT)\n", + "os.chdir(ROOT)\n", "\n", "from odyssey.utils.utils import seed_everything\n", "from odyssey.data.tokenizer import ConceptTokenizer\n", @@ -467,18 +459,6 @@ "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.9" } }, "nbformat": 4, diff --git a/odyssey/data/dataset.py b/odyssey/data/dataset.py index 5cc6530..f24e048 100644 --- a/odyssey/data/dataset.py +++ b/odyssey/data/dataset.py @@ -251,12 +251,33 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: """ data = self.data.iloc[idx] cutoff = data[self.cutoff_col] if self.cutoff_col else None - data = truncate_and_pad(data, cutoff=cutoff, max_len=self.max_len) + + # Truncate and pad the data to the specified cutoff. + data = truncate_and_pad(data, cutoff, self.max_len) + + # Prepare model input tokenized_input = self.tokenize_data(data[f"event_tokens_{self.max_len}"]) concept_ids = tokenized_input["input_ids"].squeeze() + type_tokens = data[f"type_tokens_{self.max_len}"] + age_tokens = data[f"age_tokens_{self.max_len}"] + time_tokens = data[f"time_tokens_{self.max_len}"] + visit_tokens = data[f"visit_tokens_{self.max_len}"] + position_tokens = data[f"position_tokens_{self.max_len}"] + + type_tokens = torch.tensor(type_tokens) + age_tokens = torch.tensor(age_tokens) + time_tokens = torch.tensor(time_tokens) + visit_tokens = torch.tensor(visit_tokens) + position_tokens = torch.tensor(position_tokens) + return { "concept_ids": concept_ids, + "type_ids": type_tokens, + "ages": age_tokens, + "time_stamps": time_tokens, + "visit_orders": position_tokens, + "visit_segments": visit_tokens, "labels": concept_ids, } @@ -691,6 +712,22 @@ def __len__(self) -> int: """Return the length of dataset.""" return len(self.index_mapper) + def tokenize_data(self, sequence: Union[str, List[str]]) -> Any: + """Tokenize the sequence and return input_ids and attention mask. + + Parameters + ---------- + sequence : Union[str, List[str]] + The sequence to be tokenized. + + Returns + ------- + Any + A dictionary containing input_ids and attention_mask. + + """ + return self.tokenizer(sequence, max_length=self.max_len) + def __getitem__(self, idx: int) -> Dict[str, Any]: """Get data at corresponding index. @@ -717,25 +754,30 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: # Prepare model input tokenized_input = self.tokenize_data(data[f"event_tokens_{self.max_len}"]) concept_ids = tokenized_input["input_ids"].squeeze() - labels = torch.tensor(labels) - - return {"concept_ids": concept_ids, "labels": labels, "task": task} - - def tokenize_data(self, sequence: Union[str, List[str]]) -> Any: - """Tokenize the sequence and return input_ids and attention mask. - Parameters - ---------- - sequence : Union[str, List[str]] - The sequence to be tokenized. + type_tokens = data[f"type_tokens_{self.max_len}"] + age_tokens = data[f"age_tokens_{self.max_len}"] + time_tokens = data[f"time_tokens_{self.max_len}"] + visit_tokens = data[f"visit_tokens_{self.max_len}"] + position_tokens = data[f"position_tokens_{self.max_len}"] - Returns - ------- - Any - A dictionary containing input_ids and attention_mask. + type_tokens = torch.tensor(type_tokens) + age_tokens = torch.tensor(age_tokens) + time_tokens = torch.tensor(time_tokens) + visit_tokens = torch.tensor(visit_tokens) + position_tokens = torch.tensor(position_tokens) + labels = torch.tensor(labels) - """ - return self.tokenizer(sequence, max_length=self.max_len) + return { + "concept_ids": concept_ids, + "type_ids": type_tokens, + "ages": age_tokens, + "time_stamps": time_tokens, + "visit_orders": position_tokens, + "visit_segments": visit_tokens, + "labels": labels, + "task": task, + } def balance_labels(self, task: str, positive_ratio: float) -> None: """Balance the labels for the specified task in the dataset. diff --git a/odyssey/evals/TestAnalysis.ipynb b/odyssey/evals/TestAnalysis.ipynb index 865a761..c1bd622 100644 --- a/odyssey/evals/TestAnalysis.ipynb +++ b/odyssey/evals/TestAnalysis.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -24,25 +24,25 @@ "ROOT = \"/fs01/home/afallah/odyssey/odyssey\"\n", "os.chdir(ROOT)\n", "\n", - "from odyssey.data.dataset import FinetuneMultiDataset\n", + "from odyssey.data.dataset import FinetuneMultiDataset, FinetuneDatasetDecoder\n", "from odyssey.data.tokenizer import ConceptTokenizer\n", "from odyssey.models.model_utils import load_finetune_data" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "class config:\n", " \"\"\"Save the configuration arguments.\"\"\"\n", "\n", - " model_path = \"checkpoints/multibird_finetune/multibird_finetune/test_outputs/test_outputs_1ec842db.pt\"\n", + " model_path = \"checkpoints/mamba_finetune/test_outputs/test_outputs_f8471ffd.pt\"\n", " vocab_dir = \"odyssey/data/vocab\"\n", " data_dir = \"odyssey/data/bigbird_data\"\n", - " sequence_file = \"patient_sequences/patient_sequences_2048_multi.parquet\"\n", - " id_file = \"patient_id_dict/dataset_2048_multi.pkl\"\n", + " sequence_file = \"patient_sequences_2048_multi.parquet\"\n", + " id_file = \"dataset_2048_multi.pkl\"\n", " valid_scheme = \"few_shot\"\n", " num_finetune_patients = \"all\"\n", " # label_name = \"label_mortality_1month\"\n", @@ -55,7 +55,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -80,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -97,7 +97,7 @@ " config.num_finetune_patients,\n", ")\n", "\n", - "test_dataset = FinetuneMultiDataset(\n", + "test_dataset = FinetuneDatasetDecoder(\n", " data=fine_test,\n", " tokenizer=tokenizer,\n", " tasks=config.tasks,\n", @@ -127,9 +127,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['labels', 'logits'])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "test_outputs = torch.load(config.model_path, map_location=config.device)\n", "test_outputs.keys()" @@ -137,9 +148,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 0, 0, ..., 0, 0, 0])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "labels = test_outputs[\"labels\"].cpu().numpy()\n", "logits = test_outputs[\"logits\"].cpu().numpy()\n", @@ -151,15 +173,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'Balanced Accuracy': 0.8985595582256243,\n", + " 'F1 Score': 0.6536144578313253,\n", + " 'Precision': 0.5174058178350024,\n", + " 'Recall': 0.8871627146361406,\n", + " 'AUROC': 0.9667942606114659,\n", + " 'Average Precision Score': 0.47009681385740587,\n", + " 'AUC-PR': 0.7078210982047579}" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Tasks we have are: tasks = ['mortality_1month', 'los_1week', 'c0', 'c1', 'c2']\n", "\n", "task_idx = []\n", "for i, task in enumerate(tasks):\n", - " if task == \"los_1week\":\n", + " if task == \"mortality_1month\":\n", " task_idx.append(i)\n", "\n", "\n", @@ -224,6 +263,18 @@ "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" } }, "nbformat": 4, diff --git a/odyssey/models/cehr_mamba/mamba-dev.ipynb b/odyssey/models/cehr_mamba/mamba-dev.ipynb index 14968f7..30858c5 100644 --- a/odyssey/models/cehr_mamba/mamba-dev.ipynb +++ b/odyssey/models/cehr_mamba/mamba-dev.ipynb @@ -2,9 +2,20 @@ "cells": [ { "cell_type": "code", - "execution_count": 77, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'\\nNew Training:\\n1. Num parameters\\n2. Epochs\\n3. Overfitting\\n 4. Emebeddings\\n 5. Label balance\\n 6. Dataset\\n'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import os\n", "from typing import Any, Dict, Optional, Tuple, Union\n", @@ -63,7 +74,26 @@ ")\n", "from odyssey.utils.utils import seed_everything\n", "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "\"\"\"\n", + "New Training:\n", + "1. Num parameters\n", + "2. Epochs\n", + "3. Overfitting\n", + " 4. Emebeddings\n", + " 5. Label balance\n", + " 6. Dataset\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "model = torch.load(\"checkpoints/mamba_pretrain_with_embeddings/best.ckpt\")" ] }, { @@ -101,26 +131,26 @@ "\n", "\n", "# Setup data\n", - "# pre_data = load_pretrain_data(\n", - "# args.data_dir,\n", - "# f'patient_sequences/{args.sequence_file}',\n", - "# f'patient_id_dict/{args.id_file}',\n", - "# )\n", - "# train_dataset = PretrainDatasetDecoder(\n", - "# data=pre_data,\n", - "# tokenizer=tokenizer,\n", - "# max_len=args.max_len,\n", - "# )\n", - "\n", - "\n", - "_, fine_test = load_finetune_data(\n", - " args.data_dir, args.sequence_file, args.id_file, \"few_shot\", \"all\"\n", + "pre_data = load_pretrain_data(\n", + " args.data_dir,\n", + " f\"patient_sequences/{args.sequence_file}\",\n", + " f\"patient_id_dict/{args.id_file}\",\n", ")\n", - "test_dataset = PretrainDatasetDecoder(\n", - " data=fine_test,\n", + "train_dataset = PretrainDatasetDecoder(\n", + " data=pre_data,\n", " tokenizer=tokenizer,\n", " max_len=args.max_len,\n", - ")" + ")\n", + "\n", + "\n", + "# _, fine_test = load_finetune_data(\n", + "# args.data_dir, args.sequence_file, args.id_file, \"few_shot\", \"all\"\n", + "# )\n", + "# test_dataset = PretrainDatasetDecoder(\n", + "# data=fine_test,\n", + "# tokenizer=tokenizer,\n", + "# max_len=args.max_len,\n", + "# )" ] }, { @@ -177,45 +207,14 @@ "model = MambaForCausalLM(config=config)\n", "# model.backbone.embeddings = embeddings\n", "model.to(device)\n", - "\n", "model" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "MambaForCausalLM(\n", - " (backbone): MambaModel(\n", - " (embeddings): Embedding(20600, 768)\n", - " (layers): ModuleList(\n", - " (0-31): 32 x MambaBlock(\n", - " (norm): MambaRMSNorm()\n", - " (mixer): MambaMixer(\n", - " (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)\n", - " (act): SiLU()\n", - " (in_proj): Linear(in_features=768, out_features=3072, bias=False)\n", - " (x_proj): Linear(in_features=1536, out_features=80, bias=False)\n", - " (dt_proj): Linear(in_features=48, out_features=1536, bias=True)\n", - " (out_proj): Linear(in_features=1536, out_features=768, bias=False)\n", - " )\n", - " )\n", - " )\n", - " (norm_f): MambaRMSNorm()\n", - " )\n", - " (lm_head): Linear(in_features=768, out_features=20600, bias=False)\n", - ")" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# Load pretrained model\n", "checkpoint = torch.load(\"checkpoints/mamba_pretrain/best.ckpt\", map_location=device)\n", @@ -227,22 +226,9 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'concept_ids': tensor([[20592, 3, 17326, ..., 0, 0, 0]], device='cuda:0'),\n", - " 'labels': tensor([1], device='cuda:0'),\n", - " 'task': 'mortality_1month'}" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "train_loader = DataLoader(\n", " decoder_dataset, # test_dataset, #train_dataset\n", @@ -263,17 +249,9 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[MOR_1M] [VS] 58160087546 00904516561 00182853489 00574705050 00121054410 66553000401 00310027539 00006003121 00456320563 62856024541 00310027539 00172531210 00338069104 51006_2 50983_4 50971_2 50970_2 50960_3 50931_1 50912_2 50902_4 50893_3 50882_2 50868_2 51301_1 51279_4 51277_1 51265_2 51250_0 51248_1 51222_4 51221_4 00172531110 10432017002 10432017002 51006_2 50983_4 50971_3 50970_2 50960_3 50931_4 50912_3 50902_3 50893_4 50882_3 50868_3 51301_1 51279_4 51277_1 51265_1 51250_0 51248_1 51222_4 51221_4 00904224461 51301_2 51279_4 51277_1 51265_2 51250_0 51248_1 51222_4 51221_4 51006_2 50983_4 50971_3 50970_2 50960_2 50931_2 50912_3 50902_2 50893_3 50882_3 50868_4 [VE] [REG] [W_0] [VS] 8938 8838 8744 51006_2 50983_4 50971_2 50931_1 50912_2 50902_4 50882_3 50868_3 51301_1 51279_4 51277_1 51265_2 51256_1 51254_3 51250_1 51248_1 51244_3 51222_4 51221_4 51200_3 51146_2 63323026201 00713016550 00182844789 62856024541 00182853489 00904516561 00006003121 00121065721 10432017002 00182844789 66553000401 00904224461 00310027539 [VE] [REG] [M_2] [VS] 00338004304 63323026201 62856024541 00182844789 00310027539 10432017002 00904404073 00338004904 00456320563 66553000401 55390014710 51006_2 50983_3 50971_1 50970_0 50960_3 50931_2 50912_2 50902_4 50893_1 50882_2 50868_1 51301_2 51279_4 51277_1 51265_1 51250_1 51248_1 51222_4 51221_4 51006_2 50983_4 50971_0 50970_1 50960_2 50931_0 50912_2 50902_4 50893_2 50882_1 50868_3 00310027539 00172531210 00310027839 [VE] [REG] [LT] [VS] 51498_4 51491_0 33332001001 63323026201 00456320563 51079000220 60505258600 62856024541 60505258600 60505258600 51516_1 51498_4 51493_1 51491_0 51476_1 51301_0 51279_4 51277_1 51265_1 51250_0 51248_1 51222_4 51221_4 51006_2 50983_4 50971_1 50931_1 50912_2 50902_4 50893_3 50882_3 50868_1 60505258500 00310027539 [VE] [MOR_1M]\n" - ] - } - ], + "outputs": [], "source": [ "input_ids = sample[\"concept_ids\"].squeeze().tolist()\n", "input_ids = input_ids[: input_ids.index(0)]\n", @@ -282,40 +260,18 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'50931_1 50912_2 50902_4 50893_3 50882_3 50868_1 60505258500 00310027539 [VE] [MOR_1M]'" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "tokenizer.decode(input_ids[-10:])" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'50970_1 50960_2 50931_0 50912_2 50902_4 50893_3 50882_3 50868_1 [VE] [REG]'" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "output = model.generate(\n", " torch.tensor(input_ids[:-10], dtype=torch.int32).unsqueeze(0).to(device),\n", @@ -372,20 +328,9 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 2048, 768])" - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# model = model.backbone\n", "outputs = model(input_ids=sample[\"concept_ids\"], return_dict=True)\n", @@ -396,46 +341,18 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'silu'" - ] - }, - "execution_count": 55, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "config.hidden_act" ] }, { "cell_type": "code", - "execution_count": 38, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[ 0.1596, -0.0591],\n", - " [ 0.2433, -1.0078],\n", - " [-0.3050, -0.3962],\n", - " ...,\n", - " [ 0.4917, -0.3203],\n", - " [ 0.4917, -0.3203],\n", - " [ 0.4917, -0.3203]]], device='cuda:0', grad_fn=)" - ] - }, - "execution_count": 38, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "classifier = torch.nn.Linear(config.hidden_size, 2, bias=False).to(device)\n", "logits = classifier(last_hidden_states)\n", @@ -444,40 +361,18 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(20592, device='cuda:0')" - ] - }, - "execution_count": 44, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "sample[\"concept_ids\"].squeeze()[204]" ] }, { "cell_type": "code", - "execution_count": 40, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([204], device='cuda:0')" - ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "sequence_lengths = torch.eq(sample[\"concept_ids\"], 0).int().argmax(-1) - 1\n", "sequence_lengths" @@ -485,20 +380,9 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[ 0.1697, -0.7979]], device='cuda:0', grad_fn=)" - ] - }, - "execution_count": 54, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "pooled_logits = logits[torch.arange(1, device=device), sequence_lengths]\n", "pooled_logits" @@ -506,20 +390,9 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[ 0.1697, -0.7979]], device='cuda:0', grad_fn=)" - ] - }, - "execution_count": 53, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "pooled_last_hidden_states = last_hidden_states[\n", " torch.arange(1, device=device), sequence_lengths\n", @@ -529,20 +402,9 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[-0.0512, -0.2630]], device='cuda:0', grad_fn=)" - ] - }, - "execution_count": 72, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "import copy\n", "\n", @@ -554,20 +416,9 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(0.3221, device='cuda:0', grad_fn=)" - ] - }, - "execution_count": 46, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "loss_fct = torch.nn.CrossEntropyLoss()\n", "loss = loss_fct(pooled_logits.view(-1, 2), torch.tensor([0]).to(device).view(-1))\n", @@ -633,22 +484,9 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'concept_ids': tensor([20593, 3, 13054, ..., 0, 0, 0]),\n", - " 'labels': tensor(0),\n", - " 'task': 'los_1week'}" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "import random\n", "from typing import Any, Dict, List, Optional, Tuple, Union\n", @@ -664,189 +502,7 @@ "CUTOFF_INDEX = 3\n", "\n", "\n", - "class FinetuneDatasetDecoder(Dataset):\n", - " \"\"\"Dataset for finetuning a decoder-based model.\n", - "\n", - " Parameters\n", - " ----------\n", - " data : pd.DataFrame\n", - " The input data containing sequences to be tokenized and masked.\n", - " tokenizer : ConceptTokenizer\n", - " An instance of the ConceptTokenizer class used for tokenizing sequences.\n", - " tasks : List[str]\n", - " A list of tasks (labels) that need to be predicted.\n", - " balance_guide : Optional[Dict[str, float]], optional\n", - " A dictionary containing the desired positive ratios for each task,\n", - " by default None.\n", - " max_len : int, optional\n", - " The maximum length of the tokenized sequences, by default 2048.\n", - " nan_indicator : int, optional\n", - " Value used to represent missing labels in the dataset, by default -1.\n", - "\n", - " Attributes\n", - " ----------\n", - " data : pd.DataFrame\n", - " Stores the input data.\n", - " tokenizer : ConceptTokenizer\n", - " Tokenizer used for tokenizing sequences.\n", - " tasks : List[str]\n", - " A list of tasks (labels) that need to be predicted.\n", - " balance_guide : Optional[Dict[str, float]]\n", - " A dictionary containing the desired positive ratios for each task.\n", - " max_len : int\n", - " Maximum length of the tokenized sequences.\n", - " nan_indicator : int\n", - " Value used to represent missing labels in the dataset.\n", - " task_to_index : Dict[str, List[Tuple[int, str, int, Optional[int]]]]\n", - " A dictionary mapping each task to a list of tuples containing the\n", - " index, task, label, and cutoff.\n", - " index_mapper : List[Tuple[int, str, int, Optional[int]]]\n", - " A list of all datapoints to be used by __getitem__.\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " data: pd.DataFrame,\n", - " tokenizer: ConceptTokenizer,\n", - " tasks: List[str],\n", - " balance_guide: Optional[Dict[str, float]] = None,\n", - " max_len: int = 2048,\n", - " nan_indicator: int = -1,\n", - " ):\n", - " \"\"\"Initiate the class.\"\"\"\n", - " super().__init__()\n", - "\n", - " self.data = data\n", - " self.tokenizer = tokenizer\n", - " self.tasks = tasks # List of tasks for which the model is being finetuned.\n", - " self.balance_guide = balance_guide\n", - " self.max_len = max_len\n", - " self.nan_indicator = (\n", - " nan_indicator # Value used to indicate missing data in labels.\n", - " )\n", - "\n", - " # Precompute indices for quick mapping in __getitem__ that\n", - " # exclude missing labels.\n", - " # This helps in filtering out entries where the label is missing\n", - " # for the specified tasks.\n", - " self.task_to_index = {task: [] for task in self.tasks}\n", - " self.data.reset_index(drop=True, inplace=True)\n", - "\n", - " for patient in self.data.itertuples():\n", - " index = patient.Index\n", - "\n", - " for task in self.tasks:\n", - " label_col = f\"label_{task}\"\n", - " # Skip this task for the current patient if the label is missing.\n", - " if getattr(patient, label_col) == self.nan_indicator:\n", - " continue\n", - "\n", - " label = getattr(patient, label_col)\n", - " # Check for the existence of a task-specific cutoff in the data,\n", - " # else use None.\n", - " if f\"cutoff_{task}\" in self.data.columns:\n", - " cutoff = getattr(patient, f\"cutoff_{task}\")\n", - " else:\n", - " cutoff = None\n", - " # Append a tuple containing the necessary information\n", - " # for training to index_mapper.\n", - " datapoint = (index, task, label, cutoff)\n", - " self.task_to_index[task].append(datapoint)\n", - "\n", - " # Balance labels for specified tasks\n", - " if self.balance_guide:\n", - " for task in self.balance_guide:\n", - " self.balance_labels(task=task, positive_ratio=self.balance_guide[task])\n", - "\n", - " # Create a list of all datapoints to be used by __getitem__\n", - " self.index_mapper = [\n", - " datapoints\n", - " for task_data in self.task_to_index.values()\n", - " for datapoints in task_data\n", - " ]\n", - " del self.task_to_index\n", - "\n", - " def __len__(self) -> int:\n", - " \"\"\"Return the length of dataset.\"\"\"\n", - " return len(self.index_mapper)\n", - "\n", - " def __getitem__(self, idx: int) -> Dict[str, Any]:\n", - " \"\"\"Get data at corresponding index.\n", - "\n", - " Parameters\n", - " ----------\n", - " idx : int\n", - " The index of the data to be retrieved.\n", - "\n", - " Returns\n", - " -------\n", - " Dict[str, Any]\n", - " A dictionary containing all different token sequences along with labels.\n", - " \"\"\"\n", - " index, task, labels, cutoff = self.index_mapper[idx]\n", - " data = self.data.iloc[index]\n", - "\n", - " # Swap the first and last token with the task token.\n", - " data[f\"event_tokens_{self.max_len}\"][0] = self.tokenizer.task_to_token(task)\n", - " data[f\"event_tokens_{self.max_len}\"][-1] = self.tokenizer.task_to_token(task)\n", - "\n", - " # Truncate and pad the data to the specified cutoff.\n", - " data = truncate_and_pad(data, cutoff, self.max_len)\n", - "\n", - " # Prepare model input\n", - " tokenized_input = self.tokenize_data(data[f\"event_tokens_{self.max_len}\"])\n", - " concept_ids = tokenized_input[\"input_ids\"].squeeze()\n", - " labels = torch.tensor(labels)\n", - "\n", - " return {\"concept_ids\": concept_ids, \"labels\": labels, \"task\": task}\n", - "\n", - " def tokenize_data(self, sequence: Union[str, List[str]]) -> Any:\n", - " \"\"\"Tokenize the sequence and return input_ids and attention mask.\n", - "\n", - " Parameters\n", - " ----------\n", - " sequence : Union[str, List[str]]\n", - " The sequence to be tokenized.\n", - "\n", - " Returns\n", - " -------\n", - " Any\n", - " A dictionary containing input_ids and attention_mask.\n", - "\n", - " \"\"\"\n", - " return self.tokenizer(sequence, max_length=self.max_len)\n", - "\n", - " def balance_labels(self, task: str, positive_ratio: float) -> None:\n", - " \"\"\"Balance the labels for the specified task in the dataset.\n", - "\n", - " This function modifies the dataset to ensure that the ratio of positive samples\n", - " to the total number of samples matches the specified positive_ratio,\n", - " while keeping all positive data points.\n", - "\n", - " Parameters\n", - " ----------\n", - " task : str\n", - " The task for which the labels need to be balanced.\n", - " positive_ratio : float\n", - " The desired positive ratio for the task.\n", - "\n", - " \"\"\"\n", - " # Separate positive and negative datapoints\n", - " datapoints = self.task_to_index[task]\n", - " positives = [data for data in datapoints if data[LABEL_INDEX] == 1]\n", - " negatives = [data for data in datapoints if data[LABEL_INDEX] == 0]\n", - "\n", - " # Calculate the total number of samples needed to achieve the\n", - " # desired positive ratio\n", - " num_positives = len(positives)\n", - " total_needed = int(num_positives / positive_ratio) - num_positives\n", - " num_negatives_to_keep = min(len(negatives), total_needed)\n", - "\n", - " # Randomly select the negatives to keep\n", - " negatives_kept = random.sample(negatives, num_negatives_to_keep)\n", - "\n", - " # Combine the kept negatives with all positives\n", - " self.task_to_index[task] = positives + negatives_kept\n", + "# Load FinetuneDatasetDecoder for debugging\n", "\n", "\n", "decoder_dataset = FinetuneDatasetDecoder(\n", @@ -861,210 +517,145 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "class MambaEmbeddingsForCEHR(nn.Module):\n", - " \"\"\"Construct the embeddings from word, position and token_type embeddings.\"\"\"\n", - "\n", - " # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__\n", - " def __init__(\n", - " self,\n", - " config: MambaConfig,\n", - " max_position_embeddings: int = 2048,\n", - " type_vocab_size: int = 8,\n", - " time_embeddings_size: int = 16,\n", - " visit_order_size: int = 3,\n", - " layer_norm_eps: float = 1e-12,\n", - " hidden_dropout_prob: float = 0.1,\n", - " ) -> None:\n", - " \"\"\"Initiate wrapper class for embeddings used in BigBird CEHR classes.\"\"\"\n", - " super().__init__()\n", - " self.max_position_embeddings = max_position_embeddings\n", - " self.type_vocab_size = type_vocab_size\n", - " self.layer_norm_eps = layer_norm_eps\n", - " self.hidden_dropout_prob = hidden_dropout_prob\n", - " self.hidden_size = config.hidden_size\n", - "\n", - " self.word_embeddings = nn.Embedding(\n", - " config.vocab_size,\n", - " config.hidden_size,\n", - " padding_idx=config.pad_token_id,\n", - " )\n", - " self.position_embeddings = nn.Embedding(\n", - " self.max_position_embeddings,\n", - " config.hidden_size,\n", - " )\n", - " self.token_type_embeddings = nn.Embedding(\n", - " self.type_vocab_size,\n", - " config.hidden_size,\n", - " )\n", - " self.visit_order_embeddings = nn.Embedding(\n", - " self.max_position_embeddings,\n", - " config.hidden_size,\n", - " )\n", - " self.time_embeddings = TimeEmbeddingLayer(\n", - " embedding_size=time_embeddings_size,\n", - " is_time_delta=True,\n", - " )\n", - " self.age_embeddings = TimeEmbeddingLayer(\n", - " embedding_size=time_embeddings_size,\n", - " )\n", - " self.visit_segment_embeddings = VisitEmbedding(\n", - " visit_order_size=visit_order_size,\n", - " embedding_size=config.hidden_size,\n", - " )\n", - " self.scale_back_concat_layer = nn.Linear(\n", - " config.hidden_size + 2 * time_embeddings_size,\n", - " config.hidden_size,\n", - " )\n", - "\n", - " self.time_stamps: Optional[torch.Tensor] = None\n", - " self.ages: Optional[torch.Tensor] = None\n", - " self.visit_orders: Optional[torch.Tensor] = None\n", - " self.visit_segments: Optional[torch.Tensor] = None\n", - "\n", - " # self.LayerNorm is not snake-cased to stick with TensorFlow model\n", - " # variable name and be able to load any TensorFlow checkpoint file.\n", - " self.tanh = nn.Tanh()\n", - " self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=self.layer_norm_eps)\n", - " self.dropout = nn.Dropout(self.hidden_dropout_prob)\n", - "\n", - " # position_ids (1, len position emb) is contiguous in memory.\n", - " self.position_embedding_type = getattr(\n", - " config,\n", - " \"position_embedding_type\",\n", - " \"absolute\",\n", - " )\n", - " self.register_buffer(\n", - " \"position_ids\",\n", - " torch.arange(self.max_position_embeddings).expand((1, -1)),\n", - " persistent=False,\n", - " )\n", - " self.register_buffer(\n", - " \"token_type_ids\",\n", - " torch.zeros(self.position_ids.size(), dtype=torch.long),\n", - " persistent=False,\n", - " )\n", - " # End copy\n", - "\n", - " def cache_input(\n", - " self,\n", - " token_type_ids_batch: Optional[torch.Tensor] = None,\n", - " position_ids_batch: Optional[torch.Tensor] = None,\n", - " inputs_embeds: Optional[torch.Tensor] = None,\n", - " time_stamps: Optional[torch.Tensor] = None,\n", - " ages: Optional[torch.Tensor] = None,\n", - " visit_orders: Optional[torch.Tensor] = None,\n", - " visit_segments: Optional[torch.Tensor] = None,\n", - " ) -> None:\n", - " \"\"\"Cache values for time_stamps, ages, visit_orders & visit_segments.\n", - "\n", - " These values will be used by the forward pass to change the final embedding.\n", - "\n", - " Parameters\n", - " ----------\n", - " token_type_ids_batch : torch.Tensor\n", - " The token type IDs of the input data.\n", - " position_ids_batch : torch.Tensor\n", - " The position IDs of the input data.\n", - " inputs_embeds : torch.Tensor\n", - " The embeddings of the input data.\n", - " time_stamps : torch.Tensor\n", - " Time stamps of the input data.\n", - " ages : torch.Tensor\n", - " Ages of the input data.\n", - " visit_orders : torch.Tensor\n", - " Visit orders of the input data.\n", - " visit_segments : torch.Tensor\n", - " Visit segments of the input data.\n", - " \"\"\"\n", - " self.token_type_ids_batch = token_type_ids_batch\n", - " self.position_ids_batch = position_ids_batch\n", - " self.inputs_embeds = inputs_embeds\n", - " self.time_stamps = time_stamps\n", - " self.ages = ages\n", - " self.visit_orders = visit_orders\n", - " self.visit_segments = visit_segments\n", - "\n", - " def clear_cache(self) -> None:\n", - " \"\"\"Delete the tensors cached by cache_input method.\"\"\"\n", - " del (\n", - " self.token_type_ids_batch,\n", - " self.position_ids_batch,\n", - " self.inputs_embeds,\n", - " self.time_stamps,\n", - " self.ages,\n", - " self.visit_orders,\n", - " self.visit_segments,\n", - " )\n", - "\n", - " def forward(\n", - " self,\n", - " input_ids: Optional[torch.Tensor] = None,\n", - " past_key_values_length: int = 0,\n", - " ) -> Any:\n", - " \"\"\"Return the final embeddings of concept ids using input and cached values.\"\"\"\n", - " if input_ids is not None:\n", - " input_shape = input_ids.size()\n", - " else:\n", - " input_shape = self.inputs_embeds.size()[:-1]\n", - "\n", - " seq_length = input_shape[1]\n", - "\n", - " if self.position_ids_batch is None:\n", - " self.position_ids_batch = self.position_ids[\n", - " :,\n", - " past_key_values_length : seq_length + past_key_values_length,\n", - " ]\n", - "\n", - " # Setting the token_type_ids to the registered buffer in constructor\n", - " if self.token_type_ids_batch is None:\n", - " if hasattr(self, \"token_type_ids\"):\n", - " buffered_token_type_ids = self.token_type_ids[:, :seq_length]\n", - " buffered_token_type_ids_expanded = buffered_token_type_ids.expand(\n", - " input_shape[0],\n", - " seq_length,\n", - " )\n", - " self.token_type_ids_batch = buffered_token_type_ids_expanded\n", - " else:\n", - " self.token_type_ids_batch = torch.zeros(\n", - " input_shape,\n", - " dtype=torch.long,\n", - " device=self.position_ids.device,\n", - " )\n", - "\n", - " if self.inputs_embeds is None:\n", - " self.inputs_embeds = self.word_embeddings(input_ids)\n", - "\n", - " # Using cached values from a prior cache_input call\n", - " time_stamps_embeds = self.time_embeddings(self.time_stamps)\n", - " ages_embeds = self.age_embeddings(self.ages)\n", - " visit_segments_embeds = self.visit_segment_embeddings(self.visit_segments)\n", - " visit_order_embeds = self.visit_order_embeddings(self.visit_orders)\n", - "\n", - " position_embeds = self.position_embeddings(self.position_ids_batch)\n", - " token_type_embeds = self.token_type_embeddings(self.token_type_ids_batch)\n", - "\n", - " self.inputs_embeds = torch.cat(\n", - " (self.inputs_embeds, time_stamps_embeds, ages_embeds),\n", - " dim=-1,\n", - " )\n", - " print(self.inputs_embeds.shape)\n", - " self.inputs_embeds = self.tanh(self.scale_back_concat_layer(self.inputs_embeds))\n", - " embeddings = self.inputs_embeds + token_type_embeds\n", - " embeddings += position_embeds\n", - " embeddings += visit_order_embeds\n", - " embeddings += visit_segments_embeds\n", - "\n", - " embeddings = self.dropout(embeddings)\n", - " embeddings = self.LayerNorm(embeddings)\n", + "from odyssey.models.embeddings import *\n", "\n", - " # Clear the cache for next forward call\n", - " self.clear_cache()\n", "\n", - " return embeddings" + "embeddings = MambaEmbeddingsForCEHR(\n", + " config=config,\n", + " type_vocab_size=9,\n", + " max_num_visits=512,\n", + " time_embeddings_size=32,\n", + " visit_order_size=3,\n", + " hidden_dropout_prob=0.1,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'concept_ids': tensor([[ 5, 3, 18896, ..., 1712, 4, 6]]),\n", + " 'type_ids': tensor([[1, 2, 6, ..., 5, 3, 8]]),\n", + " 'ages': tensor([[ 0, 77, 77, ..., 78, 78, 78]]),\n", + " 'time_stamps': tensor([[ 0, 8928, 8928, ..., 8981, 8981, 8981]]),\n", + " 'visit_orders': tensor([[0, 1, 1, ..., 8, 8, 8]]),\n", + " 'visit_segments': tensor([[0, 2, 2, ..., 1, 1, 1]]),\n", + " 'labels': tensor([[ 5, 3, 18896, ..., 1712, 4, 6]])}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch = train_dataset[51020]\n", + "batch = {key: tensor.unsqueeze(0) for key, tensor in batch.items()}\n", + "batch" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{0, 1, 2, 3, 4, 5, 6, 7, 8}\n" + ] + } + ], + "source": [ + "print(set(batch[\"visit_orders\"][0].tolist()))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 0.0452, -0.9307, 0.3723, ..., -1.1524, 0.5854, -0.3397],\n", + " [ 0.0355, -0.5227, -3.0730, ..., 0.0355, -0.9060, -0.9247],\n", + " [ 0.5338, -0.9962, -1.5450, ..., 1.6476, -1.0616, -1.5152],\n", + " ...,\n", + " [ 0.1515, 1.6252, -0.2081, ..., -1.0206, 0.8621, 1.3194],\n", + " [-0.0812, 2.3611, -0.1516, ..., -1.0140, -0.0978, 1.6653],\n", + " [ 0.4165, 0.8611, -0.5180, ..., 0.3821, 1.1638, 1.4207]]],\n", + " grad_fn=)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inputs = (\n", + " batch[\"concept_ids\"],\n", + " batch[\"type_ids\"],\n", + " batch[\"time_stamps\"],\n", + " batch[\"ages\"],\n", + " batch[\"visit_orders\"],\n", + " batch[\"visit_segments\"],\n", + ")\n", + "labels = batch[\"labels\"]\n", + "\n", + "concept_ids, type_ids, time_stamps, ages, visit_orders, visit_segments = inputs\n", + "inputs_embeds = embeddings(\n", + " input_ids=concept_ids,\n", + " token_type_ids_batch=type_ids,\n", + " time_stamps=time_stamps,\n", + " ages=ages,\n", + " visit_orders=visit_orders,\n", + " visit_segments=visit_segments,\n", + ")\n", + "inputs_embeds" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "MambaCausalLMOutput(loss=tensor(14.1312, device='cuda:0', grad_fn=), logits=tensor([[[-1.7062, -5.0347, 2.0794, ..., 2.3091, 1.0115, -1.9545],\n", + " [-1.0205, -2.8787, -4.8018, ..., -3.2100, -5.3467, -1.1486],\n", + " [ 1.2367, -4.0578, -3.7514, ..., -0.0644, 0.9085, 0.9692],\n", + " ...,\n", + " [ 2.3067, 3.5723, 1.9051, ..., -0.0123, -2.5649, -0.4133],\n", + " [ 4.0669, 4.1643, 3.6506, ..., 2.8866, -5.4374, -0.8073],\n", + " [ 1.8762, 5.5222, -0.6316, ..., 0.1687, -7.1170, -4.6202]]],\n", + " device='cuda:0', grad_fn=), cache_params=None, hidden_states=None)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "outputs = model(\n", + " inputs_embeds=inputs_embeds.to(device),\n", + " labels=labels.to(device),\n", + " output_hidden_states=False,\n", + " return_dict=True,\n", + ")\n", + "outputs" ] } ], @@ -1073,6 +664,18 @@ "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" } }, "nbformat": 4, diff --git a/odyssey/models/cehr_mamba/mamba_utils.py b/odyssey/models/cehr_mamba/mamba_utils.py index 51826df..9655321 100644 --- a/odyssey/models/cehr_mamba/mamba_utils.py +++ b/odyssey/models/cehr_mamba/mamba_utils.py @@ -108,12 +108,20 @@ def forward( Returns ------- """ - sequence_outputs = self.backbone( - input_ids, - inputs_embeds=inputs_embeds, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) + if inputs_embeds is not None: + sequence_outputs = self.backbone( + input_ids=None, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + else: + sequence_outputs = self.backbone( + input_ids=input_ids, + inputs_embeds=None, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) last_hidden_states = sequence_outputs[0] batch_size = last_hidden_states.shape[0] diff --git a/odyssey/models/cehr_mamba/model.py b/odyssey/models/cehr_mamba/model.py index 3d4507a..c9a1814 100644 --- a/odyssey/models/cehr_mamba/model.py +++ b/odyssey/models/cehr_mamba/model.py @@ -26,6 +26,7 @@ MambaForSequenceClassification, MambaSequenceClassifierOutput, ) +from odyssey.models.embeddings import MambaEmbeddingsForCEHR class MambaPretrain(pl.LightningModule): @@ -35,6 +36,10 @@ def __init__( self, vocab_size: int, embedding_size: int = 768, + time_embeddings_size: int = 32, + visit_order_size: int = 3, + type_vocab_size: int = 9, + max_num_visits: int = 512, max_seq_length: int = 2048, state_size: int = 16, num_hidden_layers: int = 32, @@ -49,6 +54,10 @@ def __init__( self.vocab_size = vocab_size self.embedding_size = embedding_size + self.time_embeddings_size = time_embeddings_size + self.visit_order_size = visit_order_size + self.type_vocab_size = type_vocab_size + self.max_num_visits = max_num_visits self.max_seq_length = max_seq_length self.state_size = state_size self.num_hidden_layers = num_hidden_layers @@ -70,22 +79,68 @@ def __init__( bos_token_id=self.cls_idx, eos_token_id=self.padding_idx, ) + self.embeddings = MambaEmbeddingsForCEHR( + config=self.config, + type_vocab_size=self.type_vocab_size, + max_num_visits=self.max_num_visits, + time_embeddings_size=self.time_embeddings_size, + visit_order_size=self.visit_order_size, + hidden_dropout_prob=self.dropout_prob, + ) + # Initialize weights and apply final processing + self.post_init() + # Mamba has its own initialization self.model = MambaForCausalLM(config=self.config) + def _init_weights(self, module: torch.nn.Module) -> None: + """Initialize the weights.""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def post_init(self) -> None: + """Apply weight initialization.""" + self.apply(self._init_weights) + def forward( self, - input_ids: torch.Tensor, + inputs: Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], labels: Optional[torch.Tensor] = None, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, ) -> Union[Tuple[torch.Tensor, ...], MambaCausalLMOutput]: """Forward pass for the model.""" + concept_ids, type_ids, time_stamps, ages, visit_orders, visit_segments = inputs + inputs_embeds = self.embeddings( + input_ids=concept_ids, + token_type_ids_batch=type_ids, + time_stamps=time_stamps, + ages=ages, + visit_orders=visit_orders, + visit_segments=visit_segments, + ) + if labels is None: - labels = input_ids + labels = concept_ids return self.model( - input_ids=input_ids, + inputs_embeds=inputs_embeds, labels=labels, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -93,13 +148,20 @@ def forward( def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Any: """Train model on training dataset.""" - concept_ids = batch["concept_ids"] + inputs = ( + batch["concept_ids"], + batch["type_ids"], + batch["time_stamps"], + batch["ages"], + batch["visit_orders"], + batch["visit_segments"], + ) labels = batch["labels"] # Ensure use of mixed precision with autocast(): loss = self( - input_ids=concept_ids, + inputs, labels=labels, return_dict=True, ).loss @@ -111,18 +173,24 @@ def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Any: prog_bar=True, sync_dist=True, ) - return loss def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> Any: """Evaluate model on validation dataset.""" - concept_ids = batch["concept_ids"] + inputs = ( + batch["concept_ids"], + batch["type_ids"], + batch["time_stamps"], + batch["ages"], + batch["visit_orders"], + batch["visit_segments"], + ) labels = batch["labels"] # Ensure use of mixed precision with autocast(): loss = self( - input_ids=concept_ids, + inputs, labels=labels, return_dict=True, ).loss @@ -197,6 +265,7 @@ def __init__( # self.post_init() self.pretrained_model = pretrained_model + self.embeddings = self.pretrained_model.embeddings self.model.backbone = self.pretrained_model.model.backbone def _init_weights(self, module: torch.nn.Module) -> None: @@ -219,14 +288,32 @@ def post_init(self) -> None: def forward( self, - input_ids: torch.Tensor, + inputs: Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], labels: Optional[torch.Tensor] = None, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, ) -> Union[Tuple[torch.Tensor, ...], MambaSequenceClassifierOutput]: """Forward pass for the model.""" + concept_ids, type_ids, time_stamps, ages, visit_orders, visit_segments = inputs + inputs_embeds = self.embeddings( + input_ids=concept_ids, + token_type_ids_batch=type_ids, + time_stamps=time_stamps, + ages=ages, + visit_orders=visit_orders, + visit_segments=visit_segments, + ) + return self.model( - input_ids=input_ids, + input_ids=concept_ids, + inputs_embeds=inputs_embeds, labels=labels, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -234,13 +321,20 @@ def forward( def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Any: """Train model on training dataset.""" - concept_ids = batch["concept_ids"] + inputs = ( + batch["concept_ids"], + batch["type_ids"], + batch["time_stamps"], + batch["ages"], + batch["visit_orders"], + batch["visit_segments"], + ) labels = batch["labels"] # Ensure use of mixed precision with autocast(): loss = self( - input_ids=concept_ids, + inputs, labels=labels, return_dict=True, ).loss @@ -257,13 +351,20 @@ def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Any: def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> Any: """Evaluate model on validation dataset.""" - concept_ids = batch["concept_ids"] + inputs = ( + batch["concept_ids"], + batch["type_ids"], + batch["time_stamps"], + batch["ages"], + batch["visit_orders"], + batch["visit_segments"], + ) labels = batch["labels"] # Ensure use of mixed precision with autocast(): loss = self( - input_ids=concept_ids, + inputs, labels=labels, return_dict=True, ).loss @@ -280,13 +381,20 @@ def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> Any: def test_step(self, batch: Dict[str, Any], batch_idx: int) -> Any: """Test step.""" - concept_ids = batch["concept_ids"] + inputs = ( + batch["concept_ids"], + batch["type_ids"], + batch["time_stamps"], + batch["ages"], + batch["visit_orders"], + batch["visit_segments"], + ) labels = batch["labels"] # Ensure use of mixed precision with autocast(): outputs = self( - input_ids=concept_ids, + inputs, labels=labels, return_dict=True, ) diff --git a/odyssey/models/configs/cehr_bigbird.yaml b/odyssey/models/configs/cehr_bigbird.yaml index ab61788..c278a0d 100644 --- a/odyssey/models/configs/cehr_bigbird.yaml +++ b/odyssey/models/configs/cehr_bigbird.yaml @@ -33,7 +33,7 @@ finetune: num_workers: 6 gpus: 4 nodes: 1 - max_epochs: 4 #3 + max_epochs: 4 # 3 acc: 1 patience: 10 persistent_workers: True diff --git a/odyssey/models/configs/cehr_mamba.yaml b/odyssey/models/configs/cehr_mamba.yaml index ff27bfb..a97a377 100644 --- a/odyssey/models/configs/cehr_mamba.yaml +++ b/odyssey/models/configs/cehr_mamba.yaml @@ -1,19 +1,23 @@ model: embedding_size: 768 + time_embeddings_size: 32 + visit_order_size: 3 + type_vocab_size: 9 + max_seq_length: 2048 + max_num_visits: 512 state_size: 16 num_hidden_layers: 32 expand: 2 conv_kernel: 4 dropout_prob: 0.1 learning_rate: 5.e-5 - max_seq_length: 2048 model_finetune: learning_rate: 5.e-5 classifier_dropout: 0.1 train: - batch_size: 44 #32 + batch_size: 44 num_workers: 5 gpus: 4 nodes: 1 @@ -24,10 +28,10 @@ train: finetune: batch_size: 64 #26 - num_workers: 6 + num_workers: 5 gpus: 4 nodes: 1 - max_epochs: 5 #3 + max_epochs: 3 acc: 1 patience: 10 persistent_workers: True diff --git a/odyssey/models/embeddings.py b/odyssey/models/embeddings.py index df06d7f..71559f3 100644 --- a/odyssey/models/embeddings.py +++ b/odyssey/models/embeddings.py @@ -5,7 +5,7 @@ import torch from torch import nn -from transformers import BigBirdConfig +from transformers import BigBirdConfig, MambaConfig class TimeEmbeddingLayer(nn.Module): @@ -380,3 +380,116 @@ def forward( self.clear_cache() return embeddings + + +class MambaEmbeddingsForCEHR(nn.Module): + """Construct the embeddings from concept, token_type, etc., embeddings.""" + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ + def __init__( + self, + config: MambaConfig, + type_vocab_size: int = 9, + max_num_visits: int = 512, + time_embeddings_size: int = 32, + visit_order_size: int = 3, + layer_norm_eps: float = 1e-12, + hidden_dropout_prob: float = 0.1, + ) -> None: + """Initiate wrapper class for embeddings used in Mamba CEHR classes.""" + super().__init__() + self.type_vocab_size = type_vocab_size + self.max_num_visits = max_num_visits + self.layer_norm_eps = layer_norm_eps + self.hidden_dropout_prob = hidden_dropout_prob + self.hidden_size = config.hidden_size + + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id, + ) + self.token_type_embeddings = nn.Embedding( + self.type_vocab_size, + config.hidden_size, + ) + self.visit_order_embeddings = nn.Embedding( + self.max_num_visits, + config.hidden_size, + ) + self.time_embeddings = TimeEmbeddingLayer( + embedding_size=time_embeddings_size, + is_time_delta=True, + ) + self.age_embeddings = TimeEmbeddingLayer( + embedding_size=time_embeddings_size, + ) + self.visit_segment_embeddings = VisitEmbedding( + visit_order_size=visit_order_size, + embedding_size=config.hidden_size, + ) + self.scale_back_concat_layer = nn.Linear( + config.hidden_size + 2 * time_embeddings_size, + config.hidden_size, + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model + # variable name and be able to load any TensorFlow checkpoint file. + self.tanh = nn.Tanh() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=self.layer_norm_eps) + self.dropout = nn.Dropout(self.hidden_dropout_prob) + # End copy + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids_batch: Optional[torch.Tensor] = None, + time_stamps: Optional[torch.Tensor] = None, + ages: Optional[torch.Tensor] = None, + visit_orders: Optional[torch.Tensor] = None, + visit_segments: Optional[torch.Tensor] = None, + ) -> Any: + """Return the final embeddings of concept ids. + + Parameters + ---------- + input_ids: torch.Tensor + The input data (concept_ids) to be embedded. + inputs_embeds : torch.Tensor + The embeddings of the input data. + token_type_ids_batch : torch.Tensor + The token type IDs of the input data. + time_stamps : torch.Tensor + Time stamps of the input data. + ages : torch.Tensor + Ages of the input data. + visit_orders : torch.Tensor + Visit orders of the input data. + visit_segments : torch.Tensor + Visit segments of the input data. + """ + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # Using cached values from a prior cache_input call + time_stamps_embeds = self.time_embeddings(time_stamps) + ages_embeds = self.age_embeddings(ages) + visit_segments_embeds = self.visit_segment_embeddings(visit_segments) + visit_order_embeds = self.visit_order_embeddings(visit_orders) + token_type_embeds = self.token_type_embeddings(token_type_ids_batch) + + inputs_embeds = torch.cat( + (inputs_embeds, time_stamps_embeds, ages_embeds), + dim=-1, + ) + + inputs_embeds = self.tanh(self.scale_back_concat_layer(inputs_embeds)) + embeddings = inputs_embeds + token_type_embeds + embeddings += visit_order_embeds + embeddings += visit_segments_embeds + + embeddings = self.dropout(embeddings) + embeddings = self.LayerNorm(embeddings) + + return embeddings