adding cuda disco reduce scatter
azrael417 committed Aug 19, 2024
1 parent 04d28fa commit 92d1a6d
Showing 3 changed files with 15 additions and 24 deletions.
tests/test_distributed_convolution.py (5 additions, 1 deletion)
@@ -130,6 +130,7 @@ def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
lon_shapes = convolution_dist.lon_out_shapes

# gather in W
tensor = tensor.contiguous()
if self.grid_size_w > 1:
gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
@@ -140,6 +141,7 @@ def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
tensor_gather = tensor

# gather in H
tensor_gather = tensor_gather.contiguous()
if self.grid_size_h > 1:
gather_shapes = [(B, C, h, convolution_dist.nlon_out) for h in lat_shapes]
olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
@@ -221,7 +223,8 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batc
conv_dist.bias.copy_(conv_local.bias)

# create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
inp_full = torch.arange(0, B*C*H*W, dtype=torch.float32, device=self.device).reshape(B, C, H, W)

#############################################################
# local conv
@@ -268,6 +271,7 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batc
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)

err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0:
print(f"final relative error of gradients: {err.item()}")
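
The two contiguous() calls added to _gather_helper_fwd make sure the gather collective (not visible in these hunks) receives a contiguous send buffer alongside the pre-sized olist receive buffers. Below is a minimal sketch of that gather-then-concatenate pattern, assuming a torch.distributed all_gather on a hypothetical group handle and a backend that accepts per-rank output buffers of different shapes:

import torch
import torch.distributed as dist

def gather_along_dim(tensor, dim, shapes, group):
    # the send buffer must be contiguous before the collective
    tensor = tensor.contiguous()
    # one receive buffer per rank, sized with that rank's full local shape
    olist = [torch.empty(s, dtype=tensor.dtype, device=tensor.device) for s in shapes]
    dist.all_gather(olist, tensor, group=group)
    # stitch the shards back together along the gathered dimension
    return torch.cat(olist, dim=dim)

In the test this pattern runs twice, first along W with lon_out_shapes and then along H with lat_shapes, to rebuild the full tensor for comparison against the single-device result.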
torch_harmonics/distributed/distributed_convolution.py (1 addition, 1 deletion)
@@ -268,7 +268,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

# store number of channels
num_chans = x.shape[1]

# h and w are split. First we make w local by transposing into the channel dim
if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)
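
The comment in this hunk describes the pattern: H and W arrive split across ranks, and the azimuth transpose moves the W split into the channel dimension so the longitude axis becomes fully local before the convolution. The sketch below assembles the visible pieces into one function; the backward transpose with dims (-1, 1) and the channel split sizes from compute_split_shapes are assumptions made for illustration, not lines shown in this diff:

import torch
from torch_harmonics.distributed.primitives import (
    compute_split_shapes,
    distributed_transpose_azimuth,
)

def apply_with_local_longitude(x, comm_size_azimuth, lon_in_shapes, local_op):
    # remember the channel count before it gets redistributed
    num_chans = x.shape[1]

    # make W local: swap the azimuth-split longitude dim with the channel dim
    if comm_size_azimuth > 1:
        x = distributed_transpose_azimuth.apply(x, (1, -1), lon_in_shapes)

    # run the purely local operation on fully local longitudes
    x = local_op(x)

    # swap back (assumed): channels become local again, W is re-split across ranks
    if comm_size_azimuth > 1:
        chan_shapes = compute_split_shapes(num_chans, comm_size_azimuth)
        x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)

    return x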
torch_harmonics/distributed/primitives.py (9 additions, 22 deletions)
@@ -56,14 +56,6 @@ def compute_split_shapes(size: int, num_chunks: int) -> List[int]:

return sections


# general helpers
def get_memory_format(tensor):
if tensor.is_contiguous(memory_format=torch.channels_last):
return torch.channels_last
else:
return torch.contiguous_format


def split_tensor_along_dim(tensor, dim, num_chunks):
assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
@@ -78,23 +70,20 @@ def split_tensor_along_dim(tensor, dim, num_chunks):


def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
# get input format
input_format = get_memory_format(tensor)

# get comm params
comm_size = dist.get_world_size(group=group)
comm_rank = dist.get_rank(group=group)

# split and local transposition
tsplit = split_tensor_along_dim(tensor, num_chunks=comm_size, dim=dim0)
x_send = [y.contiguous(memory_format=input_format) for y in tsplit]
x_send = [y.contiguous() for y in tsplit]
x_send_shapes = [x.shape for x in x_send]
x_recv = []
x_shape = list(x_send_shapes[comm_rank])
for dim1_len in dim1_split_sizes:
x_shape[dim1] = dim1_len
x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device, memory_format=input_format))

x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device))
# global transposition
req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)

@@ -108,49 +97,47 @@ class distributed_transpose_azimuth(torch.autograd.Function):

@staticmethod
def forward(ctx, x, dims, dim1_split_sizes):
input_format = get_memory_format(x)
# WAR for a potential contig check torch bug for channels last contig tensors
x = x.contiguous()
xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
x = torch.cat(xlist, dim=dims[1]).contiguous(memory_format=input_format)
x = torch.cat(xlist, dim=dims[1]).contiguous()
ctx.dims = dims
ctx.dim0_split_sizes = dim0_split_sizes

return x

@staticmethod
def backward(ctx, go):
input_format = get_memory_format(go)
dims = ctx.dims
dim0_split_sizes = ctx.dim0_split_sizes
# WAR for a potential contig check torch bug for channels last contig tensors
go = go.contiguous()
gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=azimuth_group())
gi = torch.cat(gilist, dim=dims[0]).contiguous(memory_format=input_format)
gi = torch.cat(gilist, dim=dims[0]).contiguous()

return gi, None, None


class distributed_transpose_polar(torch.autograd.Function):

@staticmethod
def forward(ctx, x, dim, dim1_split_sizes):
input_format = get_memory_format(x)
# WAR for a potential contig check torch bug for channels last contig tensors
x = x.contiguous()
xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
x = torch.cat(xlist, dim=dim[1]).contiguous(memory_format=input_format)
x = torch.cat(xlist, dim=dim[1]).contiguous()
ctx.dim = dim
ctx.dim0_split_sizes = dim0_split_sizes
return x

@staticmethod
def backward(ctx, go):
input_format = get_memory_format(go)
dim = ctx.dim
dim0_split_sizes = ctx.dim0_split_sizes
# WAR for a potential contig check torch bug for channels last contig tensors
go = go.contiguous()
gilist, _, _ = _transpose(go, dim[1], dim[0], dim0_split_sizes, group=polar_group())
gi = torch.cat(gilist, dim=dim[0]).contiguous(memory_format=input_format)
gi = torch.cat(gilist, dim=dim[0]).contiguous()
return gi, None, None


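
The reduce-scatter that gives this commit its title does not appear in the expanded hunks of this file. Purely as a reference point, here is a minimal sketch of a sum reduce-scatter along a split tensor dimension built on torch.distributed; the function name and the reuse of compute_split_shapes are illustrative assumptions rather than the commit's actual code:

import torch
import torch.distributed as dist
from torch_harmonics.distributed.primitives import compute_split_shapes

def reduce_scatter_along_dim(tensor, dim, group=None):
    # one chunk per rank along dim; sizes follow the library's split convention
    comm_size = dist.get_world_size(group=group)
    comm_rank = dist.get_rank(group=group)
    sections = compute_split_shapes(tensor.shape[dim], comm_size)
    x_send = [y.contiguous() for y in torch.split(tensor, sections, dim=dim)]

    # this rank keeps the sum of every rank's chunk at position comm_rank
    output = torch.empty_like(x_send[comm_rank])

    # note: NCCL expects equally sized chunks; uneven splits would need padding
    dist.reduce_scatter(output, x_send, op=dist.ReduceOp.SUM, group=group)

    return output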
