Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve code style, more docstrings #21

Merged
merged 1 commit into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

from odyssey.data.dataset import FinetuneDataset
from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.models.cehr_big_bird.model import BigBirdFinetune, BigBirdPretrain
from odyssey.models.cehr_bert.model import BertFinetune, BertPretrain
from odyssey.models.cehr_big_bird.model import BigBirdFinetune, BigBirdPretrain
from odyssey.models.utils import (
get_latest_checkpoint,
get_run_id,
Expand Down
6 changes: 3 additions & 3 deletions odyssey/models/baseline/Bi-LSTM.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from odyssey.models.big_bird_cehr.data import FinetuneDataset
from odyssey.models.big_bird_cehr.embeddings import Embeddings
from odyssey.models.big_bird_cehr.tokenizer import HuggingFaceConceptTokenizer
from odyssey.data.dataset import FinetuneDataset
from odyssey.models.cehr_big_bird.embeddings import Embeddings
from odyssey.models.cehr_big_bird.tokenizer import HuggingFaceConceptTokenizer


ROOT = "/fs01/home/afallah/odyssey/odyssey"
Expand Down
1 change: 1 addition & 0 deletions odyssey/models/cehr_bert/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""CEHR-BERT model sub-package."""
1 change: 1 addition & 0 deletions odyssey/models/cehr_big_bird/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""CEHR-Big Bird model sub-package."""
39 changes: 25 additions & 14 deletions odyssey/models/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from odyssey.models.big_bird_cehr.model import BigBirdFinetune, BigBirdPretrain
from odyssey.models.cehr_big_bird.model import BigBirdFinetune, BigBirdPretrain
from odyssey.tokenizer import ConceptTokenizer


Expand All @@ -25,14 +25,19 @@ def load_finetuned_model(
----------
model_path: str
Path to the finetuned model to load
tokenizer: Loaded tokenizer object
pre_model_config: Optional config to override default values of a pretrained model
fine_model_config: Optional config to override default values of a finetuned model
device: CUDA device. By default, GPU is used
tokenizer: ConceptTokenizer
Loaded tokenizer object
pre_model_config: Dict[str, Any], optional
Optional config to override default values of a pretrained model
fine_model_config: Dict[str, Any], optional
Optional config to override default values of a finetuned model
device: torch.device, optional
CUDA device. By default, GPU is used

Returns
-------
The loaded PyTorch model
torch.nn.Module
Finetuned model loaded from model_path

"""
# Load GPU or CPU device
Expand Down Expand Up @@ -65,16 +70,22 @@ def predict_patient_outcomes(
model: torch.nn.Module,
device: Optional[torch.device] = None,
) -> Any:
"""
Return model output predictions on given patient data.
"""Compute model output predictions on given patient data.

Parameters
----------
patient: Dict[str, torch.Tensor]
Patient data as a dictionary of tensors
model: torch.nn.Module
Model to use for prediction
device: torch.device, optional
CUDA device. By default, GPU is used

Args:
patient: Dictionary of tokenized patient information
model: Finetuned or pretrained PyTorch model
device: CUDA device. By default, GPU is used
Returns
-------
Any
Model output predictions on the given patient data

Return:
Outputs of model predictions on given patient data
"""
# Load GPU or CPU device
if not device:
Expand Down
87 changes: 81 additions & 6 deletions odyssey/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,36 @@


def load_config(config_dir: str, model_type: str) -> Any:
"""Load the model configuration."""
"""Load the model configuration.

Parameters
----------
config_dir: str
Directory containing the model configuration files

model_type: str
Model type to load configuration for

Returns
-------
Any
Model configuration

"""
config_file = join(config_dir, f"{model_type}.yaml")
with open(config_file, "r") as file:
return yaml.safe_load(file)


def seed_everything(seed: int) -> None:
"""Seed all components of the model."""
"""Seed all components of the model.

Parameters
----------
seed: int
Seed value to use

"""
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)
Expand All @@ -35,7 +57,23 @@ def load_pretrain_data(
sequence_file: str,
id_file: str,
) -> pd.DataFrame:
"""Load the pretraining data."""
"""Load the pretraining data.

Parameters
----------
data_dir: str
Directory containing the data files
sequence_file: str
Sequence file name
id_file: str
ID file name

Returns
-------
pd.DataFrame
Pretraining data

"""
sequence_path = join(data_dir, sequence_file)
id_path = join(data_dir, id_file)

Expand All @@ -59,7 +97,27 @@ def load_finetune_data(
valid_scheme: str,
num_finetune_patients: str,
) -> pd.DataFrame:
"""Load the finetuning data."""
"""Load the finetuning data.

Parameters
----------
data_dir: str
Directory containing the data files
sequence_file: str
Sequence file name
id_file: str
ID file name
valid_scheme: str
Validation scheme
num_finetune_patients: str
Number of finetune patients

Returns
-------
pd.DataFrame
Finetuning data

"""
sequence_path = join(data_dir, sequence_file)
id_path = join(data_dir, id_file)

Expand Down Expand Up @@ -88,10 +146,27 @@ def get_run_id(
run_id_file: str = "wandb_run_id.txt",
length: int = 8,
) -> str:
"""
Return the run ID for the current run.
"""Fetch the run ID for the current run.

If the run ID file exists, retrieve the run ID from the file.
Otherwise, generate a new run ID and save it to the file.

Parameters
----------
checkpoint_dir: str
Directory to store the run ID file
retrieve: bool, optional
Retrieve the run ID from the file, by default False
run_id_file: str, optional
Run ID file name, by default "wandb_run_id.txt"
length: int, optional
String length of the run ID, by default 8

Returns
-------
str
Run ID for the current run

"""
run_id_path = os.path.join(checkpoint_dir, run_id_file)
if retrieve and os.path.exists(run_id_path):
Expand Down
6 changes: 3 additions & 3 deletions pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

from odyssey.data.dataset import PretrainDataset
from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.models.cehr_big_bird.model import BigBirdPretrain
from odyssey.models.cehr_bert.model import BertPretrain
from odyssey.models.cehr_big_bird.model import BigBirdPretrain
from odyssey.models.utils import (
get_run_id,
load_config,
Expand All @@ -37,14 +37,14 @@ def main(args: Dict[str, Any], model_config: Dict[str, Any]) -> None:
args.sequence_file,
args.id_file,
)
# pre_data.rename(columns={args.label_name: "label"}, inplace=True)
# pre_data.rename(columns={args.label_name: "label"}, inplace=True) # noqa: ERA001

# Split data
pre_train, pre_val = train_test_split(
pre_data,
test_size=args.val_size,
random_state=args.seed,
# stratify=pre_data["label"],
# stratify=pre_data["label"], # noqa: ERA001
)

# Train Tokenizer
Expand Down
2 changes: 1 addition & 1 deletion tests/odyssey/data/mimiciv/test_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def setUp(self) -> None:
buffer_size=10,
)

def tearDown(self) -> None: # noqa: N802
def tearDown(self) -> None:
"""Tear down FHIRDataCollector."""
if os.path.exists(self.save_dir):
shutil.rmtree(self.save_dir)
Expand Down
Loading