Skip to content

Commit

Permalink
Minor style improvements.
Browse files Browse the repository at this point in the history
  • Loading branch information
Adibvafa committed Apr 30, 2024
2 parents c0102d4 + dc640ed commit a6ab852
Show file tree
Hide file tree
Showing 14 changed files with 697 additions and 3,684 deletions.
27 changes: 23 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0 # Use the ref you want to point at
rev: v4.6.0 # Use the ref you want to point at
hooks:
- id: trailing-whitespace
- id: check-ast
Expand All @@ -15,8 +15,8 @@ repos:
- id: check-yaml
- id: check-toml

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.3.7'
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.4.2'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -25,7 +25,7 @@ repos:
types_or: [python, jupyter]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
rev: v1.10.0
hooks:
- id: mypy
entry: python3 -m mypy --config-file pyproject.toml
Expand All @@ -41,3 +41,22 @@ repos:
language: system
pass_filenames: false
always_run: true

- repo: local
hooks:
- id: nbstripout
name: nbstripout
language: system
entry: python3 -m nbstripout

ci:
autofix_commit_msg: |
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
autofix_prs: true
autoupdate_branch: ''
autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate'
autoupdate_schedule: weekly
skip: [pytest,nbstripout,mypy]
submodules: false
23 changes: 11 additions & 12 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,11 @@


def main(
args: Dict[str, Any],
args: argparse.Namespace,
pre_model_config: Dict[str, Any],
fine_model_config: Dict[str, Any],
) -> None:
"""Train the model."""

# Setup environment
seed_everything(args.seed)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
Expand Down Expand Up @@ -66,7 +65,7 @@ def main(
random_state=args.seed,
stratify=fine_tune["label"],
)

else: # Multi-label classification
fine_train_ids, _, fine_val_ids, _ = iterative_train_test_split(
X=fine_tune["patient_id"].to_numpy().reshape(-1, 1),
Expand Down Expand Up @@ -113,7 +112,7 @@ def main(
balance_guide=None,
max_len=args.max_len,
)

else:
train_dataset = FinetuneDataset(
data=fine_train,
Expand All @@ -130,7 +129,7 @@ def main(
tokenizer=tokenizer,
max_len=args.max_len,
)

train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
Expand Down Expand Up @@ -163,12 +162,12 @@ def main(
dirpath=args.checkpoint_dir,
),
LearningRateMonitor(logging_interval="step"),
# EarlyStopping(
# monitor="val_loss",
# patience=args.patience,
# verbose=True,
# mode="min",
# ),
EarlyStopping(
monitor="val_loss",
patience=args.patience,
verbose=True,
mode="min",
),
]

# Create model
Expand All @@ -185,7 +184,7 @@ def main(
pretrained_model=pretrained_model,
**fine_model_config,
)

elif args.model_type == "cehr_bigbird":
pretrained_model = BigBirdPretrain(
vocab_size=tokenizer.get_vocab_size(),
Expand Down
2 changes: 1 addition & 1 deletion interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


def main(
args: Dict[str, Any],
args: argparse.Namespace,
pre_model_config: Dict[str, Any],
fine_model_config: Dict[str, Any],
) -> None:
Expand Down
Loading

0 comments on commit a6ab852

Please sign in to comment.