diff --git a/odyssey/data/DataProcessor.ipynb b/odyssey/data/DataProcessor.ipynb index 1192cc9..0e24f83 100644 --- a/odyssey/data/DataProcessor.ipynb +++ b/odyssey/data/DataProcessor.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2024-03-13T16:14:45.546088300Z", @@ -13,23 +13,39 @@ "outputs": [], "source": [ "import os\n", + "import sys\n", "import pickle\n", "import random\n", - "import sys\n", "from typing import Any, Dict, List, Optional\n", "\n", "import numpy as np\n", "import pandas as pd\n", - "from sklearn.model_selection import train_test_split\n", - "from skmultilearn.model_selection import iterative_train_test_split\n", - "\n", - "from odyssey.utils.utils import save_object_to_disk, seed_everything\n", "\n", - "DATA_ROOT = \"/h/afallah/odyssey/odyssey/data/bigbird_data\"\n", + "ROOT = \"/h/afallah/odyssey/odyssey\"\n", + "DATA_ROOT = f\"{ROOT}/odyssey/data/bigbird_data\"\n", "DATASET = f\"{DATA_ROOT}/patient_sequences/patient_sequences_2048.parquet\"\n", "MAX_LEN = 2048\n", "\n", - "os.chdir(DATA_ROOT)\n", + "os.chdir(ROOT)\n", + "\n", + "from odyssey.utils.utils import seed_everything\n", + "from odyssey.data.processor import (\n", + " filter_by_num_visit,\n", + " filter_by_length_of_stay,\n", + " get_last_occurence_index,\n", + " check_readmission_label,\n", + " get_length_of_stay,\n", + " get_visit_cutoff_at_threshold,\n", + " process_length_of_stay_dataset,\n", + " process_condition_dataset,\n", + " process_mortality_dataset,\n", + " process_readmission_dataset,\n", + " process_multi_dataset,\n", + " stratified_train_test_split,\n", + " sample_balanced_subset, \n", + " get_pretrain_test_split,\n", + " get_finetune_split\n", + ")\n", "\n", "SEED = 23\n", "seed_everything(seed=SEED)" @@ -37,7 +53,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2024-03-13T16:15:12.321718600Z", @@ -45,235 +61,7 @@ }, "collapsed": false }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Current columns: Index(['patient_id', 'num_visits', 'deceased', 'death_after_start',\n", - " 'death_after_end', 'length', 'token_length', 'event_tokens_2048',\n", - " 'type_tokens_2048', 'age_tokens_2048', 'time_tokens_2048',\n", - " 'visit_tokens_2048', 'position_tokens_2048', 'elapsed_tokens_2048',\n", - " 'common_conditions', 'rare_conditions'],\n", - " dtype='object')\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
patient_idnum_visitsdeceaseddeath_after_startdeath_after_endlengthtoken_lengthevent_tokens_2048type_tokens_2048age_tokens_2048time_tokens_2048visit_tokens_2048position_tokens_2048elapsed_tokens_2048common_conditionsrare_conditions
035581927-9c95-5ae9-af76-7d74870a349c10NaNNaN5054[[CLS], [VS], 00006473900, 00904516561, 510790...[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, ...[0, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85...[0, 5902, 5902, 5902, 5902, 5902, 5902, 5902, ...[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...[-2.0, -1.0, 1.97, 2.02, 2.02, 2.02, 2.02, 2.0...[1, 0, 0, 0, 0, 0, 0, 0, 0, 0][0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
1f5bba8dd-25c0-5336-8d3d-37424c18502620NaNNaN148156[[CLS], [VS], 52135_2, 52075_2, 52074_2, 52073...[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...[0, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83...[0, 6594, 6594, 6594, 6594, 6594, 6594, 6594, ...[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...[-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0...[0, 0, 0, 0, 0, 0, 0, 1, 0, 0][0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2f4938f91-cadb-5133-8541-a52fb0916cea20NaNNaN7886[[CLS], [VS], 0RB30ZZ, 0RG10A0, 00071101441, 0...[1, 2, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...[0, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44...[0, 8150, 8150, 8150, 8150, 8150, 8150, 8150, ...[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...[-2.0, -1.0, 0.0, 0.0, 1.08, 1.08, 13.89, 13.8...[0, 0, 0, 0, 0, 0, 0, 0, 0, 0][0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
36fe2371b-a6f0-5436-aade-7795005b0c6620NaNNaN8694[[CLS], [VS], 63739057310, 49281041688, 005970...[1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...[0, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72...[0, 6093, 6093, 6093, 6093, 6093, 6093, 6093, ...[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...[-2.0, -1.0, 0.75, 0.75, 0.75, 0.75, 0.75, 0.7...[1, 0, 0, 0, 0, 0, 0, 1, 0, 0][0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
46f7590ae-f3b9-50e5-9e41-d4bb1000887a10NaNNaN7276[[CLS], [VS], 50813_0, 52135_0, 52075_3, 52074...[1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ...[0, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47...[0, 6379, 6379, 6379, 6379, 6379, 6379, 6379, ...[0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...[-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0...[1, 0, 0, 0, 0, 0, 0, 0, 0, 1][0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
\n", - "
" - ], - "text/plain": [ - " patient_id num_visits deceased \\\n", - "0 35581927-9c95-5ae9-af76-7d74870a349c 1 0 \n", - "1 f5bba8dd-25c0-5336-8d3d-37424c185026 2 0 \n", - "2 f4938f91-cadb-5133-8541-a52fb0916cea 2 0 \n", - "3 6fe2371b-a6f0-5436-aade-7795005b0c66 2 0 \n", - "4 6f7590ae-f3b9-50e5-9e41-d4bb1000887a 1 0 \n", - "\n", - " death_after_start death_after_end length token_length \\\n", - "0 NaN NaN 50 54 \n", - "1 NaN NaN 148 156 \n", - "2 NaN NaN 78 86 \n", - "3 NaN NaN 86 94 \n", - "4 NaN NaN 72 76 \n", - "\n", - " event_tokens_2048 \\\n", - "0 [[CLS], [VS], 00006473900, 00904516561, 510790... \n", - "1 [[CLS], [VS], 52135_2, 52075_2, 52074_2, 52073... \n", - "2 [[CLS], [VS], 0RB30ZZ, 0RG10A0, 00071101441, 0... \n", - "3 [[CLS], [VS], 63739057310, 49281041688, 005970... \n", - "4 [[CLS], [VS], 50813_0, 52135_0, 52075_3, 52074... \n", - "\n", - " type_tokens_2048 \\\n", - "0 [1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, ... \n", - "1 [1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ... \n", - "2 [1, 2, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ... \n", - "3 [1, 2, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, ... \n", - "4 [1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, ... \n", - "\n", - " age_tokens_2048 \\\n", - "0 [0, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85, 85... \n", - "1 [0, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83... \n", - "2 [0, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44... \n", - "3 [0, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72... \n", - "4 [0, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47... \n", - "\n", - " time_tokens_2048 \\\n", - "0 [0, 5902, 5902, 5902, 5902, 5902, 5902, 5902, ... \n", - "1 [0, 6594, 6594, 6594, 6594, 6594, 6594, 6594, ... \n", - "2 [0, 8150, 8150, 8150, 8150, 8150, 8150, 8150, ... \n", - "3 [0, 6093, 6093, 6093, 6093, 6093, 6093, 6093, ... \n", - "4 [0, 6379, 6379, 6379, 6379, 6379, 6379, 6379, ... \n", - "\n", - " visit_tokens_2048 \\\n", - "0 [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ... \n", - "1 [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ... \n", - "2 [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ... \n", - "3 [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ... \n", - "4 [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ... \n", - "\n", - " position_tokens_2048 \\\n", - "0 [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... \n", - "1 [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... \n", - "2 [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... \n", - "3 [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... \n", - "4 [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... \n", - "\n", - " elapsed_tokens_2048 \\\n", - "0 [-2.0, -1.0, 1.97, 2.02, 2.02, 2.02, 2.02, 2.0... \n", - "1 [-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0... \n", - "2 [-2.0, -1.0, 0.0, 0.0, 1.08, 1.08, 13.89, 13.8... \n", - "3 [-2.0, -1.0, 0.75, 0.75, 0.75, 0.75, 0.75, 0.7... \n", - "4 [-2.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0... 
\n", - "\n", - " common_conditions rare_conditions \n", - "0 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] \n", - "1 [0, 0, 0, 0, 0, 0, 0, 1, 0, 0] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] \n", - "2 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] \n", - "3 [1, 0, 0, 0, 0, 0, 0, 1, 0, 0] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] \n", - "4 [1, 0, 0, 0, 0, 0, 0, 0, 0, 1] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] " - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# Load complete dataset\n", "dataset_2048 = pd.read_parquet(DATASET)\n", @@ -282,164 +70,14 @@ "dataset_2048.head()" ] }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "def filter_by_num_visit(dataset: pd.DataFrame, minimum_num_visits: int) -> pd.DataFrame:\n", - " \"\"\"Filter the patients based on num_visits threshold.\n", - "\n", - " Args:\n", - " dataset (pd.DataFrame): The input dataset.\n", - " minimum_num_visits (int): The threshold num_visits\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame: The filtered dataset.\n", - " \"\"\"\n", - " filtered_dataset = dataset.loc[dataset[\"num_visits\"] >= minimum_num_visits]\n", - " filtered_dataset.reset_index(drop=True, inplace=True)\n", - " return filtered_dataset\n", - "\n", - "\n", - "def filter_by_length_of_stay(dataset: pd.DataFrame, threshold: int = 1) -> pd.DataFrame:\n", - " \"\"\"Filter the patients based on length of stay threshold.\n", - "\n", - " Args:\n", - " dataset (pd.DataFrame): The input dataset.\n", - " minimum_num_visits (int): The threshold length of stay\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame: The filtered dataset.\n", - " \"\"\"\n", - " filtered_dataset = dataset.loc[dataset[\"length_of_stay\"] >= threshold]\n", - "\n", - " # Only keep the patients that their first event happens within threshold\n", - " # TODO: Check how many patients get removed here?\n", - " filtered_dataset = filtered_dataset[\n", - " filtered_dataset.apply(\n", - " lambda row: row[\"elapsed_tokens_2048\"][row[\"last_VS_index\"] + 1]\n", - " < threshold * 24,\n", - " axis=1,\n", - " )\n", - " ]\n", - "\n", - " filtered_dataset.reset_index(drop=True, inplace=True)\n", - " return filtered_dataset\n", - "\n", - "\n", - "def get_last_occurence_index(seq: List[str], target: str) -> int:\n", - " \"\"\"Return the index of the last occurrence of target in seq.\n", - "\n", - " Args:\n", - " seq (List[str]): The input sequence.\n", - " target (str): The target string to find.\n", - "\n", - " Returns\n", - " -------\n", - " int: The index of the last occurrence of target in seq.\n", - " \"\"\"\n", - " return len(seq) - (seq[::-1].index(target) + 1)\n", - "\n", - "\n", - "def check_readmission_label(row: pd.Series) -> int:\n", - " \"\"\"Check if the label indicates readmission within one month.\n", - "\n", - " Args:\n", - " row (pd.Series): The input row.\n", - "\n", - " Returns\n", - " -------\n", - " bool: True if readmission label is present, False otherwise.\n", - " \"\"\"\n", - " last_vs_index = row[\"last_VS_index\"]\n", - " return int(\n", - " row[\"event_tokens_2048\"][last_vs_index - 1]\n", - " in (\"[W_0]\", \"[W_1]\", \"[W_2]\", \"[W_3]\", \"[M_1]\"),\n", - " )\n", - "\n", - "\n", - "def get_length_of_stay(row: pd.Series) -> pd.Series:\n", - " \"\"\"Determine the length of a given visit.\n", - "\n", - " Args:\n", - " row (pd.Series): The input row.\n", - "\n", - " Returns\n", - " -------\n", - " pd.Series: The preprocessed row.\n", - " 
\"\"\"\n", - " admission_time = row[\"last_VS_index\"] + 1\n", - " discharge_time = row[\"last_VE_index\"] - 1\n", - " return (discharge_time - admission_time) / 24\n", - "\n", - "\n", - "def get_visit_cutoff_at_threshold(row: pd.Series, threshold: int = 24) -> int:\n", - " \"\"\"Get the index of the first event token of last visit that occurs after threshold hours.\n", - "\n", - " Args:\n", - " row (pd.Series): The input row.\n", - " threshold (int): The number of hours to consider.\n", - "\n", - " Returns\n", - " -------\n", - " cutoff_index (int): The corrosponding cutoff index.\n", - " \"\"\"\n", - " last_vs_index = row[\"last_VS_index\"]\n", - " last_ve_index = row[\"last_VE_index\"]\n", - "\n", - " for i in range(last_vs_index + 1, last_ve_index):\n", - " if row[\"elapsed_tokens_2048\"][i] > threshold:\n", - " return i\n", - "\n", - " return len(row[\"event_tokens_2048\"])" - ] - }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "def process_length_of_stay_dataset(\n", - " dataset: pd.DataFrame,\n", - " threshold: int = 7,\n", - ") -> pd.DataFrame:\n", - " \"\"\"Process the length of stay dataset to extract required features.\n", - "\n", - " Args:\n", - " dataset (pd.DataFrame): The input dataset.\n", - " threshold (int): The threshold length of stay.\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame: The processed dataset.\n", - " \"\"\"\n", - " dataset[\"last_VS_index\"] = dataset[\"event_tokens_2048\"].transform(\n", - " lambda seq: get_last_occurence_index(list(seq), \"[VS]\"),\n", - " )\n", - " dataset[\"last_VE_index\"] = dataset[\"event_tokens_2048\"].transform(\n", - " lambda seq: get_last_occurence_index(list(seq), \"[VE]\"),\n", - " )\n", - " dataset[\"length_of_stay\"] = dataset.apply(get_length_of_stay, axis=1)\n", - "\n", - " dataset = filter_by_length_of_stay(dataset, threshold=1)\n", - " dataset[\"label_los_1week\"] = (dataset[\"length_of_stay\"] >= threshold).astype(int)\n", - "\n", - " dataset[\"cutoff_los\"] = dataset.apply(\n", - " lambda row: get_visit_cutoff_at_threshold(row, threshold=24),\n", - " axis=1,\n", - " )\n", - " dataset[\"token_length\"] = dataset[\"event_tokens_2048\"].apply(len)\n", - "\n", - " return dataset\n", - "\n", - "\n", "# Process the dataset for length of stay prediction above a threshold\n", - "dataset_2048_los = process_length_of_stay_dataset(dataset_2048.copy(), threshold=7)" + "dataset_2048_los = process_length_of_stay_dataset(dataset_2048.copy(), threshold=7, max_len=MAX_LEN)" ] }, { @@ -448,29 +86,8 @@ "metadata": {}, "outputs": [], "source": [ - "def process_condition_dataset(dataset: pd.DataFrame) -> pd.DataFrame:\n", - " \"\"\"Process the condition dataset to extract required features.\n", - "\n", - " Args:\n", - " dataset (pd.DataFrame): The input condition dataset.\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame: The processed condition dataset.\n", - " \"\"\"\n", - " dataset[\"all_conditions\"] = dataset.apply(\n", - " lambda row: np.concatenate(\n", - " [row[\"common_conditions\"], row[\"rare_conditions\"]],\n", - " dtype=np.int64,\n", - " ),\n", - " axis=1,\n", - " )\n", - "\n", - " return dataset\n", - "\n", - "\n", "# Process the dataset for conditions including rare and common\n", - "dataset_2048_condition = process_condition_dataset(dataset_2048.copy())" + "dataset_2048_condition = process_condition_dataset(dataset_2048.copy(), max_len=MAX_LEN)" ] }, { @@ -485,28 +102,8 @@ }, "outputs": [], "source": [ - "def process_mortality_dataset(dataset: 
pd.DataFrame) -> pd.DataFrame:\n", - " \"\"\"Process the mortality dataset to extract required features.\n", - "\n", - " Args:\n", - " dataset (pd.DataFrame): The input mortality dataset.\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame: The processed mortality dataset.\n", - " \"\"\"\n", - " dataset[\"label_mortality_2weeks\"] = (\n", - " (dataset[\"death_after_start\"] >= 0) & (dataset[\"death_after_end\"] <= 15)\n", - " ).astype(int)\n", - " dataset[\"label_mortality_1month\"] = (\n", - " (dataset[\"death_after_start\"] >= 0) & (dataset[\"death_after_end\"] <= 32)\n", - " ).astype(int)\n", - "\n", - " return dataset\n", - "\n", - "\n", "# Process the dataset for mortality in two weeks or one month task\n", - "dataset_2048_mortality = process_mortality_dataset(dataset_2048.copy())" + "dataset_2048_mortality = process_mortality_dataset(dataset_2048.copy(), max_len=MAX_LEN)" ] }, { @@ -521,34 +118,8 @@ }, "outputs": [], "source": [ - "def process_readmission_dataset(dataset: pd.DataFrame) -> pd.DataFrame:\n", - " \"\"\"Process the readmission dataset to extract required features.\n", - "\n", - " Args:\n", - " dataset (pd.DataFrame): The input dataset.\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame: The processed dataset.\n", - " \"\"\"\n", - " dataset[\"last_VS_index\"] = dataset[\"event_tokens_2048\"].transform(\n", - " lambda seq: get_last_occurence_index(list(seq), \"[VS]\"),\n", - " )\n", - " dataset[\"cutoff_readmission\"] = dataset[\"last_VS_index\"] - 1\n", - " dataset[\"label_readmission_1month\"] = dataset.apply(check_readmission_label, axis=1)\n", - "\n", - " dataset[\"num_visits\"] -= 1\n", - " dataset[\"token_length\"] = dataset[\"event_tokens_2048\"].apply(len)\n", - "\n", - " return dataset\n", - "\n", - "\n", "# Process the dataset for hospital readmission in one month task\n", - "dataset_2048_readmission = filter_by_num_visit(\n", - " dataset_2048.copy(),\n", - " minimum_num_visits=2,\n", - ")\n", - "dataset_2048_readmission = process_readmission_dataset(dataset_2048_readmission)" + "dataset_2048_readmission = process_readmission_dataset(dataset_2048.copy(), max_len=MAX_LEN)" ] }, { @@ -557,97 +128,7 @@ "metadata": {}, "outputs": [], "source": [ - "def process_multi_dataset(datasets: Dict[str, pd.DataFrame]):\n", - " \"\"\"\n", - " Process the multi-task dataset by merging the original dataset with the other datasets.\n", - "\n", - " Args:\n", - " datasets (Dict): Dictionary mapping each task to its respective dataframe\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame: The processed multi-task dataset\n", - " \"\"\"\n", - " # Merging datasets on 'patient_id'\n", - " multi_dataset = datasets[\"original\"].merge(\n", - " datasets[\"condition\"][[\"patient_id\", \"all_conditions\"]],\n", - " on=\"patient_id\",\n", - " how=\"left\",\n", - " )\n", - " multi_dataset = multi_dataset.merge(\n", - " datasets[\"mortality\"][[\"patient_id\", \"label_mortality_1month\"]],\n", - " on=\"patient_id\",\n", - " how=\"left\",\n", - " )\n", - " multi_dataset = multi_dataset.merge(\n", - " datasets[\"readmission\"][\n", - " [\"patient_id\", \"cutoff_readmission\", \"label_readmission_1month\"]\n", - " ],\n", - " on=\"patient_id\",\n", - " how=\"left\",\n", - " )\n", - " multi_dataset = multi_dataset.merge(\n", - " datasets[\"los\"][[\"patient_id\", \"cutoff_los\", \"label_los_1week\"]],\n", - " on=\"patient_id\",\n", - " how=\"left\",\n", - " )\n", - "\n", - " # Selecting the required columns\n", - " multi_dataset = multi_dataset[\n", - " [\n", - " 
\"patient_id\",\n", - " \"num_visits\",\n", - " \"event_tokens_2048\",\n", - " \"type_tokens_2048\",\n", - " \"age_tokens_2048\",\n", - " \"time_tokens_2048\",\n", - " \"visit_tokens_2048\",\n", - " \"position_tokens_2048\",\n", - " \"elapsed_tokens_2048\",\n", - " \"cutoff_los\",\n", - " \"cutoff_readmission\",\n", - " \"all_conditions\",\n", - " \"label_mortality_1month\",\n", - " \"label_readmission_1month\",\n", - " \"label_los_1week\",\n", - " ]\n", - " ]\n", - "\n", - " # Transform conditions from a vector of numbers to binary classes\n", - " conditions_expanded = multi_dataset[\"all_conditions\"].apply(pd.Series)\n", - " conditions_expanded.columns = [f\"condition{i}\" for i in range(20)]\n", - " multi_dataset = multi_dataset.drop(\"all_conditions\", axis=1)\n", - " multi_dataset = pd.concat([multi_dataset, conditions_expanded], axis=1)\n", - "\n", - " # Standardize important column names\n", - " multi_dataset.rename(\n", - " columns={\n", - " \"cutoff_los\": \"cutoff_los_1week\",\n", - " \"cutoff_readmission\": \"cutoff_readmission_1month\",\n", - " },\n", - " inplace=True,\n", - " )\n", - " condition_columns = {f\"condition{i}\": f\"label_c{i}\" for i in range(20)}\n", - " multi_dataset.rename(columns=condition_columns, inplace=True)\n", - "\n", - " numerical_columns = [\n", - " \"cutoff_los_1week\",\n", - " \"cutoff_readmission_1month\",\n", - " \"label_mortality_1month\",\n", - " \"label_readmission_1month\",\n", - " \"label_los_1week\",\n", - " ] + [f\"label_c{i}\" for i in range(20)]\n", - "\n", - " # Fill NaN values and convert numerical columns to integers\n", - " for column in numerical_columns:\n", - " multi_dataset[column] = multi_dataset[column].fillna(-1).astype(int)\n", - "\n", - " # Reset dataset index\n", - " multi_dataset.reset_index(drop=True, inplace=True)\n", - "\n", - " return multi_dataset\n", - "\n", - "\n", + "# Process the multi dataset\n", "multi_dataset = process_multi_dataset(\n", " datasets={\n", " \"original\": dataset_2048,\n", @@ -656,103 +137,10 @@ " \"readmission\": dataset_2048_readmission,\n", " \"los\": dataset_2048_los,\n", " },\n", + " max_len=MAX_LEN\n", ")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def stratified_train_test_split(\n", - " dataset: pd.DataFrame,\n", - " target: str,\n", - " test_size: float,\n", - " return_test: Optional[bool] = False,\n", - "):\n", - " \"\"\"\n", - " Split the given dataset into training and testing sets using iterative stratification on given multi-label target.\n", - " \"\"\"\n", - " # Convert all_conditions into a format suitable for multi-label stratification\n", - " Y = np.array(dataset[target].values.tolist())\n", - " X = dataset[\"patient_id\"].to_numpy().reshape(-1, 1)\n", - " is_single_label = type(dataset.iloc[0][target]) == np.int64\n", - "\n", - " # Perform stratified split\n", - " if is_single_label:\n", - " X_train, X_test, y_train, y_test = train_test_split(\n", - " X,\n", - " Y,\n", - " stratify=Y,\n", - " test_size=test_size,\n", - " random_state=SEED,\n", - " )\n", - "\n", - " else:\n", - " X_train, y_train, X_test, y_test = iterative_train_test_split(\n", - " X,\n", - " Y,\n", - " test_size=test_size,\n", - " )\n", - "\n", - " X_train = X_train.flatten().tolist()\n", - " X_test = X_test.flatten().tolist()\n", - "\n", - " if return_test:\n", - " return X_test\n", - " else:\n", - " return X_train, X_test\n", - "\n", - "\n", - "def sample_balanced_subset(dataset: pd.DataFrame, target: str, sample_size: int):\n", - " 
\"\"\"\n", - " Sample a subset of dataset with balanced target labels.\n", - " \"\"\"\n", - " # Sampling positive and negative patients\n", - " pos_patients = dataset[dataset[target] == True].sample(\n", - " n=sample_size // 2,\n", - " random_state=SEED,\n", - " )\n", - " neg_patients = dataset[dataset[target] == False].sample(\n", - " n=sample_size // 2,\n", - " random_state=SEED,\n", - " )\n", - "\n", - " # Combining and shuffling patient IDs\n", - " sample_patients = (\n", - " pos_patients[\"patient_id\"].tolist() + neg_patients[\"patient_id\"].tolist()\n", - " )\n", - " random.shuffle(sample_patients)\n", - "\n", - " return sample_patients\n", - "\n", - "\n", - "def get_pretrain_test_split(\n", - " dataset: pd.DataFrame,\n", - " stratify_target: Optional[str] = None,\n", - " test_size: float = 0.15,\n", - "):\n", - " \"\"\"Split dataset into pretrain and test set. Stratify on a given target column if needed.\"\"\"\n", - " if stratify_target:\n", - " pretrain_ids, test_ids = stratified_train_test_split(\n", - " dataset,\n", - " target=stratify_target,\n", - " test_size=test_size,\n", - " )\n", - "\n", - " else:\n", - " test_patients = dataset.sample(n=test_size, random_state=SEED)\n", - " test_ids = test_patients[\"patient_id\"].tolist()\n", - " pretrain_ids = dataset[~dataset[\"patient_id\"].isin(test_patients)][\n", - " \"patient_id\"\n", - " ].tolist()\n", - "\n", - " random.shuffle(pretrain_ids)\n", - "\n", - " return pretrain_ids, test_ids" - ] - }, { "cell_type": "code", "execution_count": null, @@ -785,8 +173,7 @@ "metadata": {}, "outputs": [], "source": [ - "class config:\n", - " task_splits = {\n", + "task_config = {\n", " \"mortality\": {\n", " \"dataset\": dataset_2048_mortality,\n", " \"label_col\": \"label_mortality_1month\",\n", @@ -815,9 +202,7 @@ " \"save_path\": \"patient_id_dict/dataset_2048_condition.pkl\",\n", " \"split_mode\": \"multi_label_stratified\",\n", " },\n", - " }\n", - "\n", - " all_tasks = list(task_splits.keys())" + " }" ] }, { @@ -832,57 +217,11 @@ }, "outputs": [], "source": [ - "def get_finetune_split(\n", - " config: config,\n", - " patient_ids_dict: Dict[str, Any],\n", - ") -> Dict[str, Dict[str, List[str]]]:\n", - " \"\"\"\n", - " Splits the dataset into training and cross-finetuneation sets using k-fold cross-finetuneation\n", - " while ensuring balanced label distribution in each fold. 
Saves the resulting dictionary to disk.\n",
-    "    \"\"\"\n",
-    "    # Extract task-specific configuration\n",
-    "    task_config = config.task_splits[task]\n",
-    "    dataset = task_config[\"dataset\"]\n",
-    "    label_col = task_config[\"label_col\"]\n",
-    "    finetune_sizes = task_config[\"finetune_size\"]\n",
-    "    save_path = task_config[\"save_path\"]\n",
-    "    split_mode = task_config[\"split_mode\"]\n",
-    "\n",
-    "    # Get pretrain dataset\n",
-    "    pretrain_ids = patient_ids_dict[\"pretrain\"]\n",
-    "    dataset = dataset[dataset[\"patient_id\"].isin(pretrain_ids)]\n",
-    "\n",
-    "    # Few-shot finetune patient ids\n",
-    "    for finetune_num in finetune_sizes:\n",
-    "        if split_mode == \"single_label_balanced\":\n",
-    "            finetune_ids = sample_balanced_subset(\n",
-    "                dataset,\n",
-    "                target=label_col,\n",
-    "                sample_size=finetune_num,\n",
-    "            )\n",
-    "\n",
-    "        elif (\n",
-    "            split_mode == \"single_label_stratified\"\n",
-    "            or split_mode == \"multi_label_stratified\"\n",
-    "        ):\n",
-    "            finetune_ids = stratified_train_test_split(\n",
-    "                dataset,\n",
-    "                target=label_col,\n",
-    "                test_size=finetune_num / len(dataset),\n",
-    "                return_test=True,\n",
-    "            )\n",
-    "\n",
-    "        patient_ids_dict[\"finetune\"][\"few_shot\"][f\"{finetune_num}\"] = finetune_ids\n",
-    "\n",
-    "    # Save the dictionary to disk\n",
-    "    save_object_to_disk(patient_ids_dict, save_path)\n",
-    "\n",
-    "    return patient_ids_dict\n",
-    "\n",
-    "\n",
-    "for task in config.all_tasks:\n",
+    "# Get finetune split\n",
+    "for task in task_config.keys():\n",
     "    patient_ids_dict = get_finetune_split(\n",
-    "        config=config,\n",
+    "        task_config=task_config,\n",
+    "        task=task,\n",
     "        patient_ids_dict=patient_ids_dict,\n",
     "    )"
    ]
   }
diff --git a/odyssey/data/processor.py b/odyssey/data/processor.py
new file mode 100644
index 0000000..4d7c4bc
--- /dev/null
+++ b/odyssey/data/processor.py
@@ -0,0 +1,511 @@
+"""Process patient sequences to be usable by the model based on task and split into train-test-finetune."""
+
+import random
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import pandas as pd
+from sklearn.model_selection import train_test_split
+from skmultilearn.model_selection import iterative_train_test_split
+
+from odyssey.utils.utils import seed_everything, save_object_to_disk
+
+SEED = 23
+seed_everything(seed=SEED)
+
+
+def filter_by_num_visit(dataset: pd.DataFrame, minimum_num_visits: int) -> pd.DataFrame:
+    """Filter the patients based on num_visits threshold.
+
+    Args:
+        dataset (pd.DataFrame): The input dataset.
+        minimum_num_visits (int): The minimum number of visits required to keep a patient.
+
+    Returns
+    -------
+    pd.DataFrame: The filtered dataset.
+    """
+    filtered_dataset = dataset.loc[dataset["num_visits"] >= minimum_num_visits]
+    filtered_dataset.reset_index(drop=True, inplace=True)
+    return filtered_dataset
+
+
+def filter_by_length_of_stay(dataset: pd.DataFrame, threshold: int = 1, max_len: int = 2048) -> pd.DataFrame:
+    """Filter the patients based on length of stay threshold.
+
+    Args:
+        dataset (pd.DataFrame): The input dataset.
+        threshold (int): The minimum length of stay, in days.
+        max_len (int): The maximum length of the sequence.
+
+    Returns
+    -------
+    pd.DataFrame: The filtered dataset.
+    """
+    filtered_dataset = dataset.loc[dataset["length_of_stay"] >= threshold]
+
+    # Only keep patients whose first event of the last visit happens within threshold hours
+    filtered_dataset = filtered_dataset[
+        filtered_dataset.apply(
+            lambda row: row[f"elapsed_tokens_{max_len}"][row["last_VS_index"] + 1]
+            < threshold * 24,
+            axis=1,
+        )
+    ]
+
+    filtered_dataset.reset_index(drop=True, inplace=True)
+    return filtered_dataset
+
+
+def get_last_occurence_index(seq: List[str], target: str) -> int:
+    """Return the index of the last occurrence of target in seq.
+
+    Args:
+        seq (List[str]): The input sequence.
+        target (str): The target string to find.
+
+    Returns
+    -------
+    int: The index of the last occurrence of target in seq.
+    """
+    return len(seq) - (seq[::-1].index(target) + 1)
+
+
+def check_readmission_label(row: pd.Series, max_len: int = 2048) -> int:
+    """Check if the label indicates readmission within one month.
+
+    Args:
+        row (pd.Series): The input row.
+        max_len (int): The maximum length of the sequence.
+
+    Returns
+    -------
+    int: 1 if the readmission label is present, 0 otherwise.
+    """
+    last_vs_index = row["last_VS_index"]
+    return int(
+        row[f"event_tokens_{max_len}"][last_vs_index - 1]
+        in ("[W_0]", "[W_1]", "[W_2]", "[W_3]", "[M_1]"),
+    )
+
+
+def get_length_of_stay(row: pd.Series) -> float:
+    """Determine the length of a given visit.
+
+    Args:
+        row (pd.Series): The input row.
+
+    Returns
+    -------
+    float: The length of the visit, in days.
+    """
+    admission_time = row["last_VS_index"] + 1
+    discharge_time = row["last_VE_index"] - 1
+    return (discharge_time - admission_time) / 24
+
+
+def get_visit_cutoff_at_threshold(row: pd.Series, threshold: int = 24, max_len: int = 2048) -> int:
+    """Get the index of the first event token of the last visit that occurs after threshold hours.
+
+    Args:
+        row (pd.Series): The input row.
+        threshold (int): The number of hours to consider.
+        max_len (int): The maximum length of the sequence.
+
+    Returns
+    -------
+    cutoff_index (int): The corresponding cutoff index.
+    """
+    last_vs_index = row["last_VS_index"]
+    last_ve_index = row["last_VE_index"]
+
+    for i in range(last_vs_index + 1, last_ve_index):
+        if row[f"elapsed_tokens_{max_len}"][i] > threshold:
+            return i
+
+    return len(row[f"event_tokens_{max_len}"])
+
+
+def process_length_of_stay_dataset(
+    dataset: pd.DataFrame,
+    threshold: int = 7,
+    max_len: int = 2048,
+) -> pd.DataFrame:
+    """Process the length of stay dataset to extract required features.
+
+    Args:
+        dataset (pd.DataFrame): The input dataset.
+        threshold (int): The threshold length of stay, in days.
+        max_len (int): The maximum length of the sequence.
+
+    Returns
+    -------
+    pd.DataFrame: The processed dataset.
+    """
+    dataset["last_VS_index"] = dataset[f"event_tokens_{max_len}"].transform(
+        lambda seq: get_last_occurence_index(list(seq), "[VS]"),
+    )
+    dataset["last_VE_index"] = dataset[f"event_tokens_{max_len}"].transform(
+        lambda seq: get_last_occurence_index(list(seq), "[VE]"),
+    )
+    dataset["length_of_stay"] = dataset.apply(get_length_of_stay, axis=1)
+
+    dataset = filter_by_length_of_stay(dataset, threshold=1, max_len=max_len)
+    dataset["label_los_1week"] = (dataset["length_of_stay"] >= threshold).astype(int)
+
+    dataset["cutoff_los"] = dataset.apply(
+        lambda row: get_visit_cutoff_at_threshold(row, threshold=24, max_len=max_len),
+        axis=1,
+    )
+    dataset["token_length"] = dataset[f"event_tokens_{max_len}"].apply(len)
+
+    return dataset
+
+
+def process_condition_dataset(dataset: pd.DataFrame, max_len: int = 2048) -> pd.DataFrame:
+    """Process the condition dataset to extract required features.
+
+    Args:
+        dataset (pd.DataFrame): The input condition dataset.
+        max_len (int): The maximum length of the sequence.
+
+    Returns
+    -------
+    pd.DataFrame: The processed condition dataset.
+    """
+    dataset["all_conditions"] = dataset.apply(
+        lambda row: np.concatenate(
+            [row["common_conditions"], row["rare_conditions"]],
+            dtype=np.int64,
+        ),
+        axis=1,
+    )
+
+    return dataset
+
+
+def process_mortality_dataset(dataset: pd.DataFrame, max_len: int = 2048) -> pd.DataFrame:
+    """Process the mortality dataset to extract required features.
+
+    Args:
+        dataset (pd.DataFrame): The input mortality dataset.
+        max_len (int): The maximum length of the sequence.
+
+    Returns
+    -------
+    pd.DataFrame: The processed mortality dataset.
+    """
+    dataset["label_mortality_2weeks"] = (
+        (dataset["death_after_start"] >= 0) & (dataset["death_after_end"] <= 15)
+    ).astype(int)
+
+    dataset["label_mortality_1month"] = (
+        (dataset["death_after_start"] >= 0) & (dataset["death_after_end"] <= 32)
+    ).astype(int)
+
+    return dataset
+
+
+def process_readmission_dataset(dataset: pd.DataFrame, max_len: int = 2048) -> pd.DataFrame:
+    """Process the readmission dataset to extract required features.
+
+    Args:
+        dataset (pd.DataFrame): The input dataset.
+        max_len (int): The maximum length of the sequence.
+
+    Returns
+    -------
+    pd.DataFrame: The processed dataset.
+    """
+    dataset = filter_by_num_visit(dataset.copy(), minimum_num_visits=2)
+
+    dataset["last_VS_index"] = dataset[f"event_tokens_{max_len}"].transform(
+        lambda seq: get_last_occurence_index(list(seq), "[VS]"),
+    )
+    dataset["cutoff_readmission"] = dataset["last_VS_index"] - 1
+    dataset["label_readmission_1month"] = dataset.apply(
+        lambda row: check_readmission_label(row, max_len=max_len),
+        axis=1,
+    )
+
+    dataset["num_visits"] -= 1
+    dataset["token_length"] = dataset[f"event_tokens_{max_len}"].apply(len)
+
+    return dataset
+
+
+def process_multi_dataset(
+    datasets: Dict[str, pd.DataFrame],
+    max_len: int = 2048,
+    num_conditions: int = 20,
+    nan_indicator: int = -1,
+) -> pd.DataFrame:
+    """
+    Process the multi-task dataset by merging the original dataset with the other datasets.
+
+    Args:
+        datasets (Dict): Dictionary mapping each task to its respective dataframe.
+        max_len (int): The maximum length of the sequence.
+        num_conditions (int): The number of condition labels.
+        nan_indicator (int): The value used to fill NaN entries in the dataframe.
+
+    Returns
+    -------
+    pd.DataFrame: The processed multi-task dataset.
+    """
+    # Merging datasets on 'patient_id'
+    multi_dataset = datasets["original"].merge(
+        datasets["condition"][["patient_id", "all_conditions"]],
+        on="patient_id",
+        how="left",
+    )
+    multi_dataset = multi_dataset.merge(
+        datasets["mortality"][["patient_id", "label_mortality_1month"]],
+        on="patient_id",
+        how="left",
+    )
+    multi_dataset = multi_dataset.merge(
+        datasets["readmission"][
+            ["patient_id", "cutoff_readmission", "label_readmission_1month"]
+        ],
+        on="patient_id",
+        how="left",
+    )
+    multi_dataset = multi_dataset.merge(
+        datasets["los"][["patient_id", "cutoff_los", "label_los_1week"]],
+        on="patient_id",
+        how="left",
+    )
+
+    # Selecting the required columns
+    multi_dataset = multi_dataset[
+        [
+            "patient_id",
+            "num_visits",
+            f"event_tokens_{max_len}",
+            f"type_tokens_{max_len}",
+            f"age_tokens_{max_len}",
+            f"time_tokens_{max_len}",
+            f"visit_tokens_{max_len}",
+            f"position_tokens_{max_len}",
+            f"elapsed_tokens_{max_len}",
+            "cutoff_los",
+            "cutoff_readmission",
+            "all_conditions",
+            "label_mortality_1month",
+            "label_readmission_1month",
+            "label_los_1week",
+        ]
+    ]
+
+    # Transform conditions from a vector of numbers to binary classes
+    conditions_expanded = multi_dataset["all_conditions"].apply(pd.Series)
+    conditions_expanded.columns = [f"condition{i}" for i in range(num_conditions)]
+    multi_dataset = multi_dataset.drop("all_conditions", axis=1)
+    multi_dataset = pd.concat([multi_dataset, conditions_expanded], axis=1)
+
+    # Standardize important column names
+    multi_dataset.rename(
+        columns={
+            "cutoff_los": "cutoff_los_1week",
+            "cutoff_readmission": "cutoff_readmission_1month",
+        },
+        inplace=True,
+    )
+    condition_columns = {f"condition{i}": f"label_c{i}" for i in range(num_conditions)}
+    multi_dataset.rename(columns=condition_columns, inplace=True)
+
+    numerical_columns = [
+        "cutoff_los_1week",
+        "cutoff_readmission_1month",
+        "label_mortality_1month",
+        "label_readmission_1month",
+        "label_los_1week",
+    ] + [f"label_c{i}" for i in range(num_conditions)]
+
+    # Fill NaN values and convert numerical columns to integers
+    for column in numerical_columns:
+        multi_dataset[column] = multi_dataset[column].fillna(nan_indicator).astype(int)
+
+    # Reset dataset index
+    multi_dataset.reset_index(drop=True, inplace=True)
+
+    return multi_dataset
+
+
+def stratified_train_test_split(
+    dataset: pd.DataFrame,
+    target: str,
+    test_size: float,
+    return_test: Optional[bool] = False,
+    seed: int = SEED,
+) -> Union[List[str], Tuple[List[str], List[str]]]:
+    """Split the given dataset into training and testing sets
+    using iterative stratification on the given multi-label target.
+
+    Args:
+        dataset (pd.DataFrame): The input dataset.
+        target (str): The target column for stratification.
+        test_size (float): The size of the test set.
+        return_test (bool): Whether to return the test set only.
+        seed (int): The random seed for reproducibility.
+
+    Returns
+    -------
+    Union[List[str], Tuple[List[str], List[str]]]: The patient ids for the testing set,
+    or for both the training and testing sets.
+    """
+    # Convert all_conditions into a format suitable for multi-label stratification
+    Y = np.array(dataset[target].values.tolist())
+    X = dataset["patient_id"].to_numpy().reshape(-1, 1)
+    is_single_label = isinstance(dataset.iloc[0][target], np.int64)
+
+    # Perform stratified split
+    if is_single_label:
+        X_train, X_test, y_train, y_test = train_test_split(
+            X,
+            Y,
+            stratify=Y,
+            test_size=test_size,
+            random_state=seed,
+        )
+
+    else:
+        X_train, y_train, X_test, y_test = iterative_train_test_split(
+            X,
+            Y,
+            test_size=test_size,
+        )
+
+    X_train = X_train.flatten().tolist()
+    X_test = X_test.flatten().tolist()
+
+    if return_test:
+        return X_test
+    else:
+        return X_train, X_test
+
+
+def sample_balanced_subset(
+    dataset: pd.DataFrame,
+    target: str,
+    sample_size: int,
+    seed: int = SEED,
+) -> List[str]:
+    """Sample a subset of the dataset with balanced target labels.
+
+    Args:
+        dataset (pd.DataFrame): The input dataset.
+        target (str): The target column for sampling.
+        sample_size (int): The size of the sample.
+        seed (int): The random seed for reproducibility.
+
+    Returns
+    -------
+    List[str]: The patient ids for the sampled set.
+    """
+    # Sampling positive and negative patients
+    pos_patients = dataset[dataset[target] == True].sample(
+        n=sample_size // 2,
+        random_state=seed,
+    )
+    neg_patients = dataset[dataset[target] == False].sample(
+        n=sample_size // 2,
+        random_state=seed,
+    )
+
+    # Combining and shuffling patient IDs
+    sample_patients = (
+        pos_patients["patient_id"].tolist() + neg_patients["patient_id"].tolist()
+    )
+    random.shuffle(sample_patients)
+
+    return sample_patients
+
+
+def get_pretrain_test_split(
+    dataset: pd.DataFrame,
+    stratify_target: Optional[str] = None,
+    test_size: float = 0.15,
+    seed: int = SEED,
+) -> Tuple[List[str], List[str]]:
+    """Split dataset into pretrain and test set. Stratify on a given target column if needed.
+
+    Args:
+        dataset (pd.DataFrame): The input dataset.
+        stratify_target (str): The target column for stratification.
+        test_size (float): The proportion of the dataset used as the test set.
+        seed (int): The random seed for reproducibility.
+
+    Returns
+    -------
+    Tuple[List[str], List[str]]: The patient ids for the pretrain and test set.
+    """
+    if stratify_target:
+        pretrain_ids, test_ids = stratified_train_test_split(
+            dataset,
+            target=stratify_target,
+            test_size=test_size,
+        )
+
+    else:
+        test_patients = dataset.sample(frac=test_size, random_state=seed)
+        test_ids = test_patients["patient_id"].tolist()
+        pretrain_ids = dataset[~dataset["patient_id"].isin(test_ids)][
+            "patient_id"
+        ].tolist()
+
+    random.shuffle(pretrain_ids)
+
+    return pretrain_ids, test_ids
+
+
+def get_finetune_split(
+    task_config: Dict[str, Any],
+    task: str,
+    patient_ids_dict: Dict[str, Any],
+) -> Dict[str, Dict[str, List[str]]]:
+    """Sample few-shot finetuning subsets of the pretraining patients for the given task,
+    using balanced sampling or stratified splitting depending on the task configuration,
+    and save the resulting dictionary to disk.
+
+    Args:
+        task_config (Dict[str, Any]): A dictionary containing the task-specific configuration.
+        task (str): The task name. Must be one of the keys in the task_config dictionary.
+        patient_ids_dict (Dict[str, Any]): A dictionary containing the patient ids for each split.
+
+    Returns
+    -------
+    Dict[str, Dict[str, List[str]]]: The updated patient_ids_dict.
+    """
+    # Extract task-specific configuration
+    dataset = task_config[task]["dataset"].copy()
+    label_col = task_config[task]["label_col"]
+    finetune_sizes = task_config[task]["finetune_size"]
+    save_path = task_config[task]["save_path"]
+    split_mode = task_config[task]["split_mode"]
+
+    # Get pretrain dataset
+    pretrain_ids = patient_ids_dict["pretrain"]
+    dataset = dataset[dataset["patient_id"].isin(pretrain_ids)]
+
+    # Few-shot finetune patient ids
+    for finetune_num in finetune_sizes:
+        if split_mode == "single_label_balanced":
+            finetune_ids = sample_balanced_subset(
+                dataset,
+                target=label_col,
+                sample_size=finetune_num,
+            )
+
+        elif (
+            split_mode == "single_label_stratified"
+            or split_mode == "multi_label_stratified"
+        ):
+            finetune_ids = stratified_train_test_split(
+                dataset,
+                target=label_col,
+                test_size=finetune_num / len(dataset),
+                return_test=True,
+            )
+
+        patient_ids_dict["finetune"]["few_shot"][f"{finetune_num}"] = finetune_ids
+
+    # Save the dictionary to disk
+    save_object_to_disk(patient_ids_dict, save_path)
+
+    return patient_ids_dict
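
A minimal end-to-end sketch of how the refactored odyssey.data.processor module could be driven from a standalone script, outside the notebook. The parquet path, the finetune sizes, the save path, and the single mortality task entry below are illustrative assumptions, not part of this diff; only the function names and signatures come from processor.py above.

import pandas as pd
from odyssey.data.processor import (
    process_mortality_dataset,
    get_pretrain_test_split,
    get_finetune_split,
)

MAX_LEN = 2048
# Hypothetical path; the notebook reads the same file from its DATA_ROOT directory.
dataset = pd.read_parquet("patient_sequences/patient_sequences_2048.parquet")

# Derive mortality labels, then split patients into pretrain and held-out test sets.
dataset_mortality = process_mortality_dataset(dataset.copy(), max_len=MAX_LEN)
pretrain_ids, test_ids = get_pretrain_test_split(
    dataset_mortality,
    stratify_target="label_mortality_1month",
    test_size=0.15,
)

patient_ids_dict = {
    "pretrain": pretrain_ids,
    "test": test_ids,
    "finetune": {"few_shot": {}},
}
task_config = {
    "mortality": {
        "dataset": dataset_mortality,
        "label_col": "label_mortality_1month",
        "finetune_size": [250, 1000],  # illustrative sizes
        "save_path": "patient_id_dict/dataset_2048_mortality.pkl",  # illustrative path
        "split_mode": "single_label_balanced",
    },
}

# Sample few-shot finetune subsets and persist the id dictionary to disk.
for task in task_config:
    patient_ids_dict = get_finetune_split(
        task_config=task_config,
        task=task,
        patient_ids_dict=patient_ids_dict,
    )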