diff --git a/odyssey/data/DataProcessor.ipynb b/odyssey/data/DataProcessor.ipynb
deleted file mode 100644
index 590ca2e..0000000
--- a/odyssey/data/DataProcessor.ipynb
+++ /dev/null
@@ -1,764 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {
-    "ExecuteTime": {
-     "end_time": "2024-03-13T16:14:45.546088300Z",
-     "start_time": "2024-03-13T16:14:43.587090300Z"
-    },
-    "collapsed": true
-   },
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "[rank: 0] Seed set to 23\n",
-      "[rank: 0] Seed set to 23\n"
-     ]
-    }
-   ],
-   "source": [
-    "import os\n",
-    "import sys\n",
-    "import pickle\n",
-    "import random\n",
-    "from typing import Any, Dict, List, Optional\n",
-    "\n",
-    "import numpy as np\n",
-    "import pandas as pd\n",
-    "import polars as pl\n",
-    "\n",
-    "SEED = 23\n",
-    "ROOT = \"/h/afallah/odyssey/odyssey\"\n",
-    "DATA_ROOT = f\"{ROOT}/odyssey/data/meds_data\"  # bigbird_data\n",
-    "DATASET = f\"{DATA_ROOT}/patient_sequences/patient_sequences.parquet\"  # patient_sequences_2048.parquet\n",
-    "DATASET_2048 = f\"{DATA_ROOT}/patient_sequences/patient_sequences_2048.parquet\"\n",
-    "MAX_LEN = 2048\n",
-    "\n",
-    "os.chdir(ROOT)\n",
-    "\n",
-    "from odyssey.utils.utils import save_object_to_disk, seed_everything\n",
-    "from odyssey.data.tokenizer import ConceptTokenizer\n",
-    "from odyssey.data.dataset import FinetuneMultiDataset\n",
-    "from odyssey.data.processor import *\n",
-    "\n",
-    "seed_everything(seed=SEED)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# (lengths <= 4096).sum() / len(lengths)\n",
-    "\n",
-    "# lengths = dataset['event_tokens'].map_elements(len)\n",
-    "# (lengths.filter(lengths > 2048) - 2048).sum() / 1e6\n",
-    "# lengths.filter((lengths > 2048) & (lengths <= 4096)).sum() / 1e6\n",
-    "# lengths.sum() / 1e6\n",
-    "\n",
-    "# sample_id = 2130\n",
-    "# sample = dataset[sample_id]['event_tokens']\n",
-    "# print(list(sample[0]), '\\n', len(list(sample[0])))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/tmp/ipykernel_19611/5770890.py:3: MapWithoutReturnDtypeWarning: Calling `map_elements` without specifying `return_dtype` can lead to unpredictable results. Specify `return_dtype` to silence this warning.\n",
-      "  dataset = dataset.filter(pl.col('event_tokens').map_elements(len) > 5)\n",
-      "/tmp/ipykernel_19611/5770890.py:13: MapWithoutReturnDtypeWarning: Calling `map_elements` without specifying `return_dtype` can lead to unpredictable results. Specify `return_dtype` to silence this warning.\n",
-      "  dataset = dataset.with_columns([\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "shape: (5, 3)\n",
-      "┌────────────┬─────────────────────────────────┬──────────────┐\n",
-      "│ patient_id ┆ event_tokens                    ┆ token_length │\n",
-      "│ ---        ┆ ---                             ┆ ---          │\n",
-      "│ str        ┆ list[str]                       ┆ i64          │\n",
-      "╞════════════╪═════════════════════════════════╪══════════════╡\n",
-      "│ 11002607   ┆ [\"[CLS]\", \"AGE//82\", … \"[EOS]\"… ┆ 1284         │\n",
-      "│ 11122678   ┆ [\"[CLS]\", \"AGE//66\", … \"[EOS]\"… ┆ 1249         │\n",
-      "│ 19846309   ┆ [\"[CLS]\", \"AGE//25\", … \"[EOS]\"… ┆ 19           │\n",
-      "│ 16935023   ┆ [\"[CLS]\", \"AGE//20\", … \"[EOS]\"… ┆ 52           │\n",
-      "│ 16677169   ┆ [\"[CLS]\", \"AGE//56\", … \"[EOS]\"… ┆ 441          │\n",
-      "└────────────┴─────────────────────────────────┴──────────────┘\n",
-      "Schema([('patient_id', String), ('event_tokens', List(String)), ('token_length', Int64)])\n",
-      "shape: (5, 3)\n",
-      "┌────────────┬─────────────────────────────────┬──────────────┐\n",
-      "│ patient_id ┆ event_tokens                    ┆ token_length │\n",
-      "│ ---        ┆ ---                             ┆ ---          │\n",
-      "│ str        ┆ list[str]                       ┆ i64          │\n",
-      "╞════════════╪═════════════════════════════════╪══════════════╡\n",
-      "│ 11002607   ┆ [\"[CLS]\", \"AGE//82\", … \"[EOS]\"… ┆ 1284         │\n",
-      "│ 11122678   ┆ [\"[CLS]\", \"AGE//66\", … \"[EOS]\"… ┆ 1249         │\n",
-      "│ 19846309   ┆ [\"[CLS]\", \"AGE//25\", … \"[EOS]\"… ┆ 19           │\n",
-      "│ 16935023   ┆ [\"[CLS]\", \"AGE//20\", … \"[EOS]\"… ┆ 52           │\n",
-      "│ 16677169   ┆ [\"[CLS]\", \"AGE//56\", … \"[EOS]\"… ┆ 441          │\n",
-      "└────────────┴─────────────────────────────────┴──────────────┘\n",
-      "Schema([('patient_id', String), ('event_tokens', List(String)), ('token_length', Int64)])\n"
-     ]
-    }
-   ],
-   "source": [
-    "dataset = pl.read_parquet(DATASET)\n",
-    "dataset = dataset.rename({\"subject_id\": \"patient_id\", \"code\": \"event_tokens\"})\n",
-    "dataset = dataset.filter(pl.col('event_tokens').map_elements(len) > 5)\n",
-    "\n",
-    "dataset = dataset.with_columns([\n",
-    "    pl.col('patient_id').cast(pl.String).alias('patient_id'),\n",
-    "    pl.concat_list([\n",
-    "        pl.col('event_tokens').list.slice(0, 2047),\n",
-    "        pl.lit(['[EOS]'])\n",
-    "    ]).alias('event_tokens'),\n",
-    "])\n",
-    "\n",
-    "dataset = dataset.with_columns([\n",
-    "    pl.col('event_tokens').map_elements(len).alias('token_length'),\n",
-    "])\n",
-    "\n",
-    "print(dataset.head())\n",
-    "print(dataset.schema)\n",
-    "\n",
-    "dataset.write_parquet(DATASET_2048)\n",
-    "\n",
-    "dataset_saved = pl.read_parquet(DATASET_2048)\n",
-    "print(dataset_saved.head())\n",
-    "print(dataset_saved.schema)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# patient_ids_dict = {\n",
-    "#     \"pretrain\": [],\n",
-    "#     \"finetune\": {\"few_shot\": {}, \"kfold\": {}},\n",
-    "#     \"test\": [],\n",
-    "# }\n",
-    "\n",
-    "# import numpy as np\n",
-    "# import pickle\n",
-    "\n",
-    "# # Set random seed\n",
-    "# np.random.seed(23)\n",
-    "\n",
-    "# # Get unique patient IDs\n",
-    "# unique_patients = dataset_saved['patient_id'].unique()\n",
-    "\n",
-    "# # Randomly shuffle patient IDs\n",
-    "# np.random.shuffle(unique_patients)\n",
-    "\n",
-    "# # Calculate split sizes\n",
-    "# n_patients = len(unique_patients)\n",
-    "# n_pretrain = int(0.65 * n_patients)\n",
-    "# n_finetune = int(0.25 * n_patients)\n",
-    "\n",
-    "# # Split patient IDs\n",
-    "# patient_ids_dict[\"pretrain\"] = unique_patients[:n_pretrain].tolist()\n",
-    "# patient_ids_dict[\"finetune\"][\"few_shot\"] = unique_patients[n_pretrain:n_pretrain+n_finetune].tolist()\n",
-    "# patient_ids_dict[\"test\"] = unique_patients[n_pretrain+n_finetune:].tolist()\n",
-    "\n",
-    "# # Save the dictionary\n",
-    "# save_object_to_disk(patient_ids_dict, f\"{DATA_ROOT}/patient_id_dict/patient_splits.pkl\")\n",
-    "# len(patient_ids_dict[\"pretrain\"]), len(patient_ids_dict[\"finetune\"][\"few_shot\"]), len(patient_ids_dict[\"test\"]), patient_ids_dict[\"pretrain\"][2323]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# patient_ids_dict = load_object_from_disk(f\"{DATA_ROOT}/patient_id_dict/patient_splits.pkl\")\n",
-    "# len(patient_ids_dict[\"pretrain\"]), len(patient_ids_dict[\"finetune\"][\"few_shot\"]), len(patient_ids_dict[\"test\"]), patient_ids_dict[\"pretrain\"][2323]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "ExecuteTime": {
-     "end_time": "2024-03-13T16:15:12.321718600Z",
-     "start_time": "2024-03-13T16:14:45.553089800Z"
-    },
-    "collapsed": false
-   },
-   "outputs": [],
-   "source": [
-    "# Load complete dataset\n",
-    "dataset = pd.read_parquet(DATASET)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "dataset[\"num_visits\"] = dataset[\"event_tokens_2048\"].transform(\n",
-    "    lambda series: list(series).count(\"[VS]\")\n",
-    ")\n",
-    "\n",
-    "print(f\"Current columns: {dataset.columns}\")\n",
-    "dataset.head()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "dataset['event_tokens_2048'].iloc[0]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Process the dataset for length of stay prediction above a threshold\n",
-    "dataset_los = process_length_of_stay_dataset(\n",
-    "    dataset.copy(), threshold=7, max_len=MAX_LEN\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Process the dataset for conditions including rare and common\n",
-    "dataset_condition = process_condition_dataset(dataset.copy())"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "ExecuteTime": {
-     "end_time": "2024-03-13T16:15:16.075719400Z",
-     "start_time": "2024-03-13T16:15:12.335721100Z"
-    },
-    "collapsed": false
-   },
-   "outputs": [],
-   "source": [
-    "# Process the dataset for mortality in two weeks or one month task\n",
-    "dataset_mortality = process_mortality_dataset(dataset.copy())"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "ExecuteTime": {
-     "end_time": "2024-03-13T16:15:47.326996100Z",
-     "start_time": "2024-03-13T16:15:16.094719300Z"
-    },
-    "collapsed": false
-   },
-   "outputs": [],
-   "source": [
-    "# Process the dataset for hospital readmission in one month task\n",
-    "dataset_readmission = process_readmission_dataset(\n",
-    "    dataset.copy(), max_len=MAX_LEN\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Process the multi dataset\n",
-    "multi_dataset = process_multi_dataset(\n",
-    "    datasets={\n",
-    "        \"original\": dataset,\n",
-    "        \"mortality\": dataset_mortality,\n",
-    "        \"condition\": dataset_condition,\n",
-    "        \"readmission\": dataset_readmission,\n",
-    "        \"los\": dataset_los,\n",
-    "    },\n",
-    "    max_len=MAX_LEN,\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Split data\n",
-    "patient_ids_dict = {\n",
-    "    \"pretrain\": [],\n",
-    "    \"finetune\": {\"few_shot\": {}, \"kfold\": {}},\n",
-    "    \"test\": [],\n",
-    "}\n",
-    "\n",
-    "# Get train-test split\n",
-    "# pretrain_ids, test_ids = get_pretrain_test_split(dataset_readmission, stratify_target='label_readmission_1month', test_size=0.2)\n",
-    "# pretrain_ids, test_ids = get_pretrain_test_split(process_condition_dataset, stratify_target='all_conditions', test_size=0.15)\n",
-    "# patient_ids_dict['pretrain'] = pretrain_ids\n",
-    "# patient_ids_dict['test'] = test_ids\n",
-    "\n",
-    "# Load pretrain and test patient IDs\n",
-    "pid = pickle.load(open(f\"{DATA_ROOT}/patient_id_dict/dataset_multi.pkl\", \"rb\"))\n",
-    "patient_ids_dict[\"pretrain\"] = pid[\"pretrain\"]\n",
-    "patient_ids_dict[\"test\"] = pid[\"test\"]\n",
-    "set(pid[\"pretrain\"] + pid[\"test\"]) == set(dataset[\"patient_id\"])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "multi_dataset_pretrain = multi_dataset.loc[\n",
-    "    multi_dataset[\"patient_id\"].isin(patient_ids_dict[\"pretrain\"])\n",
-    "]\n",
-    "multi_dataset_test = multi_dataset.loc[\n",
-    "    multi_dataset[\"patient_id\"].isin(patient_ids_dict[\"test\"])\n",
-    "]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# DataFrame assumed loaded as multi_dataset_pretrain\n",
-    "# Define the requirements for each label\n",
-    "label_requirements = {\n",
-    "    \"label_mortality_1month\": 25000,\n",
-    "    \"label_readmission_1month\": 30000,\n",
-    "    \"label_los_1week\": 40000,\n",
-    "    # 'label_c0': 25000,\n",
-    "    # 'label_c1': 25000,\n",
-    "    # 'label_c2': 25000\n",
-    "}\n",
-    "\n",
-    "# Prepare a dictionary to store indices for each label and category\n",
-    "selected_indices = {label: {\"0\": set(), \"1\": set()} for label in label_requirements}\n",
-    "\n",
-    "# Initialize a dictionary to track usage of indices across labels\n",
-    "index_usage = {}\n",
-    "\n",
-    "\n",
-    "# Function to select indices while maximizing overlap\n",
-    "def select_indices(label, num_required, category):\n",
-    "    # Candidates are those indices matching the category requirement\n",
-    "    candidates = set(\n",
-    "        multi_dataset_pretrain[multi_dataset_pretrain[label] == category].index\n",
-    "    )\n",
-    "    # Prefer candidates that are already used elsewhere to maximize overlap\n",
-    "    preferred = candidates & set(index_usage.keys())\n",
-    "    additional = list(candidates - preferred)\n",
-    "    np.random.shuffle(additional)  # Shuffle to avoid any unintended order bias\n",
-    "\n",
-    "    # Determine how many more are needed\n",
-    "    needed = num_required - len(selected_indices[label][str(category)] & candidates)\n",
-    "    if needed > 0:\n",
-    "        # Select as many as possible from preferred, then from additional\n",
-    "        selected = list(preferred - selected_indices[label][str(category)])[:needed]\n",
-    "        selected += additional[: needed - len(selected)]\n",
-    "        # Update the selected indices for this label and category\n",
-    "        selected_indices[label][str(category)].update(selected)\n",
-    "        # Update overall index usage\n",
-    "        for idx in selected:\n",
-    "            index_usage[idx] = index_usage.get(idx, 0) + 1\n",
-    "\n",
-    "\n",
-    "# Process each label and category\n",
-    "for label in label_requirements:\n",
-    "    num_required = label_requirements[label] // 2  # Divide by 2 for 50-50 distribution\n",
-    "    select_indices(label, num_required, 0)\n",
-    "    select_indices(label, num_required, 1)\n",
-    "\n",
-    "# Combine all selected indices from both categories\n",
-    "all_selected_indices = set()\n",
-    "for indices in selected_indices.values():\n",
-    "    all_selected_indices.update(indices[\"0\"])\n",
-    "    all_selected_indices.update(indices[\"1\"])\n",
-    "\n",
-    "# Create the balanced DataFrame\n",
-    "multi_dataset_finetune = multi_dataset_pretrain.loc[list(all_selected_indices)]\n",
-    "multi_dataset_finetune"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "for label in [\n",
-    "    \"label_mortality_1month\",\n",
-    "    \"label_readmission_1month\",\n",
-    "    \"label_los_1week\",\n",
-    "    \"label_c0\",\n",
-    "    \"label_c1\",\n",
-    "    \"label_c2\",\n",
-    "]:\n",
-    "    print(\n",
-    "        f\"Label: {label} | Mean: {multi_dataset_finetune[label].mean()}\\n{multi_dataset_finetune[label].value_counts()}\\n\"\n",
-    "    )"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "patient_ids_dict[\"finetune\"][\"few_shot\"][\"all\"] = multi_dataset_finetune[\n",
-    "    \"patient_id\"\n",
-    "].tolist()\n",
-    "\n",
-    "multi_dataset_pretrain = multi_dataset_pretrain.loc[\n",
-    "    ~multi_dataset_pretrain[\"patient_id\"].isin(multi_dataset_finetune[\"patient_id\"])\n",
-    "]\n",
-    "\n",
-    "patient_ids_dict[\"pretrain\"] = multi_dataset_pretrain[\"patient_id\"].tolist()\n",
-    "\n",
-    "save_object_to_disk(\n",
-    "    patient_ids_dict, f\"{DATA_ROOT}/patient_id_dict/dataset_multi_v2.pkl\"\n",
-    ")\n",
-    "\n",
-    "# \"mortality_1month=0.5, los_1week=0.5, readmission_1month=0.5, c0=0.5, c1=0.5, c2=0.5\""
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "\"\"\"\n",
-    "Current Approach:\n",
-    "    - Pretrain: 141234 Patients\n",
-    "    - Test: 24924 Patients, 132682 Datapoints\n",
-    "    - Finetune: 139514 Unique Patients, 434270 Datapoints\n",
-    "        - Mortality: 26962 Patients\n",
-    "        - Readmission: 48898 Patients\n",
-    "        - Length of Stay: 72686 Patients\n",
-    "        - Condition 0: 122722 Patients\n",
-    "        - Condition 1: 94048 Patients\n",
-    "        - Condition 2: 68954 Patients\n",
-    "\"\"\""
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "task_config = {\n",
-    "    \"mortality\": {\n",
-    "        \"dataset\": dataset_mortality,\n",
-    "        \"label_col\": \"label_mortality_1month\",\n",
-    "        \"finetune_size\": [250, 500, 1000, 5000, 20000],\n",
-    "        \"save_path\": \"patient_id_dict/dataset_mortality.pkl\",\n",
-    "        \"split_mode\": \"single_label_balanced\",\n",
-    "    },\n",
-    "    \"readmission\": {\n",
-    "        \"dataset\": dataset_readmission,\n",
-    "        \"label_col\": \"label_readmission_1month\",\n",
-    "        \"finetune_size\": [250, 1000, 5000, 20000, 60000],\n",
-    "        \"save_path\": \"patient_id_dict/dataset_readmission.pkl\",\n",
-    "        \"split_mode\": \"single_label_stratified\",\n",
-    "    },\n",
-    "    \"length_of_stay\": {\n",
-    "        \"dataset\": dataset_los,\n",
-    "        \"label_col\": \"label_los_1week\",\n",
-    "        \"finetune_size\": [250, 1000, 5000, 20000, 50000],\n",
-    "        \"save_path\": \"patient_id_dict/dataset_los.pkl\",\n",
-    "        \"split_mode\": \"single_label_balanced\",\n",
-    "    },\n",
-    "    \"condition\": {\n",
-    "        \"dataset\": dataset_condition,\n",
-    "        \"label_col\": \"all_conditions\",\n",
-    "        \"finetune_size\": [50000],\n",
-    "        \"save_path\": \"patient_id_dict/dataset_condition.pkl\",\n",
-    "        \"split_mode\": \"multi_label_stratified\",\n",
-    "    },\n",
-    "}"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Get finetune split\n",
-    "for task in task_config.keys():\n",
-    "    patient_ids_dict = get_finetune_split(\n",
-    "        task_config=task_config,\n",
-    "        task=task,\n",
-    "        patient_ids_dict=patient_ids_dict,\n",
-    "    )"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "ExecuteTime": {
-     "end_time": "2024-03-13T14:14:10.181184300Z",
-     "start_time": "2024-03-13T14:13:39.154567400Z"
-    },
-    "collapsed": false
-   },
-   "outputs": [],
-   "source": [
-    "# dataset_mortality.to_parquet(\n",
-    "#     \"patient_sequences/patient_sequences_2048_mortality.parquet\",\n",
-    "# )\n",
-    "# dataset_readmission.to_parquet(\n",
-    "#     \"patient_sequences/patient_sequences_2048_readmission.parquet\",\n",
-    "# )\n",
-    "# dataset_los.to_parquet(\"patient_sequences/patient_sequences_2048_los.parquet\")\n",
-    "# dataset_condition.to_parquet(\n",
-    "#     \"patient_sequences/patient_sequences_2048_condition.parquet\",\n",
-    "# )\n",
-    "multi_dataset.to_parquet(\n",
-    "    f\"{DATA_ROOT}/patient_sequences/patient_sequences_2048_multi_v2.parquet\"\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# # Load data\n",
-    "# multi_dataset = pd.read_parquet('patient_sequences/patient_sequences_2048_multi.parquet')\n",
-    "# pid = pickle.load(open('patient_id_dict/dataset_multi.pkl', 'rb'))\n",
-    "# multi_dataset = multi_dataset[multi_dataset['patient_id'].isin(pid['finetune']['few_shot']['all'])]\n",
-    "\n",
-    "# # Train Tokenizer\n",
-    "# tokenizer = ConceptTokenizer(data_dir='/h/afallah/odyssey/odyssey/odyssey/data/vocab')\n",
-    "# tokenizer.fit_on_vocab(with_tasks=True)\n",
-    "\n",
-    "# # Load datasets\n",
-    "# tasks = ['mortality_1month', 'los_1week'] + [f'c{i}' for i in range(5)]\n",
-    "\n",
-    "# train_dataset = FinetuneMultiDataset(\n",
-    "#     data=multi_dataset,\n",
-    "#     tokenizer=tokenizer,\n",
-    "#     tasks=tasks,\n",
-    "#     balance_guide={'mortality_1month': 0.5, 'los_1week': 0.5},\n",
-    "#     max_len=2048,\n",
-    "# )"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# dataset_condition = pd.read_parquet('patient_sequences/patient_sequences_2048_condition.parquet')\n",
-    "# pid = pickle.load(open('patient_id_dict/dataset_condition.pkl', 'rb'))\n",
-    "# condition_finetune = dataset_condition.loc[dataset_condition['patient_id'].isin(pid['finetune']['few_shot']['50000'])]\n",
-    "# condition_finetune"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# freq = np.array(condition_finetune['all_conditions'].tolist()).sum(axis=0)\n",
-    "# weights = np.clip(sum(freq) / freq, 0, 50)\n",
-    "# np.max(np.sqrt(freq)) / np.sqrt(freq)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# sorted(patient_ids_dict['pretrain']) == sorted(pickle.load(open('new_data/patient_id_dict/sample_pretrain_test_patient_ids_with_conditions.pkl', 'rb'))['pretrain'])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# merged_df = pd.merge(dataset_mortality, dataset_readmission, how='outer', on='patient_id')\n",
-    "# final_merged_df = pd.merge(merged_df, dataset_condition, how='outer', on='patient_id')\n",
-    "# final_merged_df"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Performing stratified k-fold split\n",
-    "# skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=SEED)\n",
-    "\n",
-    "# for i, (train_index, cv_index) in enumerate(skf.split(dataset, dataset[label_col])):\n",
-    "\n",
-    "#     dataset_cv = dataset.iloc[cv_index]\n",
-    "#     dataset_finetune = dataset.iloc[train_index]\n",
-    "\n",
-    "#     # Separate positive and negative labeled patients\n",
-    "#     pos_patients = dataset_cv[dataset_cv[label_col] == True]['patient_id'].tolist()\n",
-    "#     neg_patients = dataset_cv[dataset_cv[label_col] == False]['patient_id'].tolist()\n",
-    "\n",
-    "#     # Calculate the number of positive and negative patients needed for balanced CV set\n",
-    "#     num_pos_needed = cv_size // 2\n",
-    "#     num_neg_needed = cv_size // 2\n",
-    "\n",
-    "#     # Select positive and negative patients for CV set ensuring balanced distribution\n",
-    "#     cv_patients = pos_patients[:num_pos_needed] + neg_patients[:num_neg_needed]\n",
-    "#     remaining_finetune_patients = pos_patients[num_pos_needed:] + neg_patients[num_neg_needed:]\n",
-    "\n",
-    "#     # Extract patient IDs for training set\n",
-    "#     finetune_patients = dataset_finetune['patient_id'].tolist()\n",
-    "#     finetune_patients += remaining_finetune_patients\n",
-    "\n",
-    "#     # Shuffle each list of patients\n",
-    "#     random.shuffle(cv_patients)\n",
-    "#     random.shuffle(finetune_patients)\n",
-    "\n",
-    "#     patient_ids_dict['finetune']['kfold'][f'group{i+1}'] = {'finetune': finetune_patients, 'cv': cv_patients}"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Assuming dataset.event_tokens is your DataFrame column\n",
-    "# dataset.event_tokens.transform(len).plot(kind='hist', bins=100)\n",
-    "# plt.xlim(1000, 8000)  # Limit x-axis\n",
-    "# plt.ylim(0, 6000)\n",
-    "# plt.xlabel('Length of Event Tokens')\n",
-    "# plt.ylabel('Frequency')\n",
-    "# plt.title('Histogram of Event Tokens Length')\n",
-    "# plt.show()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# len(patient_ids_dict['group3']['cv'])\n",
-    "\n",
-    "# dataset.loc[dataset['patient_id'].isin(patient_ids_dict['group1']['cv'])]['label_mortality_1month']\n",
-    "\n",
-    "# s = set()\n",
-    "# for i in range(1, 6):\n",
-    "#     s = s.union(set(patient_ids_dict[f'group{i}']['cv']))\n",
-    "#\n",
-    "# len(s)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "##### DEAD ZONE | DO NOT ENTER #####\n",
-    "\n",
-    "# patient_ids = pickle.load(open(join(\"/h/afallah/odyssey/odyssey/data/bigbird_data\", 'dataset_mortality_1month.pkl'), 'rb'))\n",
-    "# patient_ids['finetune']['few_shot'].keys()\n",
-    "\n",
-    "# patient_ids2 = pickle.load(open(join(\"/h/afallah/odyssey/odyssey/data/bigbird_data\", 'dataset_mortality_2weeks.pkl'), 'rb'))['pretrain']\n",
-    "#\n",
-    "# patient_ids1.sort()\n",
-    "# patient_ids2.sort()\n",
-    "#\n",
-    "# patient_ids1 == patient_ids2\n",
-    "# # dataset.loc[dataset['patient_id'].isin(patient_ids['pretrain'])]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# dataset_readmission = dataset.loc[dataset['num_visits'] > 1]\n",
-    "# dataset_readmission.reset_index(drop=True, inplace=True)\n",
-    "#\n",
-    "# dataset_readmission['last_VS_index'] = dataset_readmission['event_tokens_2048'].transform(lambda seq: get_last_occurence_index(list(seq), '[VS]'))\n",
-    "#\n",
-    "# dataset_readmission['label_readmission_1month'] = dataset_readmission.apply(\n",
-    "#     lambda row: row['event_tokens_2048'][row['last_VS_index'] - 1] in ('[W_0]', '[W_1]', '[W_2]', '[W_3]', '[M_1]'), axis=1\n",
-    "# )\n",
-    "# dataset_readmission['event_tokens_2048'] = dataset_readmission.apply(\n",
-    "#     lambda row: row['event_tokens_2048'][:row['last_VS_index'] - 1], axis=1\n",
-    "# )\n",
-    "# dataset_readmission.drop(['deceased', 'death_after_start', 'death_after_end', 'length'], axis=1, inplace=True)\n",
-    "# dataset_readmission['num_visits'] -= 1\n",
-    "# dataset_readmission['token_length'] = dataset_readmission['event_tokens_2048'].apply(len)\n",
-    "# dataset_readmission = dataset_readmission.apply(lambda row: truncate_and_pad(row), axis=1)\n",
-    "# dataset_readmission['event_tokens_2048'] = dataset_readmission['event_tokens_2048'].transform(\n",
-    "#     lambda token_list: ' '.join(token_list)\n",
-    "# )\n",
-    "#\n",
-    "# dataset_readmission"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.10.9"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}