
Commit

Bbonev/gradcheck (#9)
* Added gradient check to test suite
* Reduced the size of the unit tests
* Switched to parameterized for unit tests
bonevbs authored Sep 6, 2023
1 parent 17eefa5 commit cec07d7
Showing 5 changed files with 70 additions and 68 deletions.
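For readers unfamiliar with the pattern this commit adopts, here is a minimal sketch of how `parameterized.expand` turns a list of parameter sets into individually named test cases (the class and values below are illustrative, not taken from this commit):

```python
import unittest

from parameterized import parameterized


class ExampleTest(unittest.TestCase):

    # Each entry in the list becomes its own test case, so a failure reports
    # exactly which (nlat, nlon, grid) combination broke.
    @parameterized.expand([
        [8, 16, "equiangular"],
        [8, 16, "legendre-gauss"],
    ])
    def test_grid_shape(self, nlat, nlon, grid):
        self.assertEqual(nlon, 2 * nlat)


if __name__ == "__main__":
    unittest.main()
```

Because the parameterized tests are discovered automatically, the hand-built `unittest.TestSuite` at the bottom of `tests.py` can be replaced by a plain `unittest.main()`, as the diff below shows.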
6 changes: 3 additions & 3 deletions .github/workflows/tests.yml
@@ -1,4 +1,4 @@
-name: Unit tests
+name: tests

on: [push]

@@ -22,5 +22,5 @@ jobs:
python -m pip install -e .
- name: Test with pytest
run: |
-pip install pytest pytest-cov
-pytest ./torch_harmonics/tests.py
+python -m pip install pytest pytest-cov parameterized
+python -m pytest ./torch_harmonics/tests.py
4 changes: 4 additions & 0 deletions Changelog.md
@@ -2,6 +2,10 @@

## Versioning

+### v0.6.3
+
+* Adding gradient check in unit tests
+
### v0.6.2

* Adding github CI
6 changes: 3 additions & 3 deletions README.md
@@ -43,11 +43,11 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

<!-- ## What is torch-harmonics? -->

-`torch-harmonics` is a differentiable implementation of the Spherical Harmonic transform in PyTorch. It was originally implemented to enable Spherical Fourier Neural Operators (SFNO). It uses quadrature rules to compute the projection onto the associated Legendre polynomials and FFTs for the projection onto the harmonic basis. This algorithm tends to outperform others with better asymptotic scaling for most practical purposes.
+torch-harmonics is a differentiable implementation of the Spherical Harmonic transform in PyTorch. It was originally implemented to enable Spherical Fourier Neural Operators (SFNO). It uses quadrature rules to compute the projection onto the associated Legendre polynomials and FFTs for the projection onto the harmonic basis. This algorithm tends to outperform others with better asymptotic scaling for most practical purposes.

-`torch-harmonics` uses PyTorch primitives to implement these operations, making it fully differentiable. Moreover, the quadrature can be distributed onto multiple ranks making it spatially distributed.
+torch-harmonics uses PyTorch primitives to implement these operations, making it fully differentiable. Moreover, the quadrature can be distributed onto multiple ranks making it spatially distributed.

-`torch-harmonics` has been used to implement a variety of differentiable PDE solvers which generated the animations below. Moreover, it has enabled the development of Spherical Fourier Neural Operators (SFNOs) [1].
+torch-harmonics has been used to implement a variety of differentiable PDE solvers which generated the animations below. Moreover, it has enabled the development of Spherical Fourier Neural Operators (SFNOs) [1].


<table border="0" cellspacing="0" cellpadding="0">
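As an aside, the transforms described above can be exercised with a short round trip that mirrors what the updated tests below do. A minimal sketch, assuming the constructor signature used in `tests.py` (grid and norm values also taken from there):

```python
import torch

from torch_harmonics import RealSHT, InverseRealSHT

nlat, nlon = 64, 128
lmax = mmax = nlat // 2

sht = RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid="equiangular", norm="ortho")
isht = InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid="equiangular", norm="ortho")

# Synthesize a band-limited signal from random spectral coefficients, as the
# unit tests do, so that the forward/inverse round trip is nearly exact.
coeffs = torch.randn(1, lmax, mmax, dtype=torch.complex128)
signal = isht(coeffs)

roundtrip = isht(sht(signal))
rel_err = torch.norm(roundtrip - signal) / torch.norm(signal)
print(rel_err.item())  # the tests below check an analogous error against a 1e-9 tolerance
```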
2 changes: 1 addition & 1 deletion torch_harmonics/__init__.py
@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

-__version__ = '0.6.2'
+__version__ = '0.6.3'

from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from . import quadrature
120 changes: 59 additions & 61 deletions torch_harmonics/tests.py
@@ -30,8 +30,11 @@
#

import unittest
+from parameterized import parameterized
+import math
import numpy as np
import torch
+from torch.autograd import gradcheck
from torch_harmonics import *

# try:
@@ -44,7 +47,7 @@
class TestLegendrePolynomials(unittest.TestCase):

def setUp(self):
-self.cml = lambda m, l : np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(np.math.factorial(l-m) / np.math.factorial(l+m))
+self.cml = lambda m, l : np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(math.factorial(l-m) / math.factorial(l+m))
self.pml = dict()

# preparing associated Legendre Polynomials (These include the Condon-Shortley phase)
@@ -62,28 +65,23 @@ def setUp(self):

self.lmax = self.mmax = 4

+self.tol = 1e-9

def test_legendre(self):
print("Testing computation of associated Legendre polynomials")
from torch_harmonics.legendre import precompute_legpoly

-TOL = 1e-9

t = np.linspace(0, np.pi, 100)
pct = precompute_legpoly(self.mmax, self.lmax, t)

for l in range(self.lmax):
for m in range(l+1):
diff = pct[m, l].numpy() / self.cml(m,l) - self.pml[(m,l)](np.cos(t))
-self.assertTrue(diff.max() <= TOL)
-print("done.")
+self.assertTrue(diff.max() <= self.tol)


class TestSphericalHarmonicTransform(unittest.TestCase):

-def __init__(self, testname, norm="ortho"):
-super(TestSphericalHarmonicTransform, self).__init__(testname) # calling the super class init varies for different python versions. This works for 2.7
-self.norm = norm

def setUp(self):

if torch.cuda.is_available():
@@ -93,76 +91,76 @@ def setUp(self):
print("Running test on CPU")
self.device = torch.device('cpu')

-self.batch_size = 128
-self.nlat = 256
-self.nlon = 2*self.nlat
+@parameterized.expand([
+[256, 512, 32, "ortho", "equiangular", 1e-9],
+[256, 512, 32, "ortho", "legendre-gauss", 1e-9],
+[256, 512, 32, "four-pi", "equiangular", 1e-9],
+[256, 512, 32, "four-pi", "legendre-gauss", 1e-9],
+[256, 512, 32, "schmidt", "equiangular", 1e-9],
+[256, 512, 32, "schmidt", "legendre-gauss", 1e-9],
+])
+def test_sht(self, nlat, nlon, batch_size, norm, grid, tol):
+print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")

-def test_sht_leggauss(self):
-print(f"Testing real-valued SHT on Legendre-Gauss grid with {self.norm} normalization")

-TOL = 1e-9
testiters = [1, 2, 4, 8, 16]
-mmax = self.nlat
+if grid == "equiangular":
+mmax = nlat // 2
+else:
+mmax = nlat
lmax = mmax

-sht = RealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid="legendre-gauss", norm=self.norm).to(self.device)
-isht = InverseRealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid="legendre-gauss", norm=self.norm).to(self.device)
+sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
+isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)

-coeffs = torch.zeros(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
-coeffs[:, :lmax, :mmax] = torch.randn(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
-signal = isht(coeffs)
+with torch.no_grad():
+coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
+coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
+signal = isht(coeffs)

# testing error accumulation
for iter in testiters:
with self.subTest(i = iter):
print(f"{iter} iterations of batchsize {self.batch_size}:")
print(f"{iter} iterations of batchsize {batch_size}:")

base = signal

for _ in tqdm(range(iter)):
base = isht(sht(base))

-# err = ( torch.norm(base-self.signal, p='fro') / torch.norm(self.signal, p='fro') ).item()
-err = torch.mean(torch.norm(base-signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) ).item()
-print(f"final relative error: {err}")
-self.assertTrue(err <= TOL)
-
-def test_sht_equiangular(self):
-print(f"Testing real-valued SHT on equiangular grid with {self.norm} normalization")
-
-TOL = 1e-1
-testiters = [1, 2, 4, 8]
-mmax = self.nlat // 2
+err = torch.mean(torch.norm(base-signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) )
+print(f"final relative error: {err.item()}")
+self.assertTrue(err.item() <= tol)

+@parameterized.expand([
+[12, 24, 2, "ortho", "equiangular", 1e-5],
+[12, 24, 2, "ortho", "legendre-gauss", 1e-5],
+[12, 24, 2, "four-pi", "equiangular", 1e-5],
+[12, 24, 2, "four-pi", "legendre-gauss", 1e-5],
+[12, 24, 2, "schmidt", "equiangular", 1e-5],
+[12, 24, 2, "schmidt", "legendre-gauss", 1e-5],
+])
+def test_sht_grad(self, nlat, nlon, batch_size, norm, grid, tol):
+print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
+
+if grid == "equiangular":
+mmax = nlat // 2
+else:
+mmax = nlat
lmax = mmax

-sht = RealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid="equiangular", norm=self.norm).to(self.device)
-isht = InverseRealSHT(self.nlat, self.nlon, mmax=mmax, lmax=lmax, grid="equiangular", norm=self.norm).to(self.device)
+sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
+isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)

-coeffs = torch.zeros(self.batch_size, sht.lmax, sht.mmax, device=self.device, dtype=torch.complex128)
-coeffs[:, :lmax, :mmax] = torch.randn(self.batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
-signal = isht(coeffs)
+with torch.no_grad():
+coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
+coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
+signal = isht(coeffs)

-for iter in testiters:
-with self.subTest(i = iter):
-print(f"{iter} iterations of batchsize {self.batch_size}:")
-
-base = signal
-
-for _ in tqdm(range(iter)):
-base = isht(sht(base))
-
-# err = ( torch.norm(base-self.signal, p='fro') / torch.norm(self.signal, p='fro') ).item()
-err = torch.mean(torch.norm(base-signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) ).item()
-print(f"final relative error: {err}")
-self.assertTrue(err <= TOL)
+input = torch.randn_like(signal, requires_grad=True)
+err_handle = lambda x : torch.mean(torch.norm( isht(sht(x)) - signal , p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) )
+test_result = gradcheck(err_handle, input, eps=1e-6, atol=tol)
+self.assertTrue(test_result)


if __name__ == '__main__':
-sht_test_suite = unittest.TestSuite()
-sht_test_suite.addTest(TestLegendrePolynomials('test_legendre'))
-sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_leggauss', norm="ortho"))
-sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_equiangular', norm="ortho"))
-sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_leggauss', norm="four-pi"))
-sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_equiangular', norm="four-pi"))
-sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_leggauss', norm="schmidt"))
-sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_equiangular', norm="schmidt"))
-unittest.TextTestRunner(verbosity=2).run(sht_test_suite)
+unittest.main()
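
For context on the new `test_sht_grad`: `torch.autograd.gradcheck` compares the analytic Jacobian produced by autograd against a finite-difference estimate, which is why the gradient tests run in double precision on much smaller problems (12x24 grids, batch size 2) than the round-trip tests. A standalone sketch with an illustrative function, not taken from this commit:

```python
import torch
from torch.autograd import gradcheck

# gradcheck perturbs each input element by eps and compares the resulting
# finite-difference Jacobian against the autograd one, so the input should be
# float64 and small for the check to be fast and numerically reliable.
def f(x):
    return (x.sin() * x).sum()

x = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
assert gradcheck(f, (x,), eps=1e-6, atol=1e-5)
```

In `test_sht_grad`, the checked function is the relative round-trip error of `isht(sht(x))` against a fixed reference signal, so the gradient is propagated through both transforms.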
