diff --git a/torch_harmonics/distributed/distributed_convolution.py b/torch_harmonics/distributed/distributed_convolution.py
index 8063f6b..20e6589 100644
--- a/torch_harmonics/distributed/distributed_convolution.py
+++ b/torch_harmonics/distributed/distributed_convolution.py
@@ -47,6 +47,7 @@ from torch_harmonics.convolution import (
     _compute_support_vals_isotropic,
     _compute_support_vals_anisotropic,
+    _normalize_convolution_tensor_s2,
     DiscreteContinuousConv,
 )
 
@@ -68,7 +69,9 @@ _cuda_extension_available = False
 
 
-def _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi):
+def _precompute_distributed_convolution_tensor_s2(
+    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False
+):
     """
     Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$. Assumes a tensorized grid
     on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
@@ -83,9 +86,6 @@ def _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, kernel_sh
     \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
     \end{bmatrix}}
     $$
-
-    This is the distributed version: the matrix can either be split column- or row-wise. Column-wise seems better because the kernel has a lot of summation
-    atomics concerning the row reductions, which we can combine in a single allreduce.
     """
 
     assert len(in_shape) == 2
@@ -101,16 +101,11 @@ def _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, kernel_sh
     nlat_in, nlon_in = in_shape
     nlat_out, nlon_out = out_shape
 
-    lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
+    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
     lats_in = torch.from_numpy(lats_in).float()
-    lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
+    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
     lats_out = torch.from_numpy(lats_out).float()
 
-    # split the latitude vector:
-    comm_size_polar = polar_group_size()
-    comm_rank_polar = polar_group_rank()
-    lats_in = split_tensor_along_dim(lats_in, dim=0, num_chunks=comm_size_polar)[comm_rank_polar]
-
     # compute the phi differences
     # It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
     lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
@@ -155,6 +150,25 @@ def _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, kernel_sh
     out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
     out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()
 
+    # perform the normalization over the entire psi matrix
+    if transpose_normalization:
+        quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
+    else:
+        quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
+    out_vals = _normalize_convolution_tensor_s2(out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization)
+
+    # split the latitude indices:
+    comm_size_polar = polar_group_size()
+    comm_rank_polar = polar_group_rank()
+    ilat_in = torch.arange(lats_in.shape[0])
+    ilat_in = split_tensor_along_dim(ilat_in, dim=0, num_chunks=comm_size_polar)[comm_rank_polar]
+
+    # once normalization is done we can throw away the entries which correspond to input latitudes we do not care about
+    lats = out_idx[2] // nlon_in
+    ilats = torch.argwhere((lats < ilat_in[-1] + 1) & (lats >= ilat_in[0])).squeeze()
+    out_idx = out_idx[:, ilats]
+    out_vals = out_vals[ilats]
+
     return out_idx, out_vals
 
@@ -215,7 +229,9 @@ def __init__(
         # set local shapes according to distributed mode:
         self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar]
         self.nlat_out_local = self.nlat_out
-        idx, vals = _precompute_distributed_convolution_tensor_s2(in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff)
+        idx, vals = _precompute_distributed_convolution_tensor_s2(
+            in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False
+        )
 
         # split the weight tensor as well
         quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
@@ -349,7 +365,9 @@ def __init__(
 
         # switch in_shape and out_shape since we want transpose conv
         # distributed mode here is swapped because of the transpose
-        idx, vals = _precompute_distributed_convolution_tensor_s2(out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff)
+        idx, vals = _precompute_distributed_convolution_tensor_s2(
+            out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True
+        )
 
         # split the weight tensor as well
         quad_weights = split_tensor_along_dim(quad_weights, dim=0, num_chunks=self.comm_size_polar)[self.comm_rank_polar]
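For illustration, here is a minimal, self-contained sketch (not part of the patch) of the new filtering step at the end of _precompute_distributed_convolution_tensor_s2: the convolution tensor is now normalized over the full global grid first, and only afterwards are the entries belonging to latitudes outside the local polar chunk discarded. The toy grid sizes, rank layout, and index values below are made up, and Tensor.chunk stands in for split_tensor_along_dim.

# Sketch only: assumes the patch's convention that out_idx[2] stores the
# flattened input index ilat * nlon_in + ilon; all numbers are illustrative.
import torch

nlat_in, nlon_in = 8, 16     # toy global input grid
comm_size_polar = 2          # latitude dimension split across two ranks
comm_rank_polar = 1          # this rank owns the second latitude chunk

# toy sparse convolution tensor: rows are (kernel index, output index, flattened input index)
out_idx = torch.stack([
    torch.zeros(6, dtype=torch.long),
    torch.arange(6, dtype=torch.long),
    torch.tensor([3, 17, 40, 65, 90, 120], dtype=torch.long),
])
out_vals = torch.rand(6)

# latitude indices owned by this rank (stand-in for split_tensor_along_dim)
ilat_in = torch.arange(nlat_in).chunk(comm_size_polar, dim=0)[comm_rank_polar]

# keep only the entries whose input latitude lies in the local chunk,
# mirroring the selection done after the global normalization in the patch
lats = out_idx[2] // nlon_in
ilats = torch.argwhere((lats < ilat_in[-1] + 1) & (lats >= ilat_in[0])).squeeze()
out_idx_local = out_idx[:, ilats]
out_vals_local = out_vals[ilats]

print(out_idx_local.shape, out_vals_local.shape)  # (3, 3) and (3,) for this toy example

Doing the normalization before the split means every rank sees the same globally normalized values, so no additional reduction over the polar group is needed at this stage; the transpose_normalization flag simply selects the output-grid quadrature weights (wout) for the transpose convolution and the input-grid weights (win) otherwise.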