Improve code style, more docstrings
amrit110 committed Mar 28, 2024
1 parent f344ce5 commit 7bce138
Showing 8 changed files with 116 additions and 28 deletions.
2 changes: 1 addition & 1 deletion finetune.py
@@ -19,8 +19,8 @@
 
 from odyssey.data.dataset import FinetuneDataset
 from odyssey.data.tokenizer import ConceptTokenizer
-from odyssey.models.cehr_big_bird.model import BigBirdFinetune, BigBirdPretrain
 from odyssey.models.cehr_bert.model import BertFinetune, BertPretrain
+from odyssey.models.cehr_big_bird.model import BigBirdFinetune, BigBirdPretrain
 from odyssey.models.utils import (
     get_latest_checkpoint,
     get_run_id,
6 changes: 3 additions & 3 deletions odyssey/models/baseline/Bi-LSTM.py
@@ -17,9 +17,9 @@
 from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
 
-from odyssey.models.big_bird_cehr.data import FinetuneDataset
-from odyssey.models.big_bird_cehr.embeddings import Embeddings
-from odyssey.models.big_bird_cehr.tokenizer import HuggingFaceConceptTokenizer
+from odyssey.data.dataset import FinetuneDataset
+from odyssey.models.cehr_big_bird.embeddings import Embeddings
+from odyssey.models.cehr_big_bird.tokenizer import HuggingFaceConceptTokenizer
 
 
 ROOT = "/fs01/home/afallah/odyssey/odyssey"
1 change: 1 addition & 0 deletions odyssey/models/cehr_bert/__init__.py
@@ -0,0 +1 @@
+"""CEHR-BERT model sub-package."""
1 change: 1 addition & 0 deletions odyssey/models/cehr_big_bird/__init__.py
@@ -0,0 +1 @@
+"""CEHR-Big Bird model sub-package."""
39 changes: 25 additions & 14 deletions odyssey/models/prediction.py
@@ -4,7 +4,7 @@
 
 import torch
 
-from odyssey.models.big_bird_cehr.model import BigBirdFinetune, BigBirdPretrain
+from odyssey.models.cehr_big_bird.model import BigBirdFinetune, BigBirdPretrain
 from odyssey.tokenizer import ConceptTokenizer
 
 
@@ -25,14 +25,19 @@ def load_finetuned_model(
     ----------
     model_path: str
         Path to the finetuned model to load
-    tokenizer: Loaded tokenizer object
-    pre_model_config: Optional config to override default values of a pretrained model
-    fine_model_config: Optional config to override default values of a finetuned model
-    device: CUDA device. By default, GPU is used
+    tokenizer: ConceptTokenizer
+        Loaded tokenizer object
+    pre_model_config: Dict[str, Any], optional
+        Optional config to override default values of a pretrained model
+    fine_model_config: Dict[str, Any], optional
+        Optional config to override default values of a finetuned model
+    device: torch.device, optional
+        CUDA device. By default, GPU is used
 
     Returns
     -------
-    The loaded PyTorch model
+    torch.nn.Module
+        Finetuned model loaded from model_path
 
     """
     # Load GPU or CPU device
@@ -65,16 +70,22 @@ def predict_patient_outcomes(
     model: torch.nn.Module,
     device: Optional[torch.device] = None,
 ) -> Any:
-    """
-    Return model output predictions on given patient data.
+    """Compute model output predictions on given patient data.
 
-    Args:
-        patient: Dictionary of tokenized patient information
-        model: Finetuned or pretrained PyTorch model
-        device: CUDA device. By default, GPU is used
+    Parameters
+    ----------
+    patient: Dict[str, torch.Tensor]
+        Patient data as a dictionary of tensors
+    model: torch.nn.Module
+        Model to use for prediction
+    device: torch.device, optional
+        CUDA device. By default, GPU is used
 
-    Return:
-        Outputs of model predictions on given patient data
+    Returns
+    -------
+    Any
+        Model output predictions on the given patient data
+
     """
     # Load GPU or CPU device
     if not device:
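
The expanded numpy-style docstrings above fully specify the prediction API's inputs and outputs. As a minimal usage sketch (not part of this commit), assuming a hypothetical checkpoint path, an elided tokenizer setup, and illustrative patient tensor keys:

import torch

from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.models.prediction import load_finetuned_model, predict_patient_outcomes

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer construction is assumed; only the class name comes from this diff.
tokenizer = ConceptTokenizer(...)

model = load_finetuned_model(
    "checkpoints/bigbird_finetune.ckpt",  # hypothetical checkpoint path
    tokenizer,
    device=device,
)

# The docstring documents patient as Dict[str, torch.Tensor]; these keys are
# illustrative, not the repository's actual feature names.
patient = {
    "concept_ids": torch.randint(0, 100, (1, 512)),
    "attention_mask": torch.ones(1, 512, dtype=torch.long),
}
outputs = predict_patient_outcomes(patient, model, device=device)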
87 changes: 81 additions & 6 deletions odyssey/models/utils.py
@@ -14,14 +14,36 @@
 
 
 def load_config(config_dir: str, model_type: str) -> Any:
-    """Load the model configuration."""
+    """Load the model configuration.
+
+    Parameters
+    ----------
+    config_dir: str
+        Directory containing the model configuration files
+    model_type: str
+        Model type to load configuration for
+
+    Returns
+    -------
+    Any
+        Model configuration
+
+    """
     config_file = join(config_dir, f"{model_type}.yaml")
     with open(config_file, "r") as file:
         return yaml.safe_load(file)
 
 
 def seed_everything(seed: int) -> None:
-    """Seed all components of the model."""
+    """Seed all components of the model.
+
+    Parameters
+    ----------
+    seed: int
+        Seed value to use
+
+    """
     torch.manual_seed(seed)
     np.random.seed(seed)
     torch.cuda.manual_seed_all(seed)
@@ -35,7 +57,23 @@ def load_pretrain_data(
     sequence_file: str,
     id_file: str,
 ) -> pd.DataFrame:
-    """Load the pretraining data."""
+    """Load the pretraining data.
+
+    Parameters
+    ----------
+    data_dir: str
+        Directory containing the data files
+    sequence_file: str
+        Sequence file name
+    id_file: str
+        ID file name
+
+    Returns
+    -------
+    pd.DataFrame
+        Pretraining data
+
+    """
     sequence_path = join(data_dir, sequence_file)
     id_path = join(data_dir, id_file)
 
@@ -59,7 +97,27 @@ def load_finetune_data(
     valid_scheme: str,
     num_finetune_patients: str,
 ) -> pd.DataFrame:
-    """Load the finetuning data."""
+    """Load the finetuning data.
+
+    Parameters
+    ----------
+    data_dir: str
+        Directory containing the data files
+    sequence_file: str
+        Sequence file name
+    id_file: str
+        ID file name
+    valid_scheme: str
+        Validation scheme
+    num_finetune_patients: str
+        Number of finetune patients
+
+    Returns
+    -------
+    pd.DataFrame
+        Finetuning data
+
+    """
     sequence_path = join(data_dir, sequence_file)
     id_path = join(data_dir, id_file)
 
@@ -88,10 +146,27 @@ def get_run_id(
     run_id_file: str = "wandb_run_id.txt",
     length: int = 8,
 ) -> str:
-    """
-    Return the run ID for the current run.
+    """Fetch the run ID for the current run.
+
     If the run ID file exists, retrieve the run ID from the file.
     Otherwise, generate a new run ID and save it to the file.
+
+    Parameters
+    ----------
+    checkpoint_dir: str
+        Directory to store the run ID file
+    retrieve: bool, optional
+        Retrieve the run ID from the file, by default False
+    run_id_file: str, optional
+        Run ID file name, by default "wandb_run_id.txt"
+    length: int, optional
+        String length of the run ID, by default 8
+
+    Returns
+    -------
+    str
+        Run ID for the current run
+
+    """
     run_id_path = os.path.join(checkpoint_dir, run_id_file)
     if retrieve and os.path.exists(run_id_path):
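
The new get_run_id docstring makes its retrieve-or-generate contract explicit. A standalone sketch of that documented behavior (the actual ID-generation scheme is not shown in this diff; a random hex string is assumed here):

import os
import uuid


def get_run_id_sketch(
    checkpoint_dir: str,
    retrieve: bool = False,
    run_id_file: str = "wandb_run_id.txt",
    length: int = 8,
) -> str:
    """Reuse a saved run ID if present; otherwise create and persist one."""
    run_id_path = os.path.join(checkpoint_dir, run_id_file)
    if retrieve and os.path.exists(run_id_path):
        # Retrieve the run ID from the file.
        with open(run_id_path, "r") as f:
            return f.read().strip()
    # Generate a new run ID and save it to the file (generation scheme assumed).
    run_id = uuid.uuid4().hex[:length]
    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(run_id_path, "w") as f:
        f.write(run_id)
    return run_id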
6 changes: 3 additions & 3 deletions pretrain.py
@@ -15,8 +15,8 @@
 
 from odyssey.data.dataset import PretrainDataset
 from odyssey.data.tokenizer import ConceptTokenizer
-from odyssey.models.cehr_big_bird.model import BigBirdPretrain
 from odyssey.models.cehr_bert.model import BertPretrain
+from odyssey.models.cehr_big_bird.model import BigBirdPretrain
 from odyssey.models.utils import (
     get_run_id,
     load_config,
@@ -37,14 +37,14 @@ def main(args: Dict[str, Any], model_config: Dict[str, Any]) -> None:
         args.sequence_file,
         args.id_file,
     )
-    # pre_data.rename(columns={args.label_name: "label"}, inplace=True)
+    # pre_data.rename(columns={args.label_name: "label"}, inplace=True)  # noqa: ERA001
 
     # Split data
     pre_train, pre_val = train_test_split(
         pre_data,
         test_size=args.val_size,
         random_state=args.seed,
-        # stratify=pre_data["label"],
+        # stratify=pre_data["label"],  # noqa: ERA001
     )
 
     # Train Tokenizer
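
The two ERA001 suppressions preserve commented-out code rather than deleting it: re-enabled, it would rename a label column and stratify the train/validation split on it. A self-contained sketch of that path (the column names are illustrative, not taken from the repository):

import pandas as pd
from sklearn.model_selection import train_test_split

# Toy stand-in for pre_data; in pretrain.py it comes from load_pretrain_data.
pre_data = pd.DataFrame(
    {
        "event_tokens": [f"seq_{i}" for i in range(10)],
        "mortality_1year": [i % 2 for i in range(10)],  # hypothetical label column
    }
)
pre_data = pre_data.rename(columns={"mortality_1year": "label"})

pre_train, pre_val = train_test_split(
    pre_data,
    test_size=0.2,
    random_state=23,
    stratify=pre_data["label"],  # keeps label proportions equal in both splits
)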
2 changes: 1 addition & 1 deletion tests/odyssey/data/mimiciv/test_collect.py
@@ -24,7 +24,7 @@ def setUp(self) -> None:
             buffer_size=10,
         )
 
-    def tearDown(self) -> None:  # noqa: N802
+    def tearDown(self) -> None:
         """Tear down FHIRDataCollector."""
         if os.path.exists(self.save_dir):
            shutil.rmtree(self.save_dir)
