
Commit

debugging commit
bonevbs committed Apr 29, 2024
1 parent: 1598465 · commit: e07017f
Showing 2 changed files with 15 additions and 11 deletions.
tests/test_distributed_convolution.py — 13 additions, 11 deletions
@@ -177,17 +177,17 @@ def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
     @parameterized.expand(
         [
             [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6],
-            [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6],
-            [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-6],
-            [128, 256, 64, 128, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6],
-            [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", False, 1e-6],
-            [128, 256, 128, 256, 32, 5, [3], 1, "equiangular", "equiangular", False, 1e-6],
-            [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
-            [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
-            [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-6],
-            [64, 128, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
-            [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", True, 1e-6],
-            [128, 256, 128, 256, 32, 5, [3], 1, "equiangular", "equiangular", True, 1e-6],
+            # [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6],
+            # [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-6],
+            # [128, 256, 64, 128, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6],
+            # [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", False, 1e-6],
+            # [128, 256, 128, 256, 32, 5, [3], 1, "equiangular", "equiangular", False, 1e-6],
+            # [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
+            # [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
+            # [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-6],
+            # [64, 128, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
+            # [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", True, 1e-6],
+            # [128, 256, 128, 256, 32, 5, [3], 1, "equiangular", "equiangular", True, 1e-6],
         ]
     )
     def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, groups, grid_in, grid_out, transpose, tol):
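
Each row of the parameter list maps positionally onto the arguments of test_distributed_disco_conv. As a reading aid, here is the one case left active, annotated with the column names taken from the test signature (the annotation itself is not part of the diff):

# nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, groups, grid_in, grid_out, transpose, tol
[128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6]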
@@ -261,6 +261,8 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, groups, grid_in, grid_out, transpose, tol):
         err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
         if self.world_rank == 0:
             print(f"final relative error of output: {err.item()}")
+
+        print(err.item())
         self.assertTrue(err.item() <= tol)
 
     #############################################################
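For reference, the quantity asserted against tol above is a batch-averaged relative Frobenius error between the single-device output and the gathered distributed output. A minimal self-contained sketch of the same metric; the function name and the random test tensors are illustrative, not part of the test suite:

import torch

def relative_frobenius_error(out_full: torch.Tensor, out_gather_full: torch.Tensor) -> torch.Tensor:
    # Frobenius norm over the two trailing (spatial) dimensions, per (batch, channel) slice
    num = torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2))
    den = torch.norm(out_full, p="fro", dim=(-1, -2))
    # average the relative error over the remaining (batch, channel) dimensions
    return torch.mean(num / den)

# illustrative usage: two nearly identical (batch, channel, nlat, nlon) tensors
x = torch.randn(4, 8, 128, 256)
y = x + 1e-7 * torch.randn_like(x)
print(relative_frobenius_error(x, y).item())  # on the order of 1e-7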
torch_harmonics/distributed/distributed_convolution.py — 2 additions, 0 deletions
@@ -166,6 +166,8 @@ def _precompute_distributed_convolution_tensor_s2(
     start_idx = ([0] + list(accumulate(split_shapes)))[comm_rank_polar]
     end_idx = start_idx + split_shapes[comm_rank_polar]
 
+    print(f"polar rank: {comm_rank_polar}, start {start_idx}, end {end_idx}")
+
     # once normalization is done we can throw away the entries which correspond to input latitudes we do not care about
     lats = out_idx[2] // nlon_in
     lons = out_idx[2] % nlon_in
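For context on what the new print reports: start_idx and end_idx are prefix-sum offsets into the latitude dimension, so each polar rank owns the half-open latitude slab [start_idx, end_idx). A standalone sketch with made-up values; split_shapes and the rank loop are illustrative, not values computed by the library:

from itertools import accumulate

split_shapes = [6, 5, 5]  # illustrative: number of latitudes owned by each polar rank
for comm_rank_polar in range(len(split_shapes)):
    # offsets of this rank's latitude slab, exactly as in the code above
    start_idx = ([0] + list(accumulate(split_shapes)))[comm_rank_polar]
    end_idx = start_idx + split_shapes[comm_rank_polar]
    print(f"polar rank: {comm_rank_polar}, start {start_idx}, end {end_idx}")

# output:
# polar rank: 0, start 0, end 6
# polar rank: 1, start 6, end 11
# polar rank: 2, start 11, end 16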
