switched to parametrized for unittests
bonevbs committed Sep 6, 2023
1 parent 94b05cb commit 8656e36
Showing 2 changed files with 45 additions and 62 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
@@ -22,5 +22,5 @@ jobs:
         python -m pip install -e .
     - name: Test with pytest
       run: |
-        pip install pytest pytest-cov
-        pytest ./torch_harmonics/tests.py
+        pip install pytest pytest-cov parameterized
+        python -m unittest ./torch_harmonics/tests.py
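
Side note: python -m unittest accepts a file path here because the loader converts it to the dotted module name torch_harmonics.tests. A roughly equivalent local invocation using the standard loader, as a sketch (assumes the repository root as working directory and the package importable):

import unittest

# Load torch_harmonics/tests.py the same way the CI invocation does,
# then run it with a verbose text runner.
suite = unittest.defaultTestLoader.discover("torch_harmonics", pattern="tests.py")
unittest.TextTestRunner(verbosity=2).run(suite)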
103 changes: 43 additions & 60 deletions torch_harmonics/tests.py
@@ -30,6 +30,7 @@
 #
 
 import unittest
+from parameterized import parameterized
 import math
 import numpy as np
 import torch
@@ -45,10 +46,6 @@
 
 class TestLegendrePolynomials(unittest.TestCase):
 
-    def __init__(self, testname, tol=1e-9):
-        super(TestLegendrePolynomials, self).__init__(testname)
-        self.tol = tol
-
     def setUp(self):
         self.cml = lambda m, l: np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(math.factorial(l-m) / math.factorial(l+m))
         self.pml = dict()
@@ -68,6 +65,8 @@ def setUp(self):
 
         self.lmax = self.mmax = 4
 
+        self.tol = 1e-9
+
     def test_legendre(self):
         print("Testing computation of associated Legendre polynomials")
         from torch_harmonics.legendre import precompute_legpoly
@@ -79,20 +78,10 @@ def test_legendre(self):
             for m in range(l+1):
                 diff = pct[m, l].numpy() / self.cml(m,l) - self.pml[(m,l)](np.cos(t))
                 self.assertTrue(diff.max() <= self.tol)
-        print("done.")
 
 
 class TestSphericalHarmonicTransform(unittest.TestCase):
 
-    def __init__(self, testname, nlat=12, nlon=24, batch_size=32, norm="ortho", grid="legendre-gauss", tol=1e-9):
-        super(TestSphericalHarmonicTransform, self).__init__(testname)  # calling the super class init varies for different python versions. This works for 2.7
-        self.norm = norm
-        self.grid = grid
-        self.tol = tol
-        self.batch_size = batch_size
-        self.nlat = nlat
-        self.nlon = nlon
-
     def setUp(self):
 
         if torch.cuda.is_available():
@@ -102,28 +91,36 @@ def setUp(self):
             print("Running test on CPU")
             self.device = torch.device('cpu')
 
-    def test_sht(self):
-        print(f"Testing real-valued SHT on {self.nlat}x{self.nlon} {self.grid} grid with {self.norm} normalization")
+    @parameterized.expand([
+        [256, 512, 32, "ortho", "equiangular", 1e-9],
+        [256, 512, 32, "ortho", "legendre-gauss", 1e-9],
+        [256, 512, 32, "four-pi", "equiangular", 1e-9],
+        [256, 512, 32, "four-pi", "legendre-gauss", 1e-9],
+        [256, 512, 32, "schmidt", "equiangular", 1e-9],
+        [256, 512, 32, "schmidt", "legendre-gauss", 1e-9],
+    ])
+    def test_sht(self, nlat, nlon, batch_size, norm, grid, tol):
+        print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
 
         testiters = [1, 2, 4, 8, 16]
-        if self.grid == "equiangular":
-            mmax = self.nlat // 2
+        if grid == "equiangular":
+            mmax = nlat // 2
         else:
-            mmax = self.nlat
+            mmax = nlat
         lmax = mmax
 
-        sht = RealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid=self.grid, norm=self.norm).to(self.device)
-        isht = InverseRealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid=self.grid, norm=self.norm).to(self.device)
+        sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
+        isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
 
         with torch.no_grad():
-            coeffs = torch.zeros(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
-            coeffs[:, :lmax, :mmax] = torch.randn(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
+            coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
+            coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
             signal = isht(coeffs)
 
         # testing error accumulation
         for iter in testiters:
             with self.subTest(i=iter):
-                print(f"{iter} iterations of batchsize {self.batch_size}:")
+                print(f"{iter} iterations of batchsize {batch_size}:")
 
                 base = signal
 
@@ -132,52 +129,38 @@ def test_sht(self):
 
                 err = torch.mean(torch.norm(base-signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)))
                 print(f"final relative error: {err.item()}")
-                self.assertTrue(err.item() <= self.tol)
-
-    def test_sht_grad(self):
-        print(f"Testing gradients of real-valued SHT on {self.nlat}x{self.nlon} {self.grid} grid with {self.norm} normalization")
-
-        if self.grid == "equiangular":
-            mmax = self.nlat // 2
+                self.assertTrue(err.item() <= tol)
+
+    @parameterized.expand([
+        [12, 24, 2, "ortho", "equiangular", 1e-5],
+        [12, 24, 2, "ortho", "legendre-gauss", 1e-5],
+        [12, 24, 2, "four-pi", "equiangular", 1e-5],
+        [12, 24, 2, "four-pi", "legendre-gauss", 1e-5],
+        [12, 24, 2, "schmidt", "equiangular", 1e-5],
+        [12, 24, 2, "schmidt", "legendre-gauss", 1e-5],
+    ])
+    def test_sht_grad(self, nlat, nlon, batch_size, norm, grid, tol):
+        print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
+
+        if grid == "equiangular":
+            mmax = nlat // 2
         else:
-            mmax = self.nlat
+            mmax = nlat
         lmax = mmax
 
-        sht = RealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid=self.grid, norm=self.norm).to(self.device)
-        isht = InverseRealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid=self.grid, norm=self.norm).to(self.device)
+        sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
+        isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
 
         with torch.no_grad():
-            coeffs = torch.zeros(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
-            coeffs[:, :lmax, :mmax] = torch.randn(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
+            coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
+            coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
             signal = isht(coeffs)
 
         input = torch.randn_like(signal, requires_grad=True)
         err_handle = lambda x: torch.mean(torch.norm(isht(sht(x)) - signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)))
-        test_result = gradcheck(err_handle, input, eps=1e-6, atol=self.tol)
+        test_result = gradcheck(err_handle, input, eps=1e-6, atol=tol)
         self.assertTrue(test_result)
 
 
 if __name__ == '__main__':
-    sht_test_suite = unittest.TestSuite()
-
-    # test computation of Legendre polynomials
-    sht_test_suite.addTest(TestLegendrePolynomials('test_legendre', tol=1e-9))
-
-    # test error growth when computing repeatedly isht(sht(x))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="ortho", grid="equiangular", tol=1e-1, nlat=256, nlon=512, batch_size=16))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="ortho", grid="legendre-gauss", tol=1e-9, nlat=256, nlon=512, batch_size=16))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="four-pi", grid="equiangular", tol=1e-1, nlat=256, nlon=512, batch_size=16))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="four-pi", grid="legendre-gauss", tol=1e-9, nlat=256, nlon=512, batch_size=16))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="schmidt", grid="equiangular", tol=1e-1, nlat=256, nlon=512, batch_size=16))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="schmidt", grid="legendre-gauss", tol=1e-9, nlat=256, nlon=512, batch_size=16))
-
-    # test gradients of isht(sht(x))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_grad', norm="ortho", grid="equiangular", tol=1e-4, nlat=12, nlon=24, batch_size=2))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_grad', norm="ortho", grid="legendre-gauss", tol=1e-4, nlat=12, nlon=24, batch_size=2))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_grad', norm="four-pi", grid="equiangular", tol=1e-4, nlat=12, nlon=24, batch_size=2))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_grad', norm="four-pi", grid="legendre-gauss", tol=1e-4, nlat=12, nlon=24, batch_size=2))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_grad', norm="schmidt", grid="equiangular", tol=1e-4, nlat=12, nlon=24, batch_size=2))
-    sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_grad', norm="schmidt", grid="legendre-gauss", tol=1e-4, nlat=12, nlon=24, batch_size=2))
-
-    # run the test suite
-    unittest.TextTestRunner(verbosity=2).run(sht_test_suite)
+    unittest.main()
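
For readers unfamiliar with the library: parameterized.expand replaces the decorated method with one unittest test per parameter list (named test_sht_0, test_sht_1, and so on), which is what lets a plain unittest runner enumerate every norm/grid combination without the hand-built suite this commit deletes. A minimal self-contained sketch of the pattern, using a hypothetical test_add example that is not part of the commit:

import unittest
from parameterized import parameterized

class TestExample(unittest.TestCase):
    # expand() generates one test method per parameter list, so each
    # case passes or fails independently in the test report.
    @parameterized.expand([
        [1, 2, 3],
        [2, 3, 5],
    ])
    def test_add(self, a, b, expected):
        self.assertEqual(a + b, expected)

if __name__ == "__main__":
    unittest.main()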
