Bbonev/distributed disco refactor (#37)
* cleaned up normalization code in convolution

* formatting changes in distributed convolution

* fixed default theta_cutoff to be the same in the distributed and local cases

* fixed distributed convolution to support the same normalization as the non-distributed one

* readability improvements
bonevbs authored Apr 29, 2024
1 parent ce480c1 commit 1b04a1b
Showing 3 changed files with 104 additions and 79 deletions.
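
The substantive change is in how the DISCO convolution tensor is normalized: `_precompute_convolution_tensor_s2` no longer takes the quadrature weights as an argument and instead derives them from the latitude weights returned by `_precompute_latitudes`, using the output-grid weights when `transpose_normalization` is set and the input-grid weights otherwise. A minimal sketch of that construction, with `_build_normalization_weights` as a hypothetical helper name (the real code inlines this inside `_precompute_convolution_tensor_s2`):

```python
import math

import numpy as np
import torch


def _build_normalization_weights(w_in: np.ndarray, w_out: np.ndarray, nlon_in: int,
                                 transpose_normalization: bool) -> torch.Tensor:
    # Hypothetical helper mirroring the hunk in _precompute_convolution_tensor_s2 below:
    # pick the latitude quadrature weights of the grid that the normalization sums over,
    # scale them to the sphere's measure, and spread them over the input longitudes.
    w = w_out if transpose_normalization else w_in
    return 2.0 * math.pi * torch.from_numpy(w).float().reshape(-1, 1) / nlon_in
```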
85 changes: 42 additions & 43 deletions tests/test_distributed_convolution.py
@@ -46,11 +46,11 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
def setUpClass(cls):

# set up distributed
cls.world_rank = int(os.getenv('WORLD_RANK', 0))
cls.grid_size_h = int(os.getenv('GRID_H', 1))
cls.grid_size_w = int(os.getenv('GRID_W', 1))
port = int(os.getenv('MASTER_PORT', '29501'))
master_address = os.getenv('MASTER_ADDR', 'localhost')
cls.world_rank = int(os.getenv("WORLD_RANK", 0))
cls.grid_size_h = int(os.getenv("GRID_H", 1))
cls.grid_size_w = int(os.getenv("GRID_W", 1))
port = int(os.getenv("MASTER_PORT", "29501"))
master_address = os.getenv("MASTER_ADDR", "localhost")
cls.world_size = cls.grid_size_h * cls.grid_size_w

if torch.cuda.is_available():
@@ -60,24 +60,21 @@ def setUpClass(cls):
cls.device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(local_rank)
torch.cuda.manual_seed(333)
proc_backend = 'nccl'
proc_backend = "nccl"
else:
if cls.world_rank == 0:
print("Running test on CPU")
cls.device = torch.device('cpu')
proc_backend = 'gloo'
cls.device = torch.device("cpu")
proc_backend = "gloo"
torch.manual_seed(333)

dist.init_process_group(backend = proc_backend,
init_method = f"tcp://{master_address}:{port}",
rank = cls.world_rank,
world_size = cls.world_size)
dist.init_process_group(backend=proc_backend, init_method=f"tcp://{master_address}:{port}", rank=cls.world_rank, world_size=cls.world_size)

cls.wrank = cls.world_rank % cls.grid_size_w
cls.hrank = cls.world_rank // cls.grid_size_w

# now set up the comm groups:
#set default
# set default
cls.w_group = None
cls.h_group = None

@@ -109,14 +106,12 @@ def setUpClass(cls):
if cls.world_rank in grp:
cls.h_group = tmp_group


if cls.world_rank == 0:
print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}")

# initializing sht
thd.init(cls.h_group, cls.w_group)


def _split_helper(self, tensor):
with torch.no_grad():
# split in W
@@ -129,7 +124,6 @@ def _split_helper(self, tensor):

return tensor_local


def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
# we need the shapes
lat_shapes = convolution_dist.lat_out_shapes
@@ -155,7 +149,6 @@ def _gather_helper_fwd(self, tensor, B, C, convolution_dist):

return tensor_gather


def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
# we need the shapes
lat_shapes = convolution_dist.lat_in_shapes
@@ -181,31 +174,37 @@ def _gather_helper_bwd(self, tensor, B, C, convolution_dist):

return tensor_gather


@parameterized.expand([
[128, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", False, 1e-6],
[129, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", False, 1e-6],
[128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-6],
[128, 256, 64, 128, 32, 8, [3 ], 1, "equiangular", "equiangular", False, 1e-6],
[128, 256, 128, 256, 32, 8, [3 ], 2, "equiangular", "equiangular", False, 1e-6],
[128, 256, 128, 256, 32, 5, [3 ], 1, "equiangular", "equiangular", False, 1e-6],
[128, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", True, 1e-6],
[129, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", True, 1e-6],
[128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-6],
[ 64, 128, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", True, 1e-6],
[128, 256, 128, 256, 32, 8, [3 ], 2, "equiangular", "equiangular", True, 1e-6],
[128, 256, 128, 256, 32, 5, [3 ], 1, "equiangular", "equiangular", True, 1e-6],
])
def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan,
kernel_shape, groups, grid_in, grid_out, transpose, tol):
@parameterized.expand(
[
[128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
[129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 64, 128, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 6, [3], 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
[129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-5],
[64, 128, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 6, [3], 1, "equiangular", "equiangular", True, 1e-5],
]
)
def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, groups, grid_in, grid_out, transpose, tol):

B, C, H, W = batch_size, num_chan, nlat_in, nlon_in

disco_args = dict(in_channels=C, out_channels=C,
in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out),
kernel_shape=kernel_shape, groups=groups,
grid_in=grid_in, grid_out=grid_out, bias=True)
disco_args = dict(
in_channels=C,
out_channels=C,
in_shape=(nlat_in, nlon_in),
out_shape=(nlat_out, nlon_out),
kernel_shape=kernel_shape,
groups=groups,
grid_in=grid_in,
grid_out=grid_out,
bias=True,
)

# set up handles
if transpose:
@@ -253,13 +252,13 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batc
out_local = conv_dist(inp_local)
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()

#############################################################
# evaluate FWD pass
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist)
err = torch.mean(torch.norm(out_full-out_gather_full, p='fro', dim=(-1,-2)) / torch.norm(out_full, p='fro', dim=(-1,-2)) )
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0:
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
@@ -269,11 +268,11 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batc
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)
err = torch.mean(torch.norm(igrad_full-igrad_gather_full, p='fro', dim=(-1,-2)) / torch.norm(igrad_full, p='fro', dim=(-1,-2)) )
err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0:
print(f"final relative error of gradients: {err.item()}")
self.assertTrue(err.item() <= tol)


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
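
Both the forward and backward checks above reduce to the same metric: a per-sample, per-channel relative Frobenius norm over the two spatial dimensions, averaged into a scalar and compared against `tol`. A self-contained sketch of that metric (`relative_frobenius_error` is a hypothetical name, not part of the test):

```python
import torch


def relative_frobenius_error(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    # Relative Frobenius error over the last two (spatial) dimensions,
    # averaged over the remaining batch/channel dimensions.
    num = torch.norm(reference - candidate, p="fro", dim=(-1, -2))
    den = torch.norm(reference, p="fro", dim=(-1, -2))
    return torch.mean(num / den)


# usage (mirrors the assertions in the test):
# err = relative_frobenius_error(out_full, out_gather_full)
# assert err.item() <= tol
```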
19 changes: 10 additions & 9 deletions torch_harmonics/convolution.py
@@ -160,16 +160,13 @@ def _normalize_onvolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, kern
# pre-compute the quadrature weights
q = quad_weights[idx[1]].reshape(-1)

# compute scale factor
scale_factor = float(nlon_in // nlon_out)

# loop through dimensions which require normalization
for ik in range(kernel_size):
for ilat in range(nlat_in):
# get relevant entries
iidx = torch.argwhere((idx[0] == ik) & (idx[2] == ilat))
# normalize, while summing also over the input longitude dimension here as this is not available for the output
vnorm = torch.sum(psi_vals[iidx] * q[iidx]) / scale_factor
vnorm = torch.sum(psi_vals[iidx] * q[iidx])
psi_vals[iidx] = psi_vals[iidx] / (vnorm + eps)
else:
# pre-compute the quadrature weights
@@ -188,7 +185,7 @@


def _precompute_convolution_tensor_s2(
in_shape, out_shape, kernel_shape, quad_weights, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False
in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False
):
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
@@ -219,9 +216,9 @@ def _precompute_convolution_tensor_s2(
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape

lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in).float()
lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out).float()

# compute the phi differences
@@ -268,6 +265,10 @@ def _precompute_convolution_tensor_s2(
out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()

if transpose_normalization:
quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
else:
quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
out_vals = _normalize_onvolution_tensor_s2(out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization)

return out_idx, out_vals
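
Taken together with the normalization hunk above, the precompute step now builds the sparse filter tensor and then divides every (kernel basis function, latitude ring) slice by its quadrature-weighted sum; the old division by the `nlon_in // nlon_out` scale factor is gone because the `1 / nlon_in` factor already lives in the quadrature weights. A condensed sketch of that normalization, assuming the index conventions of the transpose branch shown above (`psi_idx[0]` = kernel index, `psi_idx[2]` = latitude ring, and `q` already gathered per entry as `quad_weights[psi_idx[1]].reshape(-1)`); `_normalize_filter_rows` and the `eps` default are hypothetical:

```python
import torch


def _normalize_filter_rows(psi_idx: torch.Tensor, psi_vals: torch.Tensor,
                           q: torch.Tensor, kernel_size: int, nlat: int,
                           eps: float = 1e-9) -> torch.Tensor:
    # Every slice belonging to one kernel basis function and one latitude ring
    # is divided by its quadrature-weighted sum (plus eps to guard empty slices).
    for ik in range(kernel_size):
        for ilat in range(nlat):
            iidx = torch.argwhere((psi_idx[0] == ik) & (psi_idx[2] == ilat))
            vnorm = torch.sum(psi_vals[iidx] * q[iidx])
            psi_vals[iidx] = psi_vals[iidx] / (vnorm + eps)
    return psi_vals
```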
@@ -364,7 +365,7 @@ def __init__(
self.register_buffer("quad_weights", quad_weights, persistent=False)

idx, vals = _precompute_convolution_tensor_s2(
in_shape, out_shape, self.kernel_shape, self.quad_weights, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False
in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False
)

# sort the values
@@ -455,7 +456,7 @@ def __init__(

# switch in_shape and out_shape since we want transpose conv
idx, vals = _precompute_convolution_tensor_s2(
out_shape, in_shape, self.kernel_shape, quad_weights, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True
out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True
)

# sort the values
