Merge pull request #36 from NVIDIA/bbonev/disco-filterbasis-bugfixes
Bbonev/disco filterbasis bugfixes
azrael417 authored Apr 26, 2024
2 parents 3680ab6 + 40d5569 commit fcc1aa2
Showing 3 changed files with 143 additions and 112 deletions.
85 changes: 57 additions & 28 deletions tests/test_convolution.py
@@ -38,6 +38,7 @@
from torch.autograd import gradcheck
from torch_harmonics import *

+from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes

def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float):
    """
@@ -54,13 +55,12 @@ def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutof
    else:
        ir = (ikernel + 0.5) * dr

-    norm_factor = 2 * math.pi * (1 - math.cos(r_cutoff - dr) + math.cos(r_cutoff - dr) + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr)

    vals = torch.where(
        ((r - ir).abs() <= dr) & (r <= r_cutoff),
-        (1 - (r - ir).abs() / dr) / norm_factor,
+        (1 - (r - ir).abs() / dr),
        0,
    )

    return vals
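For reference, each isotropic basis function above is a triangular (hat) profile in the radial coordinate: function i peaks at r = ir and decays linearly to zero over a width dr, and after this change it is left unnormalized here since normalization now happens discretely later. A minimal standalone sketch of the same expression (nr, r_cutoff, and the dr spacing are hypothetical illustration values, not taken from the diff):

import math
import torch

# hypothetical kernel settings, for illustration only
nr, r_cutoff = 3, 0.1 * math.pi
dr = r_cutoff / nr                          # assumed radial spacing
ikernel = torch.arange(nr).reshape(-1, 1)   # one row per basis function
ir = (ikernel + 0.5) * dr                   # centers, as in the even-nr branch above

r = torch.linspace(0, r_cutoff, steps=7).reshape(1, -1)
vals = torch.where(
    ((r - ir).abs() <= dr) & (r <= r_cutoff),
    1 - (r - ir).abs() / dr,                # triangular profile, unnormalized
    torch.tensor(0.0),
)
print(vals)  # each row samples one hat function on the radial grid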


@@ -82,21 +82,19 @@ def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi:
    ir = (ikernel // nphi + 0.5) * dr
    iphi = (ikernel % nphi) * dphi

-    norm_factor = 2 * math.pi * (1 - math.cos(r_cutoff - dr) + math.cos(r_cutoff - dr) + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr)

    # compute the value of the filter
    if nr % 2 == 1:
        # find the indices where the rotated position falls into the support of the kernel
        cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
        cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
-        r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr) / norm_factor, 0.0)
+        r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr), 0.0)
        phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
        vals = torch.where(ikernel > 0, r_vals * phi_vals, r_vals)
    else:
        # find the indices where the rotated position falls into the support of the kernel
        cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
        cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
-        r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr) / norm_factor, 0.0)
+        r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr), 0.0)
        phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
        vals = r_vals * phi_vals

@@ -105,22 +103,40 @@ def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi:
        phin = torch.where(phi + math.pi >= 2 * math.pi, phi - math.pi, phi + math.pi)
        cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
        cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)
-        rn_vals = torch.where(cond_rn, (1 - (rn - ir).abs() / dr) / norm_factor, 0.0)
+        rn_vals = torch.where(cond_rn, (1 - (rn - ir).abs() / dr), 0.0)
        phin_vals = torch.where(cond_phin, (1 - torch.minimum((phin - iphi).abs(), (2 * math.pi - (phin - iphi).abs())) / dphi), 0.0)
        vals += rn_vals * phin_vals

    return vals
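Note that the azimuthal factor measures distance with the periodic metric min(|phi - iphi|, 2*pi - |phi - iphi|), so a basis function centered at iphi = 0 also collects contributions from angles just below 2*pi. A quick standalone check of that wrap-around (sample values are hypothetical):

import math
import torch

phi = torch.tensor([0.1, 3.0, 6.2])   # sample angles in [0, 2*pi)
iphi = 0.0                             # basis center sitting on the seam
diff = (phi - iphi).abs()
dist = torch.minimum(diff, 2 * math.pi - diff)
print(dist)  # tensor([0.1000, 3.0000, 0.0832]); 6.2 rad is close to 0 across the seam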

+def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, eps=1e-9):
+    """
+    Discretely normalizes the convolution tensor.
+    """
+
+    kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = psi.shape
+    scale_factor = float(nlon_in // nlon_out)
+
+    if transpose_normalization:
+        # the normalization is not quite symmetric due to the compressed way psi is stored in the main code
+        # look at the normalization code in the actual implementation
+        psi_norm = torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:,:,:1], dim=(1, 4), keepdim=True) / scale_factor
+    else:
+        psi_norm = torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi, dim=(3, 4), keepdim=True)
+
+    return psi / (psi_norm + eps)

-def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi):
+def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False):
"""
Helper routine to compute the convolution Tensor in a dense fashion
"""

assert len(in_shape) == 2
assert len(out_shape) == 2

quad_weights = quad_weights.reshape(-1, 1)

if len(kernel_shape) == 1:
kernel_handle = partial(_compute_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
kernel_size = math.ceil( kernel_shape[0] / 2)
@@ -170,6 +186,9 @@ def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, grid
            # find the indices where the rotated position falls into the support of the kernel
            out[:, t, p, :, :] = kernel_handle(theta, phi)

+    # take care of normalization
+    out = _normalize_convolution_tensor_dense(out, quad_weights=quad_weights, transpose_normalization=transpose_normalization)

    return out
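With this change, normalization is no longer baked into the analytic filter values; instead each basis function of the dense tensor is rescaled so that its quadrature-weighted sum over the input grid is one (up to eps). A self-contained sketch of that property in the non-transpose branch, using random stand-ins for psi and the weights rather than the real convolution tensor (normalize_dense is a local stand-in for the helper above):

import torch

def normalize_dense(psi, quad_weights, eps=1e-9):
    # forward branch of _normalize_convolution_tensor_dense above:
    # integrate each basis function over the input grid, then divide
    psi_norm = torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi, dim=(3, 4), keepdim=True)
    return psi / (psi_norm + eps)

kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = 3, 4, 8, 8, 16
psi = torch.rand(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in)
quad_weights = torch.rand(nlat_in)  # stand-in for 2*pi*w_lat/nlon_in
psi = normalize_dense(psi, quad_weights)

# the weighted sum over the input grid is now ~1 for every output location
check = torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi, dim=(3, 4))
print(torch.allclose(check, torch.ones_like(check)))  # True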


@@ -181,28 +200,31 @@ def setUp(self):
            torch.cuda.manual_seed(333)
        else:
            self.device = torch.device("cpu")

        torch.manual_seed(333)

    @parameterized.expand(
        [
            # regular convolution
-            [8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", False, 5e-5],
-            [8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "equiangular", False, 5e-5],
-            [8, 4, 2, (16, 32), (8, 16), [3, 3], "equiangular", "equiangular", False, 5e-5],
-            [8, 4, 2, (16, 32), (8, 16), [2, 3], "equiangular", "equiangular", False, 5e-5],
-            [8, 4, 2, (18, 36), (6, 12), [7], "equiangular", "equiangular", False, 5e-5],
-            [8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "legendre-gauss", False, 5e-5],
-            [8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "equiangular", False, 5e-5],
-            [8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "legendre-gauss", False, 5e-5],
+            [8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", False, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "equiangular", False, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), [3, 3], "equiangular", "equiangular", False, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), [4, 3], "equiangular", "equiangular", False, 1e-4],
+            [8, 4, 2, (16, 24), (8, 8), [3], "equiangular", "equiangular", False, 1e-4],
+            [8, 4, 2, (18, 36), (6, 12), [7], "equiangular", "equiangular", False, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "legendre-gauss", False, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "equiangular", False, 1e-4],
+            [8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "legendre-gauss", False, 1e-4],
            # transpose convolution
-            [8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", True, 5e-5],
-            [8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "equiangular", True, 5e-5],
-            [8, 4, 2, (8, 16), (16, 32), [3, 3], "equiangular", "equiangular", True, 5e-5],
-            [8, 4, 2, (8, 16), (16, 32), [2, 3], "equiangular", "equiangular", True, 5e-5],
-            [8, 4, 2, (6, 12), (18, 36), [7], "equiangular", "equiangular", True, 5e-5],
-            [8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "legendre-gauss", True, 5e-5],
-            [8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "equiangular", True, 5e-5],
-            [8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "legendre-gauss", True, 5e-5],
+            [8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", True, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "equiangular", True, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), [3, 3], "equiangular", "equiangular", True, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), [4, 3], "equiangular", "equiangular", True, 1e-4],
+            [8, 4, 2, (8, 8), (16, 24), [3], "equiangular", "equiangular", True, 1e-4],
+            [8, 4, 2, (6, 12), (18, 36), [7], "equiangular", "equiangular", True, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "legendre-gauss", True, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "equiangular", True, 1e-4],
+            [8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "legendre-gauss", True, 1e-4],
        ]
    )
    def test_disco_convolution(
@@ -237,10 +259,17 @@ def test_disco_convolution(
            theta_cutoff=theta_cutoff
        ).to(self.device)

+        _, wgl = _precompute_latitudes(nlat_in, grid=grid_in)
+        quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / nlon_in

        if transpose:
-            psi_dense = _precompute_convolution_tensor_dense(out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff).to(self.device)
+            psi_dense = _precompute_convolution_tensor_dense(out_shape, in_shape, kernel_shape, quad_weights, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True).to(self.device)

            psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_in, conv.nlat_out * conv.nlon_out)).to_dense()

            self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_in, nlat_out * nlon_out)))
        else:
-            psi_dense = _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff).to(self.device)
+            psi_dense = _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False).to(self.device)

            psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense()
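The quadrature weights assembled above distribute each latitude weight uniformly over the nlon_in longitudes. If the latitude weights returned by _precompute_latitudes integrate sin(theta) to 2 over [0, pi] (an assumption about their convention, not stated in the diff), then the weights summed over the whole grid recover the sphere area 4*pi. A quick standalone check under that assumption:

import math
import torch
from torch_harmonics.quadrature import _precompute_latitudes

nlat_in, nlon_in = 16, 32
_, wgl = _precompute_latitudes(nlat_in, grid="equiangular")
quad_weights = 2.0 * math.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / nlon_in

# summing the per-point weights over all nlat*nlon grid points
total = (quad_weights * nlon_in).sum().item()
print(total, 4 * math.pi)  # both ~12.566 if the weights follow that convention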

2 changes: 1 addition & 1 deletion torch_harmonics/_disco_convolution.py
@@ -146,7 +146,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl
    kernel_size, nlat_out, n_out = psi.shape

    assert n_out % nlon_out == 0
-    assert nlon_out >= nlat_in
+    assert nlon_out >= nlon_in
    pscale = nlon_out // nlon_in

    # interleave zeros along the longitude dimension to allow for fractional offsets to be considered
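The corrected assert matches the comment above it: the transpose contraction upsamples along longitude by the integer factor pscale = nlon_out // nlon_in, which presupposes nlon_out >= nlon_in; the old check against nlat_in compared the wrong dimension. A minimal sketch of the zero-interleaving idea (illustrative only, not the library's exact implementation):

import torch

x = torch.arange(1.0, 5.0)   # 4 longitude samples
pscale = 2                   # nlon_out // nlon_in

# interleave zeros so fractional longitude offsets can be represented
x_up = torch.zeros(x.shape[0] * pscale)
x_up[::pscale] = x
print(x_up)  # tensor([1., 0., 2., 0., 3., 0., 4., 0.])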