Skip to content

Commit

Permalink
Update requirements, new logic for uploading weights
Browse files Browse the repository at this point in the history
  • Loading branch information
SpirinEgor committed Sep 22, 2021
1 parent 29fc3b9 commit 9a869fc
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
8 changes: 3 additions & 5 deletions embeddings_for_trees/utils/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import dgl
import torch
from commode_utils.callback import UploadCheckpointCallback, PrintEpochResultCallback
from commode_utils.callback import PrintEpochResultCallback, ModelCheckpointWithUpload
from omegaconf import DictConfig
from pytorch_lightning import seed_everything, Trainer, LightningModule, LightningDataModule
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger


Expand All @@ -20,15 +20,14 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
wandb_logger = WandbLogger(project=f"{model_name} -- {dataset_name}", log_model=False, offline=config.log_offline)

# define model checkpoint callback
checkpoint_callback = ModelCheckpoint(
checkpoint_callback = ModelCheckpointWithUpload(
dirpath=wandb_logger.experiment.dir,
filename="{epoch:02d}-val_loss={val/loss:.4f}",
monitor="val/loss",
every_n_epochs=params.save_every_epoch,
save_top_k=-1,
auto_insert_metric_name=False,
)
upload_checkpoint_callback = UploadCheckpointCallback(wandb_logger.experiment.dir)
# define early stopping callback
early_stopping_callback = EarlyStopping(patience=params.patience, monitor="val/loss", verbose=True, mode="min")
# define callback for printing intermediate result
Expand All @@ -50,7 +49,6 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
lr_logger,
early_stopping_callback,
checkpoint_callback,
upload_checkpoint_callback,
print_epoch_result_callback,
],
resume_from_checkpoint=config.get("checkpoint", None),
Expand Down
16 changes: 9 additions & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
black==21.7b0
torch==1.9.0
tqdm==4.62.0
pytorch-lightning==1.4.1
wandb==0.11.2
mypy==0.910
omegaconf==2.1.0

torch==1.9.1
pytorch-lightning==1.4.7
dgl==0.6.1
commode-utils==0.3.7
torchmetrics==0.4.1
torchmetrics==0.5.1

tqdm==4.62.3
wandb==0.12.2
omegaconf==2.1.1
commode-utils==0.3.9
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages

VERSION = "1.0.2"
VERSION = "1.0.3"

with open("README.md") as readme_file:
readme = readme_file.read()
Expand Down

0 comments on commit 9a869fc

Please sign in to comment.