
Commit

Add some small fixes, sklearn to dependencies
amrit110 committed Mar 18, 2024
1 parent deed430 commit 4cc3caf
Showing 8 changed files with 261 additions and 144 deletions.
4 changes: 2 additions & 2 deletions data/sequence.py
@@ -545,8 +545,8 @@ def _time_after_admission(
encounter_index = encounter_row["encounter_ids"].index(encounter_id)
start_time = encounter_row["starts"][encounter_index]
start_time = parser.parse(start_time)
- event_time = parser.parse(event_time)
- elapsed_time = round((event_time - start_time).total_seconds() / 3600, 2)
+ event_time_ = parser.parse(event_time)
+ elapsed_time = round((event_time_ - start_time).total_seconds() / 3600, 2)
elapsed_times.append(elapsed_time)
row["elapsed_time"] = elapsed_times
return row
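Note on the change: the old code rebound event_time from string to datetime in place; the rename to event_time_ keeps the parsed value distinct from the raw one. A minimal sketch of the computation with illustrative timestamps (dateutil is the parser the file already uses):

    from dateutil import parser

    start_time = parser.parse("2024-03-01 08:00:00")
    event_time_ = parser.parse("2024-03-01 14:30:00")

    # Hours elapsed since admission, rounded to two decimals as in the diff.
    elapsed_time = round((event_time_ - start_time).total_seconds() / 3600, 2)
    print(elapsed_time)  # 6.5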
2 changes: 2 additions & 0 deletions lib/utils.py
@@ -37,6 +37,8 @@ def get_latest_checkpoint(checkpoint_dir: str) -> Any:
list_of_files = glob.glob(join(checkpoint_dir, "last*.ckpt"))
return max(list_of_files, key=os.path.getctime) if list_of_files else None

+ return None


def load_pretrain_data(
data_dir: str,
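The added return None gives the helper an explicit fallback on every path. A self-contained sketch of how the function plausibly reads after the change; the directory guard is an assumption, since this view collapses the surrounding lines:

    import glob
    import os
    from os.path import join
    from typing import Optional

    def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
        """Return the newest 'last*.ckpt' file in checkpoint_dir, or None."""
        if os.path.isdir(checkpoint_dir):  # assumption: guard from the collapsed lines
            list_of_files = glob.glob(join(checkpoint_dir, "last*.ckpt"))
            return max(list_of_files, key=os.path.getctime) if list_of_files else None
        return None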
10 changes: 5 additions & 5 deletions models/baseline/Bi-LSTM.py
@@ -21,15 +21,15 @@
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm


+ ROOT = "/fs01/home/afallah/odyssey/odyssey"
+ os.chdir(ROOT)

from models.big_bird_cehr.data import FinetuneDataset
from models.big_bird_cehr.embeddings import Embeddings
from models.big_bird_cehr.tokenizer import HuggingFaceConceptTokenizer


- ROOT = "/fs01/home/afallah/odyssey/odyssey"
- os.chdir(ROOT)


DATA_ROOT = f"{ROOT}/data/slurm_data/512/one_month"
DATA_PATH = f"{DATA_ROOT}/pretrain.parquet"
FINE_TUNE_PATH = f"{DATA_ROOT}/fine_tune.parquet"
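This hunk moves the ROOT assignment and os.chdir(ROOT) above the models.big_bird_cehr imports, so the working directory is set before project modules load. A sketch of the resulting preamble; the sys.path line is an added assumption for running the script from outside the repo root:

    import os
    import sys

    ROOT = "/fs01/home/afallah/odyssey/odyssey"
    os.chdir(ROOT)            # relative data paths now resolve against the repo root
    sys.path.insert(0, ROOT)  # assumption: makes `models.*` importable from any cwd

    from models.big_bird_cehr.data import FinetuneDataset  # noqa: E402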
@@ -162,7 +162,7 @@ def get_inputs_labels(

@staticmethod
def get_balanced_accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> Any:
"""Return the balanced accuracy metric by comparing outputs to labels"""
"""Return the balanced accuracy metric by comparing outputs to labels."""
predictions = torch.round(sigmoid(outputs))
predictions = predictions.detach().cpu().numpy()
labels = labels.detach().cpu().numpy()
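The commit message adds sklearn to the dependencies, and this metric is its likely consumer. A hedged sketch of the full method, assuming the collapsed lines end with sklearn's balanced_accuracy_score:

    import torch
    from sklearn.metrics import balanced_accuracy_score

    def get_balanced_accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
        """Return the balanced accuracy metric by comparing outputs to labels."""
        predictions = torch.round(torch.sigmoid(outputs)).detach().cpu().numpy()
        labels_np = labels.detach().cpu().numpy()
        return balanced_accuracy_score(labels_np, predictions)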
4 changes: 3 additions & 1 deletion models/big_bird_cehr/embeddings.py
@@ -271,8 +271,10 @@ def cache_input(
visit_orders: torch.Tensor,
visit_segments: torch.Tensor,
) -> None:
"""Cache values for time_stamps, ages, visit_orders & visit_segments inside the class object.
"""Cache values for time_stamps, ages, visit_orders & visit_segments.
These values will be used by the forward pass to change the final embedding.
"""
self.time_stamps = time_stamps
self.ages = ages
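The reworded docstring describes a cache-then-forward pattern: side inputs are stashed on the module so the next forward pass can fold them into the final embedding. A stripped-down sketch of the idea (the class name and the way forward() combines the cached tensors are illustrative, not the repo's actual logic):

    import torch
    import torch.nn as nn

    class EmbeddingsSketch(nn.Module):
        """Illustrative module that caches side inputs for the next forward call."""

        def cache_input(self, time_stamps: torch.Tensor, ages: torch.Tensor) -> None:
            # Stored on the instance, not returned; forward() reads them later.
            self.time_stamps = time_stamps
            self.ages = ages

        def forward(self, concept_embed: torch.Tensor) -> torch.Tensor:
            # Cached tensors adjust the final embedding, per the docstring above.
            return concept_embed + self.time_stamps.unsqueeze(-1) + self.ages.unsqueeze(-1)

    m = EmbeddingsSketch()
    m.cache_input(torch.zeros(2, 8), torch.ones(2, 8))
    out = m(torch.randn(2, 8, 16))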
15 changes: 8 additions & 7 deletions models/cehr_bert/embeddings.py
@@ -1,3 +1,5 @@
"""Embedding modules."""

import math

import torch
@@ -19,7 +21,7 @@ def __init__(self, embedding_size: int, is_time_delta: bool = False):
nn.init.xavier_uniform_(self.phi)

def forward(self, time_stamps: torch.Tensor) -> torch.Tensor:
"""Applies time embedding to the input time stamps."""
"""Apply time embedding to the input time stamps."""
if self.is_time_delta:
# If the time_stamps represent time deltas, we calculate the deltas.
# This is equivalent to the difference between consecutive elements.
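The comment describes a consecutive-difference step. One way to compute it, as a small sketch where the first element keeps a delta of zero (the actual tensor ops are collapsed in this view):

    import torch

    time_stamps = torch.tensor([[1.0, 3.0, 6.0, 10.0]])
    # Difference between consecutive elements; the first delta stays zero.
    deltas = time_stamps - torch.cat((time_stamps[:, :1], time_stamps[:, :-1]), dim=-1)
    print(deltas)  # tensor([[0., 2., 3., 4.]])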
@@ -48,7 +50,7 @@ def __init__(
self.embedding = nn.Embedding(self.visit_order_size, self.embedding_size)

def forward(self, visit_segments: torch.Tensor) -> torch.Tensor:
"""Applies visit embedding to the input visit segments."""
"""Apply visit embedding to the input visit segments."""
return self.embedding(visit_segments)


@@ -69,7 +71,7 @@ def __init__(
)

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""Applies concept embedding to the input concepts."""
"""Apply concept embedding to the input concepts."""
return self.embedding(inputs)


@@ -95,7 +97,7 @@ def __init__(self, embedding_size, max_len=512):
self.register_buffer("pe", pe)

def forward(self, visit_orders: torch.Tensor) -> torch.Tensor:
"""Applies positional embedding to the input visit orders."""
"""Apply positional embedding to the input visit orders."""
first_visit_concept_orders = visit_orders[:, 0:1]
normalized_visit_orders = torch.clamp(
visit_orders - first_visit_concept_orders,
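Visit orders are re-based against the first visit before indexing the positional table, so a sequence starting late in a patient's history still maps into range. A sketch of the normalization; the clamp bounds are an assumption matching a max_len-sized buffer:

    import torch

    max_len = 512
    visit_orders = torch.tensor([[37, 37, 38, 40]])
    first_visit_concept_orders = visit_orders[:, 0:1]
    normalized_visit_orders = torch.clamp(
        visit_orders - first_visit_concept_orders,
        min=0,
        max=max_len - 1,  # assumption: keeps indices inside the positional buffer
    )
    print(normalized_visit_orders)  # tensor([[0, 0, 1, 3]])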
@@ -157,7 +159,7 @@ def forward(
visit_orders: torch.Tensor,
visit_segments: torch.Tensor,
) -> torch.Tensor:
"""Applies embeddings to the input features."""
"""Apply embeddings to the input features."""
concept_embed = self.concept_embedding(concept_ids)
type_embed = self.token_type_embeddings(type_ids)
time_embed = self.time_embedding(time_stamps)
@@ -169,6 +171,5 @@ def forward(
embeddings = self.tanh(self.scale_back_concat_layer(embeddings))
embeddings = embeddings + type_embed + positional_embed + visit_segment_embed
embeddings = self.LayerNorm(embeddings)
- embeddings = self.dropout(embeddings)
-
- return embeddings
+ return self.dropout(embeddings)
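For orientation, the visible lines imply this order of operations: concatenate the embedding streams, project back to model width through tanh, add the type, positional, and segment embeddings, then LayerNorm and dropout. A toy-sized sketch; which streams enter the concatenation is an assumption based on the embeddings the visible lines add separately:

    import torch
    import torch.nn as nn

    batch, seq, hidden = 2, 8, 64
    scale_back_concat_layer = nn.Linear(3 * hidden, hidden)
    layer_norm, dropout, tanh = nn.LayerNorm(hidden), nn.Dropout(0.1), nn.Tanh()

    concept_embed = torch.randn(batch, seq, hidden)
    time_embed = torch.randn(batch, seq, hidden)
    age_embed = torch.randn(batch, seq, hidden)  # assumption: third concatenated stream
    type_embed = torch.randn(batch, seq, hidden)
    positional_embed = torch.randn(batch, seq, hidden)
    visit_segment_embed = torch.randn(batch, seq, hidden)

    embeddings = torch.cat((concept_embed, time_embed, age_embed), dim=-1)
    embeddings = tanh(scale_back_concat_layer(embeddings))
    embeddings = embeddings + type_embed + positional_embed + visit_segment_embed
    embeddings = layer_norm(embeddings)
    embeddings = dropout(embeddings)  # i.e., `return self.dropout(embeddings)`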
46 changes: 25 additions & 21 deletions models/cehr_bert/model.py
@@ -1,3 +1,5 @@
"""CEHR-BERT model."""

from typing import Optional, Tuple, Union

import pytorch_lightning as pl
@@ -89,7 +91,7 @@ def __init__(
self.post_init()

def _init_weights(self, module) -> None:
"""Initialize the weights"""
"""Initialize the weights."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
if module.bias is not None:
@@ -103,11 +105,12 @@ def _init_weights(self, module) -> None:
module.weight.data.fill_(1.0)

def post_init(self) -> None:
"""Apply post initialization."""
self.apply(self._init_weights)

def forward(
self,
- input: Tuple[
+ input_: Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
@@ -122,7 +125,7 @@ def forward(
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor, ...], MaskedLMOutput]:
"""Forward pass for the model."""
- concept_ids, type_ids, time_stamps, ages, visit_orders, visit_segments = input
+ concept_ids, type_ids, time_stamps, ages, visit_orders, visit_segments = input_
embedding_output = self.embeddings(
concept_ids,
type_ids,
@@ -166,8 +169,8 @@ def forward(
)

def training_step(self, batch, batch_idx) -> torch.Tensor:
"""Training step."""
input = (
"""Compute training step."""
input_ = (
batch["concept_ids"],
batch["type_ids"],
batch["time_stamps"],
@@ -178,7 +181,7 @@ def training_step(self, batch, batch_idx) -> torch.Tensor:
labels = batch["labels"]
attention_mask = batch["attention_mask"]
loss = self(
- input,
+ input_,
attention_mask=attention_mask,
labels=labels,
return_dict=True,
@@ -187,8 +190,8 @@ def training_step(self, batch, batch_idx) -> torch.Tensor:
return loss

def validation_step(self, batch, batch_idx) -> torch.Tensor:
"""Validation step."""
input = (
"""Compute validation step."""
input_ = (
batch["concept_ids"],
batch["type_ids"],
batch["time_stamps"],
@@ -199,7 +202,7 @@ def validation_step(self, batch, batch_idx) -> torch.Tensor:
labels = batch["labels"]
attention_mask = batch["attention_mask"]
loss = self(
- input,
+ input_,
attention_mask=attention_mask,
labels=labels,
return_dict=True,
@@ -281,11 +284,12 @@ def _init_weights(self, module):
module.weight.data.fill_(1.0)

def post_init(self):
"""Apply post initialization."""
self.apply(self._init_weights)

def forward(
self,
- input: Tuple[
+ input_: Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
@@ -301,9 +305,9 @@ def forward(
) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutput]:
"""Forward pass for the model."""
if attention_mask is None:
- attention_mask = torch.ones_like(input[0])
+ attention_mask = torch.ones_like(input_[0])
outputs = self.pretrained_model(
- input,
+ input_,
attention_mask=attention_mask,
output_attentions=True,
output_hidden_states=True,
@@ -332,8 +336,8 @@ def forward(
)

def training_step(self, batch, batch_idx):
"""Training step."""
input = (
"""Compute training step."""
input_ = (
batch["concept_ids"],
batch["type_ids"],
batch["time_stamps"],
@@ -344,7 +348,7 @@ def training_step(self, batch, batch_idx):
labels = batch["labels"]
attention_mask = batch["attention_mask"]
loss = self(
- input,
+ input_,
attention_mask=attention_mask,
labels=labels,
return_dict=True,
@@ -353,8 +357,8 @@ def validation_step(self, batch, batch_idx):
return loss

def validation_step(self, batch, batch_idx):
"""Validation step."""
input = (
"""Compute validation step."""
input_ = (
batch["concept_ids"],
batch["type_ids"],
batch["time_stamps"],
@@ -365,7 +369,7 @@ def validation_step(self, batch, batch_idx):
labels = batch["labels"]
attention_mask = batch["attention_mask"]
loss = self(
- input,
+ input_,
attention_mask=attention_mask,
labels=labels,
return_dict=True,
@@ -374,8 +378,8 @@ def test_step(self, batch, batch_idx):
return loss

def test_step(self, batch, batch_idx):
"""Test step."""
input = (
"""Compute test step."""
input_ = (
batch["concept_ids"],
batch["type_ids"],
batch["time_stamps"],
@@ -386,7 +390,7 @@ def test_step(self, batch, batch_idx):
labels = batch["labels"]
attention_mask = batch["attention_mask"]
outputs = self(
- input,
+ input_,
attention_mask=attention_mask,
labels=labels,
return_dict=True,
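The recurring rename in this file swaps the parameter input for input_, because input shadows the Python builtin of the same name. A minimal sketch of the convention with the same six-tensor tuple the steps unpack (toy sizes, no real model):

    import torch

    def forward(input_: tuple) -> torch.Tensor:
        """Unpack the six-tensor tuple the way the model steps do."""
        concept_ids, type_ids, time_stamps, ages, visit_orders, visit_segments = input_
        # ... embedding lookup and the encoder would run here ...
        return concept_ids

    batch = tuple(torch.zeros(2, 8, dtype=torch.long) for _ in range(6))
    _ = forward(batch)  # the builtin input() stays unshadowed inside forward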
[Diffs for the remaining 2 of 8 changed files did not load on this page.]
