Updating to Lightning 2.0 (#210)
* Update to lightning 2.0

* Update env

* Update env

* Add pydantic<2

* Update

* Update train.py

* Update a couple imports

* Update

* Update

* Update

* Default test_interval to -1, print a warning if it's positive

* Reproduce the trick used before to test during training
Alas, it requires reloading the dataloaders every epoch when test_interval>0

* Blacken

* Blacken

* Small update

* Fix typo

* Fix default reset_dataloaders_every_n_epochs

* Add l1_loss to validation

* Report val_loss, train_loss and test_loss with the same names as before
Change the order of operations to reproduce trainings done with previous
versions more closely

* Blacken

* Use rank_zero_warn

* Set inference_mode to False during testing (see the note after the tests/test_module.py diff below)

* Add more arguments to testing Trainer

* Add a test result log
Always log losses on y and neg_dy even if their weight in the total
loss is zero

* Change lightning.pytorch to just lightning

* Change a comprehension into an if/else

* Remove spurious self. in _get_mean_losses

* Use defaultdict (see the sketch after the changed-files summary below)

* Update to lightning 2.0.8

* Blacken
RaulPPelaez authored Sep 6, 2023
1 parent dca6679 commit ac16c09
Showing 10 changed files with 200 additions and 153 deletions.
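
An editor's sketch: the torchmdnet/module.py diff is not among those rendered below, but the "Use defaultdict" and "_get_mean_losses" bullets above refer to accumulating per-stage losses. A minimal illustration of that pattern, with made-up names rather than the exact torchmdnet code:

from collections import defaultdict

import torch

# Accumulate every logged loss under its name without pre-declaring keys;
# defaultdict(list) creates the list on first access.
losses = defaultdict(list)
losses["val_y"].append(torch.tensor(0.12))
losses["val_neg_dy"].append(torch.tensor(0.34))

def get_mean_losses(losses):
    # Average the values accumulated under each loss name.
    return {name: torch.stack(values).mean() for name, values in losses.items()}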
environment.yml (3 changes: 2 additions & 1 deletion)
@@ -11,7 +11,8 @@ dependencies:
 - pytorch_geometric==2.3.1
 - pytorch_scatter==2.1.1
 - pytorch_sparse==0.6.17
-- pytorch-lightning==1.6.3
+- lightning==2.0.8
+- pydantic<2
 - torchmetrics==0.11.4
 - tqdm
 # Dev tools
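An editor's note, not part of the commit: a quick way to confirm the new pins took effect after rebuilding the environment:

import lightning
import pydantic

print(lightning.__version__)             # expected 2.0.8, matching the pin above
assert pydantic.VERSION.startswith("1")  # the pydantic<2 pin keeps the v1 series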
tests/test_model.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 import pickle
 from os.path import exists, dirname, join
 import torch
-import pytorch_lightning as pl
+import lightning as pl
 from torchmdnet import models
 from torchmdnet.models.model import create_model
 from torchmdnet.models import output_modules
tests/test_module.py (4 changes: 2 additions & 2 deletions)
@@ -1,7 +1,7 @@
 from pytest import mark
 from glob import glob
 from os.path import dirname, join
-import pytorch_lightning as pl
+import lightning as pl
 from torchmdnet import models
 from torchmdnet.models.model import load_model
 from torchmdnet.priors import Atomref
@@ -49,6 +49,6 @@ def test_train(model_name, use_atomref, precision, tmpdir):

     module = LNNP(args, prior_model=prior)
 
-    trainer = pl.Trainer(max_steps=10, default_root_dir=tmpdir, precision=args["precision"])
+    trainer = pl.Trainer(max_steps=10, default_root_dir=tmpdir, precision=args["precision"], inference_mode=False)
     trainer.fit(module, datamodule)
     trainer.test(module, datamodule)
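
An editor's note on inference_mode=False, since the diff does not explain it: Lightning 2.0 runs validation and test loops under torch.inference_mode() by default, which disables autograd, while TorchMD-Net derives forces as the negative gradient of the predicted energy. A toy illustration of why the flag is needed:

import torch

def forces(pos):
    # Toy potential; TorchMD-Net obtains forces the same way, via autograd.
    energy = (pos ** 2).sum()
    return -torch.autograd.grad(energy, pos)[0]

pos = torch.randn(5, 3, requires_grad=True)
print(forces(pos))  # works: a graph was recorded for the energy

with torch.inference_mode():
    energy = (pos ** 2).sum()  # no graph is recorded in inference mode
    # torch.autograd.grad(energy, pos) would raise a RuntimeError here,
    # which is why the Trainer above is built with inference_mode=False.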
tests/test_priors.py (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
 import pytest
 from pytest import mark
 import torch
-import pytorch_lightning as pl
+import lightning as pl
 from torchmdnet import models
 from torchmdnet.models.model import create_model, create_prior_models
 from torchmdnet.module import LNNP
torchmdnet/data.py (25 changes: 14 additions & 11 deletions)
@@ -3,8 +3,8 @@
 import torch
 from torch.utils.data import Subset
 from torch_geometric.loader import DataLoader
-from pytorch_lightning import LightningDataModule
-from pytorch_lightning.utilities import rank_zero_warn
+from lightning import LightningDataModule
+from lightning_utilities.core.rank_zero import rank_zero_warn
 from torchmdnet import datasets
 from torch_geometric.data import Dataset
 from torchmdnet.utils import make_splits, MissingEnergyException
@@ -92,10 +92,10 @@ def train_dataloader(self):

     def val_dataloader(self):
         loaders = [self._get_dataloader(self.val_dataset, "val")]
-        if (
-            len(self.test_dataset) > 0
-            and (self.trainer.current_epoch + 1) % self.hparams["test_interval"] == 0
-        ):
+        # To allow reporting performance on the test dataset during training,
+        # we hand the trainer two dataloaders every few epochs and make the
+        # validation step treat the second dataloader as test data.
+        if self._is_test_during_training_epoch():
             loaders.append(self._get_dataloader(self.test_dataset, "test"))
         return loaders

@@ -116,13 +116,16 @@ def mean(self):
     def std(self):
         return self._std
 
-    def _get_dataloader(self, dataset, stage, store_dataloader=True):
-        store_dataloader = (
-            store_dataloader and self.trainer.reload_dataloaders_every_n_epochs <= 0
+    def _is_test_during_training_epoch(self):
+        return (
+            len(self.test_dataset) > 0
+            and self.hparams["test_interval"] > 0
+            and self.trainer.current_epoch > 0
+            and self.trainer.current_epoch % self.hparams["test_interval"] == 0
         )
+
+    def _get_dataloader(self, dataset, stage, store_dataloader=True):
         if stage in self._saved_dataloaders and store_dataloader:
-            # storing the dataloaders like this breaks calls to trainer.reload_train_val_dataloaders
-            # but makes it possible that the dataloaders are not recreated on every testing epoch
             return self._saved_dataloaders[stage]
 
         if stage == "train":
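The counterpart of this two-dataloader trick lives in the LightningModule (torchmdnet/module.py, whose diff is not shown on this page). A hedged sketch with a stand-in module: when val_dataloader() returns a list, Lightning passes dataloader_idx to validation_step, so the second loader can be logged under test names:

import torch
import lightning as pl

class TwoLoaderModule(pl.LightningModule):
    # Illustrative stand-in; the real logic lives in torchmdnet's LNNP.
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(3, 1)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.model(x), y)
        # Loader 0 is the real validation set; loader 1, appended every
        # test_interval epochs by the DataModule above, is the test set.
        self.log("val_loss" if dataloader_idx == 0 else "test_loss", loss)
        return loss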
torchmdnet/models/model.py (2 changes: 1 addition & 1 deletion)
@@ -4,11 +4,11 @@
 from torch.autograd import grad
 from torch import nn, Tensor
 from torch_scatter import scatter
-from pytorch_lightning.utilities import rank_zero_warn
 from torchmdnet.models import output_modules
 from torchmdnet.models.wrappers import AtomFilter
 from torchmdnet.models.utils import dtype_mapping
 from torchmdnet import priors
+from lightning_utilities.core.rank_zero import rank_zero_warn
 import warnings


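As the import shows, rank_zero_warn now comes from lightning_utilities rather than pytorch_lightning.utilities. It emits its warning only on the rank-0 process, so multi-GPU runs print it once rather than once per worker. For example (illustrative message text, not the exact one in the code):

from lightning_utilities.core.rank_zero import rank_zero_warn

rank_zero_warn("test_interval is positive; dataloaders will be reloaded every epoch")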
(diffs for the remaining four changed files not loaded)
