Skip to content

Commit

Permalink
Moving conversion from numpy to torch from the legendre module to the…
Browse files Browse the repository at this point in the history
… sht module
  • Loading branch information
bonevbs committed Oct 20, 2023
1 parent bd99bfb commit b5b9d6e
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 6 deletions.
8 changes: 6 additions & 2 deletions torch_harmonics/distributed/distributed_sht.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho
# combine quadrature weights with the legendre weights
weights = torch.from_numpy(w)
pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
pct = torch.from_numpy(pct)
weights = torch.einsum('mlk,k->mlk', pct, weights)

# we need to split in m, pad before:
Expand Down Expand Up @@ -256,6 +257,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho

# compute legende polynomials
pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
pct = torch.from_numpy(pct)

# split in m
pct = F.pad(pct, [0, 0, 0, 0, 0, self.mpad], mode="constant")
Expand Down Expand Up @@ -405,6 +407,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho

weights = torch.from_numpy(w)
dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
dpct = torch.from_numpy(dpct)

# combine integration weights, normalization factor in to one:
l = torch.arange(0, self.lmax)
Expand Down Expand Up @@ -567,10 +570,11 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho

# compute legende polynomials
dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
dpct = torch.from_numpy(dpct)

# split in m
pct = F.pad(pct, [0, 0, 0, 0, 0, self.mpad], mode="constant")
pct = torch.split(pct, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=0)[self.comm_rank_azimuth]
dpct = F.pad(dpct, [0, 0, 0, 0, 0, self.mpad], mode="constant")
dpct = torch.split(dpct, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=0)[self.comm_rank_azimuth]

# register buffer
self.register_buffer('dpct', dpct, persistent=False)
Expand Down
5 changes: 2 additions & 3 deletions torch_harmonics/legendre.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#

import numpy as np
import torch

def clm(l, m):
"""
Expand Down Expand Up @@ -85,7 +84,7 @@ def legpoly(mmax, lmax, x, norm="ortho", inverse=False, csphase=True):
for m in range(1, mmax, 2):
vdm[m] *= -1

return torch.from_numpy(vdm)
return vdm

def _precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True):
r"""
Expand Down Expand Up @@ -115,7 +114,7 @@ def _precompute_dlegpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=Tru

pct = _precompute_legpoly(mmax+1, lmax+1, t, norm=norm, inverse=inverse, csphase=False)

dpct = torch.zeros((2, mmax, lmax, len(t)), dtype=torch.float64)
dpct = np.zeros((2, mmax, lmax, len(t)), dtype=np.float64)

# fill the derivative terms wrt theta
for l in range(0, lmax):
Expand Down
4 changes: 4 additions & 0 deletions torch_harmonics/sht.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho
# combine quadrature weights with the legendre weights
weights = torch.from_numpy(w)
pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
pct = torch.from_numpy(pct)
weights = torch.einsum('mlk,k->mlk', pct, weights)

# remember quadrature weights
Expand Down Expand Up @@ -167,6 +168,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho
self.mmax = mmax or self.nlon // 2 + 1

pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
pct = torch.from_numpy(pct)

# register buffer
self.register_buffer('pct', pct, persistent=False)
Expand Down Expand Up @@ -246,6 +248,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho

weights = torch.from_numpy(w)
dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
dpct = torch.from_numpy(dpct)

# combine integration weights, normalization factor in to one:
l = torch.arange(0, self.lmax)
Expand Down Expand Up @@ -338,6 +341,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho
self.mmax = mmax or self.nlon // 2 + 1

dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
dpct = torch.from_numpy(dpct)

# register weights
self.register_buffer('dpct', dpct, persistent=False)
Expand Down
2 changes: 1 addition & 1 deletion torch_harmonics/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_legendre(self):

for l in range(self.lmax):
for m in range(l+1):
diff = pct[m, l].numpy() / self.cml(m,l) - self.pml[(m,l)](t)
diff = pct[m, l] / self.cml(m,l) - self.pml[(m,l)](t)
self.assertTrue(diff.max() <= self.tol)


Expand Down

0 comments on commit b5b9d6e

Please sign in to comment.