
Commit

[pre-commit.ci] Add auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Nov 4, 2024
1 parent 180a9ad commit 27fcbe3
Showing 6 changed files with 35 additions and 43 deletions.
37 changes: 19 additions & 18 deletions odyssey/data/DataProcessor.ipynb
@@ -33,8 +33,8 @@
"\n",
"SEED = 23\n",
"ROOT = \"/h/afallah/odyssey/odyssey\"\n",
"DATA_ROOT = f\"{ROOT}/odyssey/data/meds_data\" # bigbird_data\n",
"DATASET = f\"{DATA_ROOT}/patient_sequences/patient_sequences.parquet\" #patient_sequences_2048.parquet\\\n",
"DATA_ROOT = f\"{ROOT}/odyssey/data/meds_data\" # bigbird_data\n",
"DATASET = f\"{DATA_ROOT}/patient_sequences/patient_sequences.parquet\" # patient_sequences_2048.parquet\\\n",
"DATASET_2048 = f\"{DATA_ROOT}/patient_sequences/patient_sequences_2048.parquet\"\n",
"MAX_LEN = 2048\n",
"\n",
@@ -117,19 +117,22 @@
"source": [
"dataset = pl.read_parquet(DATASET)\n",
"dataset = dataset.rename({\"subject_id\": \"patient_id\", \"code\": \"event_tokens\"})\n",
"dataset = dataset.filter(pl.col('event_tokens').map_elements(len) > 5)\n",
"\n",
"dataset = dataset.with_columns([\n",
" pl.col('patient_id').cast(pl.String).alias('patient_id'),\n",
" pl.concat_list([\n",
" pl.col('event_tokens').list.slice(0, 2047),\n",
" pl.lit(['[EOS]'])\n",
" ]).alias('event_tokens'),\n",
"])\n",
"dataset = dataset.filter(pl.col(\"event_tokens\").map_elements(len) > 5)\n",
"\n",
"dataset = dataset.with_columns(\n",
" [\n",
" pl.col(\"patient_id\").cast(pl.String).alias(\"patient_id\"),\n",
" pl.concat_list(\n",
" [pl.col(\"event_tokens\").list.slice(0, 2047), pl.lit([\"[EOS]\"])]\n",
" ).alias(\"event_tokens\"),\n",
" ]\n",
")\n",
"\n",
"dataset = dataset.with_columns([\n",
" pl.col('event_tokens').map_elements(len).alias('token_length'),\n",
"])\n",
"dataset = dataset.with_columns(\n",
" [\n",
" pl.col(\"event_tokens\").map_elements(len).alias(\"token_length\"),\n",
" ]\n",
")\n",
"\n",
"print(dataset.head())\n",
"print(dataset.schema)\n",
@@ -226,7 +229,7 @@
"metadata": {},
"outputs": [],
"source": [
"dataset['event_tokens_2048'].iloc[0]"
"dataset[\"event_tokens_2048\"].iloc[0]"
]
},
{
@@ -280,9 +283,7 @@
"outputs": [],
"source": [
"# Process the dataset for hospital readmission in one month task\n",
"dataset_readmission = process_readmission_dataset(\n",
" dataset.copy(), max_len=MAX_LEN\n",
")"
"dataset_readmission = process_readmission_dataset(dataset.copy(), max_len=MAX_LEN)"
]
},
{
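Note: the notebook edits above only change quoting and line wrapping; the Polars pipeline itself is unchanged. For readers checking the reformatted cell, a minimal standalone sketch of the same operations follows. The toy frame and printout are assumptions added for illustration; the column names, expressions, and the 2047-token slice come from the notebook.

import polars as pl

# Hypothetical toy frame standing in for the MEDS parquet (assumption, not from the commit).
dataset = pl.DataFrame(
    {
        "patient_id": [101, 102],
        "event_tokens": [
            ["[BOS]", "A", "B", "C", "D", "E", "F"],
            ["[BOS]", "X", "Y"],  # dropped below: 5 tokens or fewer
        ],
    }
)

# Keep only sequences longer than 5 tokens, as in the notebook cell.
dataset = dataset.filter(pl.col("event_tokens").map_elements(len) > 5)

# Cast the id to string, truncate to the context window, and append the [EOS] token.
dataset = dataset.with_columns(
    [
        pl.col("patient_id").cast(pl.String).alias("patient_id"),
        pl.concat_list(
            [pl.col("event_tokens").list.slice(0, 2047), pl.lit(["[EOS]"])]
        ).alias("event_tokens"),
    ]
)

# Record the resulting sequence length per patient.
dataset = dataset.with_columns(
    [pl.col("event_tokens").map_elements(len).alias("token_length")]
)

print(dataset)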
4 changes: 2 additions & 2 deletions odyssey/data/dataset.py
@@ -464,7 +464,7 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
tokens["labels"] = tokens["concept_ids"]

if self.return_attention_mask:
tokens["attention_mask"] = tokenized_input["attention_mask"].squeeze()
tokens["attention_mask"] = tokenized_input["attention_mask"].squeeze()

return tokens

@@ -631,7 +631,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
data = self.data.iloc[index]

# Swap the first token with the task token.
data[f"event_tokens"][0] = self.tokenizer.task_to_token(task)
data["event_tokens"][0] = self.tokenizer.task_to_token(task)
data = self.truncate_and_pad(
row=data, cutoff=cutoff, additional_columns=self.additional_token_types
)
7 changes: 4 additions & 3 deletions odyssey/data/tokenizer.py
@@ -7,7 +7,6 @@
from itertools import chain
from typing import Any, Dict, List, Optional, Set, Union

import numpy as np
import pandas as pd
from tokenizers import Tokenizer, models, pre_tokenizers
from transformers import BatchEncoding, PreTrainedTokenizerFast
@@ -49,7 +48,7 @@ class ConceptTokenizer:
Mask token.
start_token: str
Sequence Start token.
end_token: str
end_token: str
Sequence End token.
class_token: str
Class token.
@@ -88,7 +87,9 @@ def __init__(
reg_token: str = "[REG]",
unknown_token: str = "[UNK]",
data_dir: str = "data_files",
time_tokens: List[str] = [f"[W_{i}]" for i in range(0, 4)] + [f"[M_{i}]" for i in range(0, 13)] + ["[LT]"],
time_tokens: List[str] = [f"[W_{i}]" for i in range(0, 4)]
+ [f"[M_{i}]" for i in range(0, 13)]
+ ["[LT]"],
tokenizer_object: Optional[Tokenizer] = None,
tokenizer: Optional[PreTrainedTokenizerFast] = None,
) -> None:
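For reference, the reflowed time_tokens default above expands to 18 special tokens. A small sketch of the same comprehension written standalone (the week/month/long-term reading of the prefixes is an assumption, not stated in this diff):

# Same comprehension as the new ConceptTokenizer default, written out standalone.
time_tokens = (
    [f"[W_{i}]" for i in range(0, 4)]      # [W_0] ... [W_3]
    + [f"[M_{i}]" for i in range(0, 13)]   # [M_0] ... [M_12]
    + ["[LT]"]
)

print(len(time_tokens))  # 18
print(time_tokens[:5])   # ['[W_0]', '[W_1]', '[W_2]', '[W_3]', '[M_0]']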
2 changes: 1 addition & 1 deletion odyssey/models/configs/ehr_mamba2.yaml
@@ -33,4 +33,4 @@ finetune:
acc: 1
patience: 10
persistent_workers: True
pin_memory: False
pin_memory: False
20 changes: 4 additions & 16 deletions odyssey/models/ehr_mamba2/model.py
@@ -1,23 +1,12 @@
"""Mamba2 model."""

from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from sklearn.metrics import (
accuracy_score,
f1_score,
precision_score,
recall_score,
roc_auc_score,
)
from torch import nn
from torch.cuda.amp import autocast
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from transformers import Mamba2Config, Mamba2ForCausalLM
from transformers.models.mamba2.modeling_mamba2 import Mamba2CausalLMOutput


class Mamba2Pretrain(pl.LightningModule):
@@ -38,7 +27,7 @@ def __init__(
learning_rate: float = 5e-5,
dropout_prob: float = 0.1,
padding_idx: int = 0,
cls_idx: int = 1, # used as bos token
cls_idx: int = 1, # used as bos token
eos_idx: int = 2,
n_groups: int = 1,
chunk_size: int = 256,
@@ -78,15 +67,15 @@ def __init__(
dropout=self.dropout_prob,
num_heads=self.num_heads,
head_dim=self.head_dim,
max_position_embeddings=self.max_seq_length
max_position_embeddings=self.max_seq_length,
)

# Mamba has its own initialization
self.model = Mamba2ForCausalLM(config=self.config)

def _step(self, batch: Dict[str, Any], batch_idx: int, stage: str) -> Any:
"""Run a single step for training or validation.
Args:
batch: Input batch dictionary
batch_idx: Index of current batch
@@ -154,4 +143,3 @@ def configure_optimizers(
)

return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

8 changes: 5 additions & 3 deletions pretrain.py
@@ -53,14 +53,16 @@ def main(args: argparse.Namespace, model_config: Dict[str, Any]) -> None:
data_dir=args.vocab_dir,
start_token="[VS]",
end_token="[VE]",
time_tokens=[f"[W_{i}]" for i in range(0, 4)] + [f"[M_{i}]" for i in range(0, 13)] + ["[LT]"]
time_tokens=[f"[W_{i}]" for i in range(0, 4)]
+ [f"[M_{i}]" for i in range(0, 13)]
+ ["[LT]"],
)
else: # meds
else: # meds
tokenizer = ConceptTokenizer(
data_dir=args.vocab_dir,
start_token="[BOS]",
end_token="[EOS]",
time_tokens=None # New tokenizer comes with predefined time tokens
time_tokens=None, # New tokenizer comes with predefined time tokens
)
tokenizer.fit_on_vocab()

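The pretrain.py hunk reflows the two tokenizer setups; a consolidated sketch of both branches is below. The import path and the is_meds flag are assumptions (the branch condition is not shown in this hunk); the constructor arguments and comments are taken from the diff.

from odyssey.data.tokenizer import ConceptTokenizer  # import path assumed from the repo layout

is_meds = True            # placeholder for the branch condition elided in this hunk
vocab_dir = "data_files"  # placeholder for args.vocab_dir

if not is_meds:
    # Original data: visit-start/visit-end markers plus explicit time tokens.
    tokenizer = ConceptTokenizer(
        data_dir=vocab_dir,
        start_token="[VS]",
        end_token="[VE]",
        time_tokens=[f"[W_{i}]" for i in range(0, 4)]
        + [f"[M_{i}]" for i in range(0, 13)]
        + ["[LT]"],
    )
else:  # meds
    # MEDS data: [BOS]/[EOS] markers; time_tokens=None because the new
    # tokenizer comes with predefined time tokens.
    tokenizer = ConceptTokenizer(
        data_dir=vocab_dir,
        start_token="[BOS]",
        end_token="[EOS]",
        time_tokens=None,
    )

tokenizer.fit_on_vocab()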
