From 618854be68284a89fce4ff49fe404a5282bca263 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Apr 2024 03:57:25 +0000 Subject: [PATCH] [pre-commit.ci] Add auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- finetune.py | 2 +- odyssey/data/DataProcessor.ipynb | 72 +++++++++++---------- odyssey/data/tokenizer.py | 4 +- odyssey/interp/AttentionVisualization.ipynb | 18 ++---- odyssey/models/cehr_mamba/model.py | 10 +-- odyssey/models/cehr_mamba/playground.ipynb | 38 +++++++---- odyssey/models/cehr_mamba/pretrain.ipynb | 57 +++++++++------- odyssey/models/model_utils.py | 1 - pretrain.py | 12 ++-- 9 files changed, 118 insertions(+), 96 deletions(-) diff --git a/finetune.py b/finetune.py index 512760f..c8a5ca5 100644 --- a/finetune.py +++ b/finetune.py @@ -1,8 +1,8 @@ """Finetune the pre-trained model.""" +import argparse import os import sys -import argparse from typing import Any, Dict import numpy as np diff --git a/odyssey/data/DataProcessor.ipynb b/odyssey/data/DataProcessor.ipynb index 21c211f..b4d81bf 100644 --- a/odyssey/data/DataProcessor.ipynb +++ b/odyssey/data/DataProcessor.ipynb @@ -52,9 +52,9 @@ " process_readmission_dataset,\n", " process_multi_dataset,\n", " stratified_train_test_split,\n", - " sample_balanced_subset, \n", + " sample_balanced_subset,\n", " get_pretrain_test_split,\n", - " get_finetune_split\n", + " get_finetune_split,\n", ")\n", "\n", "SEED = 23\n", @@ -87,7 +87,9 @@ "outputs": [], "source": [ "# Process the dataset for length of stay prediction above a threshold\n", - "dataset_2048_los = process_length_of_stay_dataset(dataset_2048.copy(), threshold=7, max_len=MAX_LEN)" + "dataset_2048_los = process_length_of_stay_dataset(\n", + " dataset_2048.copy(), threshold=7, max_len=MAX_LEN\n", + ")" ] }, { @@ -129,7 +131,9 @@ "outputs": [], "source": [ "# Process the dataset for hospital readmission in one month task\n", - "dataset_2048_readmission = process_readmission_dataset(dataset_2048.copy(), max_len=MAX_LEN)" + "dataset_2048_readmission = process_readmission_dataset(\n", + " dataset_2048.copy(), max_len=MAX_LEN\n", + ")" ] }, { @@ -147,7 +151,7 @@ " \"readmission\": dataset_2048_readmission,\n", " \"los\": dataset_2048_los,\n", " },\n", - " max_len=MAX_LEN\n", + " max_len=MAX_LEN,\n", ")" ] }, @@ -184,35 +188,35 @@ "outputs": [], "source": [ "task_config = {\n", - " \"mortality\": {\n", - " \"dataset\": dataset_2048_mortality,\n", - " \"label_col\": \"label_mortality_1month\",\n", - " \"finetune_size\": [250, 500, 1000, 5000, 20000],\n", - " \"save_path\": \"patient_id_dict/dataset_2048_mortality.pkl\",\n", - " \"split_mode\": \"single_label_balanced\",\n", - " },\n", - " \"readmission\": {\n", - " \"dataset\": dataset_2048_readmission,\n", - " \"label_col\": \"label_readmission_1month\",\n", - " \"finetune_size\": [250, 1000, 5000, 20000, 60000],\n", - " \"save_path\": \"patient_id_dict/dataset_2048_readmission.pkl\",\n", - " \"split_mode\": \"single_label_stratified\",\n", - " },\n", - " \"length_of_stay\": {\n", - " \"dataset\": dataset_2048_los,\n", - " \"label_col\": \"label_los_1week\",\n", - " \"finetune_size\": [250, 1000, 5000, 20000, 50000],\n", - " \"save_path\": \"patient_id_dict/dataset_2048_los.pkl\",\n", - " \"split_mode\": \"single_label_balanced\",\n", - " },\n", - " \"condition\": {\n", - " \"dataset\": dataset_2048_condition,\n", - " \"label_col\": \"all_conditions\",\n", - " \"finetune_size\": [50000],\n", - " \"save_path\": \"patient_id_dict/dataset_2048_condition.pkl\",\n", - " \"split_mode\": \"multi_label_stratified\",\n", - " },\n", - " }" + " \"mortality\": {\n", + " \"dataset\": dataset_2048_mortality,\n", + " \"label_col\": \"label_mortality_1month\",\n", + " \"finetune_size\": [250, 500, 1000, 5000, 20000],\n", + " \"save_path\": \"patient_id_dict/dataset_2048_mortality.pkl\",\n", + " \"split_mode\": \"single_label_balanced\",\n", + " },\n", + " \"readmission\": {\n", + " \"dataset\": dataset_2048_readmission,\n", + " \"label_col\": \"label_readmission_1month\",\n", + " \"finetune_size\": [250, 1000, 5000, 20000, 60000],\n", + " \"save_path\": \"patient_id_dict/dataset_2048_readmission.pkl\",\n", + " \"split_mode\": \"single_label_stratified\",\n", + " },\n", + " \"length_of_stay\": {\n", + " \"dataset\": dataset_2048_los,\n", + " \"label_col\": \"label_los_1week\",\n", + " \"finetune_size\": [250, 1000, 5000, 20000, 50000],\n", + " \"save_path\": \"patient_id_dict/dataset_2048_los.pkl\",\n", + " \"split_mode\": \"single_label_balanced\",\n", + " },\n", + " \"condition\": {\n", + " \"dataset\": dataset_2048_condition,\n", + " \"label_col\": \"all_conditions\",\n", + " \"finetune_size\": [50000],\n", + " \"save_path\": \"patient_id_dict/dataset_2048_condition.pkl\",\n", + " \"split_mode\": \"multi_label_stratified\",\n", + " },\n", + "}" ] }, { diff --git a/odyssey/data/tokenizer.py b/odyssey/data/tokenizer.py index 43835da..81900b9 100644 --- a/odyssey/data/tokenizer.py +++ b/odyssey/data/tokenizer.py @@ -127,7 +127,7 @@ class ConceptTokenizer: Tokenizer object. tokenizer_object: Tokenizer Tokenizer object. - + """ def __init__( @@ -416,7 +416,7 @@ def get_vocab_size(self) -> int: """ return len(self.tokenizer) - + def get_class_token_id(self) -> int: """Return the token id of CLS token. diff --git a/odyssey/interp/AttentionVisualization.ipynb b/odyssey/interp/AttentionVisualization.ipynb index a7703ec..82f6bb1 100644 --- a/odyssey/interp/AttentionVisualization.ipynb +++ b/odyssey/interp/AttentionVisualization.ipynb @@ -85,14 +85,10 @@ "class args:\n", " \"\"\"Save the configuration arguments.\"\"\"\n", "\n", - " model_path = (\n", - " \"checkpoints/best.ckpt\"\n", - " )\n", + " model_path = \"checkpoints/best.ckpt\"\n", " vocab_dir = \"odyssey/data/vocab\"\n", " data_dir = \"odyssey/data/bigbird_data\"\n", - " sequence_file = (\n", - " \"patient_sequences_2048_mortality.parquet\"\n", - " )\n", + " sequence_file = \"patient_sequences_2048_mortality.parquet\"\n", " id_file = \"dataset_2048_mortality.pkl\"\n", " valid_scheme = \"few_shot\"\n", " num_finetune_patients = \"20000\"\n", @@ -1097,18 +1093,18 @@ 0.025323085486888885, 0.023675603792071342, 0.016575941815972328, - 0.024279847741127014, + 0.024279847741127018, 0.016048545017838478, 0.024201413616538048, 0.032493140548467636, - 0.025662193074822426, + 0.025662193074822422, 0.01891118660569191, - 0.022078484296798706, + 0.022078484296798703, 0.05853615701198578, 0.04014396667480469, 0.04093107953667641, 0.03502603992819786, - 0.040257714688777924, + 0.04025771468877792, 0.03630286082625389, 0.03580455854535103, 0.02569277584552765, @@ -1116,7 +1112,7 @@ 0.0181843601167202, 0.03194662928581238, 0.034090328961610794, - 0.023539429530501366, + 0.02353942953050137, 0.04850481450557709, 0.019092628732323647, 0.02532283030450344 diff --git a/odyssey/models/cehr_mamba/model.py b/odyssey/models/cehr_mamba/model.py index f6e459b..dd0feff 100644 --- a/odyssey/models/cehr_mamba/model.py +++ b/odyssey/models/cehr_mamba/model.py @@ -1,17 +1,14 @@ """Mamba model.""" -from typing import Any, Dict, List, Tuple, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union -import numpy as np import pytorch_lightning as pl - import torch -from torch import nn, optim +from torch import optim from torch.cuda.amp import autocast from torch.optim.lr_scheduler import LinearLR, SequentialLR - from transformers import MambaConfig, MambaForCausalLM -from transformers.models.mamba.modeling_mamba import MambaCausalLMOutput +from transformers.models.mamba.modeling_mamba import MambaCausalLMOutput class MambaPretrain(pl.LightningModule): @@ -59,7 +56,6 @@ def __init__( self.model = MambaForCausalLM(config=self.config) - def forward( self, input_ids: torch.Tensor, diff --git a/odyssey/models/cehr_mamba/playground.ipynb b/odyssey/models/cehr_mamba/playground.ipynb index d494dae..2967aae 100644 --- a/odyssey/models/cehr_mamba/playground.ipynb +++ b/odyssey/models/cehr_mamba/playground.ipynb @@ -29,7 +29,7 @@ " MambaPreTrainedModel,\n", " MambaOutput,\n", " MAMBA_START_DOCSTRING,\n", - " MAMBA_INPUTS_DOCSTRING\n", + " MAMBA_INPUTS_DOCSTRING,\n", ")\n", "from transformers.activations import ACT2FN\n", "from transformers.utils import (\n", @@ -42,7 +42,7 @@ "\n", "from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel\n", "\n", - "ROOT = '/h/afallah/odyssey/odyssey'\n", + "ROOT = \"/h/afallah/odyssey/odyssey\"\n", "_CHECKPOINT_FOR_DOC = \"state-spaces/mamba-130m-hf\"\n", "_CONFIG_FOR_DOC = \"MambaConfig\"\n", "os.chdir(ROOT)\n", @@ -58,7 +58,7 @@ "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"state-spaces/mamba-130m-hf\")\n", "model = MambaForCausalLM.from_pretrained(\"state-spaces/mamba-130m-hf\")\n", - "input_ids = tokenizer(\"Hey how are you doing?\", return_tensors= \"pt\")[\"input_ids\"]\n", + "input_ids = tokenizer(\"Hey how are you doing?\", return_tensors=\"pt\")[\"input_ids\"]\n", "\n", "out = model.generate(input_ids, max_new_tokens=10)\n", "print(tokenizer.batch_decode(out))" @@ -170,7 +170,9 @@ " def __init__(self, config):\n", " super().__init__(config)\n", " self.backbone = MambaModel(config)\n", - " self.lm_head = MambaOnlyMLMHead(config) # nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n", + " self.lm_head = MambaOnlyMLMHead(\n", + " config\n", + " ) # nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n", " # Initialize weights and apply final processing\n", " self.post_init()\n", "\n", @@ -179,7 +181,7 @@ "\n", " def set_output_embeddings(self, new_embeddings):\n", " self.lm_head = new_embeddings\n", - " \n", + "\n", " def get_input_embeddings(self):\n", " return self.backbone.get_input_embeddings()\n", "\n", @@ -262,7 +264,9 @@ " 1.99\n", " ```\n", " \"\"\"\n", - " return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n", + " return_dict = (\n", + " return_dict if return_dict is not None else self.config.use_return_dict\n", + " )\n", "\n", " outputs = self.bert(\n", " input_ids,\n", @@ -284,11 +288,15 @@ " masked_lm_loss = None\n", " if labels is not None:\n", " loss_fct = CrossEntropyLoss() # -100 index = padding token\n", - " masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))\n", + " masked_lm_loss = loss_fct(\n", + " prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)\n", + " )\n", "\n", " if not return_dict:\n", " output = (prediction_scores,) + outputs[2:]\n", - " return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output\n", + " return (\n", + " ((masked_lm_loss,) + output) if masked_lm_loss is not None else output\n", + " )\n", "\n", " return MaskedLMOutput(\n", " loss=masked_lm_loss,\n", @@ -297,16 +305,24 @@ " attentions=outputs.attentions,\n", " )\n", "\n", - " def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):\n", + " def prepare_inputs_for_generation(\n", + " self, input_ids, attention_mask=None, **model_kwargs\n", + " ):\n", " input_shape = input_ids.shape\n", " effective_batch_size = input_shape[0]\n", "\n", " # add a dummy token\n", " if self.config.pad_token_id is None:\n", " raise ValueError(\"The PAD token should be defined for generation\")\n", - " attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)\n", + " attention_mask = torch.cat(\n", + " [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],\n", + " dim=-1,\n", + " )\n", " dummy_token = torch.full(\n", - " (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device\n", + " (effective_batch_size, 1),\n", + " self.config.pad_token_id,\n", + " dtype=torch.long,\n", + " device=input_ids.device,\n", " )\n", " input_ids = torch.cat([input_ids, dummy_token], dim=1)\n", "\n", diff --git a/odyssey/models/cehr_mamba/pretrain.ipynb b/odyssey/models/cehr_mamba/pretrain.ipynb index f05c1d7..fbbdff7 100644 --- a/odyssey/models/cehr_mamba/pretrain.ipynb +++ b/odyssey/models/cehr_mamba/pretrain.ipynb @@ -18,7 +18,14 @@ "import numpy as np\n", "import pandas as pd\n", "from tokenizers import Tokenizer, models, pre_tokenizers\n", - "from transformers import BatchEncoding, PreTrainedTokenizerFast, AutoTokenizer, MambaConfig, MambaModel, MambaForCausalLM\n", + "from transformers import (\n", + " BatchEncoding,\n", + " PreTrainedTokenizerFast,\n", + " AutoTokenizer,\n", + " MambaConfig,\n", + " MambaModel,\n", + " MambaForCausalLM,\n", + ")\n", "\n", "import numpy as np\n", "import pytorch_lightning as pl\n", @@ -39,7 +46,7 @@ "\n", "# from mamba_ssm.models.mixer_seq_simple import MambaConfig, MambaLMHeadModel\n", "\n", - "ROOT = '/h/afallah/odyssey/odyssey'\n", + "ROOT = \"/h/afallah/odyssey/odyssey\"\n", "os.chdir(ROOT)\n", "\n", "from odyssey.models.embeddings import *\n", @@ -64,10 +71,10 @@ "outputs": [], "source": [ "class args:\n", - " data_dir = 'odyssey/data/bigbird_data'\n", - " sequence_file = 'patient_sequences/patient_sequences_2048.parquet'\n", - " id_file = 'patient_id_dict/dataset_2048_multi.pkl'\n", - " vocab_dir = 'odyssey/data/vocab'\n", + " data_dir = \"odyssey/data/bigbird_data\"\n", + " sequence_file = \"patient_sequences/patient_sequences_2048.parquet\"\n", + " id_file = \"patient_id_dict/dataset_2048_multi.pkl\"\n", + " vocab_dir = \"odyssey/data/vocab\"\n", " max_len = 2048\n", " mask_prob = 0.15" ] @@ -93,7 +100,7 @@ "# tokenizer=tokenizer,\n", "# max_len=args.max_len,\n", "# mask_prob=args.mask_prob,\n", - " # )" + "# )" ] }, { @@ -147,9 +154,7 @@ "# config=config\n", "# )\n", "\n", - "model = MambaForCausalLM(\n", - " config=config\n", - ")\n", + "model = MambaForCausalLM(config=config)\n", "# model.backbone.embeddings = embeddings\n", "model.to(device)\n", "\n", @@ -163,16 +168,16 @@ "outputs": [], "source": [ "train_loader = DataLoader(\n", - " train_dataset,\n", - " batch_size=3,\n", - " shuffle=False,\n", - " )\n", + " train_dataset,\n", + " batch_size=3,\n", + " shuffle=False,\n", + ")\n", "\n", "# sample = train_dataset[0]\n", "# sample = {key:tensor.unsqueeze(0).to(device) for key, tensor in sample.items()}\n", "\n", "sample = next(iter(train_loader))\n", - "sample = {key:tensor.to(device) for key, tensor in sample.items()}" + "sample = {key: tensor.to(device) for key, tensor in sample.items()}" ] }, { @@ -221,9 +226,7 @@ "# )\n", "\n", "outputs = model(\n", - " input_ids=sample[\"concept_ids\"],\n", - " labels=sample[\"concept_ids\"],\n", - " return_dict=True\n", + " input_ids=sample[\"concept_ids\"], labels=sample[\"concept_ids\"], return_dict=True\n", ")\n", "\n", "loss = outputs.loss\n", @@ -279,7 +282,6 @@ "\n", " self.model = MambaForCausalLM(config=config)\n", "\n", - "\n", " def forward(\n", " self,\n", " input_ids: torch.Tensor,\n", @@ -399,7 +401,7 @@ " time_embeddings_size: int = 16,\n", " visit_order_size: int = 3,\n", " layer_norm_eps: float = 1e-12,\n", - " hidden_dropout_prob: float = 0.1\n", + " hidden_dropout_prob: float = 0.1,\n", " ) -> None:\n", " \"\"\"Initiate wrapper class for embeddings used in BigBird CEHR classes.\"\"\"\n", " super().__init__()\n", @@ -479,7 +481,7 @@ " time_stamps: Optional[torch.Tensor] = None,\n", " ages: Optional[torch.Tensor] = None,\n", " visit_orders: Optional[torch.Tensor] = None,\n", - " visit_segments: Optional[torch.Tensor] = None\n", + " visit_segments: Optional[torch.Tensor] = None,\n", " ) -> None:\n", " \"\"\"Cache values for time_stamps, ages, visit_orders & visit_segments.\n", "\n", @@ -509,11 +511,18 @@ " self.ages = ages\n", " self.visit_orders = visit_orders\n", " self.visit_segments = visit_segments\n", - " \n", + "\n", " def clear_cache(self) -> None:\n", " \"\"\"Delete the tensors cached by cache_input method.\"\"\"\n", - " del self.token_type_ids_batch, self.position_ids_batch, self.inputs_embeds, \\\n", - " self.time_stamps, self.ages, self.visit_orders, self.visit_segments\n", + " del (\n", + " self.token_type_ids_batch,\n", + " self.position_ids_batch,\n", + " self.inputs_embeds,\n", + " self.time_stamps,\n", + " self.ages,\n", + " self.visit_orders,\n", + " self.visit_segments,\n", + " )\n", "\n", " def forward(\n", " self,\n", diff --git a/odyssey/models/model_utils.py b/odyssey/models/model_utils.py index 914772f..73d8cca 100644 --- a/odyssey/models/model_utils.py +++ b/odyssey/models/model_utils.py @@ -13,7 +13,6 @@ from odyssey.data.tokenizer import ConceptTokenizer from odyssey.models.cehr_bert.model import BertFinetune, BertPretrain from odyssey.models.cehr_big_bird.model import BigBirdFinetune, BigBirdPretrain -from odyssey.models.cehr_mamba.model import MambaPretrain def load_config(config_dir: str, model_type: str) -> Any: diff --git a/pretrain.py b/pretrain.py index 392f17e..a98aeff 100644 --- a/pretrain.py +++ b/pretrain.py @@ -5,8 +5,8 @@ import sys from typing import Any, Dict -import torch import pytorch_lightning as pl +import torch from lightning.pytorch.loggers import WandbLogger from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.strategies.ddp import DDPStrategy @@ -51,7 +51,7 @@ def main(args: argparse.Namespace, model_config: Dict[str, Any]) -> None: tokenizer.fit_on_vocab() # Load datasets - if args.model_type == 'cehr_mamba': # Decoder model + if args.model_type == "cehr_mamba": # Decoder model train_dataset = PretrainDatasetDecoder( data=pre_train, tokenizer=tokenizer, @@ -61,8 +61,8 @@ def main(args: argparse.Namespace, model_config: Dict[str, Any]) -> None: data=pre_val, tokenizer=tokenizer, max_len=args.max_len, - ) - + ) + else: train_dataset = PretrainDataset( data=pre_train, @@ -267,7 +267,9 @@ def main(args: argparse.Namespace, model_config: Dict[str, Any]) -> None: args = parser.parse_args() if args.model_type not in ["cehr_bert", "cehr_bigbird", "cehr_mamba"]: - print("Invalid model type. Choose 'cehr_bert' or 'cehr_bigbird' or 'cehr_mamba'.") + print( + "Invalid model type. Choose 'cehr_bert' or 'cehr_bigbird' or 'cehr_mamba'." + ) sys.exit(1) args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.exp_name)