Skip to content

Commit

Permalink
Update docstrings to be consistent with numpy style
Browse files Browse the repository at this point in the history
  • Loading branch information
amrit110 committed Apr 19, 2024
1 parent a729dd6 commit 21e3f24
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 158 deletions.
9 changes: 4 additions & 5 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def main(
fine_model_config: Dict[str, Any],
) -> None:
"""Train the model."""

# Setup environment
seed_everything(args.seed)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
Expand Down Expand Up @@ -66,7 +65,7 @@ def main(
random_state=args.seed,
stratify=fine_tune["label"],
)

else: # Multi label classfication
fine_train_ids, _, fine_val_ids, _ = iterative_train_test_split(
X=fine_tune["patient_id"].to_numpy().reshape(-1, 1),
Expand Down Expand Up @@ -113,7 +112,7 @@ def main(
balance_guide=None,
max_len=args.max_len,
)

else:
train_dataset = FinetuneDataset(
data=fine_train,
Expand All @@ -130,7 +129,7 @@ def main(
tokenizer=tokenizer,
max_len=args.max_len,
)

train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
Expand Down Expand Up @@ -185,7 +184,7 @@ def main(
pretrained_model=pretrained_model,
**fine_model_config,
)

elif args.model_type == "cehr_bigbird":
pretrained_model = BigBirdPretrain(
vocab_size=tokenizer.get_vocab_size(),
Expand Down
72 changes: 38 additions & 34 deletions odyssey/data/DataProcessor.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@
" process_readmission_dataset,\n",
" process_multi_dataset,\n",
" stratified_train_test_split,\n",
" sample_balanced_subset, \n",
" sample_balanced_subset,\n",
" get_pretrain_test_split,\n",
" get_finetune_split\n",
" get_finetune_split,\n",
")\n",
"\n",
"SEED = 23\n",
Expand Down Expand Up @@ -77,7 +77,9 @@
"outputs": [],
"source": [
"# Process the dataset for length of stay prediction above a threshold\n",
"dataset_2048_los = process_length_of_stay_dataset(dataset_2048.copy(), threshold=7, max_len=MAX_LEN)"
"dataset_2048_los = process_length_of_stay_dataset(\n",
" dataset_2048.copy(), threshold=7, max_len=MAX_LEN\n",
")"
]
},
{
Expand Down Expand Up @@ -119,7 +121,9 @@
"outputs": [],
"source": [
"# Process the dataset for hospital readmission in one month task\n",
"dataset_2048_readmission = process_readmission_dataset(dataset_2048.copy(), max_len=MAX_LEN)"
"dataset_2048_readmission = process_readmission_dataset(\n",
" dataset_2048.copy(), max_len=MAX_LEN\n",
")"
]
},
{
Expand All @@ -137,7 +141,7 @@
" \"readmission\": dataset_2048_readmission,\n",
" \"los\": dataset_2048_los,\n",
" },\n",
" max_len=MAX_LEN\n",
" max_len=MAX_LEN,\n",
")"
]
},
Expand Down Expand Up @@ -174,35 +178,35 @@
"outputs": [],
"source": [
"task_config = {\n",
" \"mortality\": {\n",
" \"dataset\": dataset_2048_mortality,\n",
" \"label_col\": \"label_mortality_1month\",\n",
" \"finetune_size\": [250, 500, 1000, 5000, 20000],\n",
" \"save_path\": \"patient_id_dict/dataset_2048_mortality.pkl\",\n",
" \"split_mode\": \"single_label_balanced\",\n",
" },\n",
" \"readmission\": {\n",
" \"dataset\": dataset_2048_readmission,\n",
" \"label_col\": \"label_readmission_1month\",\n",
" \"finetune_size\": [250, 1000, 5000, 20000, 60000],\n",
" \"save_path\": \"patient_id_dict/dataset_2048_readmission.pkl\",\n",
" \"split_mode\": \"single_label_stratified\",\n",
" },\n",
" \"length_of_stay\": {\n",
" \"dataset\": dataset_2048_los,\n",
" \"label_col\": \"label_los_1week\",\n",
" \"finetune_size\": [250, 1000, 5000, 20000, 50000],\n",
" \"save_path\": \"patient_id_dict/dataset_2048_los.pkl\",\n",
" \"split_mode\": \"single_label_balanced\",\n",
" },\n",
" \"condition\": {\n",
" \"dataset\": dataset_2048_condition,\n",
" \"label_col\": \"all_conditions\",\n",
" \"finetune_size\": [50000],\n",
" \"save_path\": \"patient_id_dict/dataset_2048_condition.pkl\",\n",
" \"split_mode\": \"multi_label_stratified\",\n",
" },\n",
" }"
" \"mortality\": {\n",
" \"dataset\": dataset_2048_mortality,\n",
" \"label_col\": \"label_mortality_1month\",\n",
" \"finetune_size\": [250, 500, 1000, 5000, 20000],\n",
" \"save_path\": \"patient_id_dict/dataset_2048_mortality.pkl\",\n",
" \"split_mode\": \"single_label_balanced\",\n",
" },\n",
" \"readmission\": {\n",
" \"dataset\": dataset_2048_readmission,\n",
" \"label_col\": \"label_readmission_1month\",\n",
" \"finetune_size\": [250, 1000, 5000, 20000, 60000],\n",
" \"save_path\": \"patient_id_dict/dataset_2048_readmission.pkl\",\n",
" \"split_mode\": \"single_label_stratified\",\n",
" },\n",
" \"length_of_stay\": {\n",
" \"dataset\": dataset_2048_los,\n",
" \"label_col\": \"label_los_1week\",\n",
" \"finetune_size\": [250, 1000, 5000, 20000, 50000],\n",
" \"save_path\": \"patient_id_dict/dataset_2048_los.pkl\",\n",
" \"split_mode\": \"single_label_balanced\",\n",
" },\n",
" \"condition\": {\n",
" \"dataset\": dataset_2048_condition,\n",
" \"label_col\": \"all_conditions\",\n",
" \"finetune_size\": [50000],\n",
" \"save_path\": \"patient_id_dict/dataset_2048_condition.pkl\",\n",
" \"split_mode\": \"multi_label_stratified\",\n",
" },\n",
"}"
]
},
{
Expand Down
Loading

0 comments on commit 21e3f24

Please sign in to comment.