From cec07d7a3ee9d056212b752baeee02c37e9a43d9 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Wed, 6 Sep 2023 17:50:47 +0200 Subject: [PATCH] Bbonev/gradcheck (#9) * Added gradient check to test suite * reduced size of the unit test * switched to parametrized for unittests --- .github/workflows/tests.yml | 6 +- Changelog.md | 4 ++ README.md | 6 +- torch_harmonics/__init__.py | 2 +- torch_harmonics/tests.py | 120 ++++++++++++++++++------------------ 5 files changed, 70 insertions(+), 68 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index eb80a75..e522315 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,4 +1,4 @@ -name: Unit tests +name: tests on: [push] @@ -22,5 +22,5 @@ jobs: python -m pip install -e . - name: Test with pytest run: | - pip install pytest pytest-cov - pytest ./torch_harmonics/tests.py \ No newline at end of file + python -m pip install pytest pytest-cov parameterized + python -m pytest ./torch_harmonics/tests.py \ No newline at end of file diff --git a/Changelog.md b/Changelog.md index eaf0b0b..60c00e1 100644 --- a/Changelog.md +++ b/Changelog.md @@ -2,6 +2,10 @@ ## Versioning +### v0.6.3 + +* Adding gradient check in unit tests + ### v0.6.2 * Adding github CI diff --git a/README.md b/README.md index 96e5dd0..de97ef8 100644 --- a/README.md +++ b/README.md @@ -43,11 +43,11 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -`torch-harmonics` is a differentiable implementation of the Spherical Harmonic transform in PyTorch. It was originally implemented to enable Spherical Fourier Neural Operators (SFNO). It uses quadrature rules to compute the projection onto the associated Legendre polynomials and FFTs for the projection onto the harmonic basis. This algorithm tends to outperform others with better asymptotic scaling for most practical purposes. +torch-harmonics is a differentiable implementation of the Spherical Harmonic transform in PyTorch. It was originally implemented to enable Spherical Fourier Neural Operators (SFNO). It uses quadrature rules to compute the projection onto the associated Legendre polynomials and FFTs for the projection onto the harmonic basis. This algorithm tends to outperform others with better asymptotic scaling for most practical purposes. -`torch-harmonics` uses PyTorch primitives to implement these operations, making it fully differentiable. Moreover, the quadrature can be distributed onto multiple ranks making it spatially distributed. +torch-harmonics uses PyTorch primitives to implement these operations, making it fully differentiable. Moreover, the quadrature can be distributed onto multiple ranks making it spatially distributed. -`torch-harmonics` has been used to implement a variety of differentiable PDE solvers which generated the animations below. Moreover, it has enabled the development of Spherical Fourier Neural Operators (SFNOs) [1]. +torch-harmonics has been used to implement a variety of differentiable PDE solvers which generated the animations below. Moreover, it has enabled the development of Spherical Fourier Neural Operators (SFNOs) [1]. diff --git a/torch_harmonics/__init__.py b/torch_harmonics/__init__.py index 94e8dca..69a98f3 100644 --- a/torch_harmonics/__init__.py +++ b/torch_harmonics/__init__.py @@ -29,7 +29,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -__version__ = '0.6.2' +__version__ = '0.6.3' from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT from . import quadrature diff --git a/torch_harmonics/tests.py b/torch_harmonics/tests.py index 9b6063d..3de8b75 100644 --- a/torch_harmonics/tests.py +++ b/torch_harmonics/tests.py @@ -30,8 +30,11 @@ # import unittest +from parameterized import parameterized +import math import numpy as np import torch +from torch.autograd import gradcheck from torch_harmonics import * # try: @@ -44,7 +47,7 @@ class TestLegendrePolynomials(unittest.TestCase): def setUp(self): - self.cml = lambda m, l : np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(np.math.factorial(l-m) / np.math.factorial(l+m)) + self.cml = lambda m, l : np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(math.factorial(l-m) / math.factorial(l+m)) self.pml = dict() # preparing associated Legendre Polynomials (These include the Condon-Shortley phase) @@ -62,28 +65,23 @@ def setUp(self): self.lmax = self.mmax = 4 + self.tol = 1e-9 + def test_legendre(self): print("Testing computation of associated Legendre polynomials") from torch_harmonics.legendre import precompute_legpoly - TOL = 1e-9 - t = np.linspace(0, np.pi, 100) pct = precompute_legpoly(self.mmax, self.lmax, t) for l in range(self.lmax): for m in range(l+1): diff = pct[m, l].numpy() / self.cml(m,l) - self.pml[(m,l)](np.cos(t)) - self.assertTrue(diff.max() <= TOL) - print("done.") + self.assertTrue(diff.max() <= self.tol) class TestSphericalHarmonicTransform(unittest.TestCase): - def __init__(self, testname, norm="ortho"): - super(TestSphericalHarmonicTransform, self).__init__(testname) # calling the super class init varies for different python versions. This works for 2.7 - self.norm = norm - def setUp(self): if torch.cuda.is_available(): @@ -93,76 +91,76 @@ def setUp(self): print("Running test on CPU") self.device = torch.device('cpu') - self.batch_size = 128 - self.nlat = 256 - self.nlon = 2*self.nlat + @parameterized.expand([ + [256, 512, 32, "ortho", "equiangular", 1e-9], + [256, 512, 32, "ortho", "legendre-gauss", 1e-9], + [256, 512, 32, "four-pi", "equiangular", 1e-9], + [256, 512, 32, "four-pi", "legendre-gauss", 1e-9], + [256, 512, 32, "schmidt", "equiangular", 1e-9], + [256, 512, 32, "schmidt", "legendre-gauss", 1e-9], + ]) + def test_sht(self, nlat, nlon, batch_size, norm, grid, tol): + print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization") - def test_sht_leggauss(self): - print(f"Testing real-valued SHT on Legendre-Gauss grid with {self.norm} normalization") - - TOL = 1e-9 testiters = [1, 2, 4, 8, 16] - mmax = self.nlat + if grid == "equiangular": + mmax = nlat // 2 + else: + mmax = nlat lmax = mmax - sht = RealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid="legendre-gauss", norm=self.norm).to(self.device) - isht = InverseRealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid="legendre-gauss", norm=self.norm).to(self.device) + sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device) + isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device) - coeffs = torch.zeros(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) - coeffs[:, :lmax, :mmax] = torch.randn(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) - signal = isht(coeffs) + with torch.no_grad(): + coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) + coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) + signal = isht(coeffs) + # testing error accumulation for iter in testiters: with self.subTest(i = iter): - print(f"{iter} iterations of batchsize {self.batch_size}:") + print(f"{iter} iterations of batchsize {batch_size}:") base = signal for _ in tqdm(range(iter)): base = isht(sht(base)) - # err = ( torch.norm(base-self.signal, p='fro') / torch.norm(self.signal, p='fro') ).item() - err = torch.mean(torch.norm(base-signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) ).item() - print(f"final relative error: {err}") - self.assertTrue(err <= TOL) - - def test_sht_equiangular(self): - print(f"Testing real-valued SHT on equiangular grid with {self.norm} normalization") - - TOL = 1e-1 - testiters = [1, 2, 4, 8] - mmax = self.nlat // 2 + err = torch.mean(torch.norm(base-signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) ) + print(f"final relative error: {err.item()}") + self.assertTrue(err.item() <= tol) + + @parameterized.expand([ + [12, 24, 2, "ortho", "equiangular", 1e-5], + [12, 24, 2, "ortho", "legendre-gauss", 1e-5], + [12, 24, 2, "four-pi", "equiangular", 1e-5], + [12, 24, 2, "four-pi", "legendre-gauss", 1e-5], + [12, 24, 2, "schmidt", "equiangular", 1e-5], + [12, 24, 2, "schmidt", "legendre-gauss", 1e-5], + ]) + def test_sht_grad(self, nlat, nlon, batch_size, norm, grid, tol): + print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization") + + if grid == "equiangular": + mmax = nlat // 2 + else: + mmax = nlat lmax = mmax - sht = RealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid="equiangular", norm=self.norm).to(self.device) - isht = InverseRealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid="equiangular", norm=self.norm).to(self.device) + sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device) + isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device) - coeffs = torch.zeros(self.batch_size, sht.lmax, sht.mmax, device=self.device, dtype=torch.complex128) - coeffs[:, :lmax, :mmax] = torch.randn(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) - signal = isht(coeffs) + with torch.no_grad(): + coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) + coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) + signal = isht(coeffs) - for iter in testiters: - with self.subTest(i = iter): - print(f"{iter} iterations of batchsize {self.batch_size}:") - - base = signal - - for _ in tqdm(range(iter)): - base = isht(sht(base)) - - # err = ( torch.norm(base-self.signal, p='fro') / torch.norm(self.signal, p='fro') ).item() - err = torch.mean(torch.norm(base-signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) ).item() - print(f"final relative error: {err}") - self.assertTrue(err <= TOL) + input = torch.randn_like(signal, requires_grad=True) + err_handle = lambda x : torch.mean(torch.norm( isht(sht(x)) - signal , p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) ) + test_result = gradcheck(err_handle, input, eps=1e-6, atol=tol) + self.assertTrue(test_result) if __name__ == '__main__': - sht_test_suite = unittest.TestSuite() - sht_test_suite.addTest(TestLegendrePolynomials('test_legendre')) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_leggauss', norm="ortho")) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_equiangular', norm="ortho")) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_leggauss', norm="four-pi")) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_equiangular', norm="four-pi")) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_leggauss', norm="schmidt")) - sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_equiangular', norm="schmidt")) - unittest.TextTestRunner(verbosity=2).run(sht_test_suite) \ No newline at end of file + unittest.main() \ No newline at end of file