From ea809845ddbd14fb7b3f623dd5086e435eada7ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Th=C3=B6lke?= Date: Sun, 6 Nov 2022 16:10:19 -0500 Subject: [PATCH] Add missing __all__ attribute --- tests/test_priors.py | 8 ++++++++ torchmdnet/priors/__init__.py | 2 ++ 2 files changed, 10 insertions(+) diff --git a/tests/test_priors.py b/tests/test_priors.py index 31481233d..d5547919c 100644 --- a/tests/test_priors.py +++ b/tests/test_priors.py @@ -4,6 +4,7 @@ import pytorch_lightning as pl from torchmdnet import models from torchmdnet.models.model import create_model +from torchmdnet import priors from torchmdnet.priors import Atomref from torch_scatter import scatter from utils import load_example_args, create_example_batch, DummyDataset @@ -31,3 +32,10 @@ def test_atomref(model_name): # check if the output of both models differs by the expected atomref contribution expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1) torch.testing.assert_allclose(x_atomref, x_no_atomref + expected_offset) + + +def test_prior_list(): + # make sure the priors submodule defines the __all__ attribute + assert hasattr(priors, "__all__") + # make sure it's not empty + assert len(priors.__all__) > 0 diff --git a/torchmdnet/priors/__init__.py b/torchmdnet/priors/__init__.py index 50c134b30..93aa7cea7 100644 --- a/torchmdnet/priors/__init__.py +++ b/torchmdnet/priors/__init__.py @@ -1 +1,3 @@ from torchmdnet.priors.atomref import Atomref + +__all__ = ["Atomref"]