
Commit

cleanup
bonevbs committed Apr 29, 2024
1 parent c1c701f commit 03830e5
Showing 2 changed files with 11 additions and 18 deletions.
23 changes: 10 additions & 13 deletions tests/test_distributed_convolution.py
@@ -176,17 +176,17 @@ def _gather_helper_bwd(self, tensor, B, C, convolution_dist):

     @parameterized.expand(
         [
-            # [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
-            # [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
-            # [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-5],
-            # [128, 256, 64, 128, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
-            # [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", False, 1e-5],
+            [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
+            [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
+            [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-5],
+            [128, 256, 64, 128, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
+            [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", False, 1e-5],
             [128, 256, 128, 256, 32, 6, [3], 1, "equiangular", "equiangular", False, 1e-5],
-            # [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
-            # [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
-            # [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-5],
-            # [64, 128, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
-            # [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", True, 1e-5],
+            [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
+            [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
+            [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-5],
+            [64, 128, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
+            [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", True, 1e-5],
             [128, 256, 128, 256, 32, 6, [3], 1, "equiangular", "equiangular", True, 1e-5],
         ]
     )
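For readers unfamiliar with the harness: `parameterized.expand` generates one test method per tuple above. A minimal sketch of the mechanism, using hypothetical parameter names inferred from the first few arguments of the test signature in the next hunk (the remaining names are assumptions, not the test's exact signature):

```python
import unittest
from parameterized import parameterized

class DemoTest(unittest.TestCase):
    @parameterized.expand([
        # (nlat_in, nlon_in, nlat_out, nlon_out)
        (128, 256, 128, 256),
        (129, 256, 128, 256),
    ])
    def test_grid_shapes(self, nlat_in, nlon_in, nlat_out, nlon_out):
        # one test method is generated per tuple: test_grid_shapes_0, test_grid_shapes_1, ...
        self.assertGreater(min(nlat_in, nlon_in, nlat_out, nlon_out), 0)
```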
@@ -261,9 +261,6 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batc
         err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
         if self.world_rank == 0:
             print(f"final relative error of output: {err.item()}")
-
-        print(torch.argmax(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2))))
-
         self.assertTrue(err.item() <= tol)

#############################################################
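For reference, the quantity asserted above is a relative Frobenius error over the two spatial dimensions, averaged across the leading batch and channel dimensions. A self-contained sketch of the same metric (the tensor names here are placeholders, not the test's actual fixtures):

```python
import torch

def relative_frobenius_error(reference: torch.Tensor, candidate: torch.Tensor) -> float:
    # Frobenius norm over the two trailing (spatial) dims; mean over leading dims
    diff = torch.norm(reference - candidate, p="fro", dim=(-1, -2))
    ref = torch.norm(reference, p="fro", dim=(-1, -2))
    return torch.mean(diff / ref).item()

# sanity check: identical tensors yield zero error
x = torch.randn(2, 4, 64, 128)
assert relative_frobenius_error(x, x) == 0.0
```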
6 changes: 1 addition & 5 deletions torch_harmonics/distributed/distributed_convolution.py
@@ -167,16 +167,12 @@ def _precompute_distributed_convolution_tensor_s2(
     start_idx = offsets[comm_rank_polar]
     end_idx = offsets[comm_rank_polar+1]

-    print(f"polar rank: {comm_rank_polar}, start {start_idx}, end {end_idx}")
-
     # once normalization is done, we can throw away the entries which correspond to input latitudes we do not care about
     lats = out_idx[2] // nlon_in
     lons = out_idx[2] % nlon_in
     ilats = torch.argwhere((lats < end_idx) & (lats >= start_idx)).squeeze()
     out_vals = out_vals[ilats]
-    # for the indices we need to recompute them to fit the
-    # out_idx = out_idx[:, ilats]
-
+    # for the indices, we need to recompute them to refer to local indices of the input tensor
     out_idx = torch.stack([out_idx[0, ilats], out_idx[1, ilats], (lats[ilats]-start_idx) * nlon_in + lons[ilats]], dim=0)

     return out_idx, out_vals
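The surviving code flattens each input grid point as `lat * nlon_in + lon`, keeps only the sparse entries whose latitude falls in this rank's slab `[start_idx, end_idx)`, and shifts the latitude so the flattened index addresses the rank-local input tensor. A self-contained sketch of that re-indexing on synthetic data (variable names mirror the hunk above, but the values are made up):

```python
import torch

nlon_in = 8
start_idx, end_idx = 4, 8  # this rank owns global input latitudes [4, 8)

# COO-style indices: (kernel index, output point, flattened global input point)
out_idx = torch.tensor([[0, 0, 0],
                        [0, 1, 2],
                        [10, 37, 50]])
out_vals = torch.tensor([0.5, 0.25, 0.25])

lats = out_idx[2] // nlon_in  # global input latitude per entry: [1, 4, 6]
lons = out_idx[2] % nlon_in   # longitude per entry: [2, 5, 2]
ilats = torch.argwhere((lats < end_idx) & (lats >= start_idx)).squeeze()

out_vals = out_vals[ilats]  # keep only entries inside the local slab
# shift latitudes into the slab before re-flattening to local indices
out_idx = torch.stack(
    [out_idx[0, ilats], out_idx[1, ilats], (lats[ilats] - start_idx) * nlon_in + lons[ilats]],
    dim=0,
)
print(out_idx)  # third row is now [5, 18]: local latitudes 0 and 2
```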
