From 94b05cb9210196981de0201c8f01af849ac2962e Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Wed, 6 Sep 2023 16:20:00 +0200 Subject: [PATCH] fixed warning in unit tests --- torch_harmonics/tests.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/torch_harmonics/tests.py b/torch_harmonics/tests.py index 0e0e7e1..58bd449 100644 --- a/torch_harmonics/tests.py +++ b/torch_harmonics/tests.py @@ -29,7 +29,8 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -import unittest, argparse +import unittest +import math import numpy as np import torch from torch.autograd import gradcheck @@ -49,7 +50,7 @@ def __init__(self, testname, tol=1e-9): self.tol = tol def setUp(self): - self.cml = lambda m, l : np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(np.math.factorial(l-m) / np.math.factorial(l+m)) + self.cml = lambda m, l : np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(math.factorial(l-m) / math.factorial(l+m)) self.pml = dict() # preparing associated Legendre Polynomials (These include the Condon-Shortley phase) @@ -133,7 +134,6 @@ def test_sht(self): print(f"final relative error: {err.item()}") self.assertTrue(err.item() <= self.tol) - # @unittest.skipIf(mylib.__version__ < (1, 3), "skipping slow tests") def test_sht_grad(self): print(f"Testing gradients of real-valued SHT on {self.nlat}x{self.nlon} {self.grid} grid with {self.norm} normalization") @@ -158,22 +158,18 @@ def test_sht_grad(self): if __name__ == '__main__': - # parser = argparse.ArgumentParser() - # parser.add_argument("--run_slow", action='store_true', help="Run the slow tests.") - # args = parser.parse_args() - sht_test_suite = unittest.TestSuite() # test computation of Legendre polynomials sht_test_suite.addTest(TestLegendrePolynomials('test_legendre', tol=1e-9)) # test error growth when computing repeatedly isht(sht(x)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="ortho", grid="equiangular", tol=1e-1)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="ortho", grid="legendre-gauss", tol=1e-9)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="four-pi", grid="equiangular", tol=1e-1)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="four-pi", grid="legendre-gauss", tol=1e-9)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="schmidt", grid="equiangular", tol=1e-1)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="schmidt", grid="legendre-gauss", tol=1e-9)) + sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="ortho", grid="equiangular", tol=1e-1, nlat=256, nlon=512, batch_size=16)) + sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="ortho", grid="legendre-gauss", tol=1e-9, nlat=256, nlon=512, batch_size=16)) + sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="four-pi", grid="equiangular", tol=1e-1, nlat=256, nlon=512, batch_size=16)) + sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="four-pi", grid="legendre-gauss", tol=1e-9, nlat=256, nlon=512, batch_size=16)) + sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="schmidt", grid="equiangular", tol=1e-1, nlat=256, nlon=512, batch_size=16)) + sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="schmidt", grid="legendre-gauss", tol=1e-9, nlat=256, nlon=512, batch_size=16)) # test error growth when computing repeatedly isht(sht(x)) sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_grad', norm="ortho", grid="equiangular", tol=1e-4, nlat=12, nlon=24, batch_size=2))