Commit
Fixed merge conflicts.
Adibvafa committed May 1, 2024
2 parents e6d946c + 618854b commit cef4afe
Showing 9 changed files with 108 additions and 86 deletions.
2 changes: 1 addition & 1 deletion finetune.py
@@ -1,8 +1,8 @@
"""Finetune the pre-trained model."""

import argparse
import os
import sys
import argparse
from typing import Any, Dict

import numpy as np
72 changes: 38 additions & 34 deletions odyssey/data/DataProcessor.ipynb
@@ -52,9 +52,9 @@
" process_readmission_dataset,\n",
" process_multi_dataset,\n",
" stratified_train_test_split,\n",
" sample_balanced_subset, \n",
" sample_balanced_subset,\n",
" get_pretrain_test_split,\n",
" get_finetune_split\n",
" get_finetune_split,\n",
")\n",
"\n",
"SEED = 23\n",
@@ -87,7 +87,9 @@
"outputs": [],
"source": [
"# Process the dataset for length of stay prediction above a threshold\n",
"dataset_2048_los = process_length_of_stay_dataset(dataset_2048.copy(), threshold=7, max_len=MAX_LEN)"
"dataset_2048_los = process_length_of_stay_dataset(\n",
" dataset_2048.copy(), threshold=7, max_len=MAX_LEN\n",
")"
]
},
{
@@ -129,7 +131,9 @@
"outputs": [],
"source": [
"# Process the dataset for hospital readmission in one month task\n",
"dataset_2048_readmission = process_readmission_dataset(dataset_2048.copy(), max_len=MAX_LEN)"
"dataset_2048_readmission = process_readmission_dataset(\n",
" dataset_2048.copy(), max_len=MAX_LEN\n",
")"
]
},
{
@@ -147,7 +151,7 @@
" \"readmission\": dataset_2048_readmission,\n",
" \"los\": dataset_2048_los,\n",
" },\n",
" max_len=MAX_LEN\n",
" max_len=MAX_LEN,\n",
")"
]
},
@@ -184,35 +188,35 @@
"outputs": [],
"source": [
"task_config = {\n",
" \"mortality\": {\n",
" \"dataset\": dataset_2048_mortality,\n",
" \"label_col\": \"label_mortality_1month\",\n",
" \"finetune_size\": [250, 500, 1000, 5000, 20000],\n",
" \"save_path\": \"patient_id_dict/dataset_2048_mortality.pkl\",\n",
" \"split_mode\": \"single_label_balanced\",\n",
" },\n",
" \"readmission\": {\n",
" \"dataset\": dataset_2048_readmission,\n",
" \"label_col\": \"label_readmission_1month\",\n",
" \"finetune_size\": [250, 1000, 5000, 20000, 60000],\n",
" \"save_path\": \"patient_id_dict/dataset_2048_readmission.pkl\",\n",
" \"split_mode\": \"single_label_stratified\",\n",
" },\n",
" \"length_of_stay\": {\n",
" \"dataset\": dataset_2048_los,\n",
" \"label_col\": \"label_los_1week\",\n",
" \"finetune_size\": [250, 1000, 5000, 20000, 50000],\n",
" \"save_path\": \"patient_id_dict/dataset_2048_los.pkl\",\n",
" \"split_mode\": \"single_label_balanced\",\n",
" },\n",
" \"condition\": {\n",
" \"dataset\": dataset_2048_condition,\n",
" \"label_col\": \"all_conditions\",\n",
" \"finetune_size\": [50000],\n",
" \"save_path\": \"patient_id_dict/dataset_2048_condition.pkl\",\n",
" \"split_mode\": \"multi_label_stratified\",\n",
" },\n",
" }"
" \"mortality\": {\n",
" \"dataset\": dataset_2048_mortality,\n",
" \"label_col\": \"label_mortality_1month\",\n",
" \"finetune_size\": [250, 500, 1000, 5000, 20000],\n",
" \"save_path\": \"patient_id_dict/dataset_2048_mortality.pkl\",\n",
" \"split_mode\": \"single_label_balanced\",\n",
" },\n",
" \"readmission\": {\n",
" \"dataset\": dataset_2048_readmission,\n",
" \"label_col\": \"label_readmission_1month\",\n",
" \"finetune_size\": [250, 1000, 5000, 20000, 60000],\n",
" \"save_path\": \"patient_id_dict/dataset_2048_readmission.pkl\",\n",
" \"split_mode\": \"single_label_stratified\",\n",
" },\n",
" \"length_of_stay\": {\n",
" \"dataset\": dataset_2048_los,\n",
" \"label_col\": \"label_los_1week\",\n",
" \"finetune_size\": [250, 1000, 5000, 20000, 50000],\n",
" \"save_path\": \"patient_id_dict/dataset_2048_los.pkl\",\n",
" \"split_mode\": \"single_label_balanced\",\n",
" },\n",
" \"condition\": {\n",
" \"dataset\": dataset_2048_condition,\n",
" \"label_col\": \"all_conditions\",\n",
" \"finetune_size\": [50000],\n",
" \"save_path\": \"patient_id_dict/dataset_2048_condition.pkl\",\n",
" \"split_mode\": \"multi_label_stratified\",\n",
" },\n",
"}"
]
},
{
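For orientation, the split_mode values defined in task_config above distinguish two ideas: stratified splitting (preserve the label distribution) and balanced subsampling (equal patients per class). A minimal sketch of both using scikit-learn; this is not the repository's get_finetune_split implementation, and the column names are illustrative:

import pandas as pd
from sklearn.model_selection import train_test_split

SEED = 23  # matches the notebook's seed


def stratified_split(df: pd.DataFrame, label_col: str, test_size: float = 0.2):
    # Preserve the label distribution in both halves (cf. "single_label_stratified").
    return train_test_split(df, test_size=test_size, stratify=df[label_col], random_state=SEED)


def balanced_subset(df: pd.DataFrame, label_col: str, n_per_class: int) -> pd.DataFrame:
    # Draw an equal number of patients per class (cf. "single_label_balanced").
    return df.groupby(label_col, group_keys=False).apply(
        lambda g: g.sample(n=min(n_per_class, len(g)), random_state=SEED)
    )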
4 changes: 2 additions & 2 deletions odyssey/data/tokenizer.py
@@ -127,7 +127,7 @@ class ConceptTokenizer:
Tokenizer object.
tokenizer_object: Tokenizer
Tokenizer object.
"""

def __init__(
@@ -416,7 +416,7 @@ def get_vocab_size(self) -> int:
"""
return len(self.tokenizer)

def get_class_token_id(self) -> int:
"""Return the token id of CLS token.
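The two ConceptTokenizer methods touched in this hunk are thin lookups on the underlying Hugging Face tokenizer. A generic sketch of the same operations on a PreTrainedTokenizerFast (an assumption about the wrapped object; the actual class may differ):

from transformers import PreTrainedTokenizerFast


def vocab_size(tokenizer: PreTrainedTokenizerFast) -> int:
    # Number of tokens in the vocabulary, mirroring len(self.tokenizer) above.
    return len(tokenizer)


def class_token_id(tokenizer: PreTrainedTokenizerFast) -> int:
    # Assumes the vocabulary contains a literal "[CLS]" token.
    return tokenizer.convert_tokens_to_ids("[CLS]")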
18 changes: 7 additions & 11 deletions odyssey/interp/AttentionVisualization.ipynb
@@ -85,14 +85,10 @@
"class args:\n",
" \"\"\"Save the configuration arguments.\"\"\"\n",
"\n",
" model_path = (\n",
" \"checkpoints/best.ckpt\"\n",
" )\n",
" model_path = \"checkpoints/best.ckpt\"\n",
" vocab_dir = \"odyssey/data/vocab\"\n",
" data_dir = \"odyssey/data/bigbird_data\"\n",
" sequence_file = (\n",
" \"patient_sequences_2048_mortality.parquet\"\n",
" )\n",
" sequence_file = \"patient_sequences_2048_mortality.parquet\"\n",
" id_file = \"dataset_2048_mortality.pkl\"\n",
" valid_scheme = \"few_shot\"\n",
" num_finetune_patients = \"20000\"\n",
@@ -1097,26 +1093,26 @@
0.025323085486888885,
0.023675603792071342,
0.016575941815972328,
0.024279847741127014,
0.024279847741127018,
0.016048545017838478,
0.024201413616538048,
0.032493140548467636,
0.025662193074822426,
0.025662193074822422,
0.01891118660569191,
0.022078484296798706,
0.022078484296798703,
0.05853615701198578,
0.04014396667480469,
0.04093107953667641,
0.03502603992819786,
0.040257714688777924,
0.04025771468877792,
0.03630286082625389,
0.03580455854535103,
0.02569277584552765,
0.0514972023665905,
0.0181843601167202,
0.03194662928581238,
0.034090328961610794,
0.023539429530501366,
0.02353942953050137,
0.04850481450557709,
0.019092628732323647,
0.02532283030450344
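The numeric diff above comes from re-running the notebook: the attention weights changed only in their last floating-point digits. For context, a generic sketch of how per-token attention is usually pulled from a Hugging Face transformer; the tiny model and dummy inputs below are placeholders, not the notebook's actual checkpoint:

import torch
from transformers import BertConfig, BertModel

# Tiny stand-in model so the sketch is self-contained.
model = BertModel(BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                             num_attention_heads=2, intermediate_size=64))
model.eval()

input_ids = torch.randint(0, 100, (1, 16))  # dummy patient sequence
with torch.no_grad():
    outputs = model(input_ids=input_ids, output_attentions=True, return_dict=True)

# outputs.attentions: one tensor per layer, each [batch, num_heads, seq_len, seq_len].
cls_attention = outputs.attentions[-1][0].mean(dim=0)[0]  # last layer, heads averaged, CLS row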
37 changes: 23 additions & 14 deletions odyssey/models/cehr_mamba/mamba-dev.ipynb
@@ -18,7 +18,14 @@
"import numpy as np\n",
"import pandas as pd\n",
"from tokenizers import Tokenizer, models, pre_tokenizers\n",
"from transformers import BatchEncoding, PreTrainedTokenizerFast, AutoTokenizer, MambaConfig, MambaModel, MambaForCausalLM\n",
"from transformers import (\n",
" BatchEncoding,\n",
" PreTrainedTokenizerFast,\n",
" AutoTokenizer,\n",
" MambaConfig,\n",
" MambaModel,\n",
" MambaForCausalLM,\n",
")\n",
"\n",
"import numpy as np\n",
"import pytorch_lightning as pl\n",
@@ -39,7 +46,7 @@
"\n",
"# from mamba_ssm.models.mixer_seq_simple import MambaConfig, MambaLMHeadModel\n",
"\n",
"ROOT = '/h/afallah/odyssey/odyssey'\n",
"ROOT = \"/h/afallah/odyssey/odyssey\"\n",
"os.chdir(ROOT)\n",
"\n",
"from odyssey.models.embeddings import *\n",
@@ -162,9 +169,7 @@
"# config=config\n",
"# )\n",
"\n",
"model = MambaForCausalLM(\n",
" config=config\n",
")\n",
"model = MambaForCausalLM(config=config)\n",
"# model.backbone.embeddings = embeddings\n",
"model.to(device)\n",
"\n",
@@ -350,9 +355,7 @@
"# )\n",
"\n",
"outputs = model(\n",
" input_ids=sample[\"concept_ids\"],\n",
" labels=sample[\"concept_ids\"],\n",
" return_dict=True\n",
" input_ids=sample[\"concept_ids\"], labels=sample[\"concept_ids\"], return_dict=True\n",
")\n",
"\n",
"loss = outputs.loss\n",
@@ -408,7 +411,6 @@
"\n",
" self.model = MambaForCausalLM(config=config)\n",
"\n",
"\n",
" def forward(\n",
" self,\n",
" input_ids: torch.Tensor,\n",
@@ -528,7 +530,7 @@
" time_embeddings_size: int = 16,\n",
" visit_order_size: int = 3,\n",
" layer_norm_eps: float = 1e-12,\n",
" hidden_dropout_prob: float = 0.1\n",
" hidden_dropout_prob: float = 0.1,\n",
" ) -> None:\n",
" \"\"\"Initiate wrapper class for embeddings used in BigBird CEHR classes.\"\"\"\n",
" super().__init__()\n",
@@ -608,7 +610,7 @@
" time_stamps: Optional[torch.Tensor] = None,\n",
" ages: Optional[torch.Tensor] = None,\n",
" visit_orders: Optional[torch.Tensor] = None,\n",
" visit_segments: Optional[torch.Tensor] = None\n",
" visit_segments: Optional[torch.Tensor] = None,\n",
" ) -> None:\n",
" \"\"\"Cache values for time_stamps, ages, visit_orders & visit_segments.\n",
"\n",
@@ -638,11 +640,18 @@
" self.ages = ages\n",
" self.visit_orders = visit_orders\n",
" self.visit_segments = visit_segments\n",
" \n",
"\n",
" def clear_cache(self) -> None:\n",
" \"\"\"Delete the tensors cached by cache_input method.\"\"\"\n",
" del self.token_type_ids_batch, self.position_ids_batch, self.inputs_embeds, \\\n",
" self.time_stamps, self.ages, self.visit_orders, self.visit_segments\n",
" del (\n",
" self.token_type_ids_batch,\n",
" self.position_ids_batch,\n",
" self.inputs_embeds,\n",
" self.time_stamps,\n",
" self.ages,\n",
" self.visit_orders,\n",
" self.visit_segments,\n",
" )\n",
"\n",
" def forward(\n",
" self,\n",
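The last hunk above reformats the cache_input / clear_cache pair in the notebook's embeddings wrapper. A stripped-down sketch of that caching pattern, as a hypothetical minimal class rather than the repository's full embeddings module:

from typing import Optional

import torch
from torch import nn


class CachedEmbeddings(nn.Module):
    """Hypothetical minimal version of the caching pattern shown above."""

    def cache_input(
        self,
        time_stamps: Optional[torch.Tensor] = None,
        ages: Optional[torch.Tensor] = None,
    ) -> None:
        # Stash side inputs so a later forward() can use them without
        # widening the upstream model's call signature.
        self.time_stamps = time_stamps
        self.ages = ages

    def clear_cache(self) -> None:
        # Free the cached tensors once the forward pass has consumed them.
        del self.time_stamps, self.ages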
10 changes: 3 additions & 7 deletions odyssey/models/cehr_mamba/model.py
@@ -1,17 +1,14 @@
"""Mamba model."""

from typing import Any, Dict, List, Tuple, Optional, Union
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import pytorch_lightning as pl

import torch
from torch import nn, optim
from torch import optim
from torch.cuda.amp import autocast
from torch.optim.lr_scheduler import LinearLR, SequentialLR

from transformers import MambaConfig, MambaForCausalLM
from transformers.models.mamba.modeling_mamba import MambaCausalLMOutput
from transformers.models.mamba.modeling_mamba import MambaCausalLMOutput


class MambaPretrain(pl.LightningModule):
@@ -59,7 +56,6 @@ def __init__(

self.model = MambaForCausalLM(config=self.config)


def forward(
self,
input_ids: torch.Tensor,
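For reference, the MambaPretrain module edited above follows the usual PyTorch Lightning causal-LM recipe. A minimal, self-contained sketch under that assumption; the "concept_ids" batch key comes from the notebook cells above, while the class body, hyperparameters, and optimizer choice here are illustrative, not the repository's actual code:

import pytorch_lightning as pl
import torch
from torch import optim
from transformers import MambaConfig, MambaForCausalLM


class MambaPretrainSketch(pl.LightningModule):
    """Illustrative stand-in for the MambaPretrain class, not its actual implementation."""

    def __init__(self, vocab_size: int = 30000, learning_rate: float = 2e-4) -> None:
        super().__init__()
        self.learning_rate = learning_rate
        self.config = MambaConfig(vocab_size=vocab_size)
        self.model = MambaForCausalLM(config=self.config)

    def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        # Next-token prediction over the patient concept sequence.
        outputs = self.model(
            input_ids=batch["concept_ids"],
            labels=batch["concept_ids"],
            return_dict=True,
        )
        self.log("train_loss", outputs.loss)
        return outputs.loss

    def configure_optimizers(self) -> optim.Optimizer:
        return optim.AdamW(self.parameters(), lr=self.learning_rate)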