
Commit: rename

RaulPPelaez committed Jul 15, 2024
1 parent f6fa219 · commit edc930a
Showing 3 changed files with 8 additions and 9 deletions.
torchmdnet/loss.py (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 from torch.nn.functional import mse_loss, l1_loss, huber_loss
 
-loss_map = {
+loss_class_mapping = {
     "mse_loss": mse_loss,
     "l1_loss": l1_loss,
     "huber_loss": huber_loss,
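For orientation (not part of the diff): the renamed dictionary maps a loss name to the corresponding callable in torch.nn.functional, so a loss can be resolved from a config string. A minimal sketch, assuming a TorchMD-Net checkout that includes this commit:

    import torch
    from torchmdnet.loss import loss_class_mapping

    # Look up a loss function by its configured name; unknown names raise KeyError.
    loss_fn = loss_class_mapping["huber_loss"]  # torch.nn.functional.huber_loss
    pred, target = torch.randn(8), torch.randn(8)
    print(loss_fn(pred, target))  # scalar tensor with the mean Huber loss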
torchmdnet/module.py (11 changes: 5 additions & 6 deletions)
@@ -9,16 +9,13 @@
 from torch.nn.functional import local_response_norm
 from torch import Tensor
 from typing import Optional, Dict, Tuple
-
 from lightning import LightningModule
 from torchmdnet.models.model import create_model, load_model
 from torchmdnet.models.utils import dtype_mapping
+from torchmdnet.loss import l1_loss, loss_class_mapping
 import torch_geometric.transforms as T
 
 
-from torchmdnet.loss import l1_loss, loss_map
-
-
 class FloatCastDatasetWrapper(T.BaseTransform):
     """A transform that casts all floating point tensors to a given dtype.
     tensors to a given dtype.
@@ -68,6 +65,8 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
             hparams["charge"] = False
         if "spin" not in hparams:
             hparams["spin"] = False
+        if "training_loss" not in hparams:
+            hparams["training_loss"] = "mse_loss"
 
         self.save_hyperparameters(hparams)
 

@@ -95,11 +94,11 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
             ]
         )
 
-        if self.hparams.training_loss not in loss_map:
+        if self.hparams.training_loss not in loss_class_mapping:
             raise ValueError(
                 f"Training loss {self.hparams.training_loss} not supported. Supported losses are {list(loss_map.keys())}"
             )
-        self.training_loss = loss_map[self.hparams.training_loss]
+        self.training_loss = loss_class_mapping[self.hparams.training_loss]
 
     def configure_optimizers(self):
         optimizer = AdamW(
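Taken together, the two hunks above give training_loss a backward-compatible default and validate it against the mapping before use. A hedged sketch of the resulting behavior with a bare hparams dict (a real LNNP configuration carries many more hyperparameters):

    from torchmdnet.loss import loss_class_mapping

    hparams = {}  # e.g. loaded from a checkpoint that predates this option
    if "training_loss" not in hparams:
        hparams["training_loss"] = "mse_loss"  # default introduced above

    if hparams["training_loss"] not in loss_class_mapping:
        raise ValueError(
            f"Training loss {hparams['training_loss']} not supported. "
            f"Supported losses are {list(loss_class_mapping.keys())}"
        )
    training_loss = loss_class_mapping[hparams["training_loss"]]  # mse_loss here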
torchmdnet/scripts/train.py (4 changes: 2 additions & 2 deletions)
@@ -18,7 +18,7 @@
 from torchmdnet.module import LNNP
 from torchmdnet import datasets, priors, models
 from torchmdnet.data import DataModule
-from torchmdnet.loss import loss_map
+from torchmdnet.loss import loss_class_mapping
 from torchmdnet.models import output_modules
 from torchmdnet.models.model import create_prior_models
 from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping, dtype_mapping
@@ -71,7 +71,7 @@ def get_argparse():
     parser.add_argument('--dataset-preload-limit', default=1024, type=int, help='Custom and HDF5 datasets will preload to RAM datasets that are less than this size in MB')
     parser.add_argument('--y-weight', default=1.0, type=float, help='Weighting factor for y label in the loss function')
     parser.add_argument('--neg-dy-weight', default=1.0, type=float, help='Weighting factor for neg_dy label in the loss function')
-    parser.add_argument('--train_loss', default='mse', type=str, choices=loss_map.keys(), help='Loss function to use during training')
+    parser.add_argument('--train_loss', default='mse', type=str, choices=loss_class_mapping.keys(), help='Loss function to use during training')
 
     # model architecture
     parser.add_argument('--model', type=str, default='graph-network', choices=models.__all_models__, help='Which model to train')
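End to end, the rename keeps the CLI choices in sync with the dictionary keys, so picking a training loss stays a one-flag change. A hypothetical invocation (the config path is made up; torchmd-train is the console entry point installed with the package):

    torchmd-train --conf my_config.yaml --train_loss huber_loss

The flag's valid values come from loss_class_mapping.keys(), i.e. mse_loss, l1_loss, and huber_loss.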
