From 8c99244d72e1c17406850bb2818ef8b7f871f55b Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Thu, 25 Apr 2024 00:40:39 +0200 Subject: [PATCH 1/4] reworked normalization of filter basis functions --- tests/test_convolution.py | 47 +++++++++------ torch_harmonics/convolution.py | 103 +++++++++++---------------------- 2 files changed, 63 insertions(+), 87 deletions(-) diff --git a/tests/test_convolution.py b/tests/test_convolution.py index ada388a..b85f21e 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -38,8 +38,9 @@ from torch.autograd import gradcheck from torch_harmonics import * +from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes -def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float): +def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, quad_weights: torch.Tensor, nr: int, r_cutoff: float): """ helper routine to compute the values of the isotropic kernel densely """ @@ -54,17 +55,20 @@ def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutof else: ir = (ikernel + 0.5) * dr - norm_factor = 2 * math.pi * (1 - math.cos(r_cutoff - dr) + math.cos(r_cutoff - dr) + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr) - vals = torch.where( ((r - ir).abs() <= dr) & (r <= r_cutoff), - (1 - (r - ir).abs() / dr) / norm_factor, + (1 - (r - ir).abs() / dr), 0, ) + + # apply normalization + vnorm = torch.sum(vals * quad_weights, dim=(-1, -2), keepdim=True) + vals = vals / vnorm + return vals -def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float): +def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, quad_weights: torch.Tensor, nr: int, nphi: int, r_cutoff: float): """ helper routine to compute the values of the anisotropic kernel densely """ @@ -82,21 +86,19 @@ def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: ir = (ikernel // nphi + 0.5) * dr iphi = (ikernel % nphi) * dphi - norm_factor = 2 * math.pi * (1 - math.cos(r_cutoff - dr) + math.cos(r_cutoff - dr) + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr) - # compute the value of the filter if nr % 2 == 1: # find the indices where the rotated position falls into the support of the kernel cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff) cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi) - r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr) / norm_factor, 0.0) + r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr) , 0.0) phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0) vals = torch.where(ikernel > 0, r_vals * phi_vals, r_vals) else: # find the indices where the rotated position falls into the support of the kernel cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff) cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi) - r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr) / norm_factor, 0.0) + r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr), 0.0) phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0) vals = r_vals * phi_vals @@ -105,15 +107,18 @@ def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: phin = torch.where(phi + math.pi >= 2*math.pi, phi - math.pi, phi + math.pi) cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff) cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi) - rn_vals = torch.where(cond_rn, (1 - (rn - ir).abs() / dr) / norm_factor, 0.0) + rn_vals = torch.where(cond_rn, (1 - (rn - ir).abs() / dr), 0.0) phin_vals = torch.where(cond_phin, (1 - torch.minimum((phin - iphi).abs(), (2 * math.pi - (phin - iphi).abs())) / dphi), 0.0) vals += rn_vals * phin_vals + # apply normalization + vnorm = torch.sum(vals * quad_weights, dim=(-1, -2), keepdim=True) + vals = vals / vnorm return vals -def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi): +def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi): """ Helper routine to compute the convolution Tensor in a dense fashion """ @@ -121,11 +126,13 @@ def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, grid assert len(in_shape) == 2 assert len(out_shape) == 2 + quad_weights = quad_weights.reshape(-1, 1) + if len(kernel_shape) == 1: - kernel_handle = partial(_compute_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff) + kernel_handle = partial(_compute_vals_isotropic, quad_weights=quad_weights, nr=kernel_shape[0], r_cutoff=theta_cutoff) kernel_size = math.ceil( kernel_shape[0] / 2) elif len(kernel_shape) == 2: - kernel_handle = partial(_compute_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff) + kernel_handle = partial(_compute_vals_anisotropic, quad_weights=quad_weights, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff) kernel_size = (kernel_shape[0] // 2) * kernel_shape[1] + kernel_shape[0] % 2 else: raise ValueError("kernel_shape should be either one- or two-dimensional.") @@ -189,7 +196,7 @@ def setUp(self): [8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", False, 5e-5], [8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "equiangular", False, 5e-5], [8, 4, 2, (16, 32), (8, 16), [3, 3], "equiangular", "equiangular", False, 5e-5], - [8, 4, 2, (16, 32), (8, 16), [2, 3], "equiangular", "equiangular", False, 5e-5], + [8, 4, 2, (16, 32), (8, 16), [4, 3], "equiangular", "equiangular", False, 5e-5], [8, 4, 2, (18, 36), (6, 12), [7], "equiangular", "equiangular", False, 5e-5], [8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "legendre-gauss", False, 5e-5], [8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "equiangular", False, 5e-5], @@ -198,7 +205,7 @@ def setUp(self): [8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", True, 5e-5], [8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "equiangular", True, 5e-5], [8, 4, 2, (8, 16), (16, 32), [3, 3], "equiangular", "equiangular", True, 5e-5], - [8, 4, 2, (8, 16), (16, 32), [2, 3], "equiangular", "equiangular", True, 5e-5], + [8, 4, 2, (8, 16), (16, 32), [4, 3], "equiangular", "equiangular", True, 5e-5], [8, 4, 2, (6, 12), (18, 36), [7], "equiangular", "equiangular", True, 5e-5], [8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "legendre-gauss", True, 5e-5], [8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "equiangular", True, 5e-5], @@ -238,9 +245,15 @@ def test_disco_convolution( ).to(self.device) if transpose: - psi_dense = _precompute_convolution_tensor_dense(out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff).to(self.device) + _, wgl = _precompute_latitudes(nlat_out, grid=grid_out) + quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / nlon_out + + psi_dense = _precompute_convolution_tensor_dense(out_shape, in_shape, kernel_shape, quad_weights, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff).to(self.device) else: - psi_dense = _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff).to(self.device) + _, wgl = _precompute_latitudes(nlat_in, grid=grid_in) + quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / nlon_in + + psi_dense = _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff).to(self.device) psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense() diff --git a/torch_harmonics/convolution.py b/torch_harmonics/convolution.py index 63d1160..075c75a 100644 --- a/torch_harmonics/convolution.py +++ b/torch_harmonics/convolution.py @@ -55,7 +55,7 @@ _cuda_extension_available = False -def _compute_support_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float, norm: str = "s2"): +def _compute_support_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, quad_weights: torch.Tensor, nr: int, r_cutoff: float): """ Computes the index set that falls into the isotropic kernel's support and returns both indices and values. """ @@ -70,23 +70,20 @@ def _compute_support_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, else: ir = (ikernel + 0.5) * dr - if norm == "none": - norm_factor = 1.0 - elif norm == "2d": - norm_factor = math.pi * (r_cutoff * nr / (nr + 1))**2 + math.pi * r_cutoff**2 * (2 * nr / (nr + 1) + 1) / (nr + 1) / 3 - elif norm == "s2": - norm_factor = 2 * math.pi * (1 - math.cos(r_cutoff - dr) + math.cos(r_cutoff - dr) + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr) - else: - raise ValueError(f"Unknown normalization mode {norm}.") - # find the indices where the rotated position falls into the support of the kernel iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff)) - vals = (1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr) / norm_factor + vals = (1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr) + + # discretely normalize the filter basis functions + q = quad_weights[iidx[:, 1]].reshape(-1) + vnorm = torch.sum(torch.where(iidx[:, 0] == ikernel.reshape(-1, 1), vals.reshape(1, -1), 0.0) * q, dim=-1) + vals = vals / vnorm[iidx[:, 0]] + return iidx, vals -def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float, norm: str = "s2"): +def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, quad_weights: torch.Tensor, nr: int, nphi: int, r_cutoff: float): """ Computes the index set that falls into the anisotropic kernel's support and returns both indices and values. Handles the special case when there is an uneven number of collocation points across the diameter of the kernel. @@ -105,16 +102,6 @@ def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: in ir = (ikernel // nphi + 0.5) * dr iphi = (ikernel % nphi) * dphi - # TODO: this has been computed for the odd case. Still needs to be verified for the even case. - if norm == "none": - norm_factor = 1.0 - elif norm == "2d": - norm_factor = math.pi * (r_cutoff * nr / (nr + 1))**2 + math.pi * r_cutoff**2 * (2 * nr / (nr + 1) + 1) / (nr + 1) / 3 - elif norm == "s2": - norm_factor = 2 * math.pi * (1 - math.cos(r_cutoff - dr) + math.cos(r_cutoff - dr) + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr) - else: - raise ValueError(f"Unknown normalization mode {norm}.") - # find the indices where the rotated position falls into the support of the kernel if nr % 2 == 1: # find the support @@ -126,12 +113,18 @@ def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: in dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() dist_phi = (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs() # compute the value of the basis functions - vals = (1 - dist_r / dr) / norm_factor + vals = (1 - dist_r / dr) vals *= torch.where( (iidx[:, 0] > 0), (1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi), 1.0, ) + + # discretely normalize the filter basis functions + q = quad_weights[iidx[:, 1]].reshape(-1) + vnorm = torch.sum(torch.where(iidx[:, 0] == ikernel.reshape(-1, 1), vals.reshape(1, -1), 0.0) * q, dim=-1) + vals = vals / vnorm[iidx[:, 0]] + else: # in the even case, the inner casis functions overlap into areas with a negative areas rn = - r @@ -148,16 +141,21 @@ def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: in dist_rn = (rn[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() dist_phin = (phin[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs() # compute the value of the basis functions - vals = cond_r[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_r / dr) / norm_factor + vals = cond_r[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_r / dr) vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi) - valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_rn / dr) / norm_factor + valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_rn / dr) valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phin, (2 * math.pi - dist_phin)) / dphi) vals += valsn + # discretely normalize the filter basis functions + q = quad_weights[iidx[:, 1]].reshape(-1) + vnorm = torch.sum(torch.where(iidx[:, 0] == ikernel.reshape(-1, 1), vals.reshape(1, -1), 0.0) * q, dim=-1) + vals = vals / vnorm[iidx[:, 0]] + return iidx, vals -def _precompute_convolution_tensor_s2(in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi): +def _precompute_convolution_tensor_s2(in_shape, out_shape, kernel_shape, quad_weights, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi): """ Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$. Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al. @@ -178,9 +176,9 @@ def _precompute_convolution_tensor_s2(in_shape, out_shape, kernel_shape, grid_in assert len(out_shape) == 2 if len(kernel_shape) == 1: - kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff, norm="s2") + kernel_handle = partial(_compute_support_vals_isotropic, quad_weights=quad_weights, nr=kernel_shape[0], r_cutoff=theta_cutoff) elif len(kernel_shape) == 2: - kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff, norm="s2") + kernel_handle = partial(_compute_support_vals_anisotropic, quad_weights=quad_weights, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff) else: raise ValueError("kernel_shape should be either one- or two-dimensional.") @@ -239,46 +237,6 @@ def _precompute_convolution_tensor_s2(in_shape, out_shape, kernel_shape, grid_in return out_idx, out_vals -def _precompute_convolution_tensor_2d(grid_in, grid_out, kernel_shape, radius_cutoff=0.01, periodic=False): - """ - Precomputes the translated filters at positions $T^{-1}_j \omega_i = T^{-1}_j T_i \nu$. Similar to the S2 routine, - only that it assumes a non-periodic subset of the euclidean plane - """ - - # check that input arrays are valid point clouds in 2D - assert len(grid_in) == 2 - assert len(grid_out) == 2 - assert grid_in.shape[0] == 2 - assert grid_out.shape[0] == 2 - - n_in = grid_in.shape[-1] - n_out = grid_out.shape[-1] - - if len(kernel_shape) == 1: - kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=radius_cutoff, norm="2d") - elif len(kernel_shape) == 2: - kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=radius_cutoff, norm="2d") - else: - raise ValueError("kernel_shape should be either one- or two-dimensional.") - - grid_in = grid_in.reshape(2, 1, n_in) - grid_out = grid_out.reshape(2, n_out, 1) - - diffs = grid_in - grid_out - if periodic: - periodic_diffs = torch.where(diffs > 0.0, diffs-1, diffs+1) - diffs = torch.where(diffs.abs() < periodic_diffs.abs(), diffs, periodic_diffs) - - - r = torch.sqrt(diffs[0] ** 2 + diffs[1] ** 2) - phi = torch.arctan2(diffs[1], diffs[0]) + torch.pi - - idx, vals = kernel_handle(r, phi) - idx = idx.permute(1, 0) - - return idx, vals - - class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta): """ Abstract base class for DISCO convolutions @@ -367,7 +325,7 @@ def __init__( quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in self.register_buffer("quad_weights", quad_weights, persistent=False) - idx, vals = _precompute_convolution_tensor_s2(in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff) + idx, vals = _precompute_convolution_tensor_s2(in_shape, out_shape, self.kernel_shape, self.quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff) # sort the values ker_idx = idx[0, ...].contiguous() @@ -453,8 +411,13 @@ def __init__( quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in self.register_buffer("quad_weights", quad_weights, persistent=False) + # despite performing quadrature over the input grid, we keep normalization consistent with regular convolution + # this requires computation of the output quadrature rule + _, wgl_out = _precompute_latitudes(self.nlat_out, grid=grid_out) + quad_weights_out = 2.0 * torch.pi * torch.from_numpy(wgl_out).float().reshape(-1, 1) / self.nlon_out + # switch in_shape and out_shape since we want transpose conv - idx, vals = _precompute_convolution_tensor_s2(out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff) + idx, vals = _precompute_convolution_tensor_s2(out_shape, in_shape, self.kernel_shape, quad_weights_out, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff) # sort the values ker_idx = idx[0, ...].contiguous() From b597bfcfdcb2d15155d7588a2dd051bef4d70298 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Thu, 25 Apr 2024 18:29:14 +0200 Subject: [PATCH 2/4] implemented discrete normalization of disco filters --- tests/test_convolution.py | 54 ++++++++------ torch_harmonics/convolution.py | 125 ++++++++++++++++++++++----------- 2 files changed, 117 insertions(+), 62 deletions(-) diff --git a/tests/test_convolution.py b/tests/test_convolution.py index b85f21e..9ca524c 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -40,7 +40,7 @@ from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes -def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, quad_weights: torch.Tensor, nr: int, r_cutoff: float): +def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float): """ helper routine to compute the values of the isotropic kernel densely """ @@ -61,14 +61,10 @@ def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, quad_weights: to 0, ) - # apply normalization - vnorm = torch.sum(vals * quad_weights, dim=(-1, -2), keepdim=True) - vals = vals / vnorm - return vals -def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, quad_weights: torch.Tensor, nr: int, nphi: int, r_cutoff: float): +def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float): """ helper routine to compute the values of the anisotropic kernel densely """ @@ -111,14 +107,27 @@ def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, quad_weights: phin_vals = torch.where(cond_phin, (1 - torch.minimum((phin - iphi).abs(), (2 * math.pi - (phin - iphi).abs())) / dphi), 0.0) vals += rn_vals * phin_vals - # apply normalization - vnorm = torch.sum(vals * quad_weights, dim=(-1, -2), keepdim=True) - vals = vals / vnorm - return vals +def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, eps=1e-9): + """ + Discretely normalizes the convolution tensor. + """ + + kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = psi.shape + scale_factor = float(nlon_in // nlon_out) + + if transpose_normalization: + # the normalization is not quite symmetric due to the compressed way psi is stored in the main code + # look at the normalization code in the actual implementation + psi_norm = torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:,:,:1], dim=(1, 4), keepdim=True) / scale_factor + else: + psi_norm = torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi, dim=(3, 4), keepdim=True) + + return psi / (psi_norm + eps) + -def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi): +def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False): """ Helper routine to compute the convolution Tensor in a dense fashion """ @@ -129,10 +138,10 @@ def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad quad_weights = quad_weights.reshape(-1, 1) if len(kernel_shape) == 1: - kernel_handle = partial(_compute_vals_isotropic, quad_weights=quad_weights, nr=kernel_shape[0], r_cutoff=theta_cutoff) + kernel_handle = partial(_compute_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff) kernel_size = math.ceil( kernel_shape[0] / 2) elif len(kernel_shape) == 2: - kernel_handle = partial(_compute_vals_anisotropic, quad_weights=quad_weights, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff) + kernel_handle = partial(_compute_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff) kernel_size = (kernel_shape[0] // 2) * kernel_shape[1] + kernel_shape[0] % 2 else: raise ValueError("kernel_shape should be either one- or two-dimensional.") @@ -177,6 +186,9 @@ def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad # find the indices where the rotated position falls into the support of the kernel out[:, t, p, :, :] = kernel_handle(theta, phi) + # take care of normalization + out = _normalize_convolution_tensor_dense(out, quad_weights=quad_weights, transpose_normalization=transpose_normalization) + return out @@ -188,6 +200,7 @@ def setUp(self): torch.cuda.manual_seed(333) else: self.device = torch.device("cpu") + torch.manual_seed(333) @parameterized.expand( @@ -244,16 +257,17 @@ def test_disco_convolution( theta_cutoff=theta_cutoff ).to(self.device) + _, wgl = _precompute_latitudes(nlat_in, grid=grid_in) + quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / nlon_in + if transpose: - _, wgl = _precompute_latitudes(nlat_out, grid=grid_out) - quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / nlon_out + psi_dense = _precompute_convolution_tensor_dense(out_shape, in_shape, kernel_shape, quad_weights, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True).to(self.device) - psi_dense = _precompute_convolution_tensor_dense(out_shape, in_shape, kernel_shape, quad_weights, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff).to(self.device) - else: - _, wgl = _precompute_latitudes(nlat_in, grid=grid_in) - quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / nlon_in + psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_in, conv.nlat_out * conv.nlon_out)).to_dense() - psi_dense = _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff).to(self.device) + self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_in, nlat_out * nlon_out))) + else: + psi_dense = _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False).to(self.device) psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense() diff --git a/torch_harmonics/convolution.py b/torch_harmonics/convolution.py index 075c75a..acd7e6c 100644 --- a/torch_harmonics/convolution.py +++ b/torch_harmonics/convolution.py @@ -49,20 +49,23 @@ try: import disco_cuda_extension + _cuda_extension_available = True except ImportError as err: disco_cuda_extension = None _cuda_extension_available = False +# _cuda_extension_available = False + -def _compute_support_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, quad_weights: torch.Tensor, nr: int, r_cutoff: float): +def _compute_support_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float): """ Computes the index set that falls into the isotropic kernel's support and returns both indices and values. """ kernel_size = (nr // 2) + nr % 2 ikernel = torch.arange(kernel_size).reshape(-1, 1, 1) - dr = 2 * r_cutoff / (nr + 1) + dr = 2 * r_cutoff / (nr + 1) # compute the support if nr % 2 == 1: @@ -72,18 +75,12 @@ def _compute_support_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, quad_wei # find the indices where the rotated position falls into the support of the kernel iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff)) - vals = (1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr) - - # discretely normalize the filter basis functions - q = quad_weights[iidx[:, 1]].reshape(-1) - vnorm = torch.sum(torch.where(iidx[:, 0] == ikernel.reshape(-1, 1), vals.reshape(1, -1), 0.0) * q, dim=-1) - vals = vals / vnorm[iidx[:, 0]] + vals = 1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr return iidx, vals - -def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, quad_weights: torch.Tensor, nr: int, nphi: int, r_cutoff: float): +def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float): """ Computes the index set that falls into the anisotropic kernel's support and returns both indices and values. Handles the special case when there is an uneven number of collocation points across the diameter of the kernel. @@ -91,7 +88,7 @@ def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, quad_w kernel_size = (nr // 2) * nphi + nr % 2 ikernel = torch.arange(kernel_size).reshape(-1, 1, 1) - dr = 2 * r_cutoff / (nr + 1) + dr = 2 * r_cutoff / (nr + 1) dphi = 2.0 * math.pi / nphi # disambiguate even and uneven cases and compute the support @@ -113,22 +110,17 @@ def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, quad_w dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() dist_phi = (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs() # compute the value of the basis functions - vals = (1 - dist_r / dr) + vals = 1 - dist_r / dr vals *= torch.where( (iidx[:, 0] > 0), (1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi), 1.0, ) - # discretely normalize the filter basis functions - q = quad_weights[iidx[:, 1]].reshape(-1) - vnorm = torch.sum(torch.where(iidx[:, 0] == ikernel.reshape(-1, 1), vals.reshape(1, -1), 0.0) * q, dim=-1) - vals = vals / vnorm[iidx[:, 0]] - else: # in the even case, the inner casis functions overlap into areas with a negative areas - rn = - r - phin = torch.where(phi + math.pi >= 2*math.pi, phi - math.pi, phi + math.pi) + rn = -r + phin = torch.where(phi + math.pi >= 2 * math.pi, phi - math.pi, phi + math.pi) # find the support cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff) cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi) @@ -147,15 +139,59 @@ def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, quad_w valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phin, (2 * math.pi - dist_phin)) / dphi) vals += valsn - # discretely normalize the filter basis functions - q = quad_weights[iidx[:, 1]].reshape(-1) - vnorm = torch.sum(torch.where(iidx[:, 0] == ikernel.reshape(-1, 1), vals.reshape(1, -1), 0.0) * q, dim=-1) - vals = vals / vnorm[iidx[:, 0]] - return iidx, vals -def _precompute_convolution_tensor_s2(in_shape, out_shape, kernel_shape, quad_weights, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi): +def _normalize_onvolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=False, eps=1e-9): + """ + Discretely normalizes the convolution tensor. + """ + + nlat_in, nlon_in = in_shape + nlat_out, nlon_out = out_shape + + if len(kernel_shape) == 1: + kernel_size = math.ceil(kernel_shape[0] / 2) + elif len(kernel_shape) == 2: + kernel_size = (kernel_shape[0] // 2) * kernel_shape[1] + kernel_shape[0] % 2 + + # reshape the indices implicitly to be ikernel, lat_out, lat_in, lon_in + idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // nlon_in, psi_idx[2] % nlon_in], dim=0) + + if transpose_normalization: + # pre-compute the quadrature weights + q = quad_weights[idx[1]].reshape(-1) + + # compute scale factor + scale_factor = float(nlon_in // nlon_out) + + # loop through dimensions which require normalization + for ik in range(kernel_size): + for ilat in range(nlat_in): + # get relevant entries + iidx = torch.argwhere((idx[0] == ik) & (idx[2] == ilat)) + # normalize, while summing also over the input longitude dimension here as this is not available for the output + vnorm = torch.sum(psi_vals[iidx] * q[iidx]) / scale_factor + psi_vals[iidx] = psi_vals[iidx] / (vnorm + eps) + else: + # pre-compute the quadrature weights + q = quad_weights[idx[2]].reshape(-1) + + # loop through dimensions which require normalization + for ik in range(kernel_size): + for ilat in range(nlat_out): + # get relevant entries + iidx = torch.argwhere((idx[0] == ik) & (idx[1] == ilat)) + # normalize + vnorm = torch.sum(psi_vals[iidx] * q[iidx]) + psi_vals[iidx] = psi_vals[iidx] / (vnorm + eps) + + return psi_vals + + +def _precompute_convolution_tensor_s2( + in_shape, out_shape, kernel_shape, quad_weights, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False +): """ Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$. Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al. @@ -176,9 +212,9 @@ def _precompute_convolution_tensor_s2(in_shape, out_shape, kernel_shape, quad_we assert len(out_shape) == 2 if len(kernel_shape) == 1: - kernel_handle = partial(_compute_support_vals_isotropic, quad_weights=quad_weights, nr=kernel_shape[0], r_cutoff=theta_cutoff) + kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff) elif len(kernel_shape) == 2: - kernel_handle = partial(_compute_support_vals_anisotropic, quad_weights=quad_weights, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff) + kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff) else: raise ValueError("kernel_shape should be either one- or two-dimensional.") @@ -234,6 +270,8 @@ def _precompute_convolution_tensor_s2(in_shape, out_shape, kernel_shape, quad_we out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous() out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous() + out_vals = _normalize_onvolution_tensor_s2(out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization) + return out_idx, out_vals @@ -258,9 +296,11 @@ def __init__( self.kernel_shape = kernel_shape if len(self.kernel_shape) == 1: - self.kernel_size = math.ceil( self.kernel_shape[0] / 2) + self.kernel_size = math.ceil(self.kernel_shape[0] / 2) if self.kernel_shape[0] % 2 == 0: - warn("Detected isotropic kernel with even number of collocation points in the radial direction. This feature is only supported out of consistency and may lead to unexpected behavior.") + warn( + "Detected isotropic kernel with even number of collocation points in the radial direction. This feature is only supported out of consistency and may lead to unexpected behavior." + ) elif len(self.kernel_shape) == 2: self.kernel_size = (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2 if len(self.kernel_shape) > 2: @@ -325,7 +365,9 @@ def __init__( quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in self.register_buffer("quad_weights", quad_weights, persistent=False) - idx, vals = _precompute_convolution_tensor_s2(in_shape, out_shape, self.kernel_shape, self.quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff) + idx, vals = _precompute_convolution_tensor_s2( + in_shape, out_shape, self.kernel_shape, self.quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False + ) # sort the values ker_idx = idx[0, ...].contiguous() @@ -353,7 +395,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.quad_weights * x if x.is_cuda and _cuda_extension_available: - x = _disco_s2_contraction_cuda(x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out) + x = _disco_s2_contraction_cuda( + x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out + ) else: if x.is_cuda: warn("couldn't find CUDA extension, falling back to slow PyTorch implementation") @@ -411,13 +455,10 @@ def __init__( quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in self.register_buffer("quad_weights", quad_weights, persistent=False) - # despite performing quadrature over the input grid, we keep normalization consistent with regular convolution - # this requires computation of the output quadrature rule - _, wgl_out = _precompute_latitudes(self.nlat_out, grid=grid_out) - quad_weights_out = 2.0 * torch.pi * torch.from_numpy(wgl_out).float().reshape(-1, 1) / self.nlon_out - # switch in_shape and out_shape since we want transpose conv - idx, vals = _precompute_convolution_tensor_s2(out_shape, in_shape, self.kernel_shape, quad_weights_out, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff) + idx, vals = _precompute_convolution_tensor_s2( + out_shape, in_shape, self.kernel_shape, quad_weights, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True + ) # sort the values ker_idx = idx[0, ...].contiguous() @@ -436,7 +477,7 @@ def __init__( def psi_idx(self): return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous() - def get_psi(self, semi_transposed: bool=False): + def get_psi(self, semi_transposed: bool = False): if semi_transposed: # we do a semi-transposition to faciliate the computation tout = self.psi_idx[2] // self.nlon_out @@ -444,7 +485,7 @@ def get_psi(self, semi_transposed: bool=False): # flip the axis of longitudes pout = self.nlon_out - 1 - pout tin = self.psi_idx[1] - idx = torch.stack([self.psi_idx[0], tout, tin*self.nlon_out + pout], dim=0) + idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0) psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce() else: psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce() @@ -464,8 +505,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.quad_weights * x if x.is_cuda and _cuda_extension_available: - out = _disco_s2_transpose_contraction_cuda(x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, - self.kernel_size, self.nlat_out, self.nlon_out) + out = _disco_s2_transpose_contraction_cuda( + x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out + ) else: if x.is_cuda: warn("couldn't find CUDA extension, falling back to slow PyTorch implementation") @@ -476,4 +518,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out = out + self.bias.reshape(1, -1, 1, 1) return out - From 102c22d556df1143a97081007a58464ceadc3a8f Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Thu, 25 Apr 2024 21:08:18 +0200 Subject: [PATCH 3/4] relaxing tolerances in convolution unit test --- tests/test_convolution.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/test_convolution.py b/tests/test_convolution.py index 9ca524c..d766479 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -206,23 +206,23 @@ def setUp(self): @parameterized.expand( [ # regular convolution - [8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", False, 5e-5], - [8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "equiangular", False, 5e-5], - [8, 4, 2, (16, 32), (8, 16), [3, 3], "equiangular", "equiangular", False, 5e-5], - [8, 4, 2, (16, 32), (8, 16), [4, 3], "equiangular", "equiangular", False, 5e-5], - [8, 4, 2, (18, 36), (6, 12), [7], "equiangular", "equiangular", False, 5e-5], - [8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "legendre-gauss", False, 5e-5], - [8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "equiangular", False, 5e-5], - [8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "legendre-gauss", False, 5e-5], + [8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", False, 1e-4], + [8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "equiangular", False, 1e-4], + [8, 4, 2, (16, 32), (8, 16), [3, 3], "equiangular", "equiangular", False, 1e-4], + [8, 4, 2, (16, 32), (8, 16), [4, 3], "equiangular", "equiangular", False, 1e-4], + [8, 4, 2, (18, 36), (6, 12), [7], "equiangular", "equiangular", False, 1e-4], + [8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "legendre-gauss", False, 1e-4], + [8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "equiangular", False, 1e-4], + [8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "legendre-gauss", False, 1e-4], # transpose convolution - [8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", True, 5e-5], - [8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "equiangular", True, 5e-5], - [8, 4, 2, (8, 16), (16, 32), [3, 3], "equiangular", "equiangular", True, 5e-5], - [8, 4, 2, (8, 16), (16, 32), [4, 3], "equiangular", "equiangular", True, 5e-5], - [8, 4, 2, (6, 12), (18, 36), [7], "equiangular", "equiangular", True, 5e-5], - [8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "legendre-gauss", True, 5e-5], - [8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "equiangular", True, 5e-5], - [8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "legendre-gauss", True, 5e-5], + [8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", True, 1e-4], + [8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "equiangular", True, 1e-4], + [8, 4, 2, (8, 16), (16, 32), [3, 3], "equiangular", "equiangular", True, 1e-4], + [8, 4, 2, (8, 16), (16, 32), [4, 3], "equiangular", "equiangular", True, 1e-4], + [8, 4, 2, (6, 12), (18, 36), [7], "equiangular", "equiangular", True, 1e-4], + [8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "legendre-gauss", True, 1e-4], + [8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "equiangular", True, 1e-4], + [8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "legendre-gauss", True, 1e-4], ] ) def test_disco_convolution( From 40d5569712494042b634fb5bab7a2fd0c8754c8c Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Fri, 26 Apr 2024 11:20:55 +0200 Subject: [PATCH 4/4] bugfix to correctly support unequal scale factors in latitudes and longitudes --- tests/test_convolution.py | 2 ++ torch_harmonics/_disco_convolution.py | 2 +- torch_harmonics/convolution.py | 2 -- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_convolution.py b/tests/test_convolution.py index d766479..bcdd551 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -210,6 +210,7 @@ def setUp(self): [8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (16, 32), (8, 16), [3, 3], "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (16, 32), (8, 16), [4, 3], "equiangular", "equiangular", False, 1e-4], + [8, 4, 2, (16, 24), (8, 8), [3], "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (18, 36), (6, 12), [7], "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "legendre-gauss", False, 1e-4], [8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "equiangular", False, 1e-4], @@ -219,6 +220,7 @@ def setUp(self): [8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (8, 16), (16, 32), [3, 3], "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (8, 16), (16, 32), [4, 3], "equiangular", "equiangular", True, 1e-4], + [8, 4, 2, (8, 8), (16, 24), [3], "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (6, 12), (18, 36), [7], "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "legendre-gauss", True, 1e-4], [8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "equiangular", True, 1e-4], diff --git a/torch_harmonics/_disco_convolution.py b/torch_harmonics/_disco_convolution.py index 45891b2..1e28923 100644 --- a/torch_harmonics/_disco_convolution.py +++ b/torch_harmonics/_disco_convolution.py @@ -146,7 +146,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl kernel_size, nlat_out, n_out = psi.shape assert n_out % nlon_out == 0 - assert nlon_out >= nlat_in + assert nlon_out >= nlon_in pscale = nlon_out // nlon_in # interleave zeros along the longitude dimension to allow for fractional offsets to be considered diff --git a/torch_harmonics/convolution.py b/torch_harmonics/convolution.py index acd7e6c..ba8ce45 100644 --- a/torch_harmonics/convolution.py +++ b/torch_harmonics/convolution.py @@ -55,8 +55,6 @@ disco_cuda_extension = None _cuda_extension_available = False -# _cuda_extension_available = False - def _compute_support_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float): """