Skip to content

Commit

Permalink
Formatting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
amrit110 committed Apr 17, 2024
1 parent 57f576f commit 8a8bc86
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 13 deletions.
4 changes: 2 additions & 2 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from skmultilearn.model_selection import iterative_train_test_split
from torch.utils.data import DataLoader

from odyssey.utils.utils import seed_everything
from odyssey.data.dataset import FinetuneDataset, FinetuneMultiDataset
from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.models.cehr_bert.model import BertFinetune, BertPretrain
Expand All @@ -29,6 +28,7 @@
load_config,
load_finetune_data,
)
from odyssey.utils.utils import seed_everything


def main(
Expand Down Expand Up @@ -251,7 +251,7 @@ def main(

if __name__ == "__main__":
parser = argparse.ArgumentParser()

# project configuration
parser.add_argument(
"--model-type",
Expand Down
1 change: 0 additions & 1 deletion odyssey/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import pandas as pd

import torch
from torch.utils.data import Dataset

Expand Down
2 changes: 1 addition & 1 deletion odyssey/evals/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import torch

from odyssey.models.cehr_big_bird.model import BigBirdFinetune, BigBirdPretrain
from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.models.cehr_big_bird.model import BigBirdFinetune, BigBirdPretrain


def load_finetuned_model(
Expand Down
2 changes: 1 addition & 1 deletion odyssey/models/baseline/Bi-LSTM.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from tqdm import tqdm

from odyssey.data.dataset import FinetuneDataset
from odyssey.models.cehr_big_bird.embeddings import Embeddings
from odyssey.data.tokenizer import HuggingFaceConceptTokenizer
from odyssey.models.cehr_big_bird.embeddings import Embeddings


ROOT = "/fs01/home/afallah/odyssey/odyssey"
Expand Down
5 changes: 2 additions & 3 deletions odyssey/models/cehr_big_bird/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@

import numpy as np
import pytorch_lightning as pl
import torch
from sklearn.metrics import (
accuracy_score,
f1_score,
precision_score,
recall_score,
roc_auc_score,
)

import torch
from torch import nn, optim
from torch.cuda.amp import autocast
from torch.optim import AdamW
Expand Down Expand Up @@ -453,7 +452,7 @@ def on_test_epoch_end(self) -> Any:
auc = roc_auc_score(labels, preds)
precision = precision_score(labels, preds)
recall = recall_score(labels, preds)

self.log("test_loss", loss)
self.log("test_acc", accuracy)
self.log("test_f1", f1)
Expand Down
5 changes: 2 additions & 3 deletions odyssey/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
"""General utility functions and constants for the project."""

from typing import Any

import pickle
import random
from typing import Any

import numpy as np
import torch
import pytorch_lightning as pl
import torch


def seed_everything(seed: int) -> None:
Expand Down
4 changes: 2 additions & 2 deletions pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from odyssey.utils.utils import seed_everything
from odyssey.data.dataset import PretrainDataset
from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.models.cehr_bert.model import BertPretrain
Expand All @@ -23,6 +22,7 @@
load_config,
load_pretrain_data,
)
from odyssey.utils.utils import seed_everything


def main(args: argparse.Namespace, model_config: Dict[str, Any]) -> None:
Expand Down Expand Up @@ -152,7 +152,7 @@ def main(args: argparse.Namespace, model_config: Dict[str, Any]) -> None:

if __name__ == "__main__":
parser = argparse.ArgumentParser()

# project configuration
parser.add_argument(
"--model-type",
Expand Down

0 comments on commit 8a8bc86

Please sign in to comment.