From 4dab179e51cd7185c1b7ce7b00d65228346c8d76 Mon Sep 17 00:00:00 2001 From: afallah Date: Wed, 1 May 2024 14:00:52 -0400 Subject: [PATCH] Implemented Mamba related HuggingFace utility functions and the MambaForSequenceClassification model for finetuning. --- odyssey/models/cehr_mamba/mamba-dev.ipynb | 655 ++++++++++++++++----- odyssey/models/cehr_mamba/mamba_utils.py | 155 +++++ odyssey/models/cehr_mamba/model.py | 224 ++++++- odyssey/models/cehr_mamba/playground.ipynb | 7 - 4 files changed, 876 insertions(+), 165 deletions(-) create mode 100644 odyssey/models/cehr_mamba/mamba_utils.py diff --git a/odyssey/models/cehr_mamba/mamba-dev.ipynb b/odyssey/models/cehr_mamba/mamba-dev.ipynb index 9919aef..a729375 100644 --- a/odyssey/models/cehr_mamba/mamba-dev.ipynb +++ b/odyssey/models/cehr_mamba/mamba-dev.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 4, + "execution_count": 77, "metadata": {}, "outputs": [], "source": [ @@ -25,6 +25,7 @@ " MambaConfig,\n", " MambaModel,\n", " MambaForCausalLM,\n", + " MambaPreTrainedModel\n", ")\n", "\n", "import numpy as np\n", @@ -67,22 +68,24 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class args:\n", " data_dir = 'odyssey/data/bigbird_data'\n", - " sequence_file = 'patient_sequences_2048.parquet'\n", + " sequence_file = 'patient_sequences_2048_multi.parquet'\n", " id_file = 'dataset_2048_multi.pkl'\n", " vocab_dir = 'odyssey/data/vocab'\n", " max_len = 2048\n", - " mask_prob = 0.15" + " mask_prob = 0.15\n", + " tasks = ['mortality_1month', 'los_1week', 'c0', 'c1', 'c2']\n", + " balance_guide = {'mortality_1month': 0.5, 'los_1week': 0.5, 'c0': 0.5, 'c1': 0.5, 'c2': 0.5}" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -120,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -148,7 +151,7 @@ ")" ] }, - "execution_count": 9, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -178,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -206,7 +209,7 @@ ")" ] }, - "execution_count": 10, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -222,30 +225,33 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'concept_ids': tensor([[ 5, 3, 3637, ..., 0, 0, 0]], device='cuda:0'),\n", - " 'labels': tensor([[ 5, 3, 3637, ..., 0, 0, 0]], device='cuda:0')}" + "{'concept_ids': tensor([[20592, 3, 17326, ..., 0, 0, 0]], device='cuda:0'),\n", + " 'labels': tensor([1], device='cuda:0'),\n", + " 'task': 'mortality_1month'}" ] }, - "execution_count": 60, + "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_loader = DataLoader(\n", - " test_dataset, #train_dataset\n", + " decoder_dataset, #test_dataset, #train_dataset\n", " batch_size=3,\n", " shuffle=False,\n", " )\n", "\n", - "sample = test_dataset[97] #train_dataset[0]\n", + "sample = decoder_dataset[2323] #test_dataset[8765] #train_dataset[0]\n", + "task = sample.pop('task')\n", "sample = {key:tensor.unsqueeze(0).to(device) for key, tensor in sample.items()}\n", + "sample['task'] = task\n", "\n", "# sample = next(iter(train_loader))\n", "# sample = {key:tensor.to(device) for key, tensor in sample.items()}\n", @@ -255,14 +261,14 @@ }, { 
"cell_type": "code", - "execution_count": 61, + "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[CLS] [VS] 0FC78ZZ 0FC98ZZ 49281041688 00338011704 00641607825 51484_2 51491_3 51498_2 52010_3 52009_4 52008_0 52007_2 52004_4 52005_3 52006_1 51250_0 51265_3 51277_0 51279_4 51301_3 50861_4 50863_3 50868_3 50878_3 50882_1 50883_4 50884_3 50885_4 50893_4 50902_2 50912_2 50931_2 50960_3 50970_2 50971_2 50983_3 51006_0 51221_4 51222_4 51248_3 50912_1 50931_2 50971_2 50868_2 50983_3 51006_0 50878_2 50882_3 51237_2 50861_4 51274_1 50863_3 50885_4 50902_1 [VE] [REG]\n" + "[MOR_1M] [VS] 58160087546 00904516561 00182853489 00574705050 00121054410 66553000401 00310027539 00006003121 00456320563 62856024541 00310027539 00172531210 00338069104 51006_2 50983_4 50971_2 50970_2 50960_3 50931_1 50912_2 50902_4 50893_3 50882_2 50868_2 51301_1 51279_4 51277_1 51265_2 51250_0 51248_1 51222_4 51221_4 00172531110 10432017002 10432017002 51006_2 50983_4 50971_3 50970_2 50960_3 50931_4 50912_3 50902_3 50893_4 50882_3 50868_3 51301_1 51279_4 51277_1 51265_1 51250_0 51248_1 51222_4 51221_4 00904224461 51301_2 51279_4 51277_1 51265_2 51250_0 51248_1 51222_4 51221_4 51006_2 50983_4 50971_3 50970_2 50960_2 50931_2 50912_3 50902_2 50893_3 50882_3 50868_4 [VE] [REG] [W_0] [VS] 8938 8838 8744 51006_2 50983_4 50971_2 50931_1 50912_2 50902_4 50882_3 50868_3 51301_1 51279_4 51277_1 51265_2 51256_1 51254_3 51250_1 51248_1 51244_3 51222_4 51221_4 51200_3 51146_2 63323026201 00713016550 00182844789 62856024541 00182853489 00904516561 00006003121 00121065721 10432017002 00182844789 66553000401 00904224461 00310027539 [VE] [REG] [M_2] [VS] 00338004304 63323026201 62856024541 00182844789 00310027539 10432017002 00904404073 00338004904 00456320563 66553000401 55390014710 51006_2 50983_3 50971_1 50970_0 50960_3 50931_2 50912_2 50902_4 50893_1 50882_2 50868_1 51301_2 51279_4 51277_1 51265_1 51250_1 51248_1 51222_4 51221_4 51006_2 50983_4 50971_0 50970_1 50960_2 50931_0 50912_2 50902_4 50893_2 50882_1 50868_3 00310027539 00172531210 00310027839 [VE] [REG] [LT] [VS] 51498_4 51491_0 33332001001 63323026201 00456320563 51079000220 60505258600 62856024541 60505258600 60505258600 51516_1 51498_4 51493_1 51491_0 51476_1 51301_0 51279_4 51277_1 51265_1 51250_0 51248_1 51222_4 51221_4 51006_2 50983_4 50971_1 50931_1 50912_2 50902_4 50893_3 50882_3 50868_1 60505258500 00310027539 [VE] [MOR_1M]\n" ] } ], @@ -274,16 +280,16 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'50878_2 50882_3 51237_2 50861_4 51274_1 50863_3 50885_4 50902_1 [VE] [REG]'" + "'50931_1 50912_2 50902_4 50893_3 50882_3 50868_1 60505258500 00310027539 [VE] [MOR_1M]'" ] }, - "execution_count": 62, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -294,16 +300,16 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'50882_1 50885_3 50902_3 51221_4 51222_4 51248_1 51250_0 51265_3 51277_0 51279_4'" + "'50970_1 50960_2 50931_0 50912_2 50902_4 50893_3 50882_3 50868_1 [VE] [REG]'" ] }, - "execution_count": 65, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -354,12 +360,215 @@ "# visit_segments = sample['visit_segments']\n", "# )\n", "\n", + "# outputs = model(\n", + "# input_ids=sample[\"concept_ids\"], labels=sample[\"concept_ids\"], return_dict=True\n", + "# )\n", + "\n", + "# loss = 
outputs.loss\n", + "# logits = outputs.logits" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 2048, 768])" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# model = model.backbone\n", "outputs = model(\n", - " input_ids=sample[\"concept_ids\"], labels=sample[\"concept_ids\"], return_dict=True\n", + " input_ids=sample[\"concept_ids\"], return_dict=True\n", ")\n", "\n", - "loss = outputs.loss\n", - "logits = outputs.logits" + "last_hidden_states = outputs.last_hidden_state\n", + "last_hidden_states.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'silu'" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config.hidden_act" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 0.1596, -0.0591],\n", + " [ 0.2433, -1.0078],\n", + " [-0.3050, -0.3962],\n", + " ...,\n", + " [ 0.4917, -0.3203],\n", + " [ 0.4917, -0.3203],\n", + " [ 0.4917, -0.3203]]], device='cuda:0', grad_fn=)" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "classifier = torch.nn.Linear(config.hidden_size, 2, bias=False).to(device)\n", + "logits = classifier(last_hidden_states)\n", + "logits" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(20592, device='cuda:0')" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample[\"concept_ids\"].squeeze()[204]" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([204], device='cuda:0')" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sequence_lengths = torch.eq(sample['concept_ids'], 0).int().argmax(-1) - 1\n", + "sequence_lengths " + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 0.1697, -0.7979]], device='cuda:0', grad_fn=)" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pooled_logits = logits[torch.arange(1, device=device), sequence_lengths]\n", + "pooled_logits" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 0.1697, -0.7979]], device='cuda:0', grad_fn=)" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pooled_last_hidden_states = last_hidden_states[torch.arange(1, device=device), sequence_lengths]\n", + "classifier(pooled_last_hidden_states)" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[-0.0512, -0.2630]], device='cuda:0', grad_fn=)" + ] + }, + "execution_count": 72, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import copy\n", + "config_copy = copy.deepcopy(config)\n", + "config_copy.classifier_dropout = 0.1\n", + "head = MambaClassificationHead(config_copy).to(device)\n", + 
"head(pooled_last_hidden_states)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.3221, device='cuda:0', grad_fn=)" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loss_fct = torch.nn.CrossEntropyLoss()\n", + "loss = loss_fct(pooled_logits.view(-1,2), torch.tensor([0]).to(device).view(-1))\n", + "loss" ] }, { @@ -368,148 +577,288 @@ "metadata": {}, "outputs": [], "source": [ - "class MambaPretrain(pl.LightningModule):\n", - " \"\"\"Mamba model for pretraining.\"\"\"\n", + "outputs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "inputs['input_ids'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "last_hidden_states[:, 0, :].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from odyssey.models.cehr_mamba.model import MambaPretrain\n", + "from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "1. Emebeddings -> Not now\n", + "2. Padding order -> Done automatically\n", + "\n", + "---\n", + "Finetuning Approach:\n", + " 1. Replace the first and last REG token with the class token\n", + "2. Use the last hiddent state of the last token for class prediction\n", + "3. Ourselves!\n", + "\n", + "4. Dataset refactoring (inheritance, what to return, etc)\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'concept_ids': tensor([20593, 3, 13054, ..., 0, 0, 0]),\n", + " 'labels': tensor(0),\n", + " 'task': 'los_1week'}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import random\n", + "from typing import Any, Dict, List, Optional, Tuple, Union\n", + "\n", + "import pandas as pd\n", + "import torch\n", + "from torch.utils.data import Dataset\n", + "\n", + "from odyssey.data.tokenizer import ConceptTokenizer, truncate_and_pad\n", + "\n", + "TASK_INDEX = 1\n", + "LABEL_INDEX = 2\n", + "CUTOFF_INDEX = 3\n", + "\n", + "\n", + "class FinetuneDatasetDecoder(Dataset):\n", + " \"\"\"Dataset for finetuning a decoder-based model.\n", + "\n", + " Parameters\n", + " ----------\n", + " data : pd.DataFrame\n", + " The input data containing sequences to be tokenized and masked.\n", + " tokenizer : ConceptTokenizer\n", + " An instance of the ConceptTokenizer class used for tokenizing sequences.\n", + " tasks : List[str]\n", + " A list of tasks (labels) that need to be predicted.\n", + " balance_guide : Optional[Dict[str, float]], optional\n", + " A dictionary containing the desired positive ratios for each task,\n", + " by default None.\n", + " max_len : int, optional\n", + " The maximum length of the tokenized sequences, by default 2048.\n", + " nan_indicator : int, optional\n", + " Value used to represent missing labels in the dataset, by default -1.\n", + "\n", + " Attributes\n", + " ----------\n", + " data : pd.DataFrame\n", + " Stores the input data.\n", + " tokenizer : ConceptTokenizer\n", + " Tokenizer used for tokenizing sequences.\n", + " tasks : List[str]\n", + " A list of tasks (labels) that need to be predicted.\n", + " balance_guide : 
Optional[Dict[str, float]]\n", + " A dictionary containing the desired positive ratios for each task.\n", + " max_len : int\n", + " Maximum length of the tokenized sequences.\n", + " nan_indicator : int\n", + " Value used to represent missing labels in the dataset.\n", + " task_to_index : Dict[str, List[Tuple[int, str, int, Optional[int]]]]\n", + " A dictionary mapping each task to a list of tuples containing the\n", + " index, task, label, and cutoff.\n", + " index_mapper : List[Tuple[int, str, int, Optional[int]]]\n", + " A list of all datapoints to be used by __getitem__.\n", + " \"\"\"\n", "\n", " def __init__(\n", " self,\n", - " vocab_size: int,\n", - " embedding_size: int = 768,\n", - " state_size: int = 16,\n", - " num_hidden_layers: int = 32,\n", - " expand: int = 2,\n", - " conv_kernel: int = 4,\n", - " learning_rate: float = 5e-5,\n", - " dropout_prob: float = 0.1,\n", - " padding_idx: int = 0,\n", - " cls_idx: int = 5,\n", + " data: pd.DataFrame,\n", + " tokenizer: ConceptTokenizer,\n", + " tasks: List[str],\n", + " balance_guide: Optional[Dict[str, float]] = None,\n", + " max_len: int = 2048,\n", + " nan_indicator: int = -1,\n", " ):\n", + " \"\"\"Initiate the class.\"\"\"\n", " super().__init__()\n", "\n", - " self.vocab_size = vocab_size\n", - " self.embedding_size = embedding_size\n", - " self.state_size = state_size\n", - " self.num_hidden_layers = num_hidden_layers\n", - " self.expand = expand\n", - " self.conv_kernel = conv_kernel\n", - " self.learning_rate = learning_rate\n", - " self.dropout_prob = dropout_prob\n", - " self.padding_idx = padding_idx\n", - " self.cls_idx = cls_idx\n", - "\n", - " self.config = MambaConfig(\n", - " vocab_size=self.vocab_size,\n", - " hidden_size=self.embedding_size,\n", - " state_size=self.state_size,\n", - " num_hidden_layers=self.num_hidden_layers,\n", - " expand=self.expand,\n", - " conv_kernel=self.conv_kernel,\n", - " pad_token_id=self.padding_idx,\n", - " bos_token_id=self.cls_idx,\n", - " eos_token_id=self.padding_idx,\n", + " self.data = data\n", + " self.tokenizer = tokenizer\n", + " self.tasks = tasks # List of tasks for which the model is being finetuned.\n", + " self.balance_guide = balance_guide\n", + " self.max_len = max_len\n", + " self.nan_indicator = (\n", + " nan_indicator # Value used to indicate missing data in labels.\n", " )\n", "\n", - " self.model = MambaForCausalLM(config=config)\n", + " # Precompute indices for quick mapping in __getitem__ that\n", + " # exclude missing labels.\n", + " # This helps in filtering out entries where the label is missing\n", + " # for the specified tasks.\n", + " self.task_to_index = {task: [] for task in self.tasks}\n", + " self.data.reset_index(drop=True, inplace=True)\n", "\n", - " def forward(\n", - " self,\n", - " input_ids: torch.Tensor,\n", - " output_hidden_states: Optional[bool] = False,\n", - " return_dict: Optional[bool] = True,\n", - " ) -> Union[Tuple[torch.Tensor, ...], MambaOutput]:\n", - " \"\"\"Forward pass for the model.\"\"\"\n", - "\n", - " return self.model(\n", - " input_ids=input_ids,\n", - " labels=input_ids,\n", - " output_hidden_states=output_hidden_states,\n", - " return_dict=return_dict,\n", - " )\n", + " for patient in self.data.itertuples():\n", + " index = patient.Index\n", "\n", - " def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Any:\n", - " \"\"\"Train model on training dataset.\"\"\"\n", - " concept_ids = batch[\"concept_ids\"]\n", - "\n", - " # Ensure use of mixed precision\n", - " with autocast():\n", - " loss = self(\n", - " 
concept_ids,\n", - " return_dict=True,\n", - " ).loss\n", - "\n", - " (current_lr,) = self.lr_schedulers().get_last_lr()\n", - " self.log_dict(\n", - " dictionary={\"train_loss\": loss, \"lr\": current_lr},\n", - " on_step=True,\n", - " prog_bar=True,\n", - " )\n", + " for task in self.tasks:\n", + " label_col = f\"label_{task}\"\n", + " # Skip this task for the current patient if the label is missing.\n", + " if getattr(patient, label_col) == self.nan_indicator:\n", + " continue\n", "\n", - " return loss\n", - "\n", - " def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> Any:\n", - " \"\"\"Evaluate model on validation dataset.\"\"\"\n", - " concept_ids = batch[\"concept_ids\"]\n", - "\n", - " # Ensure use of mixed precision\n", - " with autocast():\n", - " loss = self(\n", - " concept_ids,\n", - " return_dict=True,\n", - " ).loss\n", - "\n", - " (current_lr,) = self.lr_schedulers().get_last_lr()\n", - " self.log_dict(\n", - " dictionary={\"val_loss\": loss, \"lr\": current_lr},\n", - " on_step=True,\n", - " prog_bar=True,\n", - " sync_dist=True,\n", - " )\n", - " return loss\n", + " label = getattr(patient, label_col)\n", + " # Check for the existence of a task-specific cutoff in the data,\n", + " # else use None.\n", + " if f\"cutoff_{task}\" in self.data.columns:\n", + " cutoff = getattr(patient, f\"cutoff_{task}\")\n", + " else:\n", + " cutoff = None\n", + " # Append a tuple containing the necessary information\n", + " # for training to index_mapper.\n", + " datapoint = (index, task, label, cutoff)\n", + " self.task_to_index[task].append(datapoint)\n", "\n", - " def configure_optimizers(\n", - " self,\n", - " ) -> Tuple[list[Any], list[dict[str, SequentialLR | str]]]:\n", - " \"\"\"Configure optimizers and learning rate scheduler.\"\"\"\n", - " optimizer = optim.AdamW(\n", - " self.parameters(),\n", - " lr=self.learning_rate,\n", - " )\n", + " # Balance labels for specified tasks\n", + " if self.balance_guide:\n", + " for task in self.balance_guide:\n", + " self.balance_labels(task=task, positive_ratio=self.balance_guide[task])\n", "\n", - " n_steps = self.trainer.estimated_stepping_batches\n", - " n_warmup_steps = int(0.1 * n_steps)\n", - " n_decay_steps = int(0.9 * n_steps)\n", + " # Create a list of all datapoints to be used by __getitem__\n", + " self.index_mapper = [\n", + " datapoints\n", + " for task_data in self.task_to_index.values()\n", + " for datapoints in task_data\n", + " ]\n", + " del self.task_to_index\n", "\n", - " warmup = LinearLR(\n", - " optimizer,\n", - " start_factor=0.01,\n", - " end_factor=1.0,\n", - " total_iters=n_warmup_steps,\n", - " )\n", - " decay = LinearLR(\n", - " optimizer,\n", - " start_factor=1.0,\n", - " end_factor=0.01,\n", - " total_iters=n_decay_steps,\n", - " )\n", - " scheduler = SequentialLR(\n", - " optimizer=optimizer,\n", - " schedulers=[warmup, decay],\n", - " milestones=[n_warmup_steps],\n", - " )\n", + " def __len__(self) -> int:\n", + " \"\"\"Return the length of dataset.\"\"\"\n", + " return len(self.index_mapper)\n", "\n", - " return [optimizer], [{\"scheduler\": scheduler, \"interval\": \"step\"}]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\"\"\"\n", - "1. Emebeddings -> Not now\n", - "2. 
Padding order -> Done automatically\n", - "\"\"\"" + " def __getitem__(self, idx: int) -> Dict[str, Any]:\n", + " \"\"\"Get data at corresponding index.\n", + "\n", + " Parameters\n", + " ----------\n", + " idx : int\n", + " The index of the data to be retrieved.\n", + "\n", + " Returns\n", + " -------\n", + " Dict[str, Any]\n", + " A dictionary containing all different token sequences along with labels.\n", + " \"\"\"\n", + " index, task, labels, cutoff = self.index_mapper[idx]\n", + " data = self.data.iloc[index]\n", + "\n", + " # Swap the first and last token with the task token.\n", + " data[f\"event_tokens_{self.max_len}\"][0] = self.tokenizer.task_to_token(task)\n", + " data[f\"event_tokens_{self.max_len}\"][-1] = self.tokenizer.task_to_token(task)\n", + "\n", + " # Truncate and pad the data to the specified cutoff.\n", + " data = truncate_and_pad(data, cutoff, self.max_len)\n", + "\n", + " # Prepare model input\n", + " tokenized_input = self.tokenize_data(data[f\"event_tokens_{self.max_len}\"])\n", + " concept_ids = tokenized_input[\"input_ids\"].squeeze()\n", + " labels = torch.tensor(labels)\n", + "\n", + " return {\n", + " \"concept_ids\": concept_ids,\n", + " \"labels\": labels,\n", + " \"task\": task\n", + " }\n", + "\n", + " def tokenize_data(self, sequence: Union[str, List[str]]) -> Any:\n", + " \"\"\"Tokenize the sequence and return input_ids and attention mask.\n", + "\n", + " Parameters\n", + " ----------\n", + " sequence : Union[str, List[str]]\n", + " The sequence to be tokenized.\n", + "\n", + " Returns\n", + " -------\n", + " Any\n", + " A dictionary containing input_ids and attention_mask.\n", + "\n", + " \"\"\"\n", + " return self.tokenizer(sequence, max_length=self.max_len)\n", + "\n", + " def balance_labels(self, task: str, positive_ratio: float) -> None:\n", + " \"\"\"Balance the labels for the specified task in the dataset.\n", + "\n", + " This function modifies the dataset to ensure that the ratio of positive samples\n", + " to the total number of samples matches the specified positive_ratio,\n", + " while keeping all positive data points.\n", + "\n", + " Parameters\n", + " ----------\n", + " task : str\n", + " The task for which the labels need to be balanced.\n", + " positive_ratio : float\n", + " The desired positive ratio for the task.\n", + "\n", + " \"\"\"\n", + " # Separate positive and negative datapoints\n", + " datapoints = self.task_to_index[task]\n", + " positives = [data for data in datapoints if data[LABEL_INDEX] == 1]\n", + " negatives = [data for data in datapoints if data[LABEL_INDEX] == 0]\n", + "\n", + " # Calculate the total number of samples needed to achieve the\n", + " # desired positive ratio\n", + " num_positives = len(positives)\n", + " total_needed = int(num_positives / positive_ratio) - num_positives\n", + " num_negatives_to_keep = min(len(negatives), total_needed)\n", + "\n", + " # Randomly select the negatives to keep\n", + " negatives_kept = random.sample(negatives, num_negatives_to_keep)\n", + "\n", + " # Combine the kept negatives with all positives\n", + " self.task_to_index[task] = positives + negatives_kept\n", + "\n", + "\n", + "decoder_dataset = FinetuneDatasetDecoder(\n", + " data=fine_test,\n", + " tokenizer=tokenizer,\n", + " max_len=args.max_len,\n", + " tasks=args.tasks,\n", + " balance_guide=args.balance_guide,\n", + ")\n", + "decoder_dataset[12112]" ] }, { diff --git a/odyssey/models/cehr_mamba/mamba_utils.py b/odyssey/models/cehr_mamba/mamba_utils.py new file mode 100644 index 0000000..b952111 --- /dev/null +++ 
b/odyssey/models/cehr_mamba/mamba_utils.py
@@ -0,0 +1,155 @@
+"""Utilities following HuggingFace style for Mamba models."""
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Set, Union, Tuple
+
+import torch
+from torch import nn
+from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss
+
+from transformers import (
+    MambaModel,
+    MambaPreTrainedModel
+)
+from transformers.activations import ACT2FN
+from transformers.models.mamba.modeling_mamba import (
+    MambaModel,
+    MambaPreTrainedModel,
+    MAMBA_START_DOCSTRING,
+    MAMBA_INPUTS_DOCSTRING,
+)
+from transformers.utils import (
+    ModelOutput,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+
+
+_CONFIG_FOR_DOC = "MambaConfig"
+
+
+@dataclass
+class MambaSequenceClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of Mamba sentence classification models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+    """
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+class MambaClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.classifier_dropout)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+        self.config = config
+
+    def forward(self, features, **kwargs):
+        x = features  # Pooling is done by the forward pass
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = ACT2FN[self.config.hidden_act](x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
+@add_start_docstrings(
+    """
+    Mamba Model with a sequence classification/regression head on top
+    (a linear layer on top of the pooled output) e.g. for GLUE tasks.
+ """, + MAMBA_START_DOCSTRING, +) +class MambaForSequenceClassification(MambaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.backbone = MambaModel(config) + self.classifier = MambaClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=MambaSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[MambaSequenceClassifierOutput, Tuple[torch.FloatTensor]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + sequence_outputs = self.backbone( + input_ids, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_states = sequence_outputs[0] + batch_size = last_hidden_states.shape[0] + + # Pool the hidden states for the last tokens before padding to use for classification + last_token_indexes = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + pooled_last_hidden_states = last_hidden_states[ + torch.arange(batch_size, device=last_hidden_states.device), last_token_indexes + ] + + logits = self.classifier(pooled_last_hidden_states) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + sequence_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MambaSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=sequence_outputs.hidden_states, + ) \ No newline at end of file diff --git a/odyssey/models/cehr_mamba/model.py b/odyssey/models/cehr_mamba/model.py index dd0feff..e6460de 100644 --- a/odyssey/models/cehr_mamba/model.py +++ b/odyssey/models/cehr_mamba/model.py @@ -2,13 +2,26 @@ from typing import Any, Dict, Optional, Tuple, Union +import numpy as np import pytorch_lightning as pl + +from sklearn.metrics import ( + accuracy_score, + f1_score, + precision_score, + recall_score, + roc_auc_score, +) + import torch -from torch import optim +from torch import nn from torch.cuda.amp import autocast +from 
torch.optim import AdamW from torch.optim.lr_scheduler import LinearLR, SequentialLR -from transformers import MambaConfig, MambaForCausalLM -from transformers.models.mamba.modeling_mamba import MambaCausalLMOutput +from transformers import MambaConfig + +from transformers.models.mamba.modeling_mamba import MambaForCausalLM, MambaCausalLMOutput +from odyssey.models.cehr_mamba.mamba_utils import MambaForSequenceClassification, MambaSequenceClassifierOutput class MambaPretrain(pl.LightningModule): @@ -64,7 +77,7 @@ def forward( return_dict: Optional[bool] = True, ) -> Union[Tuple[torch.Tensor, ...], MambaCausalLMOutput]: """Forward pass for the model.""" - if labels == None: + if labels is None: labels = input_ids return self.model( @@ -122,7 +135,208 @@ def configure_optimizers( self, ) -> Tuple[list[Any], list[dict[str, SequentialLR | str]]]: """Configure optimizers and learning rate scheduler.""" - optimizer = optim.AdamW( + optimizer = AdamW( + self.parameters(), + lr=self.learning_rate, + ) + + n_steps = self.trainer.estimated_stepping_batches + n_warmup_steps = int(0.1 * n_steps) + n_decay_steps = int(0.9 * n_steps) + + warmup = LinearLR( + optimizer, + start_factor=0.01, + end_factor=1.0, + total_iters=n_warmup_steps, + ) + decay = LinearLR( + optimizer, + start_factor=1.0, + end_factor=0.01, + total_iters=n_decay_steps, + ) + scheduler = SequentialLR( + optimizer=optimizer, + schedulers=[warmup, decay], + milestones=[n_warmup_steps], + ) + + return [optimizer], [{"scheduler": scheduler, "interval": "step"}] + + +class MambaFinetune(pl.LightningModule): + """Mamba model for fine-tuning.""" + + def __init__( + self, + pretrained_model: MambaPretrain, + problem_type: str = "single_label_classification", + num_labels: int = 2, + learning_rate: float = 5e-5, + classifier_dropout: float = 0.1, + ): + super().__init__() + + self.num_labels = num_labels + self.learning_rate = learning_rate + self.classifier_dropout = classifier_dropout + self.test_outputs = [] + + self.config = pretrained_model.config + self.config.num_labels = self.num_labels + self.config.classifier_dropout = self.classifier_dropout + self.config.problem_type = problem_type + + self.model = MambaForSequenceClassification(config=self.config) + # self.post_init() + + self.pretrained_model = pretrained_model + self.model.backbone = self.pretrained_model.model.backbone + + def _init_weights(self, module: torch.nn.Module) -> None: + """Initialize the weights.""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def post_init(self) -> None: + """Apply weight initialization.""" + self.apply(self._init_weights) + + def forward( + self, + input_ids: torch.Tensor, + labels: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor, ...], MambaSequenceClassifierOutput]: + """Forward pass for the model.""" + return self.model( + input_ids=input_ids, + labels=labels, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + + def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Any: + """Train 
model on training dataset."""
+        concept_ids = batch["concept_ids"]
+        labels = batch["labels"]
+
+        # Ensure use of mixed precision
+        with autocast():
+            loss = self(
+                input_ids=concept_ids,
+                labels=labels,
+                return_dict=True,
+            ).loss
+
+        (current_lr,) = self.lr_schedulers().get_last_lr()
+        self.log_dict(
+            dictionary={"train_loss": loss, "lr": current_lr},
+            on_step=True,
+            prog_bar=True,
+        )
+
+        return loss
+
+    def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> Any:
+        """Evaluate model on validation dataset."""
+        concept_ids = batch["concept_ids"]
+        labels = batch["labels"]
+
+        # Ensure use of mixed precision
+        with autocast():
+            loss = self(
+                input_ids=concept_ids,
+                labels=labels,
+                return_dict=True,
+            ).loss
+
+        (current_lr,) = self.lr_schedulers().get_last_lr()
+        self.log_dict(
+            dictionary={"val_loss": loss, "lr": current_lr},
+            on_step=True,
+            prog_bar=True,
+        )
+
+        return loss
+
+    def test_step(self, batch: Dict[str, Any], batch_idx: int) -> Any:
+        """Test step."""
+        concept_ids = batch["concept_ids"]
+        labels = batch["labels"]
+
+        # Ensure use of mixed precision
+        with autocast():
+            outputs = self(
+                input_ids=concept_ids,
+                labels=labels,
+                return_dict=True,
+            )
+
+        loss = outputs[0]
+        logits = outputs[1]
+        preds = torch.argmax(logits, dim=1)
+        log = {"loss": loss, "preds": preds, "labels": labels, "logits": logits}
+
+        # Append the outputs to the instance attribute
+        self.test_outputs.append(log)
+
+        return log
+
+    def on_test_epoch_end(self) -> Any:
+        """Evaluate after the test epoch."""
+        labels = torch.cat([x["labels"] for x in self.test_outputs]).cpu()
+        preds = torch.cat([x["preds"] for x in self.test_outputs]).cpu()
+        loss = torch.stack([x["loss"] for x in self.test_outputs]).mean().cpu()
+        logits = torch.cat([x["logits"] for x in self.test_outputs]).cpu()
+
+        # Update the saved outputs to include all concatenated batches
+        self.test_outputs = {
+            "labels": labels,
+            "logits": logits,
+        }
+
+        if self.config.problem_type == "multi_label_classification":
+            preds_one_hot = np.eye(labels.shape[1])[preds]
+            accuracy = accuracy_score(labels, preds_one_hot)
+            f1 = f1_score(labels, preds_one_hot, average="micro")
+            auc = roc_auc_score(labels, preds_one_hot, average="micro")
+            precision = precision_score(labels, preds_one_hot, average="micro")
+            recall = recall_score(labels, preds_one_hot, average="micro")
+
+        else:  # single_label_classification
+            accuracy = accuracy_score(labels, preds)
+            f1 = f1_score(labels, preds)
+            auc = roc_auc_score(labels, preds)
+            precision = precision_score(labels, preds)
+            recall = recall_score(labels, preds)
+
+        self.log("test_loss", loss)
+        self.log("test_acc", accuracy)
+        self.log("test_f1", f1)
+        self.log("test_auc", auc)
+        self.log("test_precision", precision)
+        self.log("test_recall", recall)
+
+        return loss
+
+    def configure_optimizers(
+        self,
+    ) -> Tuple[list[Any], list[dict[str, SequentialLR | str]]]:
+        """Configure optimizers and learning rate scheduler."""
+        optimizer = AdamW(
             self.parameters(),
             lr=self.learning_rate,
         )
diff --git a/odyssey/models/cehr_mamba/playground.ipynb b/odyssey/models/cehr_mamba/playground.ipynb
index 2967aae..7e390b6 100644
--- a/odyssey/models/cehr_mamba/playground.ipynb
+++ b/odyssey/models/cehr_mamba/playground.ipynb
@@ -144,13 +144,6 @@
     "    return prediction_scores"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
   {
    "cell_type": "code",
    "execution_count": null,