Bbonev/gradient test fix (#53)
* added analytic gradients to the gradient_analysis notebook

* fixing sht unittest to not check the roundtrip gradient but sht and isht individually

* Updated changelog
bonevbs authored Sep 20, 2024
1 parent 60b3b5a commit 4fea88b
Showing 3 changed files with 234 additions and 117 deletions.
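
For quick orientation, the change to tests/test_sht.py is condensed in the sketch below: torch.autograd.gradcheck is applied to the forward and the inverse transform separately, instead of differentiating through the full isht(sht(x)) roundtrip as before. The RealSHT/InverseRealSHT construction and the sizes here are illustrative assumptions (the actual sht and isht objects are built in an unchanged, collapsed part of the test file); only the gradcheck pattern mirrors the new test code.

# Minimal sketch of the revised per-transform gradient test (assumed setup, not the verbatim test).
import torch
from torch.autograd import gradcheck
from torch_harmonics import RealSHT, InverseRealSHT

nlat, nlon = 12, 24
sht = RealSHT(nlat, nlon, grid="equiangular", norm="ortho").double()
isht = InverseRealSHT(nlat, nlon, grid="equiangular", norm="ortho").double()

# reference signal and coefficients in double precision, as gradcheck expects
signal = torch.randn(2, nlat, nlon, dtype=torch.float64)
coeffs = sht(signal)

# gradient of the forward transform
grad_input = torch.randn_like(signal, requires_grad=True)
assert gradcheck(lambda x: torch.norm(sht(x) - coeffs), (grad_input,), eps=1e-6, atol=1e-5)

# gradient of the inverse transform
grad_input = torch.randn_like(coeffs, requires_grad=True)
assert gradcheck(lambda x: torch.norm(isht(x) - signal), (grad_input,), eps=1e-6, atol=1e-5)
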
1 change: 1 addition & 0 deletions Changelog.md
@@ -6,6 +6,7 @@

* Added resampling modules for convenience
* Changing behavior of distributed SHT to use `dim=-3` as channel dimension
+ * Fixing SHT unittests to test SHT and ISHT individually, rather than the roundtrip

### v0.7.1

248 changes: 173 additions & 75 deletions notebooks/gradient_analysis.ipynb

Large diffs are not rendered by default.

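The gradient_analysis notebook itself is not rendered above. As a stand-in for the analytic gradients it adds, the example below shows one way to compare an analytic gradient with the autograd gradient of the SHT. It is not taken from the notebook: it relies only on the linearity of RealSHT, and the grid size, normalization, and quadratic loss are assumptions chosen to keep the dense operator small.

# Hypothetical illustration (not from the notebook): for L(x) = 0.5 * ||A x - c||^2,
# the analytic gradient is Re(A^H (A x - c)), with the SHT matrix A materialized
# column by column on a tiny grid and compared against autograd.
import torch
from torch_harmonics import RealSHT

nlat, nlon = 8, 16
sht = RealSHT(nlat, nlon, grid="legendre-gauss", norm="ortho").double()

n = nlat * nlon
basis = torch.eye(n, dtype=torch.float64).reshape(n, nlat, nlon)
A = sht(basis).reshape(n, -1).T  # columns are the transforms of unit grid vectors

x = torch.randn(1, nlat, nlon, dtype=torch.float64, requires_grad=True)
c = torch.randn(A.shape[0], dtype=torch.complex128)

# autograd gradient of the quadratic loss through the SHT
loss = 0.5 * torch.sum(torch.abs(sht(x).reshape(-1) - c) ** 2)
loss.backward()

# analytic gradient of the same loss
r = A @ x.detach().reshape(-1).to(torch.complex128) - c
grad_analytic = (A.conj().T @ r).real.reshape(1, nlat, nlon)

print(torch.allclose(x.grad, grad_analytic, atol=1e-8))
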
102 changes: 60 additions & 42 deletions tests/test_sht.py
@@ -2,7 +2,7 @@

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
- #
+ #
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
@@ -37,24 +37,25 @@
from torch.autograd import gradcheck
from torch_harmonics import *

+
class TestLegendrePolynomials(unittest.TestCase):

def setUp(self):
- self.cml = lambda m, l : np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(math.factorial(l-m) / math.factorial(l+m))
+ self.cml = lambda m, l: np.sqrt((2 * l + 1) / 4 / np.pi) * np.sqrt(math.factorial(l - m) / math.factorial(l + m))
self.pml = dict()

# preparing associated Legendre Polynomials (These include the Condon-Shortley phase)
# for reference see e.g. https://en.wikipedia.org/wiki/Associated_Legendre_polynomials
- self.pml[(0, 0)] = lambda x : np.ones_like(x)
- self.pml[(0, 1)] = lambda x : x
- self.pml[(1, 1)] = lambda x : - np.sqrt(1. - x**2)
- self.pml[(0, 2)] = lambda x : 0.5 * (3*x**2 - 1)
- self.pml[(1, 2)] = lambda x : - 3 * x * np.sqrt(1. - x**2)
- self.pml[(2, 2)] = lambda x : 3 * (1 - x**2)
- self.pml[(0, 3)] = lambda x : 0.5 * (5*x**3 - 3*x)
- self.pml[(1, 3)] = lambda x : 1.5 * (1 - 5*x**2) * np.sqrt(1. - x**2)
- self.pml[(2, 3)] = lambda x : 15 * x * (1 - x**2)
- self.pml[(3, 3)] = lambda x : -15 * np.sqrt(1. - x**2)**3
+ self.pml[(0, 0)] = lambda x: np.ones_like(x)
+ self.pml[(0, 1)] = lambda x: x
+ self.pml[(1, 1)] = lambda x: -np.sqrt(1.0 - x**2)
+ self.pml[(0, 2)] = lambda x: 0.5 * (3 * x**2 - 1)
+ self.pml[(1, 2)] = lambda x: -3 * x * np.sqrt(1.0 - x**2)
+ self.pml[(2, 2)] = lambda x: 3 * (1 - x**2)
+ self.pml[(0, 3)] = lambda x: 0.5 * (5 * x**3 - 3 * x)
+ self.pml[(1, 3)] = lambda x: 1.5 * (1 - 5 * x**2) * np.sqrt(1.0 - x**2)
+ self.pml[(2, 3)] = lambda x: 15 * x * (1 - x**2)
+ self.pml[(3, 3)] = lambda x: -15 * np.sqrt(1.0 - x**2) ** 3

self.lmax = self.mmax = 4

@@ -68,8 +69,8 @@ def test_legendre(self):
vdm = legpoly(self.mmax, self.lmax, t)

for l in range(self.lmax):
- for m in range(l+1):
- diff = vdm[m, l] / self.cml(m,l) - self.pml[(m,l)](t)
+ for m in range(l + 1):
+ diff = vdm[m, l] / self.cml(m, l) - self.pml[(m, l)](t)
self.assertTrue(diff.max() <= self.tol)


@@ -79,19 +80,21 @@ def setUp(self):

if torch.cuda.is_available():
print("Running test on GPU")
- self.device = torch.device('cuda')
+ self.device = torch.device("cuda")
else:
print("Running test on CPU")
- self.device = torch.device('cpu')
-
- @parameterized.expand([
- [256, 512, 32, "ortho", "equiangular", 1e-9],
- [256, 512, 32, "ortho", "legendre-gauss", 1e-9],
- [256, 512, 32, "four-pi", "equiangular", 1e-9],
- [256, 512, 32, "four-pi", "legendre-gauss", 1e-9],
- [256, 512, 32, "schmidt", "equiangular", 1e-9],
- [256, 512, 32, "schmidt", "legendre-gauss", 1e-9],
- ])
+ self.device = torch.device("cpu")
+
+ @parameterized.expand(
+ [
+ [256, 512, 32, "ortho", "equiangular", 1e-9],
+ [256, 512, 32, "ortho", "legendre-gauss", 1e-9],
+ [256, 512, 32, "four-pi", "equiangular", 1e-9],
+ [256, 512, 32, "four-pi", "legendre-gauss", 1e-9],
+ [256, 512, 32, "schmidt", "equiangular", 1e-9],
+ [256, 512, 32, "schmidt", "legendre-gauss", 1e-9],
+ ]
+ )
def test_sht(self, nlat, nlon, batch_size, norm, grid, tol):
print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")

@@ -109,30 +112,38 @@ def test_sht(self, nlat, nlon, batch_size, norm, grid, tol):
coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
signal = isht(coeffs)

# testing error accumulation
for iter in testiters:
- with self.subTest(i = iter):
+ with self.subTest(i=iter):
print(f"{iter} iterations of batchsize {batch_size}:")

base = signal

for _ in range(iter):
base = isht(sht(base))
- err = torch.mean(torch.norm(base-signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) )
+
+ err = torch.mean(torch.norm(base - signal, p="fro", dim=(-1, -2)) / torch.norm(signal, p="fro", dim=(-1, -2)))
print(f"final relative error: {err.item()}")
self.assertTrue(err.item() <= tol)

- @parameterized.expand([
- [12, 24, 2, "ortho", "equiangular", 1e-5],
- [12, 24, 2, "ortho", "legendre-gauss", 1e-5],
- [12, 24, 2, "four-pi", "equiangular", 1e-5],
- [12, 24, 2, "four-pi", "legendre-gauss", 1e-5],
- [12, 24, 2, "schmidt", "equiangular", 1e-5],
- [12, 24, 2, "schmidt", "legendre-gauss", 1e-5],
- ])
- def test_sht_grad(self, nlat, nlon, batch_size, norm, grid, tol):
+ @parameterized.expand(
+ [
+ [12, 24, 2, "ortho", "equiangular", 1e-5],
+ [12, 24, 2, "ortho", "legendre-gauss", 1e-5],
+ [12, 24, 2, "four-pi", "equiangular", 1e-5],
+ [12, 24, 2, "four-pi", "legendre-gauss", 1e-5],
+ [12, 24, 2, "schmidt", "equiangular", 1e-5],
+ [12, 24, 2, "schmidt", "legendre-gauss", 1e-5],
+ [15, 30, 2, "ortho", "equiangular", 1e-5],
+ [15, 30, 2, "ortho", "legendre-gauss", 1e-5],
+ [15, 30, 2, "four-pi", "equiangular", 1e-5],
+ [15, 30, 2, "four-pi", "legendre-gauss", 1e-5],
+ [15, 30, 2, "schmidt", "equiangular", 1e-5],
+ [15, 30, 2, "schmidt", "legendre-gauss", 1e-5],
+ ]
+ )
+ def test_sht_grads(self, nlat, nlon, batch_size, norm, grid, tol):
print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")

if grid == "equiangular":
@@ -148,12 +159,19 @@ def test_sht_grad(self, nlat, nlon, batch_size, norm, grid, tol):
coeffs = torch.zeros(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
signal = isht(coeffs)
-
+
# test the sht
grad_input = torch.randn_like(signal, requires_grad=True)
- err_handle = lambda x : torch.mean(torch.norm( isht(sht(x)) - signal , p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) )
+ err_handle = lambda x: torch.mean(torch.norm(sht(x) - coeffs, p="fro", dim=(-1, -2)) / torch.norm(coeffs, p="fro", dim=(-1, -2)))
test_result = gradcheck(err_handle, grad_input, eps=1e-6, atol=tol)
self.assertTrue(test_result)
+
+ # test the isht
+ grad_input = torch.randn_like(coeffs, requires_grad=True)
+ err_handle = lambda x: torch.mean(torch.norm(isht(x) - signal, p="fro", dim=(-1, -2)) / torch.norm(signal, p="fro", dim=(-1, -2)))
+ test_result = gradcheck(err_handle, grad_input, eps=1e-6, atol=tol)
+ self.assertTrue(test_result)

+
- if __name__ == '__main__':
- unittest.main()
+ if __name__ == "__main__":
+ unittest.main()
