Bbonev/cuda disco cleanup #32

Merged
merged 3 commits on Mar 13, 2024
4 changes: 2 additions & 2 deletions Dockerfile
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
@@ -39,5 +39,5 @@ RUN pip install parameterized

# we need to remove old archs
ENV TORCH_CUDA_ARCH_LIST "7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
RUN pip install /workspace/torch_harmonics
RUN pip install --cuda_ext /workspace/torch_harmonics
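The new pip install line passes --cuda_ext (handled in setup.py below) so the CUDA extension is compiled even though no GPU is visible while the image is being built; the TORCH_CUDA_ARCH_LIST line above then supplies the compute capabilities that would otherwise be detected from an attached device. A minimal sketch of that general mechanism, assuming standard torch.utils.cpp_extension behavior and an installed CUDA toolkit; the .cu path is a hypothetical placeholder, not taken from this repository:

import os
from torch.utils import cpp_extension

# torch.utils.cpp_extension consults TORCH_CUDA_ARCH_LIST to decide which compute
# capabilities nvcc should target; when the variable is unset, the architectures are
# detected from an attached GPU, which is not possible during a typical docker build.
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX")

ext = cpp_extension.CUDAExtension(
    "disco_cuda_extension",
    ["torch_harmonics/csrc/disco/disco_cuda.cu"],  # hypothetical path, for illustration only
)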

21 changes: 15 additions & 6 deletions setup.py
@@ -28,15 +28,19 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import sys

try:
from setuptools import setup, find_packages
except ImportError:
from distutils.core import setup, find_packages

from torch.utils import cpp_extension
import re
from pathlib import Path

import torch
from torch.utils import cpp_extension

def version(root_path):
"""Returns the version taken from __init__.py
@@ -71,14 +75,19 @@ def readme(root_path):
return f.read()


def get_ext_modules():
import torch
def get_ext_modules(argv):

compile_cuda_extension = False

if "--cuda_ext" in sys.argv:
sys.argv.remove("--cuda_ext")
compile_cuda_extension = True

ext_modules = [
cpp_extension.CppExtension("disco_helpers", ["torch_harmonics/csrc/disco/disco_helpers.cpp"]),
]

if torch.cuda.is_available():
if torch.cuda.is_available() or compile_cuda_extension:
ext_modules.append(
cpp_extension.CUDAExtension(
"disco_cuda_extension",
@@ -98,7 +107,7 @@ def get_ext_modules():
VERSION = version(root_path)

# external modules
ext_modules = get_ext_modules()
ext_modules = get_ext_modules(sys.argv)

config = {
"name": "torch_harmonics",
@@ -110,7 +119,7 @@ def get_ext_modules():
"author": "Boris Bonev",
"author_email": "bbonev@nvidia.com",
"version": VERSION,
"install_requires": ["torch", "numpy", "triton"],
"install_requires": ["torch", "numpy"],
"extras_require": {
"sfno": ["tensorly", "tensorly-torch"],
},
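Taken together, the setup.py hunks make CUDA compilation opt-in from the command line: get_ext_modules now receives sys.argv, strips an optional --cuda_ext flag, and builds disco_cuda_extension when either a GPU is visible or the flag was given, while triton is dropped from install_requires. A condensed sketch of this flag-handling pattern, assuming the flag reaches setup.py's argument list (for example via python setup.py bdist_wheel --cuda_ext); the CUDA source path is a hypothetical stand-in, since the real source list is not shown in this excerpt:

import sys

import torch
from setuptools import setup
from torch.utils import cpp_extension


def get_ext_modules(argv):
    # Strip the custom flag before setuptools parses the remaining arguments,
    # otherwise setup() would reject the unknown option.
    force_cuda = "--cuda_ext" in argv
    if force_cuda:
        argv.remove("--cuda_ext")

    ext_modules = [
        cpp_extension.CppExtension("disco_helpers", ["torch_harmonics/csrc/disco/disco_helpers.cpp"]),
    ]

    # Compile the CUDA extension if a GPU is visible, or if the caller forces it
    # (e.g. a docker build on a machine without a GPU).
    if torch.cuda.is_available() or force_cuda:
        ext_modules.append(
            cpp_extension.CUDAExtension(
                "disco_cuda_extension",
                ["torch_harmonics/csrc/disco/disco_cuda.cu"],  # hypothetical path
            )
        )
    return ext_modules


if __name__ == "__main__":
    setup(
        name="torch_harmonics",
        ext_modules=get_ext_modules(sys.argv),
        cmdclass={"build_ext": cpp_extension.BuildExtension},
    )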
75 changes: 30 additions & 45 deletions tests/test_convolution.py
@@ -49,16 +49,7 @@ def _compute_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int,
ikernel = torch.arange(ntheta).reshape(-1, 1, 1)
itheta = ikernel * dtheta

norm_factor = (
2
* math.pi
* (
1
- math.cos(theta_cutoff - dtheta)
+ math.cos(theta_cutoff - dtheta)
+ (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta
)
)
norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)

vals = torch.where(
((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff),
@@ -67,6 +58,7 @@ def _compute_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int,
)
return vals


def _compute_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, nphi: int, theta_cutoff: float):
"""
helper routine to compute the values of the anisotropic kernel densely
@@ -75,7 +67,7 @@ def _compute_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: in
# compute the support
dtheta = (theta_cutoff - 0.0) / ntheta
dphi = 2.0 * math.pi / nphi
kernel_size = (ntheta-1)*nphi + 1
kernel_size = (ntheta - 1) * nphi + 1
ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
itheta = ((ikernel - 1) // nphi + 1) * dtheta
iphi = ((ikernel - 1) % nphi) * dphi
@@ -84,15 +76,14 @@ def _compute_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: in

# find the indices where the rotated position falls into the support of the kernel
cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2*math.pi - (phi - iphi).abs()) <= dphi)
cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
theta_vals = torch.where(cond_theta, (1 - (theta - itheta).abs() / dtheta) / norm_factor, 0.0)
phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2*math.pi - (phi - iphi).abs()) ) / dphi ), 0.0)
phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
vals = torch.where(ikernel > 0, theta_vals * phi_vals, theta_vals)
return vals

def _precompute_convolution_tensor_dense(
in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi
):

def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi):
"""
Helper routine to compute the convolution Tensor in a dense fashion
"""
@@ -105,7 +96,7 @@ def _precompute_convolution_tensor_dense(
kernel_size = kernel_shape[0]
elif len(kernel_shape) == 2:
kernel_handle = partial(_compute_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff)
kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1
kernel_size = (kernel_shape[0] - 1) * kernel_shape[1] + 1
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")

@@ -161,26 +152,25 @@ def setUp(self):
else:
self.device = torch.device("cpu")
torch.manual_seed(333)



@parameterized.expand(
[
# regular convolution
[8, 4, 2, (16, 32), (16, 32), [2 ], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), ( 8, 16), [3 ], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), ( 8, 16), [2, 3], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (18, 36), ( 6, 12), [4 ], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), ( 8, 16), [3 ], "equiangular", "legendre-gauss", False, 5e-5],
[8, 4, 2, (16, 32), ( 8, 16), [3 ], "legendre-gauss", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), ( 8, 16), [3 ], "legendre-gauss", "legendre-gauss", False, 5e-5],
[8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), (8, 16), [2, 3], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (18, 36), (6, 12), [4], "equiangular", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "legendre-gauss", False, 5e-5],
[8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "equiangular", False, 5e-5],
[8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "legendre-gauss", False, 5e-5],
# transpose convolution
[8, 4, 2, (16, 32), (16, 32), [2 ], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, ( 8, 16), (16, 32), [3 ], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, ( 8, 16), (16, 32), [2, 3], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, ( 6, 12), (18, 36), [4 ], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, ( 8, 16), (16, 32), [3 ], "equiangular", "legendre-gauss", True, 5e-5],
[8, 4, 2, ( 8, 16), (16, 32), [3 ], "legendre-gauss", "equiangular", True, 5e-5],
[8, 4, 2, ( 8, 16), (16, 32), [3 ], "legendre-gauss", "legendre-gauss", True, 5e-5],
[8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, (8, 16), (16, 32), [2, 3], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, (6, 12), (18, 36), [4], "equiangular", "equiangular", True, 5e-5],
[8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "legendre-gauss", True, 5e-5],
[8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "equiangular", True, 5e-5],
[8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "legendre-gauss", True, 5e-5],
]
)
def test_disco_convolution(
@@ -215,19 +205,13 @@ def test_disco_convolution(
theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)

if transpose:
psi_dense = _precompute_convolution_tensor_dense(
out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff
).to(self.device)
psi_dense = _precompute_convolution_tensor_dense(out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff).to(self.device)
else:
psi_dense = _precompute_convolution_tensor_dense(
in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
).to(self.device)
psi_dense = _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff).to(self.device)

psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense()

self.assertTrue(
torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))
)
self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in)))

# create a copy of the weight
w_ref = torch.empty_like(conv.weight)
@@ -255,14 +239,15 @@ def test_disco_convolution(
y_ref = torch.einsum("ftpqr,bcqr->bcftp", psi_dense, x_ref * conv.quad_weights)
y_ref = torch.einsum("oif,biftp->botp", w_ref, y_ref)
y_ref.backward(grad_input)
x_ref_grad = x_ref.grad.clone()
x_ref_grad = x_ref.grad.clone()

# compare results
self.assertTrue(torch.allclose(y, y_ref, rtol=tol, atol=tol))

# compare
# compare
self.assertTrue(torch.allclose(x_grad, x_ref_grad, rtol=tol, atol=tol))
self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol))


if __name__ == "__main__":
unittest.main()
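For reference, the reformatted norm_factor line in _compute_vals_isotropic transcribes to the following constant, writing θ_c for theta_cutoff and Δθ for the spacing dtheta computed earlier in that helper (not shown in this excerpt); the -cos and +cos terms cancel, so the two forms are equal:

N = 2\pi \left( 1 - \cos(\theta_c - \Delta\theta) + \cos(\theta_c - \Delta\theta) + \frac{\sin(\theta_c - \Delta\theta) - \sin(\theta_c)}{\Delta\theta} \right)
  = 2\pi \left( 1 + \frac{\sin(\theta_c - \Delta\theta) - \sin(\theta_c)}{\Delta\theta} \right)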
46 changes: 23 additions & 23 deletions tests/test_distributed_convolution.py
@@ -2,7 +2,7 @@

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
@@ -72,7 +72,7 @@ def setUpClass(cls):
init_method = f"tcp://{master_address}:{port}",
rank = cls.world_rank,
world_size = cls.world_size)

cls.wrank = cls.world_rank % cls.grid_size_w
cls.hrank = cls.world_rank // cls.grid_size_w

@@ -108,15 +108,15 @@ def setUpClass(cls):
tmp_group = dist.new_group(ranks=grp)
if cls.world_rank in grp:
cls.h_group = tmp_group


if cls.world_rank == 0:
print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}")

# initializing sht
thd.init(cls.h_group, cls.w_group)


def _split_helper(self, tensor):
with torch.no_grad():
# split in W
@@ -128,15 +128,15 @@ def _split_helper(self, tensor):
tensor_local = tensor_list_local[self.hrank]

return tensor_local


def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
# we need the shapes
lat_shapes = convolution_dist.lat_out_shapes
lon_shapes = convolution_dist.lon_out_shapes

#print("tensor before gather shape", tensor.shape)

# gather in W
if self.grid_size_w > 1:
gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
@@ -148,7 +148,7 @@ def _gather_helper_fwd(self, tensor, B, C, convolution_dist):
tensor_gather = tensor

#print("tensor_gather shape", tensor_gather.shape)

# gather in H
if self.grid_size_h > 1:
gather_shapes = [(B, C, h, convolution_dist.nlon_out) for h in lat_shapes]
@@ -159,7 +159,7 @@ def _gather_helper_fwd(self, tensor, B, C, convolution_dist):

return tensor_gather


def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
# we need the shapes
lat_shapes = convolution_dist.lat_in_shapes
@@ -203,60 +203,60 @@ def _gather_helper_bwd(self, tensor, B, C, convolution_dist):
])
def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan,
kernel_shape, groups, grid_in, grid_out, transpose, tol):

B, C, H, W = batch_size, num_chan, nlat_in, nlon_in

disco_args = dict(in_channels=C, out_channels=C,
in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out),
kernel_shape=kernel_shape, groups=groups,
grid_in=grid_in, grid_out=grid_out, bias=True)

# set up handles
if transpose:
conv_local = harmonics.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
conv_dist = thd.DistributedDiscreteContinuousConvTransposeS2(**disco_args).to(self.device)
else:
conv_local = harmonics.DiscreteContinuousConvS2(**disco_args).to(self.device)
conv_dist = thd.DistributedDiscreteContinuousConvS2(**disco_args).to(self.device)

# copy the weights from the local conv into the dist conv
with torch.no_grad():
conv_dist.weight.copy_(conv_local.weight)
conv_dist.bias.copy_(conv_local.bias)

# create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)

#############################################################
# local conv
#############################################################
# FWD pass
inp_full.requires_grad = True
out_full = conv_local(inp_full, use_triton_kernel=True)
out_full = conv_local(inp_full)

# create grad for backward
with torch.no_grad():
# create full grad
ograd_full = torch.randn_like(out_full)

# BWD pass
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()

#############################################################
# distributed conv
#############################################################
# FWD pass
inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True
out_local = conv_dist(inp_local, use_triton_kernel=True)
out_local = conv_dist(inp_local)

# BWD pass
ograd_local = self._split_helper(ograd_full)
out_local = conv_dist(inp_local, use_triton_kernel=True)
out_local = conv_dist(inp_local)
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()

#############################################################
# evaluate FWD pass
#############################################################
@@ -266,7 +266,7 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batc
if self.world_rank == 0:
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)

#############################################################
# evaluate BWD pass
#############################################################
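Both convolution calls in this test drop use_triton_kernel=True, matching the removal of triton from install_requires in setup.py; the forward pass now goes through the default kernel path. The structure of the check is unchanged: run the local and the distributed convolution on the same data and compare. A condensed sketch of that pattern, with a generic relative-error metric standing in for the exact comparison, which is not shown in this excerpt:

import torch


def relative_error(gathered: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    # Generic Frobenius-norm relative error; the metric actually used by the test
    # may be normalized differently.
    return torch.norm(gathered - reference) / torch.norm(reference)


# Conceptual flow, with identifiers taken from the diff above:
#   out_full   = conv_local(inp_full)      # full-grid reference, same weights/bias as conv_dist
#   out_local  = conv_dist(inp_local)      # rank-local shard of the split input
#   out_gather = gather(out_local)         # undo the H/W split (see _gather_helper_fwd)
#   assert relative_error(out_gather, out_full) <= tol
# The backward pass is checked the same way on the gathered input gradients.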
2 changes: 1 addition & 1 deletion torch_harmonics/__init__.py
@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

__version__ = '0.7.0'
__version__ = "0.7.0a"

from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2