From 27fcbe309e5bbb0f52044e41a3978c5407167249 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 4 Nov 2024 16:55:58 +0000
Subject: [PATCH] [pre-commit.ci] Add auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 odyssey/data/DataProcessor.ipynb       | 37 +++++++++++++-------------
 odyssey/data/dataset.py                |  4 +--
 odyssey/data/tokenizer.py              |  7 ++---
 odyssey/models/configs/ehr_mamba2.yaml |  2 +-
 odyssey/models/ehr_mamba2/model.py     | 20 +++-----------
 pretrain.py                            |  8 +++---
 6 files changed, 35 insertions(+), 43 deletions(-)

diff --git a/odyssey/data/DataProcessor.ipynb b/odyssey/data/DataProcessor.ipynb
index 590ca2e..e547db6 100644
--- a/odyssey/data/DataProcessor.ipynb
+++ b/odyssey/data/DataProcessor.ipynb
@@ -33,8 +33,8 @@
     "\n",
     "SEED = 23\n",
     "ROOT = \"/h/afallah/odyssey/odyssey\"\n",
-    "DATA_ROOT = f\"{ROOT}/odyssey/data/meds_data\" # bigbird_data\n",
-    "DATASET = f\"{DATA_ROOT}/patient_sequences/patient_sequences.parquet\" #patient_sequences_2048.parquet\\\n",
+    "DATA_ROOT = f\"{ROOT}/odyssey/data/meds_data\"  # bigbird_data\n",
+    "DATASET = f\"{DATA_ROOT}/patient_sequences/patient_sequences.parquet\"  # patient_sequences_2048.parquet\\\n",
     "DATASET_2048 = f\"{DATA_ROOT}/patient_sequences/patient_sequences_2048.parquet\"\n",
     "MAX_LEN = 2048\n",
     "\n",
@@ -117,19 +117,22 @@
    "source": [
     "dataset = pl.read_parquet(DATASET)\n",
     "dataset = dataset.rename({\"subject_id\": \"patient_id\", \"code\": \"event_tokens\"})\n",
-    "dataset = dataset.filter(pl.col('event_tokens').map_elements(len) > 5)\n",
-    "\n",
-    "dataset = dataset.with_columns([\n",
-    "    pl.col('patient_id').cast(pl.String).alias('patient_id'),\n",
-    "    pl.concat_list([\n",
-    "        pl.col('event_tokens').list.slice(0, 2047),\n",
-    "        pl.lit(['[EOS]'])\n",
-    "    ]).alias('event_tokens'),\n",
-    "])\n",
+    "dataset = dataset.filter(pl.col(\"event_tokens\").map_elements(len) > 5)\n",
+    "\n",
+    "dataset = dataset.with_columns(\n",
+    "    [\n",
+    "        pl.col(\"patient_id\").cast(pl.String).alias(\"patient_id\"),\n",
+    "        pl.concat_list(\n",
+    "            [pl.col(\"event_tokens\").list.slice(0, 2047), pl.lit([\"[EOS]\"])]\n",
+    "        ).alias(\"event_tokens\"),\n",
+    "    ]\n",
+    ")\n",
     "\n",
-    "dataset = dataset.with_columns([\n",
-    "    pl.col('event_tokens').map_elements(len).alias('token_length'),\n",
-    "])\n",
+    "dataset = dataset.with_columns(\n",
+    "    [\n",
+    "        pl.col(\"event_tokens\").map_elements(len).alias(\"token_length\"),\n",
+    "    ]\n",
+    ")\n",
     "\n",
     "print(dataset.head())\n",
     "print(dataset.schema)\n",
@@ -226,7 +229,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "dataset['event_tokens_2048'].iloc[0]"
+    "dataset[\"event_tokens_2048\"].iloc[0]"
    ]
   },
   {
@@ -280,9 +283,7 @@
    "outputs": [],
    "source": [
     "# Process the dataset for hospital readmission in one month task\n",
-    "dataset_readmission = process_readmission_dataset(\n",
-    "    dataset.copy(), max_len=MAX_LEN\n",
-    ")"
+    "dataset_readmission = process_readmission_dataset(dataset.copy(), max_len=MAX_LEN)"
    ]
   },
   {
diff --git a/odyssey/data/dataset.py b/odyssey/data/dataset.py
index f880345..cf7fd36 100644
--- a/odyssey/data/dataset.py
+++ b/odyssey/data/dataset.py
@@ -464,7 +464,7 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
             tokens["labels"] = tokens["concept_ids"]
 
         if self.return_attention_mask:
-            tokens["attention_mask"] = tokenized_input["attention_mask"].squeeze()
+            tokens["attention_mask"] = tokenized_input["attention_mask"].squeeze()
 
         return tokens
 
@@ -631,7 +631,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
         data = self.data.iloc[index]
 
         # Swap the first token with the task token.
-        data[f"event_tokens"][0] = self.tokenizer.task_to_token(task)
+        data["event_tokens"][0] = self.tokenizer.task_to_token(task)
         data = self.truncate_and_pad(
             row=data, cutoff=cutoff, additional_columns=self.additional_token_types
         )
diff --git a/odyssey/data/tokenizer.py b/odyssey/data/tokenizer.py
index f6d62af..96abfad 100644
--- a/odyssey/data/tokenizer.py
+++ b/odyssey/data/tokenizer.py
@@ -7,7 +7,6 @@
 from itertools import chain
 from typing import Any, Dict, List, Optional, Set, Union
 
-import numpy as np
 import pandas as pd
 from tokenizers import Tokenizer, models, pre_tokenizers
 from transformers import BatchEncoding, PreTrainedTokenizerFast
@@ -49,7 +48,7 @@ class ConceptTokenizer:
         Mask token.
     start_token: str
         Sequence Start token.
-    end_token: str
+    end_token: str
         Sequence End token.
     class_token: str
        Class token.
@@ -88,7 +87,9 @@ def __init__(
        reg_token: str = "[REG]",
        unknown_token: str = "[UNK]",
        data_dir: str = "data_files",
-        time_tokens: List[str] = [f"[W_{i}]" for i in range(0, 4)] + [f"[M_{i}]" for i in range(0, 13)] + ["[LT]"],
+        time_tokens: List[str] = [f"[W_{i}]" for i in range(0, 4)]
+        + [f"[M_{i}]" for i in range(0, 13)]
+        + ["[LT]"],
         tokenizer_object: Optional[Tokenizer] = None,
         tokenizer: Optional[PreTrainedTokenizerFast] = None,
     ) -> None:
diff --git a/odyssey/models/configs/ehr_mamba2.yaml b/odyssey/models/configs/ehr_mamba2.yaml
index eab4f0f..cd81f4e 100644
--- a/odyssey/models/configs/ehr_mamba2.yaml
+++ b/odyssey/models/configs/ehr_mamba2.yaml
@@ -33,4 +33,4 @@ finetune:
   acc: 1
   patience: 10
   persistent_workers: True
-  pin_memory: False
\ No newline at end of file
+  pin_memory: False
diff --git a/odyssey/models/ehr_mamba2/model.py b/odyssey/models/ehr_mamba2/model.py
index 840f2e6..c0737d8 100644
--- a/odyssey/models/ehr_mamba2/model.py
+++ b/odyssey/models/ehr_mamba2/model.py
@@ -1,23 +1,12 @@
 """Mamba2 model."""
 
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Tuple
 
-import numpy as np
 import pytorch_lightning as pl
-import torch
-from sklearn.metrics import (
-    accuracy_score,
-    f1_score,
-    precision_score,
-    recall_score,
-    roc_auc_score,
-)
-from torch import nn
 from torch.cuda.amp import autocast
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
 from transformers import Mamba2Config, Mamba2ForCausalLM
-from transformers.models.mamba2.modeling_mamba2 import Mamba2CausalLMOutput
 
 
 class Mamba2Pretrain(pl.LightningModule):
@@ -38,7 +27,7 @@ def __init__(
         learning_rate: float = 5e-5,
         dropout_prob: float = 0.1,
         padding_idx: int = 0,
-        cls_idx: int = 1, # used as bos token
+        cls_idx: int = 1,  # used as bos token
         eos_idx: int = 2,
         n_groups: int = 1,
         chunk_size: int = 256,
@@ -78,7 +67,7 @@ def __init__(
             dropout=self.dropout_prob,
             num_heads=self.num_heads,
             head_dim=self.head_dim,
-            max_position_embeddings=self.max_seq_length
+            max_position_embeddings=self.max_seq_length,
         )
 
         # Mamba has its own initialization
@@ -86,7 +75,7 @@ def __init__(
 
     def _step(self, batch: Dict[str, Any], batch_idx: int, stage: str) -> Any:
         """Run a single step for training or validation.
-
+
         Args:
             batch: Input batch dictionary
             batch_idx: Index of current batch
@@ -154,4 +143,3 @@ def configure_optimizers(
         )
 
         return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
-
diff --git a/pretrain.py b/pretrain.py
index f8b9884..f56385c 100644
--- a/pretrain.py
+++ b/pretrain.py
@@ -53,14 +53,16 @@ def main(args: argparse.Namespace, model_config: Dict[str, Any]) -> None:
             data_dir=args.vocab_dir,
             start_token="[VS]",
             end_token="[VE]",
-            time_tokens=[f"[W_{i}]" for i in range(0, 4)] + [f"[M_{i}]" for i in range(0, 13)] + ["[LT]"]
+            time_tokens=[f"[W_{i}]" for i in range(0, 4)]
+            + [f"[M_{i}]" for i in range(0, 13)]
+            + ["[LT]"],
         )
-    else: # meds
+    else:  # meds
         tokenizer = ConceptTokenizer(
             data_dir=args.vocab_dir,
             start_token="[BOS]",
             end_token="[EOS]",
-            time_tokens=None # New tokenizer comes with predefined time tokens
+            time_tokens=None,  # New tokenizer comes with predefined time tokens
         )
 
     tokenizer.fit_on_vocab()