
Implemented the functionality to use custom embeddings for Mamba models. #35

Merged · 10 commits · May 6, 2024
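This PR changes the decoder-style (Mamba) datasets so that each item returns the full set of token streams (type, age, time, visit segment, visit order) alongside the concept IDs, letting the model build custom embeddings from them. As a rough illustration of how downstream code might combine these streams, here is a minimal sketch; the class name, vocabulary sizes, and hidden size are illustrative assumptions, not the repo's actual module:

import torch
from torch import nn

class CustomMambaEmbeddings(nn.Module):
    """Illustrative embedding layer that sums the token streams returned below."""

    def __init__(self, vocab_size: int, type_vocab_size: int = 9,
                 embedding_size: int = 768, max_visits: int = 512):
        super().__init__()
        self.concept_embeddings = nn.Embedding(vocab_size, embedding_size)
        self.type_embeddings = nn.Embedding(type_vocab_size, embedding_size)
        self.visit_segment_embeddings = nn.Embedding(3, embedding_size)
        self.visit_order_embeddings = nn.Embedding(max_visits, embedding_size)
        # Continuous streams (age, time stamp) projected into the same space.
        self.age_proj = nn.Linear(1, embedding_size)
        self.time_proj = nn.Linear(1, embedding_size)

    def forward(self, batch: dict) -> torch.Tensor:
        # Keys match the dicts returned by the updated __getitem__ methods.
        x = self.concept_embeddings(batch["concept_ids"])
        x = x + self.type_embeddings(batch["type_ids"])
        x = x + self.visit_segment_embeddings(batch["visit_segments"])
        x = x + self.visit_order_embeddings(batch["visit_orders"])
        x = x + self.age_proj(batch["ages"].unsqueeze(-1).float())
        x = x + self.time_proj(batch["time_stamps"].unsqueeze(-1).float())
        return x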
26 changes: 3 additions & 23 deletions odyssey/data/DataProcessor.ipynb
@@ -2,23 +2,15 @@
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2024-03-13T16:14:45.546088300Z",
"start_time": "2024-03-13T16:14:43.587090300Z"
},
"collapsed": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[rank: 0] Seed set to 23\n"
]
}
],
"outputs": [],
"source": [
"import os\n",
"import sys\n",
@@ -34,7 +26,7 @@
"DATASET = f\"{DATA_ROOT}/patient_sequences/patient_sequences_2048.parquet\"\n",
"MAX_LEN = 2048\n",
"\n",
"os.chdir(DATA_ROOT)\n",
"os.chdir(ROOT)\n",
"\n",
"from odyssey.utils.utils import seed_everything\n",
"from odyssey.data.tokenizer import ConceptTokenizer\n",
@@ -467,18 +459,6 @@
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
76 changes: 59 additions & 17 deletions odyssey/data/dataset.py
@@ -251,12 +251,33 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""
data = self.data.iloc[idx]
cutoff = data[self.cutoff_col] if self.cutoff_col else None
data = truncate_and_pad(data, cutoff=cutoff, max_len=self.max_len)

# Truncate and pad the data to the specified cutoff.
data = truncate_and_pad(data, cutoff, self.max_len)

# Prepare model input
tokenized_input = self.tokenize_data(data[f"event_tokens_{self.max_len}"])
concept_ids = tokenized_input["input_ids"].squeeze()

type_tokens = data[f"type_tokens_{self.max_len}"]
age_tokens = data[f"age_tokens_{self.max_len}"]
time_tokens = data[f"time_tokens_{self.max_len}"]
visit_tokens = data[f"visit_tokens_{self.max_len}"]
position_tokens = data[f"position_tokens_{self.max_len}"]

type_tokens = torch.tensor(type_tokens)
age_tokens = torch.tensor(age_tokens)
time_tokens = torch.tensor(time_tokens)
visit_tokens = torch.tensor(visit_tokens)
position_tokens = torch.tensor(position_tokens)

return {
"concept_ids": concept_ids,
"type_ids": type_tokens,
"ages": age_tokens,
"time_stamps": time_tokens,
"visit_orders": position_tokens,
"visit_segments": visit_tokens,
"labels": concept_ids,
}
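A hedged usage sketch of the new pretraining item format; `dataset` stands in for an instance of this class, and default DataLoader collation is assumed:

from torch.utils.data import DataLoader

# `dataset` is assumed to be an instance of the pretraining dataset above.
loader = DataLoader(dataset, batch_size=32, shuffle=True)
batch = next(iter(loader))
# Each value is a (batch_size, max_len) tensor; "labels" mirrors
# "concept_ids", as expected for a next-token (decoder) objective.
print({key: value.shape for key, value in batch.items()})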

@@ -691,6 +712,22 @@ def __len__(self) -> int:
"""Return the length of dataset."""
return len(self.index_mapper)

def tokenize_data(self, sequence: Union[str, List[str]]) -> Any:
"""Tokenize the sequence and return input_ids and attention mask.

Parameters
----------
sequence : Union[str, List[str]]
The sequence to be tokenized.

Returns
-------
Any
A dictionary containing input_ids and attention_mask.

"""
return self.tokenizer(sequence, max_length=self.max_len)
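An illustrative call, assuming ConceptTokenizer behaves like a Hugging Face tokenizer that pads and truncates to max_len and returns tensors (the token string below is a made-up placeholder):

encoded = dataset.tokenize_data("code_a code_b code_c")  # placeholder tokens
input_ids = encoded["input_ids"].squeeze()   # shape: (max_len,)
mask = encoded["attention_mask"]             # 1 for real tokens, 0 for padding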

def __getitem__(self, idx: int) -> Dict[str, Any]:
"""Get data at corresponding index.

@@ -717,25 +754,30 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
# Prepare model input
tokenized_input = self.tokenize_data(data[f"event_tokens_{self.max_len}"])
concept_ids = tokenized_input["input_ids"].squeeze()
labels = torch.tensor(labels)

return {"concept_ids": concept_ids, "labels": labels, "task": task}

def tokenize_data(self, sequence: Union[str, List[str]]) -> Any:
"""Tokenize the sequence and return input_ids and attention mask.

Parameters
----------
sequence : Union[str, List[str]]
The sequence to be tokenized.
type_tokens = data[f"type_tokens_{self.max_len}"]
age_tokens = data[f"age_tokens_{self.max_len}"]
time_tokens = data[f"time_tokens_{self.max_len}"]
visit_tokens = data[f"visit_tokens_{self.max_len}"]
position_tokens = data[f"position_tokens_{self.max_len}"]

Returns
-------
Any
A dictionary containing input_ids and attention_mask.
type_tokens = torch.tensor(type_tokens)
age_tokens = torch.tensor(age_tokens)
time_tokens = torch.tensor(time_tokens)
visit_tokens = torch.tensor(visit_tokens)
position_tokens = torch.tensor(position_tokens)
labels = torch.tensor(labels)

"""
return self.tokenizer(sequence, max_length=self.max_len)
return {
"concept_ids": concept_ids,
"type_ids": type_tokens,
"ages": age_tokens,
"time_stamps": time_tokens,
"visit_orders": position_tokens,
"visit_segments": visit_tokens,
"labels": labels,
"task": task,
}

def balance_labels(self, task: str, positive_ratio: float) -> None:
"""Balance the labels for the specified task in the dataset.
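Because each finetune item now carries a `task` string alongside its tensors, downstream code can route samples to per-task classifier heads. A hedged sketch — the task names come from the analysis notebook below, and the hidden size of 768 is an assumption:

import torch
from torch import nn

# Illustrative per-task heads; 768 is an assumed hidden size.
heads = nn.ModuleDict({
    "mortality_1month": nn.Linear(768, 2),
    "los_1week": nn.Linear(768, 2),
})

def task_logits(hidden: torch.Tensor, task: str) -> torch.Tensor:
    """Apply the classifier head matching this sample's task."""
    return heads[task](hidden)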
83 changes: 67 additions & 16 deletions odyssey/evals/TestAnalysis.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -24,25 +24,25 @@
"ROOT = \"/fs01/home/afallah/odyssey/odyssey\"\n",
"os.chdir(ROOT)\n",
"\n",
"from odyssey.data.dataset import FinetuneMultiDataset\n",
"from odyssey.data.dataset import FinetuneMultiDataset, FinetuneDatasetDecoder\n",
"from odyssey.data.tokenizer import ConceptTokenizer\n",
"from odyssey.models.model_utils import load_finetune_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"class config:\n",
" \"\"\"Save the configuration arguments.\"\"\"\n",
"\n",
" model_path = \"checkpoints/multibird_finetune/multibird_finetune/test_outputs/test_outputs_1ec842db.pt\"\n",
" model_path = \"checkpoints/mamba_finetune/test_outputs/test_outputs_f8471ffd.pt\"\n",
" vocab_dir = \"odyssey/data/vocab\"\n",
" data_dir = \"odyssey/data/bigbird_data\"\n",
" sequence_file = \"patient_sequences/patient_sequences_2048_multi.parquet\"\n",
" id_file = \"patient_id_dict/dataset_2048_multi.pkl\"\n",
" sequence_file = \"patient_sequences_2048_multi.parquet\"\n",
" id_file = \"dataset_2048_multi.pkl\"\n",
" valid_scheme = \"few_shot\"\n",
" num_finetune_patients = \"all\"\n",
" # label_name = \"label_mortality_1month\"\n",
@@ -55,7 +55,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -80,7 +80,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -97,7 +97,7 @@
" config.num_finetune_patients,\n",
")\n",
"\n",
"test_dataset = FinetuneMultiDataset(\n",
"test_dataset = FinetuneDatasetDecoder(\n",
" data=fine_test,\n",
" tokenizer=tokenizer,\n",
" tasks=config.tasks,\n",
@@ -127,19 +127,41 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['labels', 'logits'])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_outputs = torch.load(config.model_path, map_location=config.device)\n",
"test_outputs.keys()"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"array([0, 0, 0, ..., 0, 0, 0])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels = test_outputs[\"labels\"].cpu().numpy()\n",
"logits = test_outputs[\"logits\"].cpu().numpy()\n",
@@ -151,15 +173,32 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 22,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"{'Balanced Accuracy': 0.8985595582256243,\n",
" 'F1 Score': 0.6536144578313253,\n",
" 'Precision': 0.5174058178350024,\n",
" 'Recall': 0.8871627146361406,\n",
" 'AUROC': 0.9667942606114659,\n",
" 'Average Precision Score': 0.47009681385740587,\n",
" 'AUC-PR': 0.7078210982047579}"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Tasks we have are: tasks = ['mortality_1month', 'los_1week', 'c0', 'c1', 'c2']\n",
"\n",
"task_idx = []\n",
"for i, task in enumerate(tasks):\n",
" if task == \"los_1week\":\n",
" if task == \"mortality_1month\":\n",
" task_idx.append(i)\n",
"\n",
"\n",
@@ -224,6 +263,18 @@
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
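The metrics reported in the notebook output above can be recomputed from the saved labels and logits; a hedged sketch with scikit-learn (the sigmoid and 0.5 threshold are assumptions — the repo's own evaluation code may choose differently):

import numpy as np
from sklearn.metrics import (
    auc,
    average_precision_score,
    balanced_accuracy_score,
    f1_score,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_auc_score,
)

def binary_metrics(labels: np.ndarray, logits: np.ndarray) -> dict:
    """Recompute the headline metrics from labels and positive-class logits."""
    probs = 1.0 / (1.0 + np.exp(-logits))   # sigmoid
    preds = (probs >= 0.5).astype(int)      # assumed decision threshold
    precision, recall, _ = precision_recall_curve(labels, probs)
    return {
        "Balanced Accuracy": balanced_accuracy_score(labels, preds),
        "F1 Score": f1_score(labels, preds),
        "Precision": precision_score(labels, preds),
        "Recall": recall_score(labels, preds),
        "AUROC": roc_auc_score(labels, probs),
        "Average Precision Score": average_precision_score(labels, probs),
        "AUC-PR": auc(recall, precision),
    }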