From a47e63a3926839e6ea8fbeda4a33d38c4cb7f451 Mon Sep 17 00:00:00 2001
From: RaulPPelaez
Date: Mon, 15 Jul 2024 13:54:00 +0200
Subject: [PATCH] Fix LR scheduler naming

---
 torchmdnet/module.py        |  4 +++-
 torchmdnet/scripts/train.py | 11 +++++------
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index b61c690f..6e508d0d 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -111,9 +111,11 @@ def configure_optimizers(self):
             patience=self.hparams.lr_patience,
             min_lr=self.hparams.lr_min,
         )
+        lr_metric = getattr(self.hparams, "lr_metric", "val")
+        monitor = f"{lr_metric}_total_{self.hparams.train_loss}"
         lr_scheduler = {
             "scheduler": scheduler,
-            "monitor": getattr(self.hparams, "lr_metric", "val_loss"),
+            "monitor": monitor,
             "interval": "epoch",
             "frequency": 1,
         }
diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index 7fb144be..1fd656b7 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -36,7 +36,7 @@ def get_argparse():
     parser.add_argument('--inference-batch-size', default=None, type=int, help='Batchsize for validation and tests.')
     parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
     parser.add_argument('--lr-patience', type=int, default=10, help='Patience for lr-schedule. Patience per eval-interval of validation')
-    parser.add_argument('--lr-metric', type=str, default='val_total_mse_loss', choices=['train_total_mse_loss', 'val_total_mse_loss'], help='Metric to monitor when deciding whether to reduce learning rate')
+    parser.add_argument('--lr-metric', type=str, default='val', choices=['train', 'val'], help='Metric to monitor when deciding whether to reduce learning rate')
     parser.add_argument('--lr-min', type=float, default=1e-6, help='Minimum learning rate before early stop')
     parser.add_argument('--lr-factor', type=float, default=0.8, help='Factor by which to multiply the learning rate when the metric stops improving')
     parser.add_argument('--lr-warmup-steps', type=int, default=0, help='How many steps to warm-up over. Defaults to 0 for no warm-up')
@@ -168,17 +168,16 @@ def main():
     # initialize lightning module
     model = LNNP(args, prior_model=prior_models, mean=data.mean, std=data.std)
 
+    val_loss_name = f"val_total_{args.train_loss}"
     checkpoint_callback = ModelCheckpoint(
         dirpath=args.log_dir,
-        monitor="val_total_mse_loss",
+        monitor=val_loss_name,
         save_top_k=10,  # -1 to save all
         every_n_epochs=args.save_interval,
-        filename="epoch={epoch}-val_loss={val_total_mse_loss:.4f}-test_loss={test_total_l1_loss:.4f}",
+        filename=f"epoch={{epoch}}-val_loss={{{val_loss_name}:.4f}}-test_loss={{test_total_l1_loss:.4f}}",
         auto_insert_metric_name=False,
     )
-    early_stopping = EarlyStopping(
-        "val_total_mse_loss", patience=args.early_stopping_patience
-    )
+    early_stopping = EarlyStopping(val_loss_name, patience=args.early_stopping_patience)
     csv_logger = CSVLogger(args.log_dir, name="", version="")
     _logger = [csv_logger]
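
Note on the resulting metric names: the monitored key is now assembled from
the split chosen via --lr-metric ('train' or 'val') and the configured
--train-loss, following the "<split>_total_<loss>" convention the hard-coded
names above already used. A minimal Python sketch of that naming scheme (the
helper monitor_name is illustrative only, not part of this patch):

    def monitor_name(lr_metric: str, train_loss: str) -> str:
        # Mirrors f"{lr_metric}_total_{self.hparams.train_loss}" in module.py.
        return f"{lr_metric}_total_{train_loss}"

    # The previously hard-coded names fall out as the mse_loss case:
    assert monitor_name("val", "mse_loss") == "val_total_mse_loss"
    assert monitor_name("train", "mse_loss") == "train_total_mse_loss"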