Skip to content

Commit

Permalink
fixing contig in distributed sht
Browse files Browse the repository at this point in the history
  • Loading branch information
azrael417 committed Oct 18, 2023
1 parent 7a68534 commit 2293309
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions torch_harmonics/distributed/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,32 +72,36 @@ class distributed_transpose_azimuth(torch.autograd.Function):

@staticmethod
def forward(ctx, x, dim):
input_format = get_memory_format(x)
xlist, _ = _transpose(x, dim[0], dim[1], group=azimuth_group())
x = torch.cat(xlist, dim=dim[1])
x = torch.cat(xlist, dim=dim[1]).contiguous(memory_format=input_format)
ctx.dim = dim
return x

@staticmethod
def backward(ctx, go):
input_format = get_memory_format(go)
dim = ctx.dim
gilist, _ = _transpose(go, dim[1], dim[0], group=azimuth_group())
gi = torch.cat(gilist, dim=dim[0])
gi = torch.cat(gilist, dim=dim[0]).contiguous(memory_format=input_format)
return gi, None


class distributed_transpose_polar(torch.autograd.Function):

@staticmethod
def forward(ctx, x, dim):
input_format = get_memory_format(x)
xlist, _ = _transpose(x, dim[0], dim[1], group=polar_group())
x = torch.cat(xlist, dim=dim[1])
x = torch.cat(xlist, dim=dim[1]).contiguous(memory_format=input_format)
ctx.dim = dim
return x

@staticmethod
def backward(ctx, go):
input_format = get_memory_format(go)
dim = ctx.dim
gilist, _ = _transpose(go, dim[1], dim[0], group=polar_group())
gi = torch.cat(gilist, dim=dim[0])
gi = torch.cat(gilist, dim=dim[0]).contiguous(memory_format=input_format)
return gi, None

0 comments on commit 2293309

Please sign in to comment.