Fix LR scheduler naming
RaulPPelaez committed Jul 15, 2024
1 parent 404afc0 commit a47e63a
Showing 2 changed files with 8 additions and 7 deletions.
4 changes: 3 additions & 1 deletion torchmdnet/module.py
@@ -111,9 +111,11 @@ def configure_optimizers(self):
             patience=self.hparams.lr_patience,
             min_lr=self.hparams.lr_min,
         )
+        lr_metric = getattr(self.hparams, "lr_metric", "val")
+        monitor = f"{lr_metric}_total_{self.hparams.train_loss}"
         lr_scheduler = {
             "scheduler": scheduler,
-            "monitor": getattr(self.hparams, "lr_metric", "val_loss"),
+            "monitor": monitor,
             "interval": "epoch",
             "frequency": 1,
         }
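With this change the monitored metric name is built from the --lr-metric choice ('train' or 'val') and the configured training loss. A minimal, self-contained sketch of that naming convention (the monitor_name helper below is illustrative only, not part of the repository; the real values come from self.hparams inside LNNP.configure_optimizers):

    # Illustrative helper mirroring the f-string in the diff above.
    def monitor_name(lr_metric: str = "val", train_loss: str = "mse_loss") -> str:
        """Compose the logged metric name, e.g. 'val_total_mse_loss'."""
        return f"{lr_metric}_total_{train_loss}"

    assert monitor_name("val", "mse_loss") == "val_total_mse_loss"
    assert monitor_name("train", "mse_loss") == "train_total_mse_loss"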
11 changes: 5 additions & 6 deletions torchmdnet/scripts/train.py
@@ -36,7 +36,7 @@ def get_argparse():
     parser.add_argument('--inference-batch-size', default=None, type=int, help='Batchsize for validation and tests.')
     parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
     parser.add_argument('--lr-patience', type=int, default=10, help='Patience for lr-schedule. Patience per eval-interval of validation')
-    parser.add_argument('--lr-metric', type=str, default='val_total_mse_loss', choices=['train_total_mse_loss', 'val_total_mse_loss'], help='Metric to monitor when deciding whether to reduce learning rate')
+    parser.add_argument('--lr-metric', type=str, default='val', choices=['train', 'val'], help='Metric to monitor when deciding whether to reduce learning rate')
     parser.add_argument('--lr-min', type=float, default=1e-6, help='Minimum learning rate before early stop')
     parser.add_argument('--lr-factor', type=float, default=0.8, help='Factor by which to multiply the learning rate when the metric stops improving')
     parser.add_argument('--lr-warmup-steps', type=int, default=0, help='How many steps to warm-up over. Defaults to 0 for no warm-up')
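The flag now takes the short form 'train' or 'val'. A stand-alone sketch of the narrowed option (not the repository's full parser), just to show how the choice parses:

    import argparse

    # Minimal stand-in for the --lr-metric option after this change.
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr-metric', type=str, default='val', choices=['train', 'val'],
                        help='Metric to monitor when deciding whether to reduce learning rate')
    args = parser.parse_args(['--lr-metric', 'train'])
    print(args.lr_metric)  # -> train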
@@ -168,17 +168,16 @@ def main():
     # initialize lightning module
     model = LNNP(args, prior_model=prior_models, mean=data.mean, std=data.std)
 
+    val_loss_name = f"val_total_{args.train_loss}"
     checkpoint_callback = ModelCheckpoint(
         dirpath=args.log_dir,
-        monitor="val_total_mse_loss",
+        monitor=val_loss_name,
         save_top_k=10, # -1 to save all
         every_n_epochs=args.save_interval,
-        filename="epoch={epoch}-val_loss={val_total_mse_loss:.4f}-test_loss={test_total_l1_loss:.4f}",
+        filename=f"epoch={{epoch}}-val_loss={{{val_loss_name}:.4f}}-test_loss={{test_total_l1_loss:.4f}}",
         auto_insert_metric_name=False,
     )
-    early_stopping = EarlyStopping(
-        "val_total_mse_loss", patience=args.early_stopping_patience
-    )
+    early_stopping = EarlyStopping(val_loss_name, patience=args.early_stopping_patience)
 
     csv_logger = CSVLogger(args.log_dir, name="", version="")
     _logger = [csv_logger]
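As a rough sanity check of the filename template above (a sketch only; the epoch and metric values are made up, and in practice PyTorch Lightning fills these placeholders from the logged metric values):

    val_loss_name = "val_total_mse_loss"  # assumes args.train_loss == "mse_loss"
    template = f"epoch={{epoch}}-val_loss={{{val_loss_name}:.4f}}-test_loss={{test_total_l1_loss:.4f}}"
    # Simulate the substitution performed from logged metrics at checkpoint time:
    metrics = {"epoch": 12, val_loss_name: 0.0321, "test_total_l1_loss": 0.0456}
    print(template.format(**metrics))  # -> epoch=12-val_loss=0.0321-test_loss=0.0456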
