Skip to content

Commit

Permalink
changing distributeds SHT to use dim=-3 as the channel dimension for …
Browse files Browse the repository at this point in the history
…distributed transpose
  • Loading branch information
bonevbs committed Sep 9, 2024
1 parent d34f298 commit 9f71769
Showing 1 changed file with 5 additions and 12 deletions.
17 changes: 5 additions & 12 deletions torch_harmonics/distributed/distributed_sht.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,12 @@ def extra_repr(self):

def forward(self, x: torch.Tensor):

# reshape x to contain only the last two dimensions. Prior dimensions are kept for batching
inshape = x.shape
x = x.reshape(-1, *inshape[-2:])

# we need to ensure that we can split the channels evenly
num_chans = x.shape[0]
num_chans = x.shape[-3]

# h and w is split. First we make w local by transposing into channel dim
if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (0, -1), self.lon_shapes)
x = distributed_transpose_azimuth.apply(x, (-3, -1), self.lon_shapes)

# apply real fft in the longitudinal direction: make sure to truncate to nlon
x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")
Expand All @@ -145,11 +141,11 @@ def forward(self, x: torch.Tensor):
# transpose: after this, m is split and c is local
if self.comm_size_azimuth > 1:
chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
x = distributed_transpose_azimuth.apply(x, (-1, 0), chan_shapes)
x = distributed_transpose_azimuth.apply(x, (-1, -3), chan_shapes)

# transpose: after this, c is split and h is local
if self.comm_size_polar > 1:
x = distributed_transpose_polar.apply(x, (0, -2), self.lat_shapes)
x = distributed_transpose_polar.apply(x, (-3, -2), self.lat_shapes)

# do the Legendre-Gauss quadrature
x = torch.view_as_real(x)
Expand All @@ -163,10 +159,7 @@ def forward(self, x: torch.Tensor):
# transpose: after this, l is split and c is local
if self.comm_size_polar > 1:
chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
x = distributed_transpose_polar.apply(x, (-2, 0), chan_shapes)

# cast channel dimensions back to the original shape
x = x.reshape(*inshape[:-2], *x.shape[-2:])
x = distributed_transpose_polar.apply(x, (-2, -3), chan_shapes)

return x

Expand Down

0 comments on commit 9f71769

Please sign in to comment.