Merge pull request #34 from VectorInstitute/feature/mamba
amrit110 authored May 1, 2024
2 parents 9ae62fe + bacacc5 commit be69218
Showing 6 changed files with 73 additions and 9 deletions.
62 changes: 56 additions & 6 deletions finetune.py
@@ -19,10 +19,15 @@
from skmultilearn.model_selection import iterative_train_test_split
from torch.utils.data import DataLoader

from odyssey.data.dataset import FinetuneDataset, FinetuneMultiDataset
from odyssey.data.dataset import (
FinetuneDataset,
FinetuneDatasetDecoder,
FinetuneMultiDataset,
)
from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.models.cehr_bert.model import BertFinetune, BertPretrain
from odyssey.models.cehr_big_bird.model import BigBirdFinetune, BigBirdPretrain
from odyssey.models.cehr_mamba.model import MambaFinetune, MambaPretrain
from odyssey.models.model_utils import (
get_run_id,
load_config,
@@ -90,7 +95,30 @@ def main(
tokenizer.fit_on_vocab(with_tasks=args.is_multi_model)

# Load datasets based on model type
if args.is_multi_model:
if args.is_decoder:
train_dataset = FinetuneDatasetDecoder(
data=fine_train,
tokenizer=tokenizer,
tasks=args.tasks,
balance_guide=args.balance_guide,
max_len=args.max_len,
)
val_dataset = FinetuneDatasetDecoder(
data=fine_val,
tokenizer=tokenizer,
tasks=args.tasks,
balance_guide=args.balance_guide,
max_len=args.max_len,
)
test_dataset = FinetuneDatasetDecoder(
data=fine_test,
tokenizer=tokenizer,
tasks=args.tasks,
balance_guide=None,
max_len=args.max_len,
)

elif args.is_multi_model:
train_dataset = FinetuneMultiDataset(
data=fine_train,
tokenizer=tokenizer,
@@ -192,14 +220,28 @@ def main(
**pre_model_config,
)
pretrained_model.load_state_dict(torch.load(args.pretrained_path)["state_dict"])

model = BigBirdFinetune(
pretrained_model=pretrained_model,
num_labels=args.num_labels,
problem_type=args.problem_type,
**fine_model_config,
)

elif args.model_type == "cehr_mamba":
pretrained_model = MambaPretrain(
vocab_size=tokenizer.get_vocab_size(),
padding_idx=tokenizer.get_pad_token_id(),
cls_idx=tokenizer.get_class_token_id(),
**pre_model_config,
)
pretrained_model.load_state_dict(torch.load(args.pretrained_path)["state_dict"])
model = MambaFinetune(
pretrained_model=pretrained_model,
num_labels=args.num_labels,
problem_type=args.problem_type,
**fine_model_config,
)

run_id = get_run_id(args.checkpoint_dir)

wandb_logger = WandbLogger(
@@ -265,7 +307,7 @@ def main(
"--model-type",
type=str,
required=True,
help="Model type: 'cehr_bert' or 'cehr_bigbird'",
help="Model type: 'cehr_bert', 'cehr_bigbird', or 'cehr_mamba'",
)
parser.add_argument(
"--exp-name",
@@ -302,6 +344,12 @@ def main(
default=False,
help="Is the model a multimodel like multibird or not",
)
parser.add_argument(
"--is-decoder",
type=bool,
default=False,
help="Is the model a decoder (e.g. Mamba) or not",
)

# data-related arguments
parser.add_argument(
@@ -419,8 +467,10 @@ def main(

# Process arguments
args = parser.parse_args()
if args.model_type not in ["cehr_bert", "cehr_bigbird"]:
print("Invalid model type. Choose 'cehr_bert' or 'cehr_bigbird'.")
if args.model_type not in ["cehr_bert", "cehr_bigbird", "cehr_mamba"]:
print(
"Invalid model type. Choose 'cehr_bert' or 'cehr_bigbird' or 'cehr_mamba'."
)
sys.exit(1)

args.checkpoint_dir = os.path.join(args.checkpoint_dir, args.exp_name)
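For orientation, the pieces added to finetune.py above compose as follows when --is-decoder is set. This is a minimal sketch assembled from the names in this diff (fine_train, tokenizer, args, pre_model_config, and fine_model_config are defined elsewhere in the script); the DataLoader keyword arguments are illustrative assumptions, not the repository's exact settings.

import torch
from torch.utils.data import DataLoader

from odyssey.data.dataset import FinetuneDatasetDecoder
from odyssey.models.cehr_mamba.model import MambaFinetune, MambaPretrain

# Build the decoder finetuning dataset exactly as in the hunk above.
train_dataset = FinetuneDatasetDecoder(
    data=fine_train,
    tokenizer=tokenizer,
    tasks=args.tasks,
    balance_guide=args.balance_guide,
    max_len=args.max_len,
)
# Illustrative DataLoader settings, not the script's actual values.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Restore the pretrained Mamba backbone and wrap it for finetuning.
pretrained_model = MambaPretrain(
    vocab_size=tokenizer.get_vocab_size(),
    padding_idx=tokenizer.get_pad_token_id(),
    cls_idx=tokenizer.get_class_token_id(),
    **pre_model_config,
)
pretrained_model.load_state_dict(torch.load(args.pretrained_path)["state_dict"])
model = MambaFinetune(
    pretrained_model=pretrained_model,
    num_labels=args.num_labels,
    problem_type=args.problem_type,
    **fine_model_config,
)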
2 changes: 2 additions & 0 deletions odyssey/models/cehr_big_bird/model.py
@@ -172,6 +172,7 @@ def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Any:
dictionary={"train_loss": loss, "lr": current_lr},
on_step=True,
prog_bar=True,
sync_dist=True,
)
return loss

@@ -356,6 +357,7 @@ def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Any:
dictionary={"train_loss": loss, "lr": current_lr},
on_step=True,
prog_bar=True,
sync_dist=True,
)
return loss

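The only change in this file is sync_dist=True on the logging calls: under multi-GPU (DDP) training, PyTorch Lightning then reduces the logged value across processes, so every rank reports the same averaged metric instead of its local one, at the cost of an extra collective per logging step. A minimal, hypothetical LightningModule showing the call in context (the import path may be pytorch_lightning or lightning.pytorch depending on the installed version):

import torch
from pytorch_lightning import LightningModule  # assumption: classic import path


class TinyModule(LightningModule):
    """Hypothetical module, only to show the logging call in context."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).mean()
        # sync_dist=True averages the metric across processes before logging.
        self.log_dict(
            dictionary={"train_loss": loss},
            on_step=True,
            prog_bar=True,
            sync_dist=True,
        )
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)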
5 changes: 4 additions & 1 deletion odyssey/models/cehr_mamba/mamba_utils.py
@@ -1,5 +1,6 @@
"""Utilities following HuggingFace style for Mamba models."""

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
@@ -17,7 +18,6 @@
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
dataclass,
replace_return_docstrings,
)

@@ -104,6 +104,9 @@ def forward(
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns
-------
"""
sequence_outputs = self.backbone(
input_ids,
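This hunk corrects the import of dataclass: it is Python's standard-library decorator from the dataclasses module, not something exported alongside transformers' docstring utilities, so it now comes from the right place. The decorator is used in the usual HuggingFace pattern of dataclass-based output containers; a small sketch of that pattern follows (the class name and fields are illustrative, not the file's actual definitions):

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from transformers.utils import ModelOutput


@dataclass
class MambaSequenceOutput(ModelOutput):
    """Illustrative output container in HuggingFace style."""

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None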
3 changes: 3 additions & 0 deletions odyssey/models/cehr_mamba/model.py
@@ -109,6 +109,7 @@ def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Any:
dictionary={"train_loss": loss, "lr": current_lr},
on_step=True,
prog_bar=True,
sync_dist=True,
)

return loss
@@ -249,6 +250,7 @@ def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Any:
dictionary={"train_loss": loss, "lr": current_lr},
on_step=True,
prog_bar=True,
sync_dist=True,
)

return loss
@@ -271,6 +273,7 @@ def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> Any:
dictionary={"val_loss": loss, "lr": current_lr},
on_step=True,
prog_bar=True,
sync_dist=True,
)

return loss
2 changes: 1 addition & 1 deletion odyssey/models/configs/cehr_mamba.yaml
@@ -23,7 +23,7 @@ train:
pin_memory: False

finetune:
batch_size: 26
batch_size: 64 #26
num_workers: 6
gpus: 4
nodes: 1
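The finetuning batch size moves from 26 to 64, with the old value kept as a comment. One way to read the block, with the caveat that treating batch_size as a per-device value is a common convention but an assumption here:

import yaml

with open("odyssey/models/configs/cehr_mamba.yaml") as f:
    config = yaml.safe_load(f)

finetune_cfg = config["finetune"]
# If batch_size is per device (an assumption), the effective batch across
# gpus * nodes would be 64 * 4 * 1 = 256 with the values shown above.
effective_batch = finetune_cfg["batch_size"] * finetune_cfg["gpus"] * finetune_cfg["nodes"]
print(finetune_cfg["batch_size"], effective_batch)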
8 changes: 7 additions & 1 deletion pretrain.py
@@ -51,7 +51,7 @@ def main(args: argparse.Namespace, model_config: Dict[str, Any]) -> None:
tokenizer.fit_on_vocab()

# Load datasets
if args.model_type == "cehr_mamba": # Decoder model
if args.is_decoder: # e.g. Mamba
train_dataset = PretrainDatasetDecoder(
data=pre_train,
tokenizer=tokenizer,
@@ -197,6 +197,12 @@ def main(args: argparse.Namespace, model_config: Dict[str, Any]) -> None:
required=True,
help="Path to model config file",
)
parser.add_argument(
"--is-decoder",
type=bool,
default=False,
help="Is the model a decoder (e.g. Mamba) or not",
)

# data-related arguments
parser.add_argument(
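Both pretrain.py and finetune.py declare --is-decoder with type=bool. One thing to keep in mind: argparse applies bool() to the raw string, and any non-empty string is truthy, so "--is-decoder False" still evaluates to True; only omitting the flag yields the False default. A store_true flag is a common alternative, shown below purely as a sketch of that behavior, not as what the repository uses:

import argparse

parser = argparse.ArgumentParser()
# store_true makes the flag's presence, not its string value, set the boolean.
parser.add_argument(
    "--is-decoder",
    action="store_true",
    help="Is the model a decoder (e.g. Mamba) or not",
)

args = parser.parse_args(["--is-decoder"])
print(args.is_decoder)  # True; omitting the flag yields False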
