From ac16c09069f3c8de85b0e6451db51eca4e86144d Mon Sep 17 00:00:00 2001 From: Raul Date: Wed, 6 Sep 2023 11:23:12 +0200 Subject: [PATCH] Updating to Lightning 2.0 (#210) * Update to lightning 2.0 * Update env * Update env * Add pydantic<2 * Update * Update train.py * Update a couple imports * Update * Update * Update * Default test_interval to -1, print warning if it's positive * Reproduce the trick used before to test during training Alas, it requires reloading the dataloaders every epoch when test_interval>0 * Blacken * Blacken * Small update * Fix typo * Fix default reload_dataloaders_every_n_epochs * Add l1_loss to validation * Report val_loss, train_loss and test_loss with the same name as before Change the order of operations to reproduce training runs done with previous versions more closely * blacken * Use rank_zero_warn * Set inference_mode to false during testing * Add more arguments to testing Trainer * Add a test result log Always log losses on y and neg_dy even if they are weighted 0 for the total loss * Change lightning.pytorch to just lightning * Change a comprehension into an if/else * Remove spurious self. 
in _get_mean_losses * Use defaultdict * Update to lightning 2.0.8 * Blacken --- environment.yml | 3 +- tests/test_model.py | 2 +- tests/test_module.py | 4 +- tests/test_priors.py | 2 +- torchmdnet/data.py | 25 ++-- torchmdnet/models/model.py | 2 +- torchmdnet/module.py | 268 +++++++++++++++++++---------------- torchmdnet/priors/atomref.py | 2 +- torchmdnet/scripts/train.py | 43 ++++-- torchmdnet/utils.py | 2 +- 10 files changed, 200 insertions(+), 153 deletions(-) diff --git a/environment.yml b/environment.yml index aeac149e3..d7a5207e5 100644 --- a/environment.yml +++ b/environment.yml @@ -11,7 +11,8 @@ dependencies: - pytorch_geometric==2.3.1 - pytorch_scatter==2.1.1 - pytorch_sparse==0.6.17 - - pytorch-lightning==1.6.3 + - lightning==2.0.8 + - pydantic<2 - torchmetrics==0.11.4 - tqdm # Dev tools diff --git a/tests/test_model.py b/tests/test_model.py index 80e2461ef..8b436abb3 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -3,7 +3,7 @@ import pickle from os.path import exists, dirname, join import torch -import pytorch_lightning as pl +import lightning as pl from torchmdnet import models from torchmdnet.models.model import create_model from torchmdnet.models import output_modules diff --git a/tests/test_module.py b/tests/test_module.py index 9631a6996..eba599cfb 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -1,7 +1,7 @@ from pytest import mark from glob import glob from os.path import dirname, join -import pytorch_lightning as pl +import lightning as pl from torchmdnet import models from torchmdnet.models.model import load_model from torchmdnet.priors import Atomref @@ -49,6 +49,6 @@ def test_train(model_name, use_atomref, precision, tmpdir): module = LNNP(args, prior_model=prior) - trainer = pl.Trainer(max_steps=10, default_root_dir=tmpdir, precision=args["precision"]) + trainer = pl.Trainer(max_steps=10, default_root_dir=tmpdir, precision=args["precision"],inference_mode=False) trainer.fit(module, datamodule) 
trainer.test(module, datamodule) diff --git a/tests/test_priors.py b/tests/test_priors.py index 1d6b477f9..47b838381 100644 --- a/tests/test_priors.py +++ b/tests/test_priors.py @@ -1,7 +1,7 @@ import pytest from pytest import mark import torch -import pytorch_lightning as pl +import lightning as pl from torchmdnet import models from torchmdnet.models.model import create_model, create_prior_models from torchmdnet.module import LNNP diff --git a/torchmdnet/data.py b/torchmdnet/data.py index f150d5f8d..702df4f46 100644 --- a/torchmdnet/data.py +++ b/torchmdnet/data.py @@ -3,8 +3,8 @@ import torch from torch.utils.data import Subset from torch_geometric.loader import DataLoader -from pytorch_lightning import LightningDataModule -from pytorch_lightning.utilities import rank_zero_warn +from lightning import LightningDataModule +from lightning_utilities.core.rank_zero import rank_zero_warn from torchmdnet import datasets from torch_geometric.data import Dataset from torchmdnet.utils import make_splits, MissingEnergyException @@ -92,10 +92,10 @@ def train_dataloader(self): def val_dataloader(self): loaders = [self._get_dataloader(self.val_dataset, "val")] - if ( - len(self.test_dataset) > 0 - and (self.trainer.current_epoch + 1) % self.hparams["test_interval"] == 0 - ): + # To allow to report the performance on the testing dataset during training + # we send the trainer two dataloaders every few steps and modify the + # validation step to understand the second dataloader as test data. 
+ if self._is_test_during_training_epoch(): loaders.append(self._get_dataloader(self.test_dataset, "test")) return loaders @@ -116,13 +116,16 @@ def mean(self): def std(self): return self._std - def _get_dataloader(self, dataset, stage, store_dataloader=True): - store_dataloader = ( - store_dataloader and self.trainer.reload_dataloaders_every_n_epochs <= 0 + def _is_test_during_training_epoch(self): + return ( + len(self.test_dataset) > 0 + and self.hparams["test_interval"] > 0 + and self.trainer.current_epoch > 0 + and self.trainer.current_epoch % self.hparams["test_interval"] == 0 ) + + def _get_dataloader(self, dataset, stage, store_dataloader=True): if stage in self._saved_dataloaders and store_dataloader: - # storing the dataloaders like this breaks calls to trainer.reload_train_val_dataloaders - # but makes it possible that the dataloaders are not recreated on every testing epoch return self._saved_dataloaders[stage] if stage == "train": diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 8a56dd39b..8a5bf7038 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -4,11 +4,11 @@ from torch.autograd import grad from torch import nn, Tensor from torch_scatter import scatter -from pytorch_lightning.utilities import rank_zero_warn from torchmdnet.models import output_modules from torchmdnet.models.wrappers import AtomFilter from torchmdnet.models.utils import dtype_mapping from torchmdnet import priors +from lightning_utilities.core.rank_zero import rank_zero_warn import warnings diff --git a/torchmdnet/module.py b/torchmdnet/module.py index a04142a61..12da3943d 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -1,11 +1,12 @@ +from collections import defaultdict import torch from torch.optim import AdamW from torch.optim.lr_scheduler import ReduceLROnPlateau -from torch.nn.functional import mse_loss, l1_loss +from torch.nn.functional import local_response_norm, mse_loss, l1_loss from torch import Tensor 
from typing import Optional, Dict, Tuple -from pytorch_lightning import LightningModule +from lightning import LightningModule from torchmdnet.models.model import create_model, load_model @@ -13,9 +14,9 @@ class LNNP(LightningModule): """ Lightning wrapper for the Neural Network Potentials in TorchMD-Net. """ + def __init__(self, hparams, prior_model=None, mean=None, std=None): super(LNNP, self).__init__() - if "charge" not in hparams: hparams["charge"] = False if "spin" not in hparams: @@ -57,34 +58,89 @@ def configure_optimizers(self): } return [optimizer], [lr_scheduler] - def forward(self, - z: Tensor, - pos: Tensor, - batch: Optional[Tensor] = None, - q: Optional[Tensor] = None, - s: Optional[Tensor] = None, - extra_args: Optional[Dict[str, Tensor]] = None - ) -> Tuple[Tensor, Optional[Tensor]]: - + def forward( + self, + z: Tensor, + pos: Tensor, + batch: Optional[Tensor] = None, + q: Optional[Tensor] = None, + s: Optional[Tensor] = None, + extra_args: Optional[Dict[str, Tensor]] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: return self.model(z, pos, batch=batch, q=q, s=s, extra_args=extra_args) def training_step(self, batch, batch_idx): - return self.step(batch, mse_loss, "train") + return self.step(batch, [mse_loss], "train") def validation_step(self, batch, batch_idx, *args): - if len(args) == 0 or (len(args) > 0 and args[0] == 0): - # validation step - return self.step(batch, mse_loss, "val") - # test step - return self.step(batch, l1_loss, "test") + # If args is not empty the first (and only) element is the dataloader_idx + # We want to test every number of epochs just for reporting, but this is not supported by Lightning. + # Instead, we trick it by providing two validation dataloaders and interpreting the second one as test. + # The dataloader takes care of sending the two sets only when the second one is needed. 
+ is_val = len(args) == 0 or (len(args) > 0 and args[0] == 0) + if is_val: + step_type = {"loss_fn_list": [l1_loss, mse_loss], "stage": "val"} + else: + step_type = {"loss_fn_list": [l1_loss], "stage": "test"} + return self.step(batch, **step_type) def test_step(self, batch, batch_idx): - return self.step(batch, l1_loss, "test") + return self.step(batch, [l1_loss], "test") + + def _compute_losses(self, y, neg_y, batch, loss_fn, stage): + # Compute the loss for the predicted value and the negative derivative (if available) + # Args: + # y: predicted value + # neg_y: predicted negative derivative + # batch: batch of data + # loss_fn: loss function to compute + # Returns: + # loss_y: loss for the predicted value + # loss_neg_y: loss for the predicted negative derivative + loss_y, loss_neg_y = 0.0, 0.0 + loss_name = loss_fn.__name__ + if self.hparams.derivative and "neg_dy" in batch: + loss_neg_y = loss_fn(neg_y, batch.neg_dy) + loss_neg_y = self._update_loss_with_ema( + stage, "neg_dy", loss_name, loss_neg_y + ) + if "y" in batch: + loss_y = loss_fn(y, batch.y) + loss_y = self._update_loss_with_ema(stage, "y", loss_name, loss_y) + return {"y": loss_y, "neg_dy": loss_neg_y} + + def _update_loss_with_ema(self, stage, type, loss_name, loss): + # Update the loss using an exponential moving average when applicable + # Args: + # stage: stage of the training (train, val, test) + # type: type of loss (y, neg_dy) + # loss_name: name of the loss function + # loss: loss value + alpha = getattr(self.hparams, f"ema_alpha_{type}") + if stage in ["train", "val"] and alpha < 1: + ema = ( + self.ema[stage][type][loss_name] + if loss_name in self.ema[stage][type] + else loss.detach() + ) + loss = alpha * loss + (1 - alpha) * ema + self.ema[stage][type][loss_name] = loss.detach() + return loss - def step(self, batch, loss_fn, stage): + def step(self, batch, loss_fn_list, stage): + # Run a forward pass and compute the loss for each loss function + # If the batch contains the derivative, 
also compute the loss for the negative derivative + # Args: + # batch: batch of data + # loss_fn_list: list of loss functions to compute and record (the last one is used for the total loss returned by this function) + # stage: stage of the training (train, val, test) + # Returns: + # total_loss: sum of all losses (weighted by the loss weights) for the last loss function in the provided list + assert len(loss_fn_list) > 0 + assert self.losses is not None with torch.set_grad_enabled(stage == "train" or self.hparams.derivative): extra_args = batch.to_dict() - for a in ('y', 'neg_dy', 'z', 'pos', 'batch', 'q', 's'): + for a in ("y", "neg_dy", "z", "pos", "batch", "q", "s"): if a in extra_args: del extra_args[a] # TODO: the model doesn't necessarily need to return a derivative once @@ -95,59 +151,29 @@ def step(self, batch, loss_fn, stage): batch=batch.batch, q=batch.q if self.hparams.charge else None, s=batch.s if self.hparams.spin else None, - extra_args=extra_args + extra_args=extra_args, ) - - loss_y, loss_neg_dy = 0, 0 - if self.hparams.derivative: - if "y" not in batch: - # "use" both outputs of the model's forward function but discard the first - # to only use the negative derivative and avoid 'Expected to have finished reduction - # in the prior iteration before starting a new one.', which otherwise get's - # thrown because of setting 'find_unused_parameters=False' in the DDPPlugin - neg_dy = neg_dy + y.sum() * 0 - - # negative derivative loss - loss_neg_dy = loss_fn(neg_dy, batch.neg_dy) - - if stage in ["train", "val"] and self.hparams.ema_alpha_neg_dy < 1: - if self.ema[stage + "_neg_dy"] is None: - self.ema[stage + "_neg_dy"] = loss_neg_dy.detach() - # apply exponential smoothing over batches to neg_dy - loss_neg_dy = ( - self.hparams.ema_alpha_neg_dy * loss_neg_dy - + (1 - self.hparams.ema_alpha_neg_dy) * self.ema[stage + "_neg_dy"] - ) - self.ema[stage + "_neg_dy"] = loss_neg_dy.detach() - - if self.hparams.neg_dy_weight > 0: - self.losses[stage + 
"_neg_dy"].append(loss_neg_dy.detach()) - - if "y" in batch: - if batch.y.ndim == 1: - batch.y = batch.y.unsqueeze(1) - - # y loss - loss_y = loss_fn(y, batch.y) - - if stage in ["train", "val"] and self.hparams.ema_alpha_y < 1: - if self.ema[stage + "_y"] is None: - self.ema[stage + "_y"] = loss_y.detach() - # apply exponential smoothing over batches to y - loss_y = ( - self.hparams.ema_alpha_y * loss_y - + (1 - self.hparams.ema_alpha_y) * self.ema[stage + "_y"] - ) - self.ema[stage + "_y"] = loss_y.detach() - - if self.hparams.y_weight > 0: - self.losses[stage + "_y"].append(loss_y.detach()) - - # total loss - loss = loss_y * self.hparams.y_weight + loss_neg_dy * self.hparams.neg_dy_weight - - self.losses[stage].append(loss.detach()) - return loss + if self.hparams.derivative and "y" not in batch: + # "use" both outputs of the model's forward function but discard the first + # to only use the negative derivative and avoid 'Expected to have finished reduction + # in the prior iteration before starting a new one.', which otherwise get's + # thrown because of setting 'find_unused_parameters=False' in the DDPPlugin + neg_dy = neg_dy + y.sum() * 0 + if "y" in batch and batch.y.ndim == 1: + batch.y = batch.y.unsqueeze(1) + for loss_fn in loss_fn_list: + step_losses = self._compute_losses(y, neg_dy, batch, loss_fn, stage) + total_loss = ( + step_losses["y"] * self.hparams.y_weight + + step_losses["neg_dy"] * self.hparams.neg_dy_weight + ) + loss_name = loss_fn.__name__ + self.losses[stage]["neg_dy"][loss_name].append( + step_losses["neg_dy"].detach() + ) + self.losses[stage]["y"][loss_name].append(step_losses["y"].detach()) + self.losses[stage]["total"][loss_name].append(total_loss.detach()) + return total_loss def optimizer_step(self, *args, **kwargs): optimizer = kwargs["optimizer"] if "optimizer" in kwargs else args[2] @@ -163,64 +189,66 @@ def optimizer_step(self, *args, **kwargs): super().optimizer_step(*args, **kwargs) optimizer.zero_grad() - def 
training_epoch_end(self, training_step_outputs): - dm = self.trainer.datamodule - if hasattr(dm, "test_dataset") and len(dm.test_dataset) > 0: - should_reset = ( - self.current_epoch % self.hparams.test_interval == 0 - or (self.current_epoch + 1) % self.hparams.test_interval == 0 - ) - if should_reset: - # reset validation dataloaders before and after testing epoch, which is faster - # than skipping test validation steps by returning None - self.trainer.reset_val_dataloader(self) + def _get_mean_loss_dict_for_type(self, type): + # Returns a list with the mean loss for each loss_fn for each stage (train, val, test) + # Parameters: + # type: either y, neg_dy or total + # Returns: + # A dict with an entry for each stage (train, val, test) with the mean loss for each loss_fn (e.g. mse_loss) + # The key for each entry is "stage_type_loss_fn" + assert self.losses is not None + mean_losses = {} + for stage in ["train", "val", "test"]: + for loss_fn_name in self.losses[stage][type].keys(): + mean_losses[stage + "_" + type + "_" + loss_fn_name] = torch.stack( + self.losses[stage][type][loss_fn_name] + ).mean() + return mean_losses - def validation_epoch_end(self, validation_step_outputs): + def on_validation_epoch_end(self): if not self.trainer.sanity_checking: # construct dict of logged metrics result_dict = { "epoch": float(self.current_epoch), "lr": self.trainer.optimizers[0].param_groups[0]["lr"], - "train_loss": torch.stack(self.losses["train"]).mean(), - "val_loss": torch.stack(self.losses["val"]).mean(), } + result_dict |= self._get_mean_loss_dict_for_type("total") + result_dict |= self._get_mean_loss_dict_for_type("y") + result_dict |= self._get_mean_loss_dict_for_type("neg_dy") + # For retro compatibility with previous versions of TorchMD-Net we report some losses twice + result_dict["val_loss"] = result_dict["val_total_mse_loss"] + result_dict["train_loss"] = result_dict["train_total_mse_loss"] + if "test_total_l1_loss" in result_dict: + result_dict["test_loss"] = 
result_dict["test_total_l1_loss"] + self.log_dict(result_dict, sync_dist=True) - # add test loss if available - if len(self.losses["test"]) > 0: - result_dict["test_loss"] = torch.stack(self.losses["test"]).mean() - - # if prediction and derivative are present, also log them separately - if len(self.losses["train_y"]) > 0 and len(self.losses["train_neg_dy"]) > 0: - result_dict["train_loss_y"] = torch.stack(self.losses["train_y"]).mean() - result_dict["train_loss_neg_dy"] = torch.stack( - self.losses["train_neg_dy"] - ).mean() - result_dict["val_loss_y"] = torch.stack(self.losses["val_y"]).mean() - result_dict["val_loss_neg_dy"] = torch.stack(self.losses["val_neg_dy"]).mean() - - if len(self.losses["test"]) > 0: - result_dict["test_loss_y"] = torch.stack( - self.losses["test_y"] - ).mean() - result_dict["test_loss_neg_dy"] = torch.stack( - self.losses["test_neg_dy"] - ).mean() + self._reset_losses_dict() + def on_test_epoch_end(self): + # Log all test losses + if not self.trainer.sanity_checking: + result_dict = {} + result_dict |= self._get_mean_loss_dict_for_type("total") + result_dict |= self._get_mean_loss_dict_for_type("y") + result_dict |= self._get_mean_loss_dict_for_type("neg_dy") + # Get only test entries + result_dict = {k: v for k, v in result_dict.items() if k.startswith("test")} self.log_dict(result_dict, sync_dist=True) - self._reset_losses_dict() def _reset_losses_dict(self): - self.losses = { - "train": [], - "val": [], - "test": [], - "train_y": [], - "val_y": [], - "test_y": [], - "train_neg_dy": [], - "val_neg_dy": [], - "test_neg_dy": [], - } + # Losses has an entry for each stage in ["train", "val", "test"] + # Each entry has an entry with "total", "y" and "neg_dy" + # Each of these entries has an entry for each loss_fn (e.g. 
mse_loss) + # The loss_fn values are not known in advance + self.losses = {} + for stage in ["train", "val", "test"]: + self.losses[stage] = {} + for loss_type in ["total", "y", "neg_dy"]: + self.losses[stage][loss_type] = defaultdict(list) def _reset_ema_dict(self): - self.ema = {"train_y": None, "val_y": None, "train_neg_dy": None, "val_neg_dy": None} + self.ema = {} + for stage in ["train", "val"]: + self.ema[stage] = {} + for loss_type in ["y", "neg_dy"]: + self.ema[stage][loss_type] = {} diff --git a/torchmdnet/priors/atomref.py b/torchmdnet/priors/atomref.py index 7ef4f64f1..22e31f67c 100644 --- a/torchmdnet/priors/atomref.py +++ b/torchmdnet/priors/atomref.py @@ -2,7 +2,7 @@ from typing import Optional, Dict import torch from torch import nn, Tensor -from pytorch_lightning.utilities import rank_zero_warn +from lightning_utilities.core.rank_zero import rank_zero_warn class Atomref(BasePrior): diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index e39c48dc8..61b8dc33b 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -2,11 +2,13 @@ import os import argparse import logging -import pytorch_lightning as pl -from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.loggers import CSVLogger, WandbLogger -from pytorch_lightning.strategies.ddp import DDPStrategy +import lightning.pytorch as pl +from lightning.pytorch.strategies import DDPStrategy +from lightning.pytorch.loggers import WandbLogger, CSVLogger, TensorBoardLogger +from lightning.pytorch.callbacks import ( + ModelCheckpoint, + EarlyStopping, +) from torchmdnet.module import LNNP from torchmdnet import datasets, priors, models from torchmdnet.data import DataModule @@ -14,7 +16,8 @@ from torchmdnet.models.model import create_prior_models from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping, dtype_mapping from torchmdnet.utils import 
LoadFromFile, LoadFromCheckpoint, save_argparse, number -import torch +from lightning_utilities.core.rank_zero import rank_zero_warn + def get_args(): # fmt: off @@ -43,7 +46,7 @@ def get_args(): parser.add_argument('--train-size', type=number, default=None, help='Percentage/number of samples in training set (None to use all remaining samples)') parser.add_argument('--val-size', type=number, default=0.05, help='Percentage/number of samples in validation set (None to use all remaining samples)') parser.add_argument('--test-size', type=number, default=0.1, help='Percentage/number of samples in test set (None to use all remaining samples)') - parser.add_argument('--test-interval', type=int, default=10, help='Test interval, one test per n epochs (default: 10)') + parser.add_argument('--test-interval', type=int, default=-1, help='Test interval, one test per n epochs (default: 10)') parser.add_argument('--save-interval', type=int, default=10, help='Save interval, one save per n epochs (default: 10)') parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') parser.add_argument('--num-workers', type=int, default=4, help='Number of workers for data prefetch') @@ -82,7 +85,7 @@ def get_args(): parser.add_argument('--distance-influence', type=str, default='both', choices=['keys', 'values', 'both', 'none'], help='Where distance information is included inside the attention') parser.add_argument('--attn-activation', default='silu', choices=list(act_class_mapping.keys()), help='Attention activation function') parser.add_argument('--num-heads', type=int, default=8, help='Number of attention heads') - + # TensorNet specific parser.add_argument('--equivariance-invariance-group', type=str, default='O(3)', help='Equivariance and invariance group of TensorNet') @@ -157,30 +160,42 @@ def main(): _logger.append(wandb_logger) if args.tensorboard_use: - tb_logger = pl.loggers.TensorBoardLogger( + tb_logger = TensorBoardLogger( args.log_dir, name="tensorbord", 
version="", default_hp_metric=False ) _logger.append(tb_logger) + if args.test_interval > 0: + rank_zero_warn( + f"WARNING: Test set will be evaluated every {args.test_interval} epochs. This will slow down training." + ) trainer = pl.Trainer( strategy=DDPStrategy(find_unused_parameters=False), max_epochs=args.num_epochs, - gpus=args.ngpus, + accelerator="auto", + devices=args.ngpus, num_nodes=args.num_nodes, default_root_dir=args.log_dir, - auto_lr_find=False, - resume_from_checkpoint=None if args.reset_trainer else args.load_model, callbacks=[early_stopping, checkpoint_callback], logger=_logger, precision=args.precision, gradient_clip_val=args.gradient_clipping, + inference_mode=False, + # Test-during-training requires reloading the dataloaders every epoch + reload_dataloaders_every_n_epochs=1 if args.test_interval > 0 else 0, ) - trainer.fit(model, data) + trainer.fit(model, data, ckpt_path=None if args.reset_trainer else args.load_model) # run test set after completing the fit model = LNNP.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) - trainer = pl.Trainer(logger=_logger) + trainer = pl.Trainer( + logger=_logger, + inference_mode=False, + accelerator="auto", + devices=args.ngpus, + num_nodes=args.num_nodes, + ) trainer.test(model, data) diff --git a/torchmdnet/utils.py b/torchmdnet/utils.py index ffe391922..d7810897b 100644 --- a/torchmdnet/utils.py +++ b/torchmdnet/utils.py @@ -3,7 +3,7 @@ import numpy as np import torch from os.path import dirname, join, exists -from pytorch_lightning.utilities import rank_zero_warn +from lightning_utilities.core.rank_zero import rank_zero_warn # fmt: off # Atomic masses are based on: