fixed distributed convolution to support the same normalization as non-distributed one
bonevbs committed Apr 28, 2024
1 parent f922468 commit 8469ee1
Showing 1 changed file with 31 additions and 13 deletions.
44 changes: 31 additions & 13 deletions torch_harmonics/distributed/distributed_convolution.py
@@ -47,6 +47,7 @@
from torch_harmonics.convolution import (
    _compute_support_vals_isotropic,
    _compute_support_vals_anisotropic,
+    _normalize_convolution_tensor_s2,
    DiscreteContinuousConv,
)

@@ -68,7 +69,9 @@
_cuda_extension_available = False


-def _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi):
+def _precompute_distributed_convolution_tensor_s2(
+    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False
+):
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
@@ -83,9 +86,6 @@ def _precompute_distributed_convolution_tensor_s2
            \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$
-    This is the distributed version: the matrix can either be split column- or row-wise. Column-wise seems better because the kernel has a lot of summation
-    atomics concerning the row reductions, which we can combine in a single allreduce.
-
    """

    assert len(in_shape) == 2
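As a sanity check on the docstring formula (an editor's sketch, not part of the commit), applying the YZY rotation to the north pole with plain numpy reproduces all three components:

import numpy as np

def Y(t):
    # rotation about the y-axis
    return np.array([[np.cos(t), 0.0, np.sin(t)], [0.0, 1.0, 0.0], [-np.sin(t), 0.0, np.cos(t)]])

def Z(t):
    # rotation about the z-axis
    return np.array([[np.cos(t), -np.sin(t), 0.0], [np.sin(t), np.cos(t), 0.0], [0.0, 0.0, 1.0]])

alpha, beta, gamma = 0.3, 1.1, -0.7
n = np.array([0.0, 0.0, 1.0])
expected = np.array([
    np.cos(gamma) * np.sin(alpha) + np.cos(alpha) * np.cos(beta) * np.sin(gamma),
    np.sin(beta) * np.sin(gamma),
    np.cos(alpha) * np.cos(gamma) - np.cos(beta) * np.sin(alpha) * np.sin(gamma),
])
assert np.allclose(Y(alpha) @ Z(beta) @ Y(gamma) @ n, expected)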
@@ -101,16 +101,11 @@ def _precompute_distributed_convolution_tensor_s2
    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

-    lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
+    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in).float()
-    lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
+    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out).float()

-    # split the latitude vector:
-    comm_size_polar = polar_group_size()
-    comm_rank_polar = polar_group_rank()
-    lats_in = split_tensor_along_dim(lats_in, dim=0, num_chunks=comm_size_polar)[comm_rank_polar]
-
    # compute the phi differences
    # It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
@@ -155,6 +150,25 @@
    out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
    out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()

+    # perform the normalization over the entire psi matrix
+    if transpose_normalization:
+        quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
+    else:
+        quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
+    out_vals = _normalize_convolution_tensor_s2(out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization)
+
+    # split the latitude indices:
+    comm_size_polar = polar_group_size()
+    comm_rank_polar = polar_group_rank()
+    ilat_in = torch.arange(lats_in.shape[0])
+    ilat_in = split_tensor_along_dim(ilat_in, dim=0, num_chunks=comm_size_polar)[comm_rank_polar]
+
+    # once normalization is done, we can discard the entries corresponding to input latitudes this rank does not own
+    lats = out_idx[2] // nlon_in
+    ilats = torch.argwhere((lats < ilat_in[-1] + 1) & (lats >= ilat_in[0])).squeeze()
+    out_idx = out_idx[:, ilats]
+    out_vals = out_vals[ilats]
+
    return out_idx, out_vals
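To illustrate the new masking step (an editor's sketch with made-up sizes, not part of the commit): the last row of the COO index tensor flattens (lat_in, lon_in) as lat_in * nlon_in + lon_in, so integer division by nlon_in recovers each entry's input latitude, and only entries whose latitude falls inside this rank's contiguous chunk survive:

import torch

nlon_in = 8
# rows: (kernel index, output index, flattened input index)
out_idx = torch.tensor([[0, 0, 1], [0, 1, 2], [3, 17, 42]])
out_vals = torch.tensor([0.5, 0.25, 0.125])
ilat_in = torch.arange(2, 5)  # this rank owns input latitudes 2..4

lats = out_idx[2] // nlon_in  # -> tensor([0, 2, 5])
# squeeze(-1) keeps a 1-D index list even when only one entry matches
ilats = torch.argwhere((lats < ilat_in[-1] + 1) & (lats >= ilat_in[0])).squeeze(-1)
out_idx_local = out_idx[:, ilats]  # keeps only the middle entry
out_vals_local = out_vals[ilats]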


@@ -215,7 +229,9 @@ def __init__
        # set local shapes according to distributed mode:
        self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar]
        self.nlat_out_local = self.nlat_out
-        idx, vals = _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff)
+        idx, vals = _precompute_distributed_convolution_tensor_s2(
+            in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False
+        )

        # split the weight tensor as well
        quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
@@ -349,7 +365,9 @@ def __init__

        # switch in_shape and out_shape since we want transpose conv
        # distributed mode here is swapped because of the transpose
-        idx, vals = _precompute_distributed_convolution_tensor_s2(out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff)
+        idx, vals = _precompute_distributed_convolution_tensor_s2(
+            out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True
+        )

        # split the weight tensor as well
        quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
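The net effect of the flag, summarized in an editor's sketch (assuming win/wout are the quadrature weight arrays returned by _precompute_latitudes): the forward convolution normalizes psi with the input grid's weights, while the transpose convolution normalizes with the output grid's weights, matching the normalization of the non-distributed code path:

import math
import torch

def quad_weights_for(win: torch.Tensor, wout: torch.Tensor, nlon_in: int, transpose_normalization: bool) -> torch.Tensor:
    # forward conv: input-grid weights; transpose conv: output-grid weights
    w = wout if transpose_normalization else win
    return 2.0 * math.pi * w.float().reshape(-1, 1) / nlon_in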
