Merge pull request #335 from RaulPPelaez/huberloss
Add Huber loss, allow choosing training loss function from the yaml
RaulPPelaez authored Aug 16, 2024
2 parents 26206eb + bfed435 commit 8a7b520
Showing 3 changed files with 66 additions and 21 deletions.
7 changes: 7 additions & 0 deletions torchmdnet/loss.py
@@ -0,0 +1,7 @@
from torch.nn.functional import mse_loss, l1_loss, huber_loss

loss_class_mapping = {
"mse_loss": mse_loss,
"l1_loss": l1_loss,
"huber_loss": huber_loss,
}
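
Since the mapping keys are plain strings, a loss can be resolved by name and handed any extra keyword arguments that the underlying torch.nn.functional call accepts. A minimal standalone sketch, not part of the commit (values are illustrative):

import torch
from torchmdnet.loss import loss_class_mapping

pred = torch.randn(8, 1)
target = torch.randn(8, 1)

# "delta" is the transition point of torch.nn.functional.huber_loss
name, extra_args = "huber_loss", {"delta": 1.0}
loss = loss_class_mapping[name](pred, target, **extra_args)
print(loss.item())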
66 changes: 51 additions & 15 deletions torchmdnet/module.py
@@ -6,13 +6,13 @@
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.functional import local_response_norm, mse_loss, l1_loss
from torch.nn.functional import local_response_norm
from torch import Tensor
from typing import Optional, Dict, Tuple

from lightning import LightningModule
from torchmdnet.models.model import create_model, load_model
from torchmdnet.models.utils import dtype_mapping
from torchmdnet.loss import l1_loss, loss_class_mapping
import torch_geometric.transforms as T


@@ -48,6 +48,18 @@ def __call__(self, data):
return data


# This wrapper is here in order to permit Lightning to serialize the loss function.
class LossFunction:
def __init__(self, loss_fn, extra_args=None):
self.loss_fn = loss_fn
self.extra_args = extra_args
if self.extra_args is None:
self.extra_args = {}

def __call__(self, x, batch):
return self.loss_fn(x, batch, **self.extra_args)

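
The wrapper binds the extra keyword arguments to the chosen loss so that Lightning can serialize it together with the module (per the comment above). A rough usage sketch, not part of the diff:

from torch.nn.functional import huber_loss

wrapped = LossFunction(huber_loss, {"delta": 0.5})
# wrapped(pred, target) is equivalent to huber_loss(pred, target, delta=0.5)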

class LNNP(LightningModule):
"""
Lightning wrapper for the Neural Network Potentials in TorchMD-Net.
@@ -65,7 +77,10 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
hparams["charge"] = False
if "spin" not in hparams:
hparams["spin"] = False

if "train_loss" not in hparams:
hparams["train_loss"] = "mse_loss"
if "train_loss_arg" not in hparams:
hparams["train_loss_arg"] = {}
self.save_hyperparameters(hparams)

if self.hparams.load_model:
@@ -92,6 +107,16 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
]
)

if self.hparams.train_loss not in loss_class_mapping:
raise ValueError(
f"Training loss {self.hparams.train_loss} not supported. Supported losses are {list(loss_class_mapping.keys())}"
)

self.train_loss_fn = LossFunction(
loss_class_mapping[self.hparams.train_loss],
self.hparams.train_loss_arg,
)

def configure_optimizers(self):
optimizer = AdamW(
self.model.parameters(),
@@ -105,9 +130,12 @@ def configure_optimizers(self):
patience=self.hparams.lr_patience,
min_lr=self.hparams.lr_min,
)
lr_metric = getattr(self.hparams, "lr_metric", "val")
monitor = f"{lr_metric}_total_{self.hparams.train_loss}"
lr_scheduler = {
"scheduler": scheduler,
"monitor": getattr(self.hparams, "lr_metric", "val_loss"),
"strict": True,
"monitor": monitor,
"interval": "epoch",
"frequency": 1,
}
@@ -126,7 +154,9 @@ def forward(
return self.model(z, pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args)

def training_step(self, batch, batch_idx):
return self.step(batch, [mse_loss], "train")
return self.step(
batch, [(self.hparams.train_loss, self.train_loss_fn)], "train"
)

def validation_step(self, batch, batch_idx, *args):
# If args is not empty the first (and only) element is the dataloader_idx
@@ -135,28 +165,34 @@ def validation_step(self, batch, batch_idx, *args):
# The dataloader takes care of sending the two sets only when the second one is needed.
is_val = len(args) == 0 or (len(args) > 0 and args[0] == 0)
if is_val:
step_type = {"loss_fn_list": [l1_loss, mse_loss], "stage": "val"}
step_type = {
"loss_fn_list": [
("l1_loss", l1_loss),
(self.hparams.train_loss, self.train_loss_fn),
],
"stage": "val",
}
else:
step_type = {"loss_fn_list": [l1_loss], "stage": "test"}
step_type = {"loss_fn_list": [("l1_loss", l1_loss)], "stage": "test"}
return self.step(batch, **step_type)

def test_step(self, batch, batch_idx):
return self.step(batch, [l1_loss], "test")
return self.step(batch, [("l1_loss", l1_loss)], "test")

def _compute_losses(self, y, neg_y, batch, loss_fn, stage):
def _compute_losses(self, y, neg_y, batch, loss_fn, loss_name, stage):
# Compute the loss for the predicted value and the negative derivative (if available)
# Args:
# y: predicted value
# neg_y: predicted negative derivative
# batch: batch of data
# loss_fn: loss function to compute
# loss_fn: The loss function to compute
# loss_name: The name of the loss function
# Returns:
# loss_y: loss for the predicted value
# loss_neg_y: loss for the predicted negative derivative
loss_y, loss_neg_y = torch.tensor(0.0, device=self.device), torch.tensor(
0.0, device=self.device
)
loss_name = loss_fn.__name__
if self.hparams.derivative and "neg_dy" in batch:
loss_neg_y = loss_fn(neg_y, batch.neg_dy)
loss_neg_y = self._update_loss_with_ema(
@@ -221,10 +257,10 @@ def step(self, batch, loss_fn_list, stage):
neg_dy = neg_dy + y.sum() * 0
if "y" in batch and batch.y.ndim == 1:
batch.y = batch.y.unsqueeze(1)
for loss_fn in loss_fn_list:
step_losses = self._compute_losses(y, neg_dy, batch, loss_fn, stage)

loss_name = loss_fn.__name__
for loss_name, loss_fn in loss_fn_list:
step_losses = self._compute_losses(
y, neg_dy, batch, loss_fn, loss_name, stage
)
if self.hparams.neg_dy_weight > 0:
self.losses[stage]["neg_dy"][loss_name].append(
step_losses["neg_dy"].detach()
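
Downstream, the logged metric names now follow the selected training loss rather than a hard-coded mse_loss. An illustrative sketch of the naming, assuming it mirrors the monitor string built in configure_optimizers above and the f-strings in train.py below:

train_loss = "huber_loss"   # hparams["train_loss"] / --train-loss
lr_metric = "val"           # or "train", per --lr-metric
monitor = f"{lr_metric}_total_{train_loss}"
print(monitor)              # val_total_huber_loss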
14 changes: 8 additions & 6 deletions torchmdnet/scripts/train.py
@@ -17,6 +17,7 @@
from torchmdnet.module import LNNP
from torchmdnet import datasets, priors, models
from torchmdnet.data import DataModule
from torchmdnet.loss import loss_class_mapping
from torchmdnet.models import output_modules
from torchmdnet.models.model import create_prior_models
from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping, dtype_mapping
@@ -34,7 +35,7 @@ def get_argparse():
parser.add_argument('--inference-batch-size', default=None, type=int, help='Batchsize for validation and tests.')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
parser.add_argument('--lr-patience', type=int, default=10, help='Patience for lr-schedule. Patience per eval-interval of validation')
parser.add_argument('--lr-metric', type=str, default='val_total_mse_loss', choices=['train_total_mse_loss', 'val_total_mse_loss'], help='Metric to monitor when deciding whether to reduce learning rate')
parser.add_argument('--lr-metric', type=str, default='val', choices=['train', 'val'], help='Metric to monitor when deciding whether to reduce learning rate')
parser.add_argument('--lr-min', type=float, default=1e-6, help='Minimum learning rate before early stop')
parser.add_argument('--lr-factor', type=float, default=0.8, help='Factor by which to multiply the learning rate when the metric stops improving')
parser.add_argument('--lr-warmup-steps', type=int, default=0, help='How many steps to warm-up over. Defaults to 0 for no warm-up')
@@ -69,6 +70,8 @@ def get_argparse():
parser.add_argument('--dataset-preload-limit', default=1024, type=int, help='Custom and HDF5 datasets will preload to RAM datasets that are less than this size in MB')
parser.add_argument('--y-weight', default=1.0, type=float, help='Weighting factor for y label in the loss function')
parser.add_argument('--neg-dy-weight', default=1.0, type=float, help='Weighting factor for neg_dy label in the loss function')
parser.add_argument('--train-loss', default='mse_loss', type=str, choices=loss_class_mapping.keys(), help='Loss function to use during training')
parser.add_argument('--train-loss-arg', default=None, help='Additional arguments for the loss function. Needs to be a dictionary.')

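With the two options above, the training loss can be selected from the usual YAML configuration file, assuming the YAML keys mirror the option names as elsewhere in the config: for example, train_loss: huber_loss together with train_loss_arg: {delta: 1.0} (illustrative values; delta is the transition point of torch.nn.functional.huber_loss), or --train-loss huber_loss on the command line.
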
# model architecture
parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train')
@@ -165,17 +168,16 @@ def main():
# initialize lightning module
model = LNNP(args, prior_model=prior_models, mean=data.mean, std=data.std)

val_loss_name = f"val_total_{args.train_loss}"
checkpoint_callback = ModelCheckpoint(
dirpath=args.log_dir,
monitor="val_total_mse_loss",
monitor=val_loss_name,
save_top_k=10, # -1 to save all
every_n_epochs=args.save_interval,
filename="epoch={epoch}-val_loss={val_total_mse_loss:.4f}-test_loss={test_total_l1_loss:.4f}",
filename=f"epoch={{epoch}}-val_loss={{{val_loss_name}:.4f}}-test_loss={{test_total_l1_loss:.4f}}",
auto_insert_metric_name=False,
)
early_stopping = EarlyStopping(
"val_total_mse_loss", patience=args.early_stopping_patience
)
early_stopping = EarlyStopping(val_loss_name, patience=args.early_stopping_patience)

csv_logger = CSVLogger(args.log_dir, name="", version="")
_logger = [csv_logger]
