diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 52c2616..14981e9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,5 +22,5 @@ jobs: python -m pip install -e . - name: Test with pytest run: | - pip install pytest pytest-cov - pytest ./torch_harmonics/tests.py \ No newline at end of file + pip install pytest pytest-cov parametrized + python -m unittest ./torch_harmonics/tests.py \ No newline at end of file diff --git a/torch_harmonics/tests.py b/torch_harmonics/tests.py index 58bd449..3de8b75 100644 --- a/torch_harmonics/tests.py +++ b/torch_harmonics/tests.py @@ -30,6 +30,7 @@ # import unittest +from parameterized import parameterized import math import numpy as np import torch @@ -45,10 +46,6 @@ class TestLegendrePolynomials(unittest.TestCase): - def __init__(self, testname, tol=1e-9): - super(TestLegendrePolynomials, self).__init__(testname) - self.tol = tol - def setUp(self): self.cml = lambda m, l : np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(math.factorial(l-m) / math.factorial(l+m)) self.pml = dict() @@ -68,6 +65,8 @@ def setUp(self): self.lmax = self.mmax = 4 + self.tol = 1e-9 + def test_legendre(self): print("Testing computation of associated Legendre polynomials") from torch_harmonics.legendre import precompute_legpoly @@ -79,20 +78,10 @@ def test_legendre(self): for m in range(l+1): diff = pct[m, l].numpy() / self.cml(m,l) - self.pml[(m,l)](np.cos(t)) self.assertTrue(diff.max() <= self.tol) - print("done.") class TestSphericalHarmonicTransform(unittest.TestCase): - def __init__(self, testname, nlat=12, nlon=24, batch_size=32, norm="ortho", grid="legendre-gauss", tol=1e-9): - super(TestSphericalHarmonicTransform, self).__init__(testname) # calling the super class init varies for different python versions. This works for 2.7 - self.norm = norm - self.grid = grid - self.tol = tol - self.batch_size = batch_size - self.nlat = nlat - self.nlon = nlon - def setUp(self): if torch.cuda.is_available(): @@ -102,28 +91,36 @@ def setUp(self): print("Running test on CPU") self.device = torch.device('cpu') - def test_sht(self): - print(f"Testing real-valued SHT on {self.nlat}x{self.nlon} {self.grid} grid with {self.norm} normalization") + @parameterized.expand([ + [256, 512, 32, "ortho", "equiangular", 1e-9], + [256, 512, 32, "ortho", "legendre-gauss", 1e-9], + [256, 512, 32, "four-pi", "equiangular", 1e-9], + [256, 512, 32, "four-pi", "legendre-gauss", 1e-9], + [256, 512, 32, "schmidt", "equiangular", 1e-9], + [256, 512, 32, "schmidt", "legendre-gauss", 1e-9], + ]) + def test_sht(self, nlat, nlon, batch_size, norm, grid, tol): + print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization") testiters = [1, 2, 4, 8, 16] - if self.grid == "equiangular": - mmax = self.nlat // 2 + if grid == "equiangular": + mmax = nlat // 2 else: - mmax = self.nlat + mmax = nlat lmax = mmax - sht = RealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid=self.grid, norm=self.norm).to(self.device) - isht = InverseRealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid=self.grid, norm=self.norm).to(self.device) + sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device) + isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device) with torch.no_grad(): - coeffs = torch.zeros(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) - coeffs[:, :lmax, :mmax] = torch.randn(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) + coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) + coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) signal = isht(coeffs) # testing error accumulation for iter in testiters: with self.subTest(i = iter): - print(f"{iter} iterations of batchsize {self.batch_size}:") + print(f"{iter} iterations of batchsize {batch_size}:") base = signal @@ -132,52 +129,38 @@ def test_sht(self): err = torch.mean(torch.norm(base-signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) ) print(f"final relative error: {err.item()}") - self.assertTrue(err.item() <= self.tol) - - def test_sht_grad(self): - print(f"Testing gradients of real-valued SHT on {self.nlat}x{self.nlon} {self.grid} grid with {self.norm} normalization") - - if self.grid == "equiangular": - mmax = self.nlat // 2 + self.assertTrue(err.item() <= tol) + + @parameterized.expand([ + [12, 24, 2, "ortho", "equiangular", 1e-5], + [12, 24, 2, "ortho", "legendre-gauss", 1e-5], + [12, 24, 2, "four-pi", "equiangular", 1e-5], + [12, 24, 2, "four-pi", "legendre-gauss", 1e-5], + [12, 24, 2, "schmidt", "equiangular", 1e-5], + [12, 24, 2, "schmidt", "legendre-gauss", 1e-5], + ]) + def test_sht_grad(self, nlat, nlon, batch_size, norm, grid, tol): + print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization") + + if grid == "equiangular": + mmax = nlat // 2 else: - mmax = self.nlat + mmax = nlat lmax = mmax - sht = RealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid=self.grid, norm=self.norm).to(self.device) - isht = InverseRealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid=self.grid, norm=self.norm).to(self.device) + sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device) + isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device) with torch.no_grad(): - coeffs = torch.zeros(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) - coeffs[:, :lmax, :mmax] = torch.randn(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) + coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) + coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) signal = isht(coeffs) input = torch.randn_like(signal, requires_grad=True) err_handle = lambda x : torch.mean(torch.norm( isht(sht(x)) - signal , p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) ) - test_result = gradcheck(err_handle, input, eps=1e-6, atol=self.tol) + test_result = gradcheck(err_handle, input, eps=1e-6, atol=tol) self.assertTrue(test_result) if __name__ == '__main__': - sht_test_suite = unittest.TestSuite() - - # test computation of Legendre polynomials - sht_test_suite.addTest(TestLegendrePolynomials('test_legendre', tol=1e-9)) - - # test error growth when computing repeatedly isht(sht(x)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="ortho", grid="equiangular", tol=1e-1, nlat=256, nlon=512, batch_size=16)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="ortho", grid="legendre-gauss", tol=1e-9, nlat=256, nlon=512, batch_size=16)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="four-pi", grid="equiangular", tol=1e-1, nlat=256, nlon=512, batch_size=16)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="four-pi", grid="legendre-gauss", tol=1e-9, nlat=256, nlon=512, batch_size=16)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="schmidt", grid="equiangular", tol=1e-1, nlat=256, nlon=512, batch_size=16)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="schmidt", grid="legendre-gauss", tol=1e-9, nlat=256, nlon=512, batch_size=16)) - - # test error growth when computing repeatedly isht(sht(x)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_grad', norm="ortho", grid="equiangular", tol=1e-4, nlat=12, nlon=24, batch_size=2)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_grad', norm="ortho", grid="legendre-gauss", tol=1e-4, nlat=12, nlon=24, batch_size=2)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_grad', norm="four-pi", grid="equiangular", tol=1e-4, nlat=12, nlon=24, batch_size=2)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_grad', norm="four-pi", grid="legendre-gauss", tol=1e-4, nlat=12, nlon=24, batch_size=2)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_grad', norm="schmidt", grid="equiangular", tol=1e-4, nlat=12, nlon=24, batch_size=2)) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_grad', norm="schmidt", grid="legendre-gauss", tol=1e-4, nlat=12, nlon=24, batch_size=2)) - - # run the test suite - unittest.TextTestRunner(verbosity=2).run(sht_test_suite) \ No newline at end of file + unittest.main() \ No newline at end of file