Bbonev/disc even filters (#35)
* initial working commit with new convention of counting collocation points across the diameter instead of across the radius

* fixed a bug in the computation of the even kernels

* changing heuristic for computing theta_cutoff

* Fixing unittest

* Readability improvements
bonevbs authored and azrael417 committed Aug 19, 2024
1 parent 95131c3 commit d78f7c7
Showing 5 changed files with 158 additions and 74 deletions.
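For reference, the counting convention introduced by this commit can be summarized in a few lines. The formulas are taken directly from the diff below; the helper names are illustrative and not part of the library:

```python
import math

def kernel_size_isotropic(nr: int) -> int:
    # nr counts collocation points across the *diameter*;
    # an odd nr places one basis function at the center plus nr // 2 rings,
    # an even nr places nr // 2 rings offset by half a grid spacing
    return (nr // 2) + nr % 2

def kernel_size_anisotropic(nr: int, nphi: int) -> int:
    # each off-center ring carries nphi azimuthal basis functions;
    # the center point (odd nr only) is isotropic and counts once
    return (nr // 2) * nphi + nr % 2

assert kernel_size_isotropic(3) == 2       # center + 1 ring
assert kernel_size_isotropic(4) == 2       # 2 offset rings, no center
assert kernel_size_anisotropic(3, 4) == 5  # center + 1 ring x 4 angles
assert kernel_size_anisotropic(4, 4) == 8  # 2 rings x 4 angles
```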
5 changes: 4 additions & 1 deletion Changelog.md
@@ -5,7 +5,10 @@
### v0.7.0

* CUDA-accelerated DISCO convolutions
* Updated unit tests
* Updated DISCO convolutions to support an even number of collocation points across the diameter
* Distributed DISCO convolutions
* Removed DISCO convolution in the plane to focus on the sphere
* Updated unit tests which now include tests for the distributed convolutions

### v0.6.5

4 changes: 4 additions & 0 deletions README.md
@@ -160,6 +160,10 @@ $$

Here, $x_j \in [-1,1]$ are the quadrature nodes with the respective quadrature weights $w_j$.

### Discrete-continuous convolutions

torch-harmonics now provides local discrete-continuous (DISCO) convolutions on the sphere, as outlined in [4].

## Getting started

The main functionality of `torch_harmonics` is provided in the form of `torch.nn.Modules` for composability. A minimal example is given by:
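The README's minimal example is truncated in this diff view; the following is a sketch of typical usage, assuming the standard torch_harmonics transform API rather than the verbatim README snippet:

```python
import torch
import torch_harmonics as th

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

nlat, nlon = 512, 1024
signal = torch.randn(nlat, nlon, device=device)

# forward and inverse spherical harmonic transforms on an equiangular grid
sht = th.RealSHT(nlat, nlon, grid="equiangular").to(device)
isht = th.InverseRealSHT(nlat, nlon, grid="equiangular").to(device)

coeffs = sht(signal)             # real signal -> complex coefficients
reconstruction = isht(coeffs)    # back to the grid
```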
117 changes: 75 additions & 42 deletions tests/test_convolution.py
@@ -39,47 +39,77 @@
from torch_harmonics import *


def _compute_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, theta_cutoff: float):
def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float):
"""
helper routine to compute the values of the isotropic kernel densely
"""

kernel_size = (nr // 2) + nr % 2
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
dr = 2 * r_cutoff / (nr + 1)

# compute the support
dtheta = (theta_cutoff - 0.0) / ntheta
ikernel = torch.arange(ntheta).reshape(-1, 1, 1)
itheta = ikernel * dtheta
if nr % 2 == 1:
ir = ikernel * dr
else:
ir = (ikernel + 0.5) * dr

norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)
norm_factor = 2 * math.pi * (1 - math.cos(r_cutoff - dr) + math.cos(r_cutoff - dr) + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr)

vals = torch.where(
((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff),
(1 - (theta - itheta).abs() / dtheta) / norm_factor,
((r - ir).abs() <= dr) & (r <= r_cutoff),
(1 - (r - ir).abs() / dr) / norm_factor,
0,
)
return vals


def _compute_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, nphi: int, theta_cutoff: float):
def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float):
"""
helper routine to compute the values of the anisotropic kernel densely
"""

# compute the support
dtheta = (theta_cutoff - 0.0) / ntheta
dphi = 2.0 * math.pi / nphi
kernel_size = (ntheta - 1) * nphi + 1
kernel_size = (nr // 2) * nphi + nr % 2
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
itheta = ((ikernel - 1) // nphi + 1) * dtheta
iphi = ((ikernel - 1) % nphi) * dphi
dr = 2 * r_cutoff / (nr + 1)
dphi = 2.0 * math.pi / nphi

# disambiguate even and uneven cases and compute the support
if nr % 2 == 1:
ir = ((ikernel - 1) // nphi + 1) * dr
iphi = ((ikernel - 1) % nphi) * dphi
else:
ir = (ikernel // nphi + 0.5) * dr
iphi = (ikernel % nphi) * dphi

norm_factor = 2 * math.pi * (1 - math.cos(r_cutoff - dr) + math.cos(r_cutoff - dr) + (math.sin(r_cutoff - dr) - math.sin(r_cutoff)) / dr)

# compute the value of the filter
if nr % 2 == 1:
# find the indices where the rotated position falls into the support of the kernel
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr) / norm_factor, 0.0)
phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
vals = torch.where(ikernel > 0, r_vals * phi_vals, r_vals)
else:
# find the indices where the rotated position falls into the support of the kernel
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr) / norm_factor, 0.0)
phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
vals = r_vals * phi_vals

# in the even case, the inner basis functions overlap into the region of negative radii
rn = -r
phin = torch.where(phi + math.pi >= 2*math.pi, phi - math.pi, phi + math.pi)
cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)
rn_vals = torch.where(cond_rn, (1 - (rn - ir).abs() / dr) / norm_factor, 0.0)
phin_vals = torch.where(cond_phin, (1 - torch.minimum((phin - iphi).abs(), (2 * math.pi - (phin - iphi).abs())) / dphi), 0.0)
vals += rn_vals * phin_vals

norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)

# find the indices where the rotated position falls into the support of the kernel
cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
theta_vals = torch.where(cond_theta, (1 - (theta - itheta).abs() / dtheta) / norm_factor, 0.0)
phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
vals = torch.where(ikernel > 0, theta_vals * phi_vals, theta_vals)
return vals

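Why the even case needs the reflected contribution: with an even `nr` the innermost basis functions are centered at `r = dr / 2`, so their hat support crosses the origin; a point just past the center appears in polar coordinates as `(-r, phi + pi)`, which is exactly what `rn` and `phin` encode above. A quick numeric check with illustrative values:

```python
import math

r_cutoff = math.pi / 8
nr = 4                        # even number of collocation points across the diameter
dr = 2 * r_cutoff / (nr + 1)  # same spacing as in the helpers above

inner = 0.5 * dr              # innermost ring for even nr sits at (0 + 0.5) * dr
assert inner - dr < 0.0       # its support crosses r = 0, hence the (-r, phi + pi) term
```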

@@ -92,11 +122,11 @@ def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, grid
assert len(out_shape) == 2

if len(kernel_shape) == 1:
kernel_handle = partial(_compute_vals_isotropic, ntheta=kernel_shape[0], theta_cutoff=theta_cutoff)
kernel_size = kernel_shape[0]
kernel_handle = partial(_compute_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
kernel_size = math.ceil(kernel_shape[0] / 2)
elif len(kernel_shape) == 2:
kernel_handle = partial(_compute_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff)
kernel_size = (kernel_shape[0] - 1) * kernel_shape[1] + 1
kernel_handle = partial(_compute_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff)
kernel_size = (kernel_shape[0] // 2) * kernel_shape[1] + kernel_shape[0] % 2
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")

@@ -156,21 +186,23 @@ def setUp(self):
@parameterized.expand(
[
# regular convolution
[8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), (8, 16), [3, 3], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), (8, 16), [2, 3], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (18, 36), (6, 12), [4], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "legendre-gauss", False, 5e-5],
[8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "legendre-gauss", False, 5e-5],
[8, 4, 2, (18, 36), (6, 12), [7], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), (8, 16), [5], "equiangular", "legendre-gauss", False, 5e-5],
[8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), (8, 16), [5], "legendre-gauss", "legendre-gauss", False, 5e-5],
# transpose convolution
[8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, (16, 32), (16, 32), [3], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, (8, 16), (16, 32), [3, 3], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, (8, 16), (16, 32), [2, 3], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, (6, 12), (18, 36), [4], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "legendre-gauss", True, 5e-5],
[8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "equiangular", True, 5e-5],
[8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "legendre-gauss", True, 5e-5],
[8, 4, 2, (6, 12), (18, 36), [7], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, (8, 16), (16, 32), [5], "equiangular", "legendre-gauss", True, 5e-5],
[8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "equiangular", True, 5e-5],
[8, 4, 2, (8, 16), (16, 32), [5], "legendre-gauss", "legendre-gauss", True, 5e-5],
]
)
def test_disco_convolution(
@@ -186,6 +218,11 @@ def test_disco_convolution(
transpose,
tol,
):
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape

theta_cutoff = (kernel_shape[0] + 1) / 2 * torch.pi / float(nlat_out - 1)

Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
conv = Conv(
in_channels,
@@ -197,13 +234,9 @@
grid_in=grid_in,
grid_out=grid_out,
bias=False,
theta_cutoff=theta_cutoff
).to(self.device)

nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape

theta_cutoff = (kernel_shape[0]) * torch.pi / float(nlat_in - 1)

if transpose:
psi_dense = _precompute_convolution_tensor_dense(out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff).to(self.device)
else:
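A quick way to inspect the hat-function bases defined by the test helpers above, runnable alongside `_compute_vals_isotropic` as defined earlier in this file (the grid choice here is illustrative):

```python
import math
import torch

# evaluate the isotropic basis functions on a 1D radial grid
nr, r_cutoff = 4, math.pi / 8
r = torch.linspace(0.0, r_cutoff, steps=32).reshape(1, -1)
phi = torch.zeros_like(r)  # ignored by the isotropic kernel

vals = _compute_vals_isotropic(r, phi, nr=nr, r_cutoff=r_cutoff)
print(vals.shape)  # (2, 1, 32): kernel_size = nr // 2 + nr % 2 = 2 basis functions
```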
98 changes: 71 additions & 27 deletions torch_harmonics/convolution.py
@@ -60,10 +60,15 @@ def _compute_support_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int,
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""

kernel_size = (nr // 2) + nr % 2
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
dr = 2 * r_cutoff / (nr + 1)

# compute the support
dr = (r_cutoff - 0.0) / nr
ikernel = torch.arange(nr).reshape(-1, 1, 1)
ir = ikernel * dr
if nr % 2 == 1:
ir = ikernel * dr
else:
ir = (ikernel + 0.5) * dr

if norm == "none":
norm_factor = 1.0
@@ -80,19 +85,27 @@
return iidx, vals



def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float, norm: str = "s2"):
"""
Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
Computes the index set that falls into the anisotropic kernel's support and returns both indices and values. Handles both
even and odd numbers of collocation points across the diameter of the kernel.
"""

# compute the support
dr = (r_cutoff - 0.0) / nr
dphi = 2.0 * math.pi / nphi
kernel_size = (nr - 1) * nphi + 1
kernel_size = (nr // 2) * nphi + nr % 2
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
ir = ((ikernel - 1) // nphi + 1) * dr
iphi = ((ikernel - 1) % nphi) * dphi
dr = 2 * r_cutoff / (nr + 1)
dphi = 2.0 * math.pi / nphi

# disambiguate even and uneven cases and compute the support
if nr % 2 == 1:
ir = ((ikernel - 1) // nphi + 1) * dr
iphi = ((ikernel - 1) % nphi) * dphi
else:
ir = (ikernel // nphi + 0.5) * dr
iphi = (ikernel % nphi) * dphi

# TODO: this has been computed for the odd case. Still needs to be verified for the even case.
if norm == "none":
norm_factor = 1.0
elif norm == "2d":
@@ -103,15 +116,44 @@ def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: in
raise ValueError(f"Unknown normalization mode {norm}.")

# find the indices where the rotated position falls into the support of the kernel
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
iidx = torch.argwhere(cond_r & cond_phi)
vals = (1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr) / norm_factor
vals *= torch.where(
iidx[:, 0] > 0,
(1 - torch.minimum((phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(), (2 * math.pi - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs())) / dphi),
1.0,
)
if nr % 2 == 1:
# find the support
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
# find indices where conditions are met
iidx = torch.argwhere(cond_r & cond_phi)
# compute the distance to the collocation points
dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phi = (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
# compute the value of the basis functions
vals = (1 - dist_r / dr) / norm_factor
vals *= torch.where(
(iidx[:, 0] > 0),
(1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi),
1.0,
)
else:
# in the even case, the inner basis functions overlap into the region of negative radii
rn = -r
phin = torch.where(phi + math.pi >= 2*math.pi, phi - math.pi, phi + math.pi)
# find the support
cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)
# find indices where conditions are met
iidx = torch.argwhere((cond_r & cond_phi) | (cond_rn & cond_phin))
dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phi = (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
dist_rn = (rn[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
dist_phin = (phin[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
# compute the value of the basis functions
vals = cond_r[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_r / dr) / norm_factor
vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi)
valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_rn / dr) / norm_factor
valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phin, (2 * math.pi - dist_phin)) / dphi)
vals += valsn

return iidx, vals


@@ -258,10 +300,12 @@ def __init__(
self.kernel_shape = kernel_shape

if len(self.kernel_shape) == 1:
self.kernel_size = self.kernel_shape[0]
self.kernel_size = math.ceil(self.kernel_shape[0] / 2)
if self.kernel_shape[0] % 2 == 0:
warn("Detected isotropic kernel with even number of collocation points in the radial direction. This feature is only supported out of consistency and may lead to unexpected behavior.")
elif len(self.kernel_shape) == 2:
self.kernel_size = (self.kernel_shape[0] - 1) * self.kernel_shape[1] + 1
else:
self.kernel_size = (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2
if len(self.kernel_shape) > 2:
raise ValueError("kernel_shape should be either one- or two-dimensional.")

# groups
@@ -311,9 +355,9 @@ def __init__(
self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape

# compute theta cutoff based on the bandlimit of the input field
# heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions
if theta_cutoff is None:
theta_cutoff = self.kernel_shape[0] * torch.pi / float(self.nlat_in - 1)
theta_cutoff = (self.kernel_shape[0] + 1) / 2 * torch.pi / float(self.nlat_out - 1)

if theta_cutoff <= 0.0:
raise ValueError("Error, theta_cutoff has to be positive.")
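One way to read the new heuristic (this derivation is editorial; the commit itself only calls it a heuristic): the kernel support is discretized with `dr = 2 * theta_cutoff / (nr + 1)`, so `theta_cutoff = (nr + 1) / 2 * dr`; matching `dr` to the output grid spacing `pi / (nlat_out - 1)` recovers exactly the formula above. With illustrative numbers:

```python
import math

nr, nlat_out = 3, 16
dr = math.pi / float(nlat_out - 1)  # collocation spacing matched to the output grid
theta_cutoff = (nr + 1) / 2 * dr    # == (nr + 1) / 2 * pi / (nlat_out - 1)
```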
@@ -399,7 +443,7 @@ def __init__(

# bandlimit
if theta_cutoff is None:
theta_cutoff = self.kernel_shape[0] * torch.pi / float(self.nlat_in - 1)
theta_cutoff = (self.kernel_shape[0] + 1) / 2 * torch.pi / float(self.nlat_in - 1)

if theta_cutoff <= 0.0:
raise ValueError("Error, theta_cutoff has to be positive.")
@@ -455,7 +499,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

# pre-multiply x with the quadrature weights
x = self.quad_weights * x

if x.is_cuda and _cuda_extension_available:
out = _disco_s2_transpose_contraction_cuda(x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals,
self.kernel_size, self.nlat_out, self.nlon_out)
Expand All @@ -464,7 +508,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
psi = self.get_psi(semi_transposed=True)
out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)

if self.bias is not None:
out = out + self.bias.reshape(1, -1, 1, 1)

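A minimal usage sketch of the even-kernel support added here, assuming the constructor keywords match those exercised in the tests above:

```python
import torch
from torch_harmonics import DiscreteContinuousConvS2

# anisotropic kernel with an even number of collocation points across the diameter
conv = DiscreteContinuousConvS2(
    in_channels=4,
    out_channels=8,
    in_shape=(16, 32),
    out_shape=(8, 16),
    kernel_shape=[4, 3],  # kernel_size = (4 // 2) * 3 + 4 % 2 = 6 basis functions
    grid_in="equiangular",
    grid_out="equiangular",
    bias=False,
)

x = torch.randn(1, 4, 16, 32)
y = conv(x)  # -> torch.Size([1, 8, 8, 16])
```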
8 changes: 4 additions & 4 deletions torch_harmonics/distributed/distributed_convolution.py
@@ -100,7 +100,7 @@ def _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, kernel_sh

nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape

lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float()
lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
@@ -217,7 +217,7 @@ def __init__(
self.nlat_out_local = self.nlat_out
idx, vals = _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out,
theta_cutoff=theta_cutoff)

# split the weight tensor as well
quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
self.register_buffer("quad_weights", quad_weights, persistent=False)
@@ -389,7 +389,7 @@ def get_psi(self, semi_transposed: bool=False):
return psi

def forward(self, x: torch.Tensor) -> torch.Tensor:

# extract shape
B, C, H, W = x.shape
x = x.reshape(B, self.groups, self.groupsize, H, W)
@@ -411,7 +411,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

# register allreduce for bwd pass
x = copy_to_polar_region(x)

if x.is_cuda and _cuda_extension_available:
out = _disco_s2_transpose_contraction_cuda(x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals,
self.kernel_size, self.nlat_out_local, self.nlon_out)
