fixing distributed sht
bonevbs committed Sep 7, 2024
1 parent 4a6ed46 commit d34f298
Showing 5 changed files with 35 additions and 29 deletions.

examples/train_sfno.py (2 changes: 1 addition & 1 deletion)
@@ -347,7 +347,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
     torch.cuda.manual_seed(333)
 
     # set device
-    device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
     if torch.cuda.is_available():
         torch.cuda.set_device(device.index)
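When this example runs with one process per GPU, the device index is usually derived from the launcher rather than hard-coded. A minimal sketch, assuming a torchrun-style LOCAL_RANK environment variable (this helper is not part of the script above):

```python
import os
import torch

# Hypothetical device selection by local rank; falls back to CPU if no GPU is present.
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.cuda.set_device(device.index)
```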

tests/test_distributed_convolution.py (4 changes: 2 additions & 2 deletions)

@@ -114,8 +114,8 @@ def setUpClass(cls):
 
     @classmethod
     def tearDownClass(cls):
-        thd.finalize()
-        dist.destroy_process_group(None)
+        thd.finalize()
+        dist.destroy_process_group(None)
 
     def _split_helper(self, tensor):
         with torch.no_grad():

torch_harmonics/distributed/distributed_sht.py (33 changes: 20 additions & 13 deletions)

@@ -108,9 +108,9 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho
         # combine quadrature weights with the legendre weights
         weights = torch.from_numpy(w)
         pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
-        pct = torch.from_numpy(pct)
+        pct = torch.from_numpy(pct)
         weights = torch.einsum('mlk,k->mlk', pct, weights)
 
         # split weights
         weights = split_tensor_along_dim(weights, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]
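The einsum in this hunk folds the quadrature weights into the precomputed Legendre polynomials before the m-dimension is split across the azimuth ranks. A small sketch with placeholder shapes, showing that 'mlk,k->mlk' is just a per-latitude scaling:

```python
import torch

# Placeholder sizes for illustration only (mmax, lmax, nlat).
mmax, lmax, nlat = 4, 5, 8
pct = torch.randn(mmax, lmax, nlat)   # Legendre polynomials sampled on nlat latitude rings
w = torch.rand(nlat)                  # one quadrature weight per latitude ring

# every latitude sample is multiplied by its quadrature weight
weights = torch.einsum('mlk,k->mlk', pct, w)
assert torch.allclose(weights, pct * w)   # equivalent to plain broadcasting
```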

@@ -125,12 +125,16 @@ def extra_repr(self):
 
     def forward(self, x: torch.Tensor):
 
+        # reshape x to contain only the last two dimensions. Prior dimensions are kept for batching
+        inshape = x.shape
+        x = x.reshape(-1, *inshape[-2:])
+
         # we need to ensure that we can split the channels evenly
-        num_chans = x.shape[1]
+        num_chans = x.shape[0]
 
         # h and w is split. First we make w local by transposing into channel dim
         if self.comm_size_azimuth > 1:
-            x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_shapes)
+            x = distributed_transpose_azimuth.apply(x, (0, -1), self.lon_shapes)
 
         # apply real fft in the longitudinal direction: make sure to truncate to nlon
         x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")

@@ -141,11 +145,11 @@ def forward(self, x: torch.Tensor):
         # transpose: after this, m is split and c is local
         if self.comm_size_azimuth > 1:
             chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
-            x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)
+            x = distributed_transpose_azimuth.apply(x, (-1, 0), chan_shapes)
 
         # transpose: after this, c is split and h is local
         if self.comm_size_polar > 1:
-            x = distributed_transpose_polar.apply(x, (1, -2), self.lat_shapes)
+            x = distributed_transpose_polar.apply(x, (0, -2), self.lat_shapes)
 
         # do the Legendre-Gauss quadrature
         x = torch.view_as_real(x)

@@ -159,8 +163,11 @@ def forward(self, x: torch.Tensor):
         # transpose: after this, l is split and c is local
         if self.comm_size_polar > 1:
             chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
-            x = distributed_transpose_polar.apply(x, (-2, 1), chan_shapes)
+            x = distributed_transpose_polar.apply(x, (-2, 0), chan_shapes)
 
+        # cast channel dimensions back to the original shape
+        x = x.reshape(*inshape[:-2], *x.shape[-2:])
+
         return x
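The reshape added at the top and bottom of forward() is why the distributed transposes now use dim 0 instead of dim 1: all leading batch/channel dimensions are collapsed into a single dim 0 before the spectral steps and restored afterwards. A rough sketch of that bracketing pattern (the spectral pipeline is stubbed out with a plain rfft; names are illustrative):

```python
import torch

x = torch.randn(2, 3, 16, 32)                 # (batch, channels, nlat, nlon)

inshape = x.shape
x = x.reshape(-1, *inshape[-2:])              # (batch*channels, nlat, nlon): dim 0 plays the "channel" role

# ... distributed transposes, the longitudinal FFT and the Legendre quadrature act on dims (0, -2, -1) ...
x = torch.fft.rfft(x, dim=-1)                 # stand-in for the actual spectral steps

x = x.reshape(*inshape[:-2], *x.shape[-2:])   # restore the leading batch/channel dimensions
print(x.shape)                                # torch.Size([2, 3, 16, 17])
```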


@@ -210,7 +217,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho
         # determine the dimensions
         self.mmax = mmax or self.nlon // 2 + 1
 
-        # compute splits
+        # compute splits
         self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
         self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
         self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
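compute_split_shapes determines how many latitude rows, longitude columns and l/m modes each rank owns. A sketch of the presumed semantics, assuming the remainder is spread over the leading ranks (the actual helper lives in torch_harmonics.distributed and may differ in detail):

```python
def even_split_shapes(n: int, num_chunks: int):
    # Assumed behaviour: split n items into num_chunks nearly equal parts.
    base, rem = divmod(n, num_chunks)
    return [base + 1 if i < rem else base for i in range(num_chunks)]

print(even_split_shapes(721, 4))   # [181, 180, 180, 180] latitude rows across four polar ranks
print(even_split_shapes(16, 3))    # [6, 5, 5]
```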
@@ -351,7 +358,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho
         # remember quadrature weights
         self.register_buffer('weights', weights, persistent=False)
 
 
     def extra_repr(self):
         """
         Pretty print module

@@ -361,7 +368,7 @@ def forward(self, x: torch.Tensor):
     def forward(self, x: torch.Tensor):
 
         assert(len(x.shape) >= 3)
 
         # we need to ensure that we can split the channels evenly
         num_chans = x.shape[1]
 

@@ -459,7 +466,7 @@ def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho
         # determine the dimensions
         self.mmax = mmax or self.nlon // 2 + 1
 
-        # compute splits
+        # compute splits
         self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
         self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
         self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)

torch_harmonics/examples/sfno/models/layers.py (1 change: 0 additions & 1 deletion)

@@ -272,7 +272,6 @@ def __init__(self,
         scale = math.sqrt(gain / in_channels) * torch.ones(self.modes_lat, 2)
         scale[0] *= math.sqrt(2)
         self.weight = nn.Parameter(scale * torch.view_as_real(torch.randn(*weight_shape, dtype=torch.complex64)))
-        # self.weight = nn.Parameter(scale * torch.randn(*weight_shape, 2))
 
         # get the right contraction function
         self._contract = _contract
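For context, the surviving line above initialises a complex Gaussian weight stored as an (..., 2) real view, so the parameter stays a real tensor for autograd and optimizers. A standalone sketch with placeholder dimensions (the real layer derives them from the transforms):

```python
import math
import torch
import torch.nn as nn

in_channels, out_channels, modes_lat = 8, 4, 6       # placeholder sizes
weight_shape = (out_channels, in_channels, modes_lat)
gain = 2.0

scale = math.sqrt(gain / in_channels) * torch.ones(modes_lat, 2)
scale[0] *= math.sqrt(2)                             # boost the first mode, as in the layer above

# complex64 Gaussian init, viewed as real pairs in the trailing dimension
weight = nn.Parameter(scale * torch.view_as_real(torch.randn(*weight_shape, dtype=torch.complex64)))
print(weight.shape)                                  # torch.Size([4, 8, 6, 2])
```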

torch_harmonics/examples/sfno/models/sfno.py (24 changes: 12 additions & 12 deletions)

@@ -2,7 +2,7 @@
 
 # SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
-#
+#
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions are met:
 #

@@ -57,7 +57,7 @@ def __init__(
                  separable = False,
                  rank = 1e-2,
                  bias = True):
-        super(SpectralFilterLayer, self).__init__()
+        super(SpectralFilterLayer, self).__init__()
 
         if factorization is None:
             self.filter = SpectralConvS2(forward_transform,

@@ -67,7 +67,7 @@ def __init__(
                                          gain = gain,
                                          operator_type = operator_type,
                                          bias = bias)
 
         elif factorization is not None:
             self.filter = FactorizedSpectralConvS2(forward_transform,
                                                    inverse_transform,

@@ -117,7 +117,7 @@ def __init__(
 
         if inner_skip == "linear" or inner_skip == "identity":
             gain_factor /= 2.0
 
         # convolution layer
         self.filter = SpectralFilterLayer(forward_transform,
                                           inverse_transform,

@@ -146,14 +146,14 @@ def __init__(
 
         # first normalisation layer
         self.norm0 = norm_layer()
 
         # dropout
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         gain_factor = 1.0
         if outer_skip == "linear" or inner_skip == "identity":
             gain_factor /= 2.
 
         if use_mlp == True:
             mlp_hidden_dim = int(output_dim * mlp_ratio)
             self.mlp = MLP(in_features = output_dim,
@@ -355,7 +355,7 @@ def __init__(
             norm_layer0 = nn.Identity
             norm_layer1 = norm_layer0
         else:
-            raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")
+            raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.")
 
         if pos_embed == "latlon" or pos_embed==True:
             self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_size[0], self.img_size[1]))
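The "latlon" positional embedding is a single learnable tensor of shape (1, embed_dim, H, W) that is presumably added to the encoded features with broadcasting over the batch. A tiny shape-level illustration with placeholder sizes:

```python
import torch
import torch.nn as nn

embed_dim, nlat, nlon = 16, 32, 64
pos_embed = nn.Parameter(torch.zeros(1, embed_dim, nlat, nlon))

x = torch.randn(4, embed_dim, nlat, nlon)   # encoded features for a batch of 4
x = x + pos_embed                           # broadcast over the batch dimension
print(x.shape)                              # torch.Size([4, 16, 32, 64])
```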
@@ -402,7 +402,7 @@ def __init__(
         nn.init.constant_(fc.bias, 0.0)
         encoder_layers.append(fc)
         self.encoder = nn.Sequential(*encoder_layers)
 
         # prepare the spectral transform
         if self.spectral_transform == "sht":
 
@@ -424,7 +424,7 @@ def __init__(
             self.itrans_up = InverseRealFFT2(*self.img_size, lmax=modes_lat, mmax=modes_lon).float()
             self.trans = RealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
             self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
 
         else:
             raise(ValueError("Unknown spectral transform"))
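Both branches construct a forward transform for the model grid and an inverse transform back to the full output grid. As a quick orientation, a round-trip sketch using the public, non-distributed SHT API (assuming an equiangular grid and default lmax/mmax):

```python
import torch
import torch_harmonics as th

nlat, nlon = 64, 128
sht = th.RealSHT(nlat, nlon, grid="equiangular").float()
isht = th.InverseRealSHT(nlat, nlon, grid="equiangular").float()

signal = torch.randn(1, 3, nlat, nlon)
coeffs = sht(signal)    # complex spectral coefficients of shape (1, 3, lmax, mmax)
recon = isht(coeffs)    # back on the (nlat, nlon) grid
print(coeffs.shape, recon.shape)
```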

@@ -508,7 +508,7 @@ def forward_features(self, x):
 
         for blk in self.blocks:
             x = blk(x)
 
         return x
 
     def forward(self, x):

@@ -529,5 +529,5 @@ def forward(self, x):
         x = self.decoder(x)
 
         return x