Skip to content

Commit

Permalink
fixed warning in unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bonevbs committed Sep 6, 2023
1 parent a645311 commit 94b05cb
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions torch_harmonics/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import unittest, argparse
import unittest
import math
import numpy as np
import torch
from torch.autograd import gradcheck
Expand All @@ -49,7 +50,7 @@ def __init__(self, testname, tol=1e-9):
self.tol = tol

def setUp(self):
self.cml = lambda m, l : np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(np.math.factorial(l-m) / np.math.factorial(l+m))
self.cml = lambda m, l : np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(math.factorial(l-m) / math.factorial(l+m))
self.pml = dict()

# preparing associated Legendre Polynomials (These include the Condon-Shortley phase)
Expand Down Expand Up @@ -133,7 +134,6 @@ def test_sht(self):
print(f"final relative error: {err.item()}")
self.assertTrue(err.item() <= self.tol)

# @unittest.skipIf(mylib.__version__ < (1, 3), "skipping slow tests")
def test_sht_grad(self):
print(f"Testing gradients of real-valued SHT on {self.nlat}x{self.nlon} {self.grid} grid with {self.norm} normalization")

Expand All @@ -158,22 +158,18 @@ def test_sht_grad(self):


if __name__ == '__main__':
# parser = argparse.ArgumentParser()
# parser.add_argument("--run_slow", action='store_true', help="Run the slow tests.")
# args = parser.parse_args()

sht_test_suite = unittest.TestSuite()

# test computation of Legendre polynomials
sht_test_suite.addTest(TestLegendrePolynomials('test_legendre', tol=1e-9))

# test error growth when computing repeatedly isht(sht(x))
sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="ortho", grid="equiangular", tol=1e-1))
sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="ortho", grid="legendre-gauss", tol=1e-9))
sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="four-pi", grid="equiangular", tol=1e-1))
sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="four-pi", grid="legendre-gauss", tol=1e-9))
sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="schmidt", grid="equiangular", tol=1e-1))
sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="schmidt", grid="legendre-gauss", tol=1e-9))
sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="ortho", grid="equiangular", tol=1e-1, nlat=256, nlon=512, batch_size=16))
sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="ortho", grid="legendre-gauss", tol=1e-9, nlat=256, nlon=512, batch_size=16))
sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="four-pi", grid="equiangular", tol=1e-1, nlat=256, nlon=512, batch_size=16))
sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="four-pi", grid="legendre-gauss", tol=1e-9, nlat=256, nlon=512, batch_size=16))
sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="schmidt", grid="equiangular", tol=1e-1, nlat=256, nlon=512, batch_size=16))
sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht', norm="schmidt", grid="legendre-gauss", tol=1e-9, nlat=256, nlon=512, batch_size=16))

# test error growth when computing repeatedly isht(sht(x))
sht_test_suite.addTest(TestSphericalHarmonicTransform('test_sht_grad', norm="ortho", grid="equiangular", tol=1e-4, nlat=12, nlon=24, batch_size=2))
Expand Down

0 comments on commit 94b05cb

Please sign in to comment.