From 92d1a6deebfdde6e64e929caf703fcf6adea1d3c Mon Sep 17 00:00:00 2001
From: Thorsten Kurth
Date: Mon, 19 Aug 2024 08:32:09 -0700
Subject: [PATCH] adding cuda disco reduce scatter

---
 tests/test_distributed_convolution.py      |  6 +++-
 .../distributed/distributed_convolution.py |  2 +-
 torch_harmonics/distributed/primitives.py  | 31 ++++++-------------
 3 files changed, 15 insertions(+), 24 deletions(-)

diff --git a/tests/test_distributed_convolution.py b/tests/test_distributed_convolution.py
index ee463ba..77a460d 100644
--- a/tests/test_distributed_convolution.py
+++ b/tests/test_distributed_convolution.py
@@ -130,6 +130,7 @@ def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
         lon_shapes = convolution_dist.lon_out_shapes
 
         # gather in W
+        tensor = tensor.contiguous()
         if self.grid_size_w > 1:
             gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
             olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
@@ -140,6 +141,7 @@ def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
             tensor_gather = tensor
 
         # gather in H
+        tensor_gather = tensor_gather.contiguous()
         if self.grid_size_h > 1:
             gather_shapes = [(B, C, h, convolution_dist.nlon_out) for h in lat_shapes]
             olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
@@ -221,7 +223,8 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batc
             conv_dist.bias.copy_(conv_local.bias)
 
         # create tensors
-        inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
+        #inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
+        inp_full = torch.arange(0, B*C*H*W, dtype=torch.float32, device=self.device).reshape(B, C, H, W)
 
         #############################################################
         # local conv
@@ -268,6 +271,7 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batc
         #############################################################
         with torch.no_grad():
             igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)
+
             err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
             if self.world_rank == 0:
                 print(f"final relative error of gradients: {err.item()}")
diff --git a/torch_harmonics/distributed/distributed_convolution.py b/torch_harmonics/distributed/distributed_convolution.py
index 0bebfd1..48f0fb1 100644
--- a/torch_harmonics/distributed/distributed_convolution.py
+++ b/torch_harmonics/distributed/distributed_convolution.py
@@ -268,7 +268,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         # store number of channels
         num_chans = x.shape[1]
-
+
         # h and w is split. First we make w local by transposing into channel dim
         if self.comm_size_azimuth > 1:
             x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)
diff --git a/torch_harmonics/distributed/primitives.py b/torch_harmonics/distributed/primitives.py
index aa75b1f..aa46b20 100644
--- a/torch_harmonics/distributed/primitives.py
+++ b/torch_harmonics/distributed/primitives.py
@@ -56,14 +56,6 @@ def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
 
     return sections
 
-
-# general helpers
-def get_memory_format(tensor):
-    if tensor.is_contiguous(memory_format=torch.channels_last):
-        return torch.channels_last
-    else:
-        return torch.contiguous_format
-
 
 def split_tensor_along_dim(tensor, dim, num_chunks):
     assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
@@ -78,23 +70,20 @@ def split_tensor_along_dim(tensor, dim, num_chunks):
 
 
 def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
-    # get input format
-    input_format = get_memory_format(tensor)
-
     # get comm params
     comm_size = dist.get_world_size(group=group)
     comm_rank = dist.get_rank(group=group)
 
     # split and local transposition
     tsplit = split_tensor_along_dim(tensor, num_chunks=comm_size, dim=dim0)
-    x_send = [y.contiguous(memory_format=input_format) for y in tsplit]
+    x_send = [y.contiguous() for y in tsplit]
     x_send_shapes = [x.shape for x in x_send]
     x_recv = []
     x_shape = list(x_send_shapes[comm_rank])
     for dim1_len in dim1_split_sizes:
         x_shape[dim1] = dim1_len
-        x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device, memory_format=input_format))
-
+        x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device))
+
     # global transposition
     req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)
@@ -108,24 +97,24 @@ class distributed_transpose_azimuth(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x, dims, dim1_split_sizes):
-        input_format = get_memory_format(x)
         # WAR for a potential contig check torch bug for channels last contig tensors
         x = x.contiguous()
         xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
-        x = torch.cat(xlist, dim=dims[1]).contiguous(memory_format=input_format)
+        x = torch.cat(xlist, dim=dims[1]).contiguous()
         ctx.dims = dims
         ctx.dim0_split_sizes = dim0_split_sizes
+
         return x
 
     @staticmethod
     def backward(ctx, go):
-        input_format = get_memory_format(go)
         dims = ctx.dims
         dim0_split_sizes = ctx.dim0_split_sizes
         # WAR for a potential contig check torch bug for channels last contig tensors
         go = go.contiguous()
         gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=azimuth_group())
-        gi = torch.cat(gilist, dim=dims[0]).contiguous(memory_format=input_format)
+        gi = torch.cat(gilist, dim=dims[0]).contiguous()
+
         return gi, None, None
@@ -133,24 +122,22 @@ class distributed_transpose_polar(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, x, dim, dim1_split_sizes):
-        input_format = get_memory_format(x)
         # WAR for a potential contig check torch bug for channels last contig tensors
         x = x.contiguous()
         xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
-        x = torch.cat(xlist, dim=dim[1]).contiguous(memory_format=input_format)
+        x = torch.cat(xlist, dim=dim[1]).contiguous()
         ctx.dim = dim
         ctx.dim0_split_sizes = dim0_split_sizes
         return x
 
     @staticmethod
     def backward(ctx, go):
-        input_format = get_memory_format(go)
         dim = ctx.dim
         dim0_split_sizes = ctx.dim0_split_sizes
         # WAR for a potential contig check torch bug for channels last contig tensors
         go = go.contiguous()
         gilist, _, _ = _transpose(go, dim[1], dim[0], dim0_split_sizes, group=polar_group())
-        gi = torch.cat(gilist, dim=dim[0]).contiguous(memory_format=input_format)
+        gi = torch.cat(gilist, dim=dim[0]).contiguous()
         return gi, None, None
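
Note (illustration only, not part of the patch): with get_memory_format removed, _transpose and the
distributed_transpose_azimuth/polar autograd functions rely on plain .contiguous() buffers around the
all_to_all. A minimal sketch of that split -> all_to_all -> cat pattern follows, assuming
torch.distributed is already initialized; transpose_sketch and its arguments are illustrative names,
not part of the torch_harmonics API.

    import torch
    import torch.distributed as dist

    def transpose_sketch(tensor, dim0, dim1, dim1_split_sizes, group=None):
        # number of ranks participating in the transpose and our position among them
        comm_size = dist.get_world_size(group=group)
        comm_rank = dist.get_rank(group=group)

        # split the local tensor along dim0, one chunk per rank; the chunks are
        # made contiguous in the default memory format (as in the patched _transpose)
        x_send = [y.contiguous() for y in torch.tensor_split(tensor, comm_size, dim=dim0)]

        # preallocate receive buffers: dim0 keeps the extent of our own chunk,
        # while dim1 takes the split size owned by each remote rank
        x_shape = list(x_send[comm_rank].shape)
        x_recv = []
        for dim1_len in dim1_split_sizes:
            x_shape[dim1] = dim1_len
            x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device))

        # exchange one chunk with every rank, then reassemble along dim1
        dist.all_to_all(x_recv, x_send, group=group)
        return torch.cat(x_recv, dim=dim1)

Dropping the channels_last bookkeeping keeps both send and receive buffers in the default contiguous
layout, which matches the explicit .contiguous() calls added to the gather helpers in the test.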