Merge pull request #35 from VectorInstitute/feature/mamba
Implemented the functionality to use custom embeddings for Mamba models.
amrit110 authored May 6, 2024
2 parents de1c61f + 9a429ca commit 426658f
Showing 9 changed files with 617 additions and 708 deletions.
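For context, the diffs below wire the decoder (Mamba) datasets to return the same structured token streams the BigBird models already use (concept, type, age, time, and visit tokens), so that embeddings can be built by summing several learned tables rather than from concept IDs alone. A minimal sketch of that pattern follows; the class, field, and size names are assumptions for illustration and are not the repository's actual embedding modules:

import torch
import torch.nn as nn
from typing import Dict


class SummedEmbeddings(nn.Module):
    """Illustrative sketch: build token embeddings by summing per-stream tables.

    All names and sizes here are assumptions for illustration only.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        num_token_types: int = 8,
        max_age: int = 128,
        max_visit_segments: int = 512,
    ) -> None:
        super().__init__()
        self.concept_embed = nn.Embedding(vocab_size, embed_dim)
        self.type_embed = nn.Embedding(num_token_types, embed_dim)
        self.age_embed = nn.Embedding(max_age, embed_dim)
        self.segment_embed = nn.Embedding(max_visit_segments, embed_dim)

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Keys match the dict returned by the new __getitem__ below; continuous
        # fields such as time_stamps would need a different encoding.
        return (
            self.concept_embed(batch["concept_ids"])
            + self.type_embed(batch["type_ids"])
            + self.age_embed(batch["ages"])
            + self.segment_embed(batch["visit_segments"])
        )

If the Mamba backbone follows the usual Hugging Face convention, the summed tensor can then be passed as inputs_embeds in place of input_ids.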
26 changes: 3 additions & 23 deletions odyssey/data/DataProcessor.ipynb
@@ -2,23 +2,15 @@
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-13T16:14:45.546088300Z",
"start_time": "2024-03-13T16:14:43.587090300Z"
},
"collapsed": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[rank: 0] Seed set to 23\n"
]
}
],
"outputs": [],
"source": [
"import os\n",
"import sys\n",
@@ -34,7 +26,7 @@
"DATASET = f\"{DATA_ROOT}/patient_sequences/patient_sequences_2048.parquet\"\n",
"MAX_LEN = 2048\n",
"\n",
"os.chdir(DATA_ROOT)\n",
"os.chdir(ROOT)\n",
"\n",
"from odyssey.utils.utils import seed_everything\n",
"from odyssey.data.tokenizer import ConceptTokenizer\n",
@@ -467,18 +459,6 @@
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
-},
-"language_info": {
-"codemirror_mode": {
-"name": "ipython",
-"version": 3
-},
-"file_extension": ".py",
-"mimetype": "text/x-python",
-"name": "python",
-"nbconvert_exporter": "python",
-"pygments_lexer": "ipython3",
-"version": "3.10.9"
}
},
"nbformat": 4,
76 changes: 59 additions & 17 deletions odyssey/data/dataset.py
@@ -251,12 +251,33 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""
data = self.data.iloc[idx]
cutoff = data[self.cutoff_col] if self.cutoff_col else None
-data = truncate_and_pad(data, cutoff=cutoff, max_len=self.max_len)

+# Truncate and pad the data to the specified cutoff.
+data = truncate_and_pad(data, cutoff, self.max_len)

# Prepare model input
tokenized_input = self.tokenize_data(data[f"event_tokens_{self.max_len}"])
concept_ids = tokenized_input["input_ids"].squeeze()

type_tokens = data[f"type_tokens_{self.max_len}"]
age_tokens = data[f"age_tokens_{self.max_len}"]
time_tokens = data[f"time_tokens_{self.max_len}"]
visit_tokens = data[f"visit_tokens_{self.max_len}"]
position_tokens = data[f"position_tokens_{self.max_len}"]

+type_tokens = torch.tensor(type_tokens)
+age_tokens = torch.tensor(age_tokens)
+time_tokens = torch.tensor(time_tokens)
+visit_tokens = torch.tensor(visit_tokens)
+position_tokens = torch.tensor(position_tokens)

+return {
+"concept_ids": concept_ids,
+"type_ids": type_tokens,
+"ages": age_tokens,
+"time_stamps": time_tokens,
+"visit_orders": position_tokens,
+"visit_segments": visit_tokens,
+"labels": concept_ids,
+}

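As a quick sanity check of the new __getitem__ contract, the dataset can be wrapped in a DataLoader and one batch inspected. A hedged sketch; the helper name is made up, and the dataset instance is whatever pretraining dataset class carries the modified __getitem__ above:

from torch.utils.data import DataLoader, Dataset


def inspect_first_batch(dataset: Dataset, batch_size: int = 4) -> None:
    """Print the key/shape contract of the new __getitem__ dict."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    batch = next(iter(loader))
    # Expected keys: concept_ids, type_ids, ages, time_stamps,
    # visit_orders, visit_segments, labels — each (batch_size, max_len).
    for key, value in batch.items():
        print(key, tuple(value.shape))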
@@ -691,6 +712,22 @@ def __len__(self) -> int:
"""Return the length of dataset."""
return len(self.index_mapper)

+def tokenize_data(self, sequence: Union[str, List[str]]) -> Any:
+"""Tokenize the sequence and return input_ids and attention mask.
+
+Parameters
+----------
+sequence : Union[str, List[str]]
+The sequence to be tokenized.
+
+Returns
+-------
+Any
+A dictionary containing input_ids and attention_mask.
+
+"""
+return self.tokenizer(sequence, max_length=self.max_len)

def __getitem__(self, idx: int) -> Dict[str, Any]:
"""Get data at corresponding index.
@@ -717,25 +754,30 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
# Prepare model input
tokenized_input = self.tokenize_data(data[f"event_tokens_{self.max_len}"])
concept_ids = tokenized_input["input_ids"].squeeze()
-labels = torch.tensor(labels)

return {"concept_ids": concept_ids, "labels": labels, "task": task}

-def tokenize_data(self, sequence: Union[str, List[str]]) -> Any:
-"""Tokenize the sequence and return input_ids and attention mask.

-Parameters
-----------
-sequence : Union[str, List[str]]
-The sequence to be tokenized.
+type_tokens = data[f"type_tokens_{self.max_len}"]
+age_tokens = data[f"age_tokens_{self.max_len}"]
+time_tokens = data[f"time_tokens_{self.max_len}"]
+visit_tokens = data[f"visit_tokens_{self.max_len}"]
+position_tokens = data[f"position_tokens_{self.max_len}"]

-Returns
--------
-Any
-A dictionary containing input_ids and attention_mask.
+type_tokens = torch.tensor(type_tokens)
+age_tokens = torch.tensor(age_tokens)
+time_tokens = torch.tensor(time_tokens)
+visit_tokens = torch.tensor(visit_tokens)
+position_tokens = torch.tensor(position_tokens)
+labels = torch.tensor(labels)

"""
return self.tokenizer(sequence, max_length=self.max_len)
return {
"concept_ids": concept_ids,
"type_ids": type_tokens,
"ages": age_tokens,
"time_stamps": time_tokens,
"visit_orders": position_tokens,
"visit_segments": visit_tokens,
"labels": labels,
"task": task,
}

def balance_labels(self, task: str, positive_ratio: float) -> None:
"""Balance the labels for the specified task in the dataset.
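For reference, tokenize_data (moved up in this file) simply forwards to the underlying ConceptTokenizer, and the .squeeze() on input_ids in __getitem__ implies the tokenizer returns batched tensors, presumably padded or truncated to max_length. A hedged usage sketch — the event string is a placeholder, not a real patient sequence, and dataset stands for a FinetuneDatasetDecoder instance:

# Hypothetical event string; real sequences come from the parquet files.
tokens = dataset.tokenize_data("[CLS] [VS] code_a code_b [VE]")
input_ids = tokens["input_ids"].squeeze()            # shape: (max_len,)
attention_mask = tokens["attention_mask"].squeeze()  # 1 = token, 0 = padding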
83 changes: 67 additions & 16 deletions odyssey/evals/TestAnalysis.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -24,25 +24,25 @@
"ROOT = \"/fs01/home/afallah/odyssey/odyssey\"\n",
"os.chdir(ROOT)\n",
"\n",
"from odyssey.data.dataset import FinetuneMultiDataset\n",
"from odyssey.data.dataset import FinetuneMultiDataset, FinetuneDatasetDecoder\n",
"from odyssey.data.tokenizer import ConceptTokenizer\n",
"from odyssey.models.model_utils import load_finetune_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"class config:\n",
" \"\"\"Save the configuration arguments.\"\"\"\n",
"\n",
" model_path = \"checkpoints/multibird_finetune/multibird_finetune/test_outputs/test_outputs_1ec842db.pt\"\n",
" model_path = \"checkpoints/mamba_finetune/test_outputs/test_outputs_f8471ffd.pt\"\n",
" vocab_dir = \"odyssey/data/vocab\"\n",
" data_dir = \"odyssey/data/bigbird_data\"\n",
" sequence_file = \"patient_sequences/patient_sequences_2048_multi.parquet\"\n",
" id_file = \"patient_id_dict/dataset_2048_multi.pkl\"\n",
" sequence_file = \"patient_sequences_2048_multi.parquet\"\n",
" id_file = \"dataset_2048_multi.pkl\"\n",
" valid_scheme = \"few_shot\"\n",
" num_finetune_patients = \"all\"\n",
" # label_name = \"label_mortality_1month\"\n",
@@ -55,7 +55,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -80,7 +80,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@@ -97,7 +97,7 @@
" config.num_finetune_patients,\n",
")\n",
"\n",
"test_dataset = FinetuneMultiDataset(\n",
"test_dataset = FinetuneDatasetDecoder(\n",
" data=fine_test,\n",
" tokenizer=tokenizer,\n",
" tasks=config.tasks,\n",
@@ -127,19 +127,41 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['labels', 'logits'])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_outputs = torch.load(config.model_path, map_location=config.device)\n",
"test_outputs.keys()"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"array([0, 0, 0, ..., 0, 0, 0])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels = test_outputs[\"labels\"].cpu().numpy()\n",
"logits = test_outputs[\"logits\"].cpu().numpy()\n",
@@ -151,15 +173,32 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 22,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"{'Balanced Accuracy': 0.8985595582256243,\n",
" 'F1 Score': 0.6536144578313253,\n",
" 'Precision': 0.5174058178350024,\n",
" 'Recall': 0.8871627146361406,\n",
" 'AUROC': 0.9667942606114659,\n",
" 'Average Precision Score': 0.47009681385740587,\n",
" 'AUC-PR': 0.7078210982047579}"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Tasks we have are: tasks = ['mortality_1month', 'los_1week', 'c0', 'c1', 'c2']\n",
"\n",
"task_idx = []\n",
"for i, task in enumerate(tasks):\n",
" if task == \"los_1week\":\n",
" if task == \"mortality_1month\":\n",
" task_idx.append(i)\n",
"\n",
"\n",
@@ -224,6 +263,18 @@
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
+},
+"language_info": {
+"codemirror_mode": {
+"name": "ipython",
+"version": 3
+},
+"file_extension": ".py",
+"mimetype": "text/x-python",
+"name": "python",
+"nbconvert_exporter": "python",
+"pygments_lexer": "ipython3",
+"version": "3.10.9"
}
},
"nbformat": 4,
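The cell above filters test examples for a single task by matching task names; the same slicing generalizes to every task in one pass. A sketch under stated assumptions: labels, preds, and probs are arrays aligned with a tasks list collected from the test dataset, as in the preceding cells:

import numpy as np
from sklearn.metrics import balanced_accuracy_score, roc_auc_score


def per_task_metrics(tasks, labels, preds, probs):
    """Balanced accuracy and AUROC for each task's slice of the test set."""
    tasks = np.asarray(tasks)
    for task_name in ("mortality_1month", "los_1week", "c0", "c1", "c2"):
        mask = tasks == task_name
        if mask.any():
            print(
                f"{task_name}: "
                f"balanced_acc={balanced_accuracy_score(labels[mask], preds[mask]):.3f} "
                f"auroc={roc_auc_score(labels[mask], probs[mask]):.3f}"
            )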