From 7bce1385372a875c32c7ce620797739041737626 Mon Sep 17 00:00:00 2001 From: Amrit K Date: Thu, 28 Mar 2024 08:42:05 -0400 Subject: [PATCH] Improve code style, more docstrings --- finetune.py | 2 +- odyssey/models/baseline/Bi-LSTM.py | 6 +- odyssey/models/cehr_bert/__init__.py | 1 + odyssey/models/cehr_big_bird/__init__.py | 1 + odyssey/models/prediction.py | 39 ++++++---- odyssey/models/utils.py | 87 ++++++++++++++++++++-- pretrain.py | 6 +- tests/odyssey/data/mimiciv/test_collect.py | 2 +- 8 files changed, 116 insertions(+), 28 deletions(-) create mode 100644 odyssey/models/cehr_bert/__init__.py diff --git a/finetune.py b/finetune.py index 079de6d..f6c5bd9 100644 --- a/finetune.py +++ b/finetune.py @@ -19,8 +19,8 @@ from odyssey.data.dataset import FinetuneDataset from odyssey.data.tokenizer import ConceptTokenizer -from odyssey.models.cehr_big_bird.model import BigBirdFinetune, BigBirdPretrain from odyssey.models.cehr_bert.model import BertFinetune, BertPretrain +from odyssey.models.cehr_big_bird.model import BigBirdFinetune, BigBirdPretrain from odyssey.models.utils import ( get_latest_checkpoint, get_run_id, diff --git a/odyssey/models/baseline/Bi-LSTM.py b/odyssey/models/baseline/Bi-LSTM.py index 8e20a38..193a2aa 100644 --- a/odyssey/models/baseline/Bi-LSTM.py +++ b/odyssey/models/baseline/Bi-LSTM.py @@ -17,9 +17,9 @@ from torch.utils.data import DataLoader, Dataset from tqdm import tqdm -from odyssey.models.big_bird_cehr.data import FinetuneDataset -from odyssey.models.big_bird_cehr.embeddings import Embeddings -from odyssey.models.big_bird_cehr.tokenizer import HuggingFaceConceptTokenizer +from odyssey.data.dataset import FinetuneDataset +from odyssey.models.cehr_big_bird.embeddings import Embeddings +from odyssey.models.cehr_big_bird.tokenizer import HuggingFaceConceptTokenizer ROOT = "/fs01/home/afallah/odyssey/odyssey" diff --git a/odyssey/models/cehr_bert/__init__.py b/odyssey/models/cehr_bert/__init__.py new file mode 100644 index 0000000..5e72351 --- /dev/null +++ b/odyssey/models/cehr_bert/__init__.py @@ -0,0 +1 @@ +"""CEHR-BERT model sub-package.""" diff --git a/odyssey/models/cehr_big_bird/__init__.py b/odyssey/models/cehr_big_bird/__init__.py index e69de29..22f8d15 100644 --- a/odyssey/models/cehr_big_bird/__init__.py +++ b/odyssey/models/cehr_big_bird/__init__.py @@ -0,0 +1 @@ +"""CEHR-Big Bird model sub-package.""" diff --git a/odyssey/models/prediction.py b/odyssey/models/prediction.py index 2d66e4f..7b8f92c 100644 --- a/odyssey/models/prediction.py +++ b/odyssey/models/prediction.py @@ -4,7 +4,7 @@ import torch -from odyssey.models.big_bird_cehr.model import BigBirdFinetune, BigBirdPretrain +from odyssey.models.cehr_big_bird.model import BigBirdFinetune, BigBirdPretrain from odyssey.tokenizer import ConceptTokenizer @@ -25,14 +25,19 @@ def load_finetuned_model( ---------- model_path: str Path to the finetuned model to load - tokenizer: Loaded tokenizer object - pre_model_config: Optional config to override default values of a pretrained model - fine_model_config: Optional config to override default values of a finetuned model - device: CUDA device. By default, GPU is used + tokenizer: ConceptTokenizer + Loaded tokenizer object + pre_model_config: Dict[str, Any], optional + Optional config to override default values of a pretrained model + fine_model_config: Dict[str, Any], optional + Optional config to override default values of a finetuned model + device: torch.device, optional + CUDA device. By default, GPU is used Returns ------- - The loaded PyTorch model + torch.nn.Module + Finetuned model loaded from model_path """ # Load GPU or CPU device @@ -65,16 +70,22 @@ def predict_patient_outcomes( model: torch.nn.Module, device: Optional[torch.device] = None, ) -> Any: - """ - Return model output predictions on given patient data. + """Compute model output predictions on given patient data. + + Parameters + ---------- + patient: Dict[str, torch.Tensor] + Patient data as a dictionary of tensors + model: torch.nn.Module + Model to use for prediction + device: torch.device, optional + CUDA device. By default, GPU is used - Args: - patient: Dictionary of tokenized patient information - model: Finetuned or pretrained PyTorch model - device: CUDA device. By default, GPU is used + Returns + ------- + Any + Model output predictions on the given patient data - Return: - Outputs of model predictions on given patient data """ # Load GPU or CPU device if not device: diff --git a/odyssey/models/utils.py b/odyssey/models/utils.py index df32ccb..efce177 100644 --- a/odyssey/models/utils.py +++ b/odyssey/models/utils.py @@ -14,14 +14,36 @@ def load_config(config_dir: str, model_type: str) -> Any: - """Load the model configuration.""" + """Load the model configuration. + + Parameters + ---------- + config_dir: str + Directory containing the model configuration files + + model_type: str + Model type to load configuration for + + Returns + ------- + Any + Model configuration + + """ config_file = join(config_dir, f"{model_type}.yaml") with open(config_file, "r") as file: return yaml.safe_load(file) def seed_everything(seed: int) -> None: - """Seed all components of the model.""" + """Seed all components of the model. + + Parameters + ---------- + seed: int + Seed value to use + + """ torch.manual_seed(seed) np.random.seed(seed) torch.cuda.manual_seed_all(seed) @@ -35,7 +57,23 @@ def load_pretrain_data( sequence_file: str, id_file: str, ) -> pd.DataFrame: - """Load the pretraining data.""" + """Load the pretraining data. + + Parameters + ---------- + data_dir: str + Directory containing the data files + sequence_file: str + Sequence file name + id_file: str + ID file name + + Returns + ------- + pd.DataFrame + Pretraining data + + """ sequence_path = join(data_dir, sequence_file) id_path = join(data_dir, id_file) @@ -59,7 +97,27 @@ def load_finetune_data( valid_scheme: str, num_finetune_patients: str, ) -> pd.DataFrame: - """Load the finetuning data.""" + """Load the finetuning data. + + Parameters + ---------- + data_dir: str + Directory containing the data files + sequence_file: str + Sequence file name + id_file: str + ID file name + valid_scheme: str + Validation scheme + num_finetune_patients: str + Number of finetune patients + + Returns + ------- + pd.DataFrame + Finetuning data + + """ sequence_path = join(data_dir, sequence_file) id_path = join(data_dir, id_file) @@ -88,10 +146,27 @@ def get_run_id( run_id_file: str = "wandb_run_id.txt", length: int = 8, ) -> str: - """ - Return the run ID for the current run. + """Fetch the run ID for the current run. If the run ID file exists, retrieve the run ID from the file. + Otherwise, generate a new run ID and save it to the file. + + Parameters + ---------- + checkpoint_dir: str + Directory to store the run ID file + retrieve: bool, optional + Retrieve the run ID from the file, by default False + run_id_file: str, optional + Run ID file name, by default "wandb_run_id.txt" + length: int, optional + String length of the run ID, by default 8 + + Returns + ------- + str + Run ID for the current run + """ run_id_path = os.path.join(checkpoint_dir, run_id_file) if retrieve and os.path.exists(run_id_path): diff --git a/pretrain.py b/pretrain.py index c123fd4..3f6eba8 100644 --- a/pretrain.py +++ b/pretrain.py @@ -15,8 +15,8 @@ from odyssey.data.dataset import PretrainDataset from odyssey.data.tokenizer import ConceptTokenizer -from odyssey.models.cehr_big_bird.model import BigBirdPretrain from odyssey.models.cehr_bert.model import BertPretrain +from odyssey.models.cehr_big_bird.model import BigBirdPretrain from odyssey.models.utils import ( get_run_id, load_config, @@ -37,14 +37,14 @@ def main(args: Dict[str, Any], model_config: Dict[str, Any]) -> None: args.sequence_file, args.id_file, ) - # pre_data.rename(columns={args.label_name: "label"}, inplace=True) + # pre_data.rename(columns={args.label_name: "label"}, inplace=True) # noqa: ERA001 # Split data pre_train, pre_val = train_test_split( pre_data, test_size=args.val_size, random_state=args.seed, - # stratify=pre_data["label"], + # stratify=pre_data["label"], # noqa: ERA001 ) # Train Tokenizer diff --git a/tests/odyssey/data/mimiciv/test_collect.py b/tests/odyssey/data/mimiciv/test_collect.py index 31852c0..6497ee3 100644 --- a/tests/odyssey/data/mimiciv/test_collect.py +++ b/tests/odyssey/data/mimiciv/test_collect.py @@ -24,7 +24,7 @@ def setUp(self) -> None: buffer_size=10, ) - def tearDown(self) -> None: # noqa: N802 + def tearDown(self) -> None: """Tear down FHIRDataCollector.""" if os.path.exists(self.save_dir): shutil.rmtree(self.save_dir)