diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index 71be10f36..9c66c35b2 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -163,15 +163,18 @@ def step(self, batch, loss_fn_list, stage):
             batch.y = batch.y.unsqueeze(1)
         for loss_fn in loss_fn_list:
             step_losses = self._compute_losses(y, neg_dy, batch, loss_fn, stage)
+
+            loss_name = loss_fn.__name__
+            if self.hparams.neg_dy_weight > 0:
+                self.losses[stage]["neg_dy"][loss_name].append(
+                    step_losses["neg_dy"].detach()
+                )
+            if self.hparams.y_weight > 0:
+                self.losses[stage]["y"][loss_name].append(step_losses["y"].detach())
             total_loss = (
                 step_losses["y"] * self.hparams.y_weight
                 + step_losses["neg_dy"] * self.hparams.neg_dy_weight
             )
-            loss_name = loss_fn.__name__
-            self.losses[stage]["neg_dy"][loss_name].append(
-                step_losses["neg_dy"].detach()
-            )
-            self.losses[stage]["y"][loss_name].append(step_losses["y"].detach())
             self.losses[stage]["total"][loss_name].append(total_loss.detach())
         return total_loss
 
@@ -215,11 +218,6 @@ def on_validation_epoch_end(self):
             result_dict.update(self._get_mean_loss_dict_for_type("total"))
             result_dict.update(self._get_mean_loss_dict_for_type("y"))
             result_dict.update(self._get_mean_loss_dict_for_type("neg_dy"))
-            # For retro compatibility with previous versions of TorchMD-Net we report some losses twice
-            result_dict["val_loss"] = result_dict["val_total_mse_loss"]
-            result_dict["train_loss"] = result_dict["train_total_mse_loss"]
-            if "test_total_l1_loss" in result_dict:
-                result_dict["test_loss"] = result_dict["test_total_l1_loss"]
             self.log_dict(result_dict, sync_dist=True)
         self._reset_losses_dict()
 
diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index 61b8dc33b..11cfed0f6 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -29,7 +29,7 @@ def get_args():
     parser.add_argument('--inference-batch-size', default=None, type=int, help='Batchsize for validation and tests.')
     parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
     parser.add_argument('--lr-patience', type=int, default=10, help='Patience for lr-schedule. Patience per eval-interval of validation')
-    parser.add_argument('--lr-metric', type=str, default='val_loss', choices=['train_loss', 'val_loss'], help='Metric to monitor when deciding whether to reduce learning rate')
+    parser.add_argument('--lr-metric', type=str, default='val_total_mse_loss', choices=['train_total_mse_loss', 'val_total_mse_loss'], help='Metric to monitor when deciding whether to reduce learning rate')
     parser.add_argument('--lr-min', type=float, default=1e-6, help='Minimum learning rate before early stop')
     parser.add_argument('--lr-factor', type=float, default=0.8, help='Factor by which to multiply the learning rate when the metric stops improving')
     parser.add_argument('--lr-warmup-steps', type=int, default=0, help='How many steps to warm-up over. Defaults to 0 for no warm-up')
@@ -140,12 +140,15 @@ def main():
     checkpoint_callback = ModelCheckpoint(
         dirpath=args.log_dir,
-        monitor="val_loss",
+        monitor="val_total_mse_loss",
         save_top_k=10,  # -1 to save all
         every_n_epochs=args.save_interval,
-        filename="{epoch}-{val_loss:.4f}-{test_loss:.4f}",
+        filename="epoch={epoch}-val_loss={val_total_mse_loss:.4f}-test_loss={test_total_l1_loss:.4f}",
+        auto_insert_metric_name=False,
+    )
+    early_stopping = EarlyStopping(
+        "val_total_mse_loss", patience=args.early_stopping_patience
     )
-    early_stopping = EarlyStopping("val_loss", patience=args.early_stopping_patience)
     csv_logger = CSVLogger(args.log_dir, name="", version="")
     _logger = [csv_logger]
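
Note (not part of the diff): the renamed monitors above follow the "{stage}_{type}_{loss_fn.__name__}" pattern implied by the keys that appear in this change. A minimal sketch of the keys that convention produces, assuming the configured loss functions are torch.nn.functional.mse_loss and l1_loss:

    # Hedged sketch: enumerates the metric keys implied by the naming scheme above;
    # this is illustrative and not code taken from torchmdnet/module.py.
    import torch.nn.functional as F

    loss_fn_list = [F.mse_loss, F.l1_loss]

    for stage in ("train", "val", "test"):
        for loss_type in ("total", "y", "neg_dy"):
            for loss_fn in loss_fn_list:
                # e.g. "val_total_mse_loss", "test_total_l1_loss"
                print(f"{stage}_{loss_type}_{loss_fn.__name__}")

This is why the --lr-metric choices, ModelCheckpoint monitor, and EarlyStopping monitor now reference val_total_mse_loss directly instead of the removed val_loss / train_loss / test_loss aliases.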