diff --git a/tests/test_distributed_convolution.py b/tests/test_distributed_convolution.py
index c24801b..e1b65ab 100644
--- a/tests/test_distributed_convolution.py
+++ b/tests/test_distributed_convolution.py
@@ -177,17 +177,17 @@ def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
     @parameterized.expand(
         [
             [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6],
-            [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6],
-            [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-6],
-            [128, 256, 64, 128, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6],
-            [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", False, 1e-6],
-            [128, 256, 128, 256, 32, 5, [3], 1, "equiangular", "equiangular", False, 1e-6],
-            [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
-            [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
-            [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-6],
-            [64, 128, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
-            [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", True, 1e-6],
-            [128, 256, 128, 256, 32, 5, [3], 1, "equiangular", "equiangular", True, 1e-6],
+            # [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6],
+            # [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-6],
+            # [128, 256, 64, 128, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6],
+            # [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", False, 1e-6],
+            # [128, 256, 128, 256, 32, 5, [3], 1, "equiangular", "equiangular", False, 1e-6],
+            # [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
+            # [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
+            # [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-6],
+            # [64, 128, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
+            # [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", True, 1e-6],
+            # [128, 256, 128, 256, 32, 5, [3], 1, "equiangular", "equiangular", True, 1e-6],
         ]
     )
     def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, groups, grid_in, grid_out, transpose, tol):
@@ -261,6 +261,8 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batc
         err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
         if self.world_rank == 0:
             print(f"final relative error of output: {err.item()}")
+
+        print(err.item())
         self.assertTrue(err.item() <= tol)
 
     #############################################################
diff --git a/torch_harmonics/distributed/distributed_convolution.py b/torch_harmonics/distributed/distributed_convolution.py
index 41505ee..acd78b9 100644
--- a/torch_harmonics/distributed/distributed_convolution.py
+++ b/torch_harmonics/distributed/distributed_convolution.py
@@ -166,6 +166,8 @@ def _precompute_distributed_convolution_tensor_s2(
     start_idx = ([0] + list(accumulate(split_shapes)))[comm_rank_polar]
     end_idx = start_idx + split_shapes[comm_rank_polar]
 
+    print(f"polar rank: {comm_rank_polar}, start {start_idx}, end {end_idx}")
+
     # once normalization is done we can throw away the entries which correspond to input latitudes we do not care about
     lats = out_idx[2] // nlon_in
     lons = out_idx[2] % nlon_in
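For reference, the debug print added in _precompute_distributed_convolution_tensor_s2 reports the per-rank latitude index range derived from the cumulative split sizes. The sketch below reproduces just that computation in isolation; the start_idx/end_idx lines are taken verbatim from the diff context, while the 4-rank setup and the split_shapes values are hypothetical and chosen only for illustration.

from itertools import accumulate

# hypothetical latitude split across 4 polar ranks (illustrative values only)
split_shapes = [32, 32, 32, 32]

for comm_rank_polar in range(len(split_shapes)):
    # cumulative offsets [0, 32, 64, 96, 128]; each rank owns the half-open range [start_idx, end_idx)
    start_idx = ([0] + list(accumulate(split_shapes)))[comm_rank_polar]
    end_idx = start_idx + split_shapes[comm_rank_polar]
    print(f"polar rank: {comm_rank_polar}, start {start_idx}, end {end_idx}")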