More formatting changes
bonevbs committed Apr 28, 2024
1 parent e6a0a61 commit f922468
Showing 1 changed file with 42 additions and 43 deletions.
tests/test_distributed_convolution.py (85 changes: 42 additions & 43 deletions)
@@ -46,11 +46,11 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
     def setUpClass(cls):

         # set up distributed
-        cls.world_rank = int(os.getenv('WORLD_RANK', 0))
-        cls.grid_size_h = int(os.getenv('GRID_H', 1))
-        cls.grid_size_w = int(os.getenv('GRID_W', 1))
-        port = int(os.getenv('MASTER_PORT', '29501'))
-        master_address = os.getenv('MASTER_ADDR', 'localhost')
+        cls.world_rank = int(os.getenv("WORLD_RANK", 0))
+        cls.grid_size_h = int(os.getenv("GRID_H", 1))
+        cls.grid_size_w = int(os.getenv("GRID_W", 1))
+        port = int(os.getenv("MASTER_PORT", "29501"))
+        master_address = os.getenv("MASTER_ADDR", "localhost")
         cls.world_size = cls.grid_size_h * cls.grid_size_w

         if torch.cuda.is_available():
@@ -60,24 +60,21 @@ def setUpClass(cls):
             cls.device = torch.device(f"cuda:{local_rank}")
             torch.cuda.set_device(local_rank)
             torch.cuda.manual_seed(333)
-            proc_backend = 'nccl'
+            proc_backend = "nccl"
         else:
             if cls.world_rank == 0:
                 print("Running test on CPU")
-            cls.device = torch.device('cpu')
-            proc_backend = 'gloo'
+            cls.device = torch.device("cpu")
+            proc_backend = "gloo"
         torch.manual_seed(333)

-        dist.init_process_group(backend = proc_backend,
-                                init_method = f"tcp://{master_address}:{port}",
-                                rank = cls.world_rank,
-                                world_size = cls.world_size)
+        dist.init_process_group(backend=proc_backend, init_method=f"tcp://{master_address}:{port}", rank=cls.world_rank, world_size=cls.world_size)

         cls.wrank = cls.world_rank % cls.grid_size_w
         cls.hrank = cls.world_rank // cls.grid_size_w

         # now set up the comm groups:
-        #set default
+        # set default
         cls.w_group = None
         cls.h_group = None

@@ -109,14 +106,12 @@ def setUpClass(cls):
             if cls.world_rank in grp:
                 cls.h_group = tmp_group

-
         if cls.world_rank == 0:
             print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}")

         # initializing sht
         thd.init(cls.h_group, cls.w_group)

-
     def _split_helper(self, tensor):
         with torch.no_grad():
             # split in W
@@ -129,7 +124,6 @@ def _split_helper(self, tensor):

         return tensor_local

-
     def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
         # we need the shapes
         lat_shapes = convolution_dist.lat_out_shapes
@@ -155,7 +149,6 @@ def _gather_helper_fwd(self, tensor, B, C, convolution_dist):

         return tensor_gather

-
     def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
         # we need the shapes
         lat_shapes = convolution_dist.lat_in_shapes
@@ -181,31 +174,37 @@ def _gather_helper_bwd(self, tensor, B, C, convolution_dist):

         return tensor_gather

-
-    @parameterized.expand([
-        [128, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", False, 1e-6],
-        [129, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", False, 1e-6],
-        [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-6],
-        [128, 256, 64, 128, 32, 8, [3 ], 1, "equiangular", "equiangular", False, 1e-6],
-        [128, 256, 128, 256, 32, 8, [3 ], 2, "equiangular", "equiangular", False, 1e-6],
-        [128, 256, 128, 256, 32, 5, [3 ], 1, "equiangular", "equiangular", False, 1e-6],
-        [128, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", True, 1e-6],
-        [129, 256, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", True, 1e-6],
-        [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-6],
-        [ 64, 128, 128, 256, 32, 8, [3 ], 1, "equiangular", "equiangular", True, 1e-6],
-        [128, 256, 128, 256, 32, 8, [3 ], 2, "equiangular", "equiangular", True, 1e-6],
-        [128, 256, 128, 256, 32, 5, [3 ], 1, "equiangular", "equiangular", True, 1e-6],
-    ])
-    def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan,
-                                    kernel_shape, groups, grid_in, grid_out, transpose, tol):
+    @parameterized.expand(
+        [
+            [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6],
+            [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6],
+            [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", False, 1e-6],
+            [128, 256, 64, 128, 32, 8, [3], 1, "equiangular", "equiangular", False, 1e-6],
+            [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", False, 1e-6],
+            [128, 256, 128, 256, 32, 5, [3], 1, "equiangular", "equiangular", False, 1e-6],
+            [128, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
+            [129, 256, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
+            [128, 256, 128, 256, 32, 8, [3, 2], 1, "equiangular", "equiangular", True, 1e-6],
+            [64, 128, 128, 256, 32, 8, [3], 1, "equiangular", "equiangular", True, 1e-6],
+            [128, 256, 128, 256, 32, 8, [3], 2, "equiangular", "equiangular", True, 1e-6],
+            [128, 256, 128, 256, 32, 5, [3], 1, "equiangular", "equiangular", True, 1e-6],
+        ]
+    )
+    def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, groups, grid_in, grid_out, transpose, tol):

         B, C, H, W = batch_size, num_chan, nlat_in, nlon_in

-        disco_args = dict(in_channels=C, out_channels=C,
-                          in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out),
-                          kernel_shape=kernel_shape, groups=groups,
-                          grid_in=grid_in, grid_out=grid_out, bias=True)
+        disco_args = dict(
+            in_channels=C,
+            out_channels=C,
+            in_shape=(nlat_in, nlon_in),
+            out_shape=(nlat_out, nlon_out),
+            kernel_shape=kernel_shape,
+            groups=groups,
+            grid_in=grid_in,
+            grid_out=grid_out,
+            bias=True,
+        )

         # set up handles
         if transpose:
@@ -253,13 +252,13 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batc
         out_local = conv_dist(inp_local)
         out_local.backward(ograd_local)
         igrad_local = inp_local.grad.clone()

         #############################################################
         # evaluate FWD pass
         #############################################################
         with torch.no_grad():
             out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist)
-            err = torch.mean(torch.norm(out_full-out_gather_full, p='fro', dim=(-1,-2)) / torch.norm(out_full, p='fro', dim=(-1,-2)) )
+            err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
         if self.world_rank == 0:
             print(f"final relative error of output: {err.item()}")
         self.assertTrue(err.item() <= tol)
@@ -269,11 +268,11 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batc
         #############################################################
         with torch.no_grad():
             igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)
-            err = torch.mean(torch.norm(igrad_full-igrad_gather_full, p='fro', dim=(-1,-2)) / torch.norm(igrad_full, p='fro', dim=(-1,-2)) )
+            err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
         if self.world_rank == 0:
             print(f"final relative error of gradients: {err.item()}")
         self.assertTrue(err.item() <= tol)


-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
