Commit
Fix module.py (#223)
* fix compute losses output

* fix compute losses according to previous version

* remove redundant metrics from the logdict

* update ptl checkpoint to get retrocompatibility

* remove blank lines
AntonioMirarchi authored Sep 26, 2023
1 parent 4e24910 commit d998fbe
Showing 2 changed files with 15 additions and 14 deletions.
18 changes: 8 additions & 10 deletions torchmdnet/module.py
@@ -163,15 +163,18 @@ def step(self, batch, loss_fn_list, stage):
             batch.y = batch.y.unsqueeze(1)
 
         for loss_fn in loss_fn_list:
             step_losses = self._compute_losses(y, neg_dy, batch, loss_fn, stage)
+
+            loss_name = loss_fn.__name__
+            if self.hparams.neg_dy_weight > 0:
+                self.losses[stage]["neg_dy"][loss_name].append(
+                    step_losses["neg_dy"].detach()
+                )
+            if self.hparams.y_weight > 0:
+                self.losses[stage]["y"][loss_name].append(step_losses["y"].detach())
             total_loss = (
                 step_losses["y"] * self.hparams.y_weight
                 + step_losses["neg_dy"] * self.hparams.neg_dy_weight
             )
-            loss_name = loss_fn.__name__
-            self.losses[stage]["neg_dy"][loss_name].append(
-                step_losses["neg_dy"].detach()
-            )
-            self.losses[stage]["y"][loss_name].append(step_losses["y"].detach())
             self.losses[stage]["total"][loss_name].append(total_loss.detach())
         return total_loss

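The effect of the step() change above is that per-term losses are appended only when their weight is non-zero, so a run trained on forces alone (y_weight = 0) no longer records meaningless energy entries. A minimal, self-contained sketch of that accumulation pattern (illustrative only; the names mirror the diff, not the TorchMD-Net API):

from collections import defaultdict

# Per-stage loss history, keyed by loss type and loss-function name.
losses = {"y": defaultdict(list), "neg_dy": defaultdict(list), "total": defaultdict(list)}

def accumulate(step_losses, loss_name, y_weight, neg_dy_weight):
    # Record a term only if it actually contributes to the objective.
    if neg_dy_weight > 0:
        losses["neg_dy"][loss_name].append(step_losses["neg_dy"])
    if y_weight > 0:
        losses["y"][loss_name].append(step_losses["y"])
    total = step_losses["y"] * y_weight + step_losses["neg_dy"] * neg_dy_weight
    losses["total"][loss_name].append(total)
    return total

# Example: a forces-only configuration never touches the "y" history.
accumulate({"y": 0.0, "neg_dy": 0.12}, "mse_loss", y_weight=0.0, neg_dy_weight=1.0)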
@@ -215,11 +218,6 @@ def on_validation_epoch_end(self):
             result_dict.update(self._get_mean_loss_dict_for_type("total"))
             result_dict.update(self._get_mean_loss_dict_for_type("y"))
             result_dict.update(self._get_mean_loss_dict_for_type("neg_dy"))
-            # For retro compatibility with previous versions of TorchMD-Net we report some losses twice
-            result_dict["val_loss"] = result_dict["val_total_mse_loss"]
-            result_dict["train_loss"] = result_dict["train_total_mse_loss"]
-            if "test_total_l1_loss" in result_dict:
-                result_dict["test_loss"] = result_dict["test_total_l1_loss"]
             self.log_dict(result_dict, sync_dist=True)
 
         self._reset_losses_dict()
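With the retro-compatibility aliases removed from the log dict, anything that monitored val_loss, train_loss or test_loss must now reference the fully qualified metric names. Judging from the keys visible in the diff, those names combine stage, loss type and loss-function name; a small sketch of that assumed pattern (not the actual _get_mean_loss_dict_for_type implementation):

# Hypothetical illustration of the metric-name scheme implied by the diff above.
def metric_name(stage: str, loss_type: str, loss_fn_name: str) -> str:
    return f"{stage}_{loss_type}_{loss_fn_name}"

assert metric_name("val", "total", "mse_loss") == "val_total_mse_loss"
assert metric_name("train", "total", "mse_loss") == "train_total_mse_loss"
assert metric_name("test", "total", "l1_loss") == "test_total_l1_loss"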
11 changes: 7 additions & 4 deletions torchmdnet/scripts/train.py
@@ -29,7 +29,7 @@ def get_args():
     parser.add_argument('--inference-batch-size', default=None, type=int, help='Batchsize for validation and tests.')
     parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
     parser.add_argument('--lr-patience', type=int, default=10, help='Patience for lr-schedule. Patience per eval-interval of validation')
-    parser.add_argument('--lr-metric', type=str, default='val_loss', choices=['train_loss', 'val_loss'], help='Metric to monitor when deciding whether to reduce learning rate')
+    parser.add_argument('--lr-metric', type=str, default='val_total_mse_loss', choices=['train_total_mse_loss', 'val_total_mse_loss'], help='Metric to monitor when deciding whether to reduce learning rate')
     parser.add_argument('--lr-min', type=float, default=1e-6, help='Minimum learning rate before early stop')
     parser.add_argument('--lr-factor', type=float, default=0.8, help='Factor by which to multiply the learning rate when the metric stops improving')
     parser.add_argument('--lr-warmup-steps', type=int, default=0, help='How many steps to warm-up over. Defaults to 0 for no warm-up')
@@ -140,12 +140,15 @@ def main():

     checkpoint_callback = ModelCheckpoint(
         dirpath=args.log_dir,
-        monitor="val_loss",
+        monitor="val_total_mse_loss",
         save_top_k=10, # -1 to save all
         every_n_epochs=args.save_interval,
-        filename="{epoch}-{val_loss:.4f}-{test_loss:.4f}",
+        filename="epoch={epoch}-val_loss={val_total_mse_loss:.4f}-test_loss={test_total_l1_loss:.4f}",
+        auto_insert_metric_name=False,
     )
-    early_stopping = EarlyStopping("val_loss", patience=args.early_stopping_patience)
+    early_stopping = EarlyStopping(
+        "val_total_mse_loss", patience=args.early_stopping_patience
+    )
 
     csv_logger = CSVLogger(args.log_dir, name="", version="")
     _logger = [csv_logger]
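Because the short aliases no longer exist, the ModelCheckpoint and EarlyStopping callbacks now monitor val_total_mse_loss directly, and the filename template spells out its own labels with auto_insert_metric_name=False (otherwise Lightning would prepend the long metric names to each value). A hedged usage sketch of how these callbacks might be wired up and what the saved files look like (the Trainer wiring and the patience value are illustrative, not from the diff):

# Illustrative only; mirrors the callback arguments in the diff above.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor="val_total_mse_loss",
    save_top_k=10,
    filename="epoch={epoch}-val_loss={val_total_mse_loss:.4f}-test_loss={test_total_l1_loss:.4f}",
    auto_insert_metric_name=False,  # keep the short val_loss/test_loss labels in the filename
)
early_stopping = EarlyStopping("val_total_mse_loss", patience=30)
trainer = Trainer(callbacks=[checkpoint_callback, early_stopping])
# Checkpoints are then saved as e.g. "epoch=42-val_loss=0.0123-test_loss=0.0456.ckpt".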
