From f6fa219b44da2ce78fdd7a8d2285ea0ecae61082 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 15 Jul 2024 11:55:53 +0200 Subject: [PATCH 01/11] Add Huber loss --- torchmdnet/loss.py | 7 +++++++ torchmdnet/module.py | 15 ++++++++++++--- torchmdnet/scripts/train.py | 2 ++ 3 files changed, 21 insertions(+), 3 deletions(-) create mode 100644 torchmdnet/loss.py diff --git a/torchmdnet/loss.py b/torchmdnet/loss.py new file mode 100644 index 00000000..e88faa60 --- /dev/null +++ b/torchmdnet/loss.py @@ -0,0 +1,7 @@ +from torch.nn.functional import mse_loss, l1_loss, huber_loss + +loss_map = { + "mse_loss": mse_loss, + "l1_loss": l1_loss, + "huber_loss": huber_loss, +} diff --git a/torchmdnet/module.py b/torchmdnet/module.py index d5ea73cf..98d9c000 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -6,7 +6,7 @@ import torch from torch.optim import AdamW from torch.optim.lr_scheduler import ReduceLROnPlateau -from torch.nn.functional import local_response_norm, mse_loss, l1_loss +from torch.nn.functional import local_response_norm from torch import Tensor from typing import Optional, Dict, Tuple @@ -16,6 +16,9 @@ import torch_geometric.transforms as T +from torchmdnet.loss import l1_loss, loss_map + + class FloatCastDatasetWrapper(T.BaseTransform): """A transform that casts all floating point tensors to a given dtype. tensors to a given dtype. @@ -92,6 +95,12 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): ] ) + if self.hparams.training_loss not in loss_map: + raise ValueError( + f"Training loss {self.hparams.training_loss} not supported. Supported losses are {list(loss_map.keys())}" + ) + self.training_loss = loss_map[self.hparams.training_loss] + def configure_optimizers(self): optimizer = AdamW( self.model.parameters(), @@ -126,7 +135,7 @@ def forward( return self.model(z, pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args) def training_step(self, batch, batch_idx): - return self.step(batch, [mse_loss], "train") + return self.step(batch, [self.training_loss], "train") def validation_step(self, batch, batch_idx, *args): # If args is not empty the first (and only) element is the dataloader_idx @@ -135,7 +144,7 @@ def validation_step(self, batch, batch_idx, *args): # The dataloader takes care of sending the two sets only when the second one is needed. 
is_val = len(args) == 0 or (len(args) > 0 and args[0] == 0) if is_val: - step_type = {"loss_fn_list": [l1_loss, mse_loss], "stage": "val"} + step_type = {"loss_fn_list": [l1_loss, self.training_loss], "stage": "val"} else: step_type = {"loss_fn_list": [l1_loss], "stage": "test"} return self.step(batch, **step_type) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 7f2d8e07..c7aa1db0 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -18,6 +18,7 @@ from torchmdnet.module import LNNP from torchmdnet import datasets, priors, models from torchmdnet.data import DataModule +from torchmdnet.loss import loss_map from torchmdnet.models import output_modules from torchmdnet.models.model import create_prior_models from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping, dtype_mapping @@ -70,6 +71,7 @@ def get_argparse(): parser.add_argument('--dataset-preload-limit', default=1024, type=int, help='Custom and HDF5 datasets will preload to RAM datasets that are less than this size in MB') parser.add_argument('--y-weight', default=1.0, type=float, help='Weighting factor for y label in the loss function') parser.add_argument('--neg-dy-weight', default=1.0, type=float, help='Weighting factor for neg_dy label in the loss function') + parser.add_argument('--train_loss', default='mse', type=str, choices=loss_map.keys(), help='Loss function to use during training') # model architecture parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train') From edc930abc3fd077f699d1a61ed4f5ed421b188a7 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 15 Jul 2024 13:00:51 +0200 Subject: [PATCH 02/11] rename --- torchmdnet/loss.py | 2 +- torchmdnet/module.py | 11 +++++------ torchmdnet/scripts/train.py | 4 ++-- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/torchmdnet/loss.py b/torchmdnet/loss.py index e88faa60..2f2d8c57 100644 --- a/torchmdnet/loss.py +++ b/torchmdnet/loss.py @@ -1,6 +1,6 @@ from torch.nn.functional import mse_loss, l1_loss, huber_loss -loss_map = { +loss_class_mapping = { "mse_loss": mse_loss, "l1_loss": l1_loss, "huber_loss": huber_loss, diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 98d9c000..aae47c30 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -9,16 +9,13 @@ from torch.nn.functional import local_response_norm from torch import Tensor from typing import Optional, Dict, Tuple - from lightning import LightningModule from torchmdnet.models.model import create_model, load_model from torchmdnet.models.utils import dtype_mapping +from torchmdnet.loss import l1_loss, loss_class_mapping import torch_geometric.transforms as T -from torchmdnet.loss import l1_loss, loss_map - - class FloatCastDatasetWrapper(T.BaseTransform): """A transform that casts all floating point tensors to a given dtype. tensors to a given dtype. @@ -68,6 +65,8 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): hparams["charge"] = False if "spin" not in hparams: hparams["spin"] = False + if "training_loss" not in hparams: + hparams["training_loss"] = "mse_loss" self.save_hyperparameters(hparams) @@ -95,11 +94,11 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): ] ) - if self.hparams.training_loss not in loss_map: + if self.hparams.training_loss not in loss_class_mapping: raise ValueError( f"Training loss {self.hparams.training_loss} not supported. 
Supported losses are {list(loss_map.keys())}" ) - self.training_loss = loss_map[self.hparams.training_loss] + self.training_loss = loss_class_mapping[self.hparams.training_loss] def configure_optimizers(self): optimizer = AdamW( diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index c7aa1db0..c88d55c3 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -18,7 +18,7 @@ from torchmdnet.module import LNNP from torchmdnet import datasets, priors, models from torchmdnet.data import DataModule -from torchmdnet.loss import loss_map +from torchmdnet.loss import loss_class_mapping from torchmdnet.models import output_modules from torchmdnet.models.model import create_prior_models from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping, dtype_mapping @@ -71,7 +71,7 @@ def get_argparse(): parser.add_argument('--dataset-preload-limit', default=1024, type=int, help='Custom and HDF5 datasets will preload to RAM datasets that are less than this size in MB') parser.add_argument('--y-weight', default=1.0, type=float, help='Weighting factor for y label in the loss function') parser.add_argument('--neg-dy-weight', default=1.0, type=float, help='Weighting factor for neg_dy label in the loss function') - parser.add_argument('--train_loss', default='mse', type=str, choices=loss_map.keys(), help='Loss function to use during training') + parser.add_argument('--train_loss', default='mse', type=str, choices=loss_class_mapping.keys(), help='Loss function to use during training') # model architecture parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train') From 7281a567b43a1ac7ae9a1fbd9558a911311b546f Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 15 Jul 2024 13:17:14 +0200 Subject: [PATCH 03/11] Typo --- torchmdnet/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index aae47c30..2e2eef65 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -96,7 +96,7 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): if self.hparams.training_loss not in loss_class_mapping: raise ValueError( - f"Training loss {self.hparams.training_loss} not supported. Supported losses are {list(loss_map.keys())}" + f"Training loss {self.hparams.training_loss} not supported. Supported losses are {list(loss_class_mapping.keys())}" ) self.training_loss = loss_class_mapping[self.hparams.training_loss] From eb321cc6d2f19a5d1029acb4bc65e36709def925 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 15 Jul 2024 13:28:34 +0200 Subject: [PATCH 04/11] Typo --- torchmdnet/module.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 2e2eef65..b61c690f 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -65,8 +65,6 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): hparams["charge"] = False if "spin" not in hparams: hparams["spin"] = False - if "training_loss" not in hparams: - hparams["training_loss"] = "mse_loss" self.save_hyperparameters(hparams) @@ -94,11 +92,11 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): ] ) - if self.hparams.training_loss not in loss_class_mapping: + if self.hparams.train_loss not in loss_class_mapping: raise ValueError( - f"Training loss {self.hparams.training_loss} not supported. 
Supported losses are {list(loss_class_mapping.keys())}" + f"Training loss {self.hparams.train_loss} not supported. Supported losses are {list(loss_class_mapping.keys())}" ) - self.training_loss = loss_class_mapping[self.hparams.training_loss] + self.training_loss = loss_class_mapping[self.hparams.train_loss] def configure_optimizers(self): optimizer = AdamW( From 404afc0cf0fe1e8dfa113d334be8caaa7c1623aa Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 15 Jul 2024 13:33:24 +0200 Subject: [PATCH 05/11] Typo --- torchmdnet/scripts/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index c88d55c3..7fb144be 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -71,7 +71,7 @@ def get_argparse(): parser.add_argument('--dataset-preload-limit', default=1024, type=int, help='Custom and HDF5 datasets will preload to RAM datasets that are less than this size in MB') parser.add_argument('--y-weight', default=1.0, type=float, help='Weighting factor for y label in the loss function') parser.add_argument('--neg-dy-weight', default=1.0, type=float, help='Weighting factor for neg_dy label in the loss function') - parser.add_argument('--train_loss', default='mse', type=str, choices=loss_class_mapping.keys(), help='Loss function to use during training') + parser.add_argument('--train-loss', default='mse_loss', type=str, choices=loss_class_mapping.keys(), help='Loss function to use during training') # model architecture parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train') From a47e63a3926839e6ea8fbeda4a33d38c4cb7f451 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 15 Jul 2024 13:54:00 +0200 Subject: [PATCH 06/11] Fix LR scheduler naming --- torchmdnet/module.py | 4 +++- torchmdnet/scripts/train.py | 11 +++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index b61c690f..6e508d0d 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -111,9 +111,11 @@ def configure_optimizers(self): patience=self.hparams.lr_patience, min_lr=self.hparams.lr_min, ) + lr_metric = getattr(self.hparams, "lr_metric", "val") + monitor = f"{lr_metric}_total_{self.hparams.train_loss}" lr_scheduler = { "scheduler": scheduler, - "monitor": getattr(self.hparams, "lr_metric", "val_loss"), + "monitor": monitor, "interval": "epoch", "frequency": 1, } diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 7fb144be..1fd656b7 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -36,7 +36,7 @@ def get_argparse(): parser.add_argument('--inference-batch-size', default=None, type=int, help='Batchsize for validation and tests.') parser.add_argument('--lr', default=1e-4, type=float, help='learning rate') parser.add_argument('--lr-patience', type=int, default=10, help='Patience for lr-schedule. 
Patience per eval-interval of validation') - parser.add_argument('--lr-metric', type=str, default='val_total_mse_loss', choices=['train_total_mse_loss', 'val_total_mse_loss'], help='Metric to monitor when deciding whether to reduce learning rate') + parser.add_argument('--lr-metric', type=str, default='val', choices=['train', 'val'], help='Metric to monitor when deciding whether to reduce learning rate') parser.add_argument('--lr-min', type=float, default=1e-6, help='Minimum learning rate before early stop') parser.add_argument('--lr-factor', type=float, default=0.8, help='Factor by which to multiply the learning rate when the metric stops improving') parser.add_argument('--lr-warmup-steps', type=int, default=0, help='How many steps to warm-up over. Defaults to 0 for no warm-up') @@ -168,17 +168,16 @@ def main(): # initialize lightning module model = LNNP(args, prior_model=prior_models, mean=data.mean, std=data.std) + val_loss_name = f"val_total_{args.train_loss}" checkpoint_callback = ModelCheckpoint( dirpath=args.log_dir, - monitor="val_total_mse_loss", + monitor=val_loss_name, save_top_k=10, # -1 to save all every_n_epochs=args.save_interval, - filename="epoch={epoch}-val_loss={val_total_mse_loss:.4f}-test_loss={test_total_l1_loss:.4f}", + filename=f"epoch={{epoch}}-val_loss={{{val_loss_name}:.4f}}-test_loss={{test_total_l1_loss:.4f}}", auto_insert_metric_name=False, ) - early_stopping = EarlyStopping( - "val_total_mse_loss", patience=args.early_stopping_patience - ) + early_stopping = EarlyStopping(val_loss_name, patience=args.early_stopping_patience) csv_logger = CSVLogger(args.log_dir, name="", version="") _logger = [csv_logger] From 1a8f0e364b2812bd749098cd19e3c83ef3db5f24 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 15 Jul 2024 14:03:17 +0200 Subject: [PATCH 07/11] Add default --- torchmdnet/module.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 6e508d0d..1f3aa1bf 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -65,6 +65,8 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): hparams["charge"] = False if "spin" not in hparams: hparams["spin"] = False + if "train_loss" not in hparams: + hparams["train_loss"] = "mse_loss" self.save_hyperparameters(hparams) From a7621ab29a6496da9c3ef97c599826ed80f62f8c Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 12 Aug 2024 12:27:38 +0200 Subject: [PATCH 08/11] Add strict so Lightning complains if the loss metric is not available in the logs --- torchmdnet/module.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 1f3aa1bf..2751462a 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -117,6 +117,7 @@ def configure_optimizers(self): monitor = f"{lr_metric}_total_{self.hparams.train_loss}" lr_scheduler = { "scheduler": scheduler, + "strict": True, "monitor": monitor, "interval": "epoch", "frequency": 1, From 860fa90159eaf6826f889c0e813a65828a314875 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 12 Aug 2024 12:28:18 +0200 Subject: [PATCH 09/11] Add argument to pass additional arguments to loss functions --- torchmdnet/module.py | 14 +++++++++++--- torchmdnet/scripts/train.py | 1 + 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 2751462a..87553f2d 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -67,7 +67,8 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): hparams["spin"] = False 
if "train_loss" not in hparams: hparams["train_loss"] = "mse_loss" - + if "train_loss_arg" not in hparams: + hparams["train_loss_arg"] = {} self.save_hyperparameters(hparams) if self.hparams.load_model: @@ -98,7 +99,15 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): raise ValueError( f"Training loss {self.hparams.train_loss} not supported. Supported losses are {list(loss_class_mapping.keys())}" ) - self.training_loss = loss_class_mapping[self.hparams.train_loss] + if self.hparams.train_loss_arg is None: + self.hparams.train_loss_arg = {} + self.training_loss = lambda x, batch: loss_class_mapping[ + self.hparams.train_loss + ](x, batch, **self.hparams.train_loss_arg) + # The name of the loss function is used for logging and the LR scheduler, it cannot be just + self.training_loss.__name__ = loss_class_mapping[ + self.hparams.train_loss + ].__name__ def configure_optimizers(self): optimizer = AdamW( @@ -234,7 +243,6 @@ def step(self, batch, loss_fn_list, stage): batch.y = batch.y.unsqueeze(1) for loss_fn in loss_fn_list: step_losses = self._compute_losses(y, neg_dy, batch, loss_fn, stage) - loss_name = loss_fn.__name__ if self.hparams.neg_dy_weight > 0: self.losses[stage]["neg_dy"][loss_name].append( diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index e2c76732..76f4c63d 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -71,6 +71,7 @@ def get_argparse(): parser.add_argument('--y-weight', default=1.0, type=float, help='Weighting factor for y label in the loss function') parser.add_argument('--neg-dy-weight', default=1.0, type=float, help='Weighting factor for neg_dy label in the loss function') parser.add_argument('--train-loss', default='mse_loss', type=str, choices=loss_class_mapping.keys(), help='Loss function to use during training') + parser.add_argument('--train-loss-arg', default=None, help='Additional arguments for the loss function. Needs to be a dictionary.') # model architecture parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train') From 5a88579c715091d3fd87377e0201ff155de8ab11 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 12 Aug 2024 13:06:18 +0200 Subject: [PATCH 10/11] Fix serialization issue --- torchmdnet/module.py | 51 +++++++++++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 87553f2d..19c7f6e9 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -48,6 +48,16 @@ def __call__(self, data): return data +# This wrapper is here in order to permit Lightning to serialize the loss function. +class LossFunction: + def __init__(self, loss_fn, **kwargs): + self.loss_fn = loss_fn + self.kwargs = kwargs + + def __call__(self, x, batch): + return self.loss_fn(x, batch, **self.kwargs) + + class LNNP(LightningModule): """ Lightning wrapper for the Neural Network Potentials in TorchMD-Net. 
@@ -101,13 +111,11 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): ) if self.hparams.train_loss_arg is None: self.hparams.train_loss_arg = {} - self.training_loss = lambda x, batch: loss_class_mapping[ - self.hparams.train_loss - ](x, batch, **self.hparams.train_loss_arg) - # The name of the loss function is used for logging and the LR scheduler, it cannot be just - self.training_loss.__name__ = loss_class_mapping[ - self.hparams.train_loss - ].__name__ + + self.train_loss_fn = LossFunction( + loss_class_mapping[self.hparams.train_loss], + **self.hparams.train_loss_arg, + ) def configure_optimizers(self): optimizer = AdamW( @@ -146,7 +154,9 @@ def forward( return self.model(z, pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args) def training_step(self, batch, batch_idx): - return self.step(batch, [self.training_loss], "train") + return self.step( + batch, [(self.hparams.train_loss, self.train_loss_fn)], "train" + ) def validation_step(self, batch, batch_idx, *args): # If args is not empty the first (and only) element is the dataloader_idx @@ -155,28 +165,34 @@ def validation_step(self, batch, batch_idx, *args): # The dataloader takes care of sending the two sets only when the second one is needed. is_val = len(args) == 0 or (len(args) > 0 and args[0] == 0) if is_val: - step_type = {"loss_fn_list": [l1_loss, self.training_loss], "stage": "val"} + step_type = { + "loss_fn_list": [ + ("l1_loss", l1_loss), + (self.hparams.train_loss, self.train_loss_fn), + ], + "stage": "val", + } else: - step_type = {"loss_fn_list": [l1_loss], "stage": "test"} + step_type = {"loss_fn_list": [("l1_loss", l1_loss)], "stage": "test"} return self.step(batch, **step_type) def test_step(self, batch, batch_idx): - return self.step(batch, [l1_loss], "test") + return self.step(batch, [("l1_loss", l1_loss)], "test") - def _compute_losses(self, y, neg_y, batch, loss_fn, stage): + def _compute_losses(self, y, neg_y, batch, loss_fn, loss_name, stage): # Compute the loss for the predicted value and the negative derivative (if available) # Args: # y: predicted value # neg_y: predicted negative derivative # batch: batch of data - # loss_fn: loss function to compute + # loss_fn: The loss function to compute + # loss_name: The name of the loss function # Returns: # loss_y: loss for the predicted value # loss_neg_y: loss for the predicted negative derivative loss_y, loss_neg_y = torch.tensor(0.0, device=self.device), torch.tensor( 0.0, device=self.device ) - loss_name = loss_fn.__name__ if self.hparams.derivative and "neg_dy" in batch: loss_neg_y = loss_fn(neg_y, batch.neg_dy) loss_neg_y = self._update_loss_with_ema( @@ -241,9 +257,10 @@ def step(self, batch, loss_fn_list, stage): neg_dy = neg_dy + y.sum() * 0 if "y" in batch and batch.y.ndim == 1: batch.y = batch.y.unsqueeze(1) - for loss_fn in loss_fn_list: - step_losses = self._compute_losses(y, neg_dy, batch, loss_fn, stage) - loss_name = loss_fn.__name__ + for loss_name, loss_fn in loss_fn_list: + step_losses = self._compute_losses( + y, neg_dy, batch, loss_fn, loss_name, stage + ) if self.hparams.neg_dy_weight > 0: self.losses[stage]["neg_dy"][loss_name].append( step_losses["neg_dy"].detach() From bfed43548c937f2507abc90d40c6e3cb0a30c4bd Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 12 Aug 2024 13:10:01 +0200 Subject: [PATCH 11/11] Fix hparam --- torchmdnet/module.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 19c7f6e9..d0f9d3bd 100644 --- 
a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -50,12 +50,14 @@ def __call__(self, data): # This wrapper is here in order to permit Lightning to serialize the loss function. class LossFunction: - def __init__(self, loss_fn, **kwargs): + def __init__(self, loss_fn, extra_args=None): self.loss_fn = loss_fn - self.kwargs = kwargs + self.extra_args = extra_args + if self.extra_args is None: + self.extra_args = {} def __call__(self, x, batch): - return self.loss_fn(x, batch, **self.kwargs) + return self.loss_fn(x, batch, **self.extra_args) class LNNP(LightningModule): @@ -109,12 +111,10 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None): raise ValueError( f"Training loss {self.hparams.train_loss} not supported. Supported losses are {list(loss_class_mapping.keys())}" ) - if self.hparams.train_loss_arg is None: - self.hparams.train_loss_arg = {} self.train_loss_fn = LossFunction( loss_class_mapping[self.hparams.train_loss], - **self.hparams.train_loss_arg, + self.hparams.train_loss_arg, ) def configure_optimizers(self):
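
Usage sketch (illustrative, not taken from the patches above): after this series, the training loss is chosen with the new --train-loss flag (or, presumably, a train_loss key in the YAML configuration), one of mse_loss, l1_loss or huber_loss from loss_class_mapping, and extra keyword arguments for the selected functional can be supplied through train_loss_arg (a dictionary, so most conveniently set in the configuration file rather than as a raw command-line string). The LossFunction helper introduced in patches 10 and 11 simply forwards those arguments to the torch functional; a minimal, self-contained example of its behaviour, assuming illustrative tensors pred and ref and an illustrative Huber delta of 2.0:

    import torch
    from torch.nn.functional import huber_loss
    from torchmdnet.module import LossFunction  # available once this series is applied

    pred = torch.randn(10, 1)
    ref = torch.randn(10, 1)

    # LossFunction(loss_fn, extra_args) stores the functional and forwards extra_args as
    # keyword arguments, so the call below is equivalent to huber_loss(pred, ref, delta=2.0).
    wrapped = LossFunction(huber_loss, {"delta": 2.0})
    loss_value = wrapped(pred, ref)

With --train-loss huber_loss, the metric monitored by checkpointing, early stopping and (by default) the ReduceLROnPlateau scheduler becomes val_total_huber_loss, following the f"{lr_metric}_total_{train_loss}" pattern introduced in patch 06; passing --lr-metric train makes the scheduler watch train_total_huber_loss instead.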