diff --git a/tests/test_distributed_convolution.py b/tests/test_distributed_convolution.py
index 554ab9e..c24801b 100644
--- a/tests/test_distributed_convolution.py
+++ b/tests/test_distributed_convolution.py
@@ -46,11 +46,11 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
     def setUpClass(cls):
 
         # set up distributed
-        cls.world_rank = int(os.getenv('WORLD_RANK', 0))
-        cls.grid_size_h = int(os.getenv('GRID_H', 1))
-        cls.grid_size_w = int(os.getenv('GRID_W', 1))
-        port = int(os.getenv('MASTER_PORT', '29501'))
-        master_address = os.getenv('MASTER_ADDR', 'localhost')
+        cls.world_rank = int(os.getenv("WORLD_RANK", 0))
+        cls.grid_size_h = int(os.getenv("GRID_H", 1))
+        cls.grid_size_w = int(os.getenv("GRID_W", 1))
+        port = int(os.getenv("MASTER_PORT", "29501"))
+        master_address = os.getenv("MASTER_ADDR", "localhost")
         cls.world_size = cls.grid_size_h * cls.grid_size_w
 
         if torch.cuda.is_available():
@@ -60,24 +60,21 @@ def setUpClass(cls):
             cls.device = torch.device(f"cuda:{local_rank}")
             torch.cuda.set_device(local_rank)
             torch.cuda.manual_seed(333)
-            proc_backend = 'nccl'
+            proc_backend = "nccl"
         else:
             if cls.world_rank == 0:
                 print("Running test on CPU")
-            cls.device = torch.device('cpu')
-            proc_backend = 'gloo'
+            cls.device = torch.device("cpu")
+            proc_backend = "gloo"
         torch.manual_seed(333)
 
-        dist.init_process_group(backend = proc_backend,
-                                init_method = f"tcp://{master_address}:{port}",
-                                rank = cls.world_rank,
-                                world_size = cls.world_size)
+        dist.init_process_group(backend=proc_backend, init_method=f"tcp://{master_address}:{port}", rank=cls.world_rank, world_size=cls.world_size)
 
         cls.wrank = cls.world_rank % cls.grid_size_w
         cls.hrank = cls.world_rank // cls.grid_size_w
 
         # now set up the comm groups:
-        #set default
+        # set default
         cls.w_group = None
         cls.h_group = None
 
@@ -109,14 +106,12 @@ def setUpClass(cls):
             if cls.world_rank in grp:
                 cls.h_group = tmp_group
 
-
         if cls.world_rank == 0:
             print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}")
 
         # initializing sht
         thd.init(cls.h_group, cls.w_group)
 
-
     def _split_helper(self, tensor):
         with torch.no_grad():
             # split in W
@@ -129,7 +124,6 @@ def _split_helper(self, tensor):
 
         return tensor_local
 
-
    def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
         # we need the shapes
         lat_shapes = convolution_dist.lat_out_shapes
@@ -155,7 +149,6 @@ def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
 
         return tensor_gather
 
-
     def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
         # we need the shapes
         lat_shapes = convolution_dist.lat_in_shapes
@@ -181,31 +174,37 @@ def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
 
         return tensor_gather
 
-
-    @parameterized.expand([
-        [128, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", False, 1e-6],
-        [129, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", False, 1e-6],
-        [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-6],
-        [128, 256, 64, 128, 32, 8, [3 ], 1, "equiangular", "equiangular", False, 1e-6],
-        [128, 256, 128, 256, 32, 8, [3 ], 2, "equiangular", "equiangular", False, 1e-6],
-        [128, 256, 128, 256, 32, 5, [3 ], 1, "equiangular", "equiangular", False, 1e-6],
-
-        [128, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", True, 1e-6],
-        [129, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", True, 1e-6],
-        [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-6],
-        [ 64, 128, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", True, 1e-6],
-        [128, 256, 128, 256, 32, 8, [3 ], 2, "equiangular", "equiangular", True, 1e-6],
-        [128, 256, 128, 256, 32, 5, [3 ], 1, "equiangular", "equiangular", True, 1e-6],
-    ])
-    def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan,
-                                    kernel_shape, groups, grid_in, grid_out, transpose, tol):
+    @parameterized.expand(
+        [
+            [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6],
+            [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6],
+            [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-6],
+            [128, 256, 64, 128, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6],
+            [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", False, 1e-6],
+            [128, 256, 128, 256, 32, 5, [3], 1, "equiangular", "equiangular", False, 1e-6],
+            [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
+            [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
+            [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-6],
+            [64, 128, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
+            [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", True, 1e-6],
+            [128, 256, 128, 256, 32, 5, [3], 1, "equiangular", "equiangular", True, 1e-6],
+        ]
+    )
+    def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, groups, grid_in, grid_out, transpose, tol):
 
         B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
 
-        disco_args = dict(in_channels=C, out_channels=C,
-                          in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out),
-                          kernel_shape=kernel_shape, groups=groups,
-                          grid_in=grid_in, grid_out=grid_out, bias=True)
+        disco_args = dict(
+            in_channels=C,
+            out_channels=C,
+            in_shape=(nlat_in, nlon_in),
+            out_shape=(nlat_out, nlon_out),
+            kernel_shape=kernel_shape,
+            groups=groups,
+            grid_in=grid_in,
+            grid_out=grid_out,
+            bias=True,
+        )
 
         # set up handles
         if transpose:
@@ -253,13 +252,13 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batc
             out_local = conv_dist(inp_local)
             out_local.backward(ograd_local)
             igrad_local = inp_local.grad.clone()
-        
+
         #############################################################
         # evaluate FWD pass
         #############################################################
         with torch.no_grad():
             out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist)
-            err = torch.mean(torch.norm(out_full-out_gather_full, p='fro', dim=(-1,-2)) / torch.norm(out_full, p='fro', dim=(-1,-2)) )
+            err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
             if self.world_rank == 0:
                 print(f"final relative error of output: {err.item()}")
             self.assertTrue(err.item() <= tol)
@@ -269,11 +268,11 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batc
         #############################################################
         with torch.no_grad():
             igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)
-            err = torch.mean(torch.norm(igrad_full-igrad_gather_full, p='fro', dim=(-1,-2)) / torch.norm(igrad_full, p='fro', dim=(-1,-2)) )
+            err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
             if self.world_rank == 0:
                 print(f"final relative error of gradients: {err.item()}")
             self.assertTrue(err.item() <= tol)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
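
For reviewers who want to exercise the test locally: a minimal launcher sketch for the CPU/gloo path. It is not part of this diff; only the environment variables that setUpClass reads (WORLD_RANK, GRID_H, GRID_W, MASTER_ADDR, MASTER_PORT) and the test file path are taken from the code above, while the 1 x 2 process grid, the address, and the port are arbitrary assumptions.

# Hypothetical launcher (not part of this change): spawn one process per rank on
# a single node and set the environment variables that setUpClass reads before
# running the test file. Grid size 1 x 2 and port 29501 are assumed values.
import os
import subprocess
import sys

grid_h, grid_w = 1, 2              # assumed H x W process grid
world_size = grid_h * grid_w

procs = []
for rank in range(world_size):
    env = dict(
        os.environ,
        WORLD_RANK=str(rank),
        GRID_H=str(grid_h),
        GRID_W=str(grid_w),
        MASTER_ADDR="localhost",
        MASTER_PORT="29501",
    )
    procs.append(subprocess.Popen([sys.executable, "tests/test_distributed_convolution.py"], env=env))

# wait for all ranks and fail if any rank failed
codes = [p.wait() for p in procs]
sys.exit(0 if all(c == 0 for c in codes) else 1)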