Commit 7e46abf
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed May 1, 2024
1 parent bd28037 commit 7e46abf
Showing 4 changed files with 73 additions and 66 deletions.
6 changes: 1 addition & 5 deletions odyssey/data/dataset.py
@@ -719,11 +719,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
concept_ids = tokenized_input["input_ids"].squeeze()
labels = torch.tensor(labels)

return {
"concept_ids": concept_ids,
"labels": labels,
"task": task
}
return {"concept_ids": concept_ids, "labels": labels, "task": task}

def tokenize_data(self, sequence: Union[str, List[str]]) -> Any:
"""Tokenize the sequence and return input_ids and attention mask.
72 changes: 37 additions & 35 deletions odyssey/models/cehr_mamba/mamba-dev.ipynb
@@ -25,7 +25,7 @@
" MambaConfig,\n",
" MambaModel,\n",
" MambaForCausalLM,\n",
" MambaPreTrainedModel\n",
" MambaPreTrainedModel,\n",
")\n",
"\n",
"import numpy as np\n",
@@ -73,14 +73,20 @@
"outputs": [],
"source": [
"class args:\n",
" data_dir = 'odyssey/data/bigbird_data'\n",
" sequence_file = 'patient_sequences_2048_multi.parquet'\n",
" id_file = 'dataset_2048_multi.pkl'\n",
" vocab_dir = 'odyssey/data/vocab'\n",
" data_dir = \"odyssey/data/bigbird_data\"\n",
" sequence_file = \"patient_sequences_2048_multi.parquet\"\n",
" id_file = \"dataset_2048_multi.pkl\"\n",
" vocab_dir = \"odyssey/data/vocab\"\n",
" max_len = 2048\n",
" mask_prob = 0.15\n",
" tasks = ['mortality_1month', 'los_1week', 'c0', 'c1', 'c2']\n",
" balance_guide = {'mortality_1month': 0.5, 'los_1week': 0.5, 'c0': 0.5, 'c1': 0.5, 'c2': 0.5}"
" tasks = [\"mortality_1month\", \"los_1week\", \"c0\", \"c1\", \"c2\"]\n",
" balance_guide = {\n",
" \"mortality_1month\": 0.5,\n",
" \"los_1week\": 0.5,\n",
" \"c0\": 0.5,\n",
" \"c1\": 0.5,\n",
" \"c2\": 0.5,\n",
" }"
]
},
{
@@ -239,15 +245,15 @@
],
"source": [
"train_loader = DataLoader(\n",
" decoder_dataset, #test_dataset, #train_dataset\n",
" batch_size=3,\n",
" shuffle=False,\n",
" )\n",
" decoder_dataset, # test_dataset, #train_dataset\n",
" batch_size=3,\n",
" shuffle=False,\n",
")\n",
"\n",
"sample = decoder_dataset[2323] #test_dataset[8765] #train_dataset[0]\n",
"task = sample.pop('task')\n",
"sample = {key:tensor.unsqueeze(0).to(device) for key, tensor in sample.items()}\n",
"sample['task'] = task\n",
"sample = decoder_dataset[2323] # test_dataset[8765] #train_dataset[0]\n",
"task = sample.pop(\"task\")\n",
"sample = {key: tensor.unsqueeze(0).to(device) for key, tensor in sample.items()}\n",
"sample[\"task\"] = task\n",
"\n",
"# sample = next(iter(train_loader))\n",
"# sample = {key:tensor.to(device) for key, tensor in sample.items()}\n",
@@ -382,9 +388,7 @@
],
"source": [
"# model = model.backbone\n",
"outputs = model(\n",
" input_ids=sample[\"concept_ids\"], return_dict=True\n",
")\n",
"outputs = model(input_ids=sample[\"concept_ids\"], return_dict=True)\n",
"\n",
"last_hidden_states = outputs.last_hidden_state\n",
"last_hidden_states.shape"
@@ -475,8 +479,8 @@
}
],
"source": [
"sequence_lengths = torch.eq(sample['concept_ids'], 0).int().argmax(-1) - 1\n",
"sequence_lengths "
"sequence_lengths = torch.eq(sample[\"concept_ids\"], 0).int().argmax(-1) - 1\n",
"sequence_lengths"
]
},
{
@@ -517,7 +521,9 @@
}
],
"source": [
"pooled_last_hidden_states = last_hidden_states[torch.arange(1, device=device), sequence_lengths]\n",
"pooled_last_hidden_states = last_hidden_states[\n",
" torch.arange(1, device=device), sequence_lengths\n",
"]\n",
"classifier(pooled_last_hidden_states)"
]
},
@@ -539,6 +545,7 @@
],
"source": [
"import copy\n",
"\n",
"config_copy = copy.deepcopy(config)\n",
"config_copy.classifier_dropout = 0.1\n",
"head = MambaClassificationHead(config_copy).to(device)\n",
@@ -563,7 +570,7 @@
],
"source": [
"loss_fct = torch.nn.CrossEntropyLoss()\n",
"loss = loss_fct(pooled_logits.view(-1,2), torch.tensor([0]).to(device).view(-1))\n",
"loss = loss_fct(pooled_logits.view(-1, 2), torch.tensor([0]).to(device).view(-1))\n",
"loss"
]
},
@@ -582,7 +589,7 @@
"metadata": {},
"outputs": [],
"source": [
"inputs['input_ids'].shape"
"inputs[\"input_ids\"].shape"
]
},
{
Expand All @@ -601,8 +608,7 @@
"outputs": [],
"source": [
"from odyssey.models.cehr_mamba.model import MambaPretrain\n",
"from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss\n",
"\n"
"from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss"
]
},
{
@@ -792,11 +798,7 @@
" concept_ids = tokenized_input[\"input_ids\"].squeeze()\n",
" labels = torch.tensor(labels)\n",
"\n",
" return {\n",
" \"concept_ids\": concept_ids,\n",
" \"labels\": labels,\n",
" \"task\": task\n",
" }\n",
" return {\"concept_ids\": concept_ids, \"labels\": labels, \"task\": task}\n",
"\n",
" def tokenize_data(self, sequence: Union[str, List[str]]) -> Any:\n",
" \"\"\"Tokenize the sequence and return input_ids and attention mask.\n",
@@ -848,11 +850,11 @@
"\n",
"\n",
"decoder_dataset = FinetuneDatasetDecoder(\n",
" data=fine_test,\n",
" tokenizer=tokenizer,\n",
" max_len=args.max_len,\n",
" tasks=args.tasks,\n",
" balance_guide=args.balance_guide,\n",
" data=fine_test,\n",
" tokenizer=tokenizer,\n",
" max_len=args.max_len,\n",
" tasks=args.tasks,\n",
" balance_guide=args.balance_guide,\n",
")\n",
"decoder_dataset[12112]"
]
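Note: as a point of reference for the pooling walkthrough in the notebook cells above, here is a minimal self-contained sketch (toy tensors and a pad token id of 0 are assumptions for illustration, not values taken from the repository) of how the last non-padding token index is located and its hidden state gathered:

import torch

pad_token_id = 0  # assumed padding id for this toy example
concept_ids = torch.tensor(
    [
        [5, 8, 3, 0, 0],  # last real token at index 2
        [7, 2, 9, 4, 0],  # last real token at index 3
    ]
)
last_hidden_states = torch.randn(2, 5, 16)  # (batch, seq_len, hidden_size)

# argmax over the padding mask gives the first pad position; subtracting 1
# lands on the last real token. A row with no padding at all would yield -1,
# i.e. the final position, which is still the last real token.
sequence_lengths = torch.eq(concept_ids, pad_token_id).int().argmax(-1) - 1
pooled = last_hidden_states[torch.arange(concept_ids.shape[0]), sequence_lengths]
print(sequence_lengths)  # tensor([2, 3])
print(pooled.shape)      # torch.Size([2, 16])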
42 changes: 24 additions & 18 deletions odyssey/models/cehr_mamba/mamba_utils.py
@@ -1,27 +1,23 @@
"""Utilities following HuggingFace style for Mamba models."""

from typing import Any, Dict, List, Optional, Set, Union, Tuple
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss

from transformers import (
MambaModel,
MambaPreTrainedModel
)
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import MambaModel, MambaPreTrainedModel
from transformers.activations import ACT2FN
from transformers.models.mamba.modeling_mamba import (
MAMBA_INPUTS_DOCSTRING,
MAMBA_START_DOCSTRING,
MambaModel,
MambaPreTrainedModel,
MAMBA_START_DOCSTRING,
MAMBA_INPUTS_DOCSTRING,
)
from transformers.utils import (
ModelOutput,
dataclass,
add_start_docstrings,
add_start_docstrings_to_model_forward,
dataclass,
replace_return_docstrings,
)

@@ -45,6 +41,7 @@ class MambaSequenceClassifierOutput(ModelOutput):
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
@@ -61,7 +58,7 @@ def __init__(self, config):
self.config = config

def forward(self, features, **kwargs):
x = features # Pooling is done by the forward pass
x = features # Pooling is done by the forward pass
x = self.dropout(x)
x = self.dense(x)
x = ACT2FN[self.config.hidden_act](x)
@@ -88,8 +85,12 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MambaSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@add_start_docstrings_to_model_forward(
MAMBA_INPUTS_DOCSTRING.format("batch_size, sequence_length")
)
@replace_return_docstrings(
output_type=MambaSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC
)
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -112,11 +113,14 @@ def forward(
)
last_hidden_states = sequence_outputs[0]
batch_size = last_hidden_states.shape[0]

# Pool the hidden states for the last tokens before padding to use for classification
last_token_indexes = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
last_token_indexes = (
torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
)
pooled_last_hidden_states = last_hidden_states[
torch.arange(batch_size, device=last_hidden_states.device), last_token_indexes
torch.arange(batch_size, device=last_hidden_states.device),
last_token_indexes,
]

logits = self.classifier(pooled_last_hidden_states)
@@ -126,7 +130,9 @@ def forward(
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
elif self.num_labels > 1 and (
labels.dtype == torch.long or labels.dtype == torch.int
):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
@@ -152,4 +158,4 @@ def forward(
loss=loss,
logits=logits,
hidden_states=sequence_outputs.hidden_states,
)
)
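Note: the problem_type branch reformatted above follows the usual HuggingFace convention of choosing the loss from num_labels and the label dtype. A standalone sketch of that dispatch (the function name and toy tensors are illustrative, not part of the repository):

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

def classification_loss(logits: torch.Tensor, labels: torch.Tensor, num_labels: int) -> torch.Tensor:
    # num_labels == 1 is treated as regression.
    if num_labels == 1:
        return MSELoss()(logits.squeeze(), labels.squeeze())
    # Integer labels with more than one class -> single-label classification.
    if labels.dtype in (torch.long, torch.int):
        return CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
    # Otherwise, float labels -> multi-label classification with one logit per class.
    return BCEWithLogitsLoss()(logits, labels)

# Example: batch of 3, binary single-label classification.
logits = torch.randn(3, 2)
labels = torch.tensor([0, 1, 1])
print(classification_loss(logits, labels, num_labels=2))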
19 changes: 11 additions & 8 deletions odyssey/models/cehr_mamba/model.py
@@ -4,24 +4,28 @@

import numpy as np
import pytorch_lightning as pl

import torch
from sklearn.metrics import (
accuracy_score,
f1_score,
precision_score,
recall_score,
roc_auc_score,
)

import torch
from torch import nn
from torch.cuda.amp import autocast
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from transformers import MambaConfig
from transformers.models.mamba.modeling_mamba import (
MambaCausalLMOutput,
MambaForCausalLM,
)

from transformers.models.mamba.modeling_mamba import MambaForCausalLM, MambaCausalLMOutput
from odyssey.models.cehr_mamba.mamba_utils import MambaForSequenceClassification, MambaSequenceClassifierOutput
from odyssey.models.cehr_mamba.mamba_utils import (
MambaForSequenceClassification,
MambaSequenceClassifierOutput,
)


class MambaPretrain(pl.LightningModule):
@@ -226,8 +230,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)



def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Any:
"""Train model on training dataset."""
concept_ids = batch["concept_ids"]
@@ -284,7 +287,7 @@ def test_step(self, batch: Dict[str, Any], batch_idx: int) -> Any:
labels=labels,
return_dict=True,
)

loss = outputs[0]
logits = outputs[1]
preds = torch.argmax(logits, dim=1)
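Note: model.py imports accuracy_score, f1_score, precision_score, recall_score, and roc_auc_score from sklearn. A brief sketch (toy arrays, not outputs of an actual test run) of how predictions and labels collected in test_step could feed those metrics; probs stands in for assumed positive-class probabilities derived from the logits:

import numpy as np
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)

labels = np.array([0, 1, 1, 0, 1])           # ground-truth task labels
preds = np.array([0, 1, 0, 0, 1])            # argmax of the logits, as in test_step
probs = np.array([0.2, 0.9, 0.4, 0.1, 0.8])  # assumed positive-class probabilities

print("accuracy ", accuracy_score(labels, preds))
print("f1       ", f1_score(labels, preds))
print("precision", precision_score(labels, preds))
print("recall   ", recall_score(labels, preds))
print("auroc    ", roc_auc_score(labels, probs))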
