From b6a035c9650d6b91898898e49ad71180a00594d3 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Wed, 13 Mar 2024 17:31:45 +0100 Subject: [PATCH 1/3] cleanup of disco convolutions based on CUDA extension --- Dockerfile | 4 +- setup.py | 21 +- tests/test_convolution.py | 12 +- tests/test_distributed_convolution.py | 46 ++-- torch_harmonics/_disco_convolution.py | 346 +------------------------- torch_harmonics/convolution.py | 63 +++-- 6 files changed, 96 insertions(+), 396 deletions(-) diff --git a/Dockerfile b/Dockerfile index 71da184..55d81f0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -# +# # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # @@ -39,5 +39,5 @@ RUN pip install parameterized # we need to remove old archs ENV TORCH_CUDA_ARCH_LIST "7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX" -RUN pip install /workspace/torch_harmonics +RUN pip install --cuda_ext /workspace/torch_harmonics diff --git a/setup.py b/setup.py index d6ff64c..2cc6512 100644 --- a/setup.py +++ b/setup.py @@ -28,15 +28,19 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # + +import sys + try: from setuptools import setup, find_packages except ImportError: from distutils.core import setup, find_packages -from torch.utils import cpp_extension import re from pathlib import Path +import torch +from torch.utils import cpp_extension def version(root_path): """Returns the version taken from __init__.py @@ -71,14 +75,19 @@ def readme(root_path): return f.read() -def get_ext_modules(): - import torch +def get_ext_modules(argv): + + compile_cuda_extension = False + + if "--cuda_ext" in sys.argv: + sys.argv.remove("--cuda_ext") + compile_cuda_extension = True ext_modules = [ cpp_extension.CppExtension("disco_helpers", ["torch_harmonics/csrc/disco/disco_helpers.cpp"]), ] - if torch.cuda.is_available(): + if torch.cuda.is_available() or compile_cuda_extension: ext_modules.append( cpp_extension.CUDAExtension( "disco_cuda_extension", @@ -98,7 +107,7 @@ def get_ext_modules(): VERSION = version(root_path) # external modules -ext_modules = get_ext_modules() +ext_modules = get_ext_modules(sys.argv) config = { "name": "torch_harmonics", @@ -110,7 +119,7 @@ def get_ext_modules(): "author": "Boris Bonev", "author_email": "bbonev@nvidia.com", "version": VERSION, - "install_requires": ["torch", "numpy", "triton"], + "install_requires": ["torch", "numpy"], "extras_require": { "sfno": ["tensorly", "tensorly-torch"], }, diff --git a/tests/test_convolution.py b/tests/test_convolution.py index c33d1c1..f6f3de1 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -161,8 +161,10 @@ def setUp(self): else: self.device = torch.device("cpu") torch.manual_seed(333) - - + + # self.device = torch.device("cpu") + + @parameterized.expand( [ # regular convolution @@ -255,12 +257,12 @@ def test_disco_convolution( y_ref = torch.einsum("ftpqr,bcqr->bcftp", psi_dense, x_ref * conv.quad_weights) y_ref = torch.einsum("oif,biftp->botp", w_ref, y_ref) y_ref.backward(grad_input) - x_ref_grad = x_ref.grad.clone() - + x_ref_grad = x_ref.grad.clone() + # compare results self.assertTrue(torch.allclose(y, y_ref, rtol=tol, atol=tol)) - # compare + # compare self.assertTrue(torch.allclose(x_grad, x_ref_grad, rtol=tol, atol=tol)) self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol)) diff --git a/tests/test_distributed_convolution.py b/tests/test_distributed_convolution.py index 59b3537..c03c4cd 100644 --- a/tests/test_distributed_convolution.py +++ b/tests/test_distributed_convolution.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -# +# # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # @@ -72,7 +72,7 @@ def setUpClass(cls): init_method = f"tcp://{master_address}:{port}", rank = cls.world_rank, world_size = cls.world_size) - + cls.wrank = cls.world_rank % cls.grid_size_w cls.hrank = cls.world_rank // cls.grid_size_w @@ -108,7 +108,7 @@ def setUpClass(cls): tmp_group = dist.new_group(ranks=grp) if cls.world_rank in grp: cls.h_group = tmp_group - + if cls.world_rank == 0: print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}") @@ -116,7 +116,7 @@ def setUpClass(cls): # initializing sht thd.init(cls.h_group, cls.w_group) - + def _split_helper(self, tensor): with torch.no_grad(): # split in W @@ -128,15 +128,15 @@ def _split_helper(self, tensor): tensor_local = tensor_list_local[self.hrank] return tensor_local - - + + def _gather_helper_fwd(self, tensor, B, C, convolution_dist): # we need the shapes lat_shapes = convolution_dist.lat_out_shapes lon_shapes = convolution_dist.lon_out_shapes #print("tensor before gather shape", tensor.shape) - + # gather in W if self.grid_size_w > 1: gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes] @@ -148,7 +148,7 @@ def _gather_helper_fwd(self, tensor, B, C, convolution_dist): tensor_gather = tensor #print("tensor_gather shape", tensor_gather.shape) - + # gather in H if self.grid_size_h > 1: gather_shapes = [(B, C, h, convolution_dist.nlon_out) for h in lat_shapes] @@ -159,7 +159,7 @@ def _gather_helper_fwd(self, tensor, B, C, convolution_dist): return tensor_gather - + def _gather_helper_bwd(self, tensor, B, C, convolution_dist): # we need the shapes lat_shapes = convolution_dist.lat_in_shapes @@ -203,14 +203,14 @@ def _gather_helper_bwd(self, tensor, B, C, convolution_dist): ]) def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, groups, grid_in, grid_out, transpose, tol): - + B, C, H, W = batch_size, num_chan, nlat_in, nlon_in - + disco_args = dict(in_channels=C, out_channels=C, in_shape=(nlat_in, nlon_in), out_shape=(nlat_out, nlon_out), kernel_shape=kernel_shape, groups=groups, grid_in=grid_in, grid_out=grid_out, bias=True) - + # set up handles if transpose: conv_local = harmonics.DiscreteContinuousConvTransposeS2(**disco_args).to(self.device) @@ -218,7 +218,7 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batc else: conv_local = harmonics.DiscreteContinuousConvS2(**disco_args).to(self.device) conv_dist = thd.DistributedDiscreteContinuousConvS2(**disco_args).to(self.device) - + # copy the weights from the local conv into the dist conv with torch.no_grad(): conv_dist.weight.copy_(conv_local.weight) @@ -226,37 +226,37 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batc # create tensors inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device) - + ############################################################# # local conv ############################################################# # FWD pass inp_full.requires_grad = True - out_full = conv_local(inp_full, use_triton_kernel=True) - + out_full = conv_local(inp_full) + # create grad for backward with torch.no_grad(): # create full grad ograd_full = torch.randn_like(out_full) - + # BWD pass out_full.backward(ograd_full) igrad_full = inp_full.grad.clone() - + ############################################################# # distributed conv ############################################################# # FWD pass inp_local = self._split_helper(inp_full) inp_local.requires_grad = True - out_local = conv_dist(inp_local, use_triton_kernel=True) - + out_local = conv_dist(inp_local) + # BWD pass ograd_local = self._split_helper(ograd_full) - out_local = conv_dist(inp_local, use_triton_kernel=True) + out_local = conv_dist(inp_local) out_local.backward(ograd_local) igrad_local = inp_local.grad.clone() - + ############################################################# # evaluate FWD pass ############################################################# @@ -266,7 +266,7 @@ def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batc if self.world_rank == 0: print(f"final relative error of output: {err.item()}") self.assertTrue(err.item() <= tol) - + ############################################################# # evaluate BWD pass ############################################################# diff --git a/torch_harmonics/_disco_convolution.py b/torch_harmonics/_disco_convolution.py index 67d9742..45891b2 100644 --- a/torch_harmonics/_disco_convolution.py +++ b/torch_harmonics/_disco_convolution.py @@ -33,329 +33,12 @@ import torch -import triton -import triton.language as tl - -if torch.cuda.is_available(): +try: import disco_cuda_extension - -BLOCK_SIZE_BATCH = 4 -BLOCK_SIZE_NZ = 8 -BLOCK_SIZE_POUT = 8 - - -@triton.jit -def _disco_s2_contraction_kernel( - inz_ptr, - vnz_ptr, - nnz, - inz_stride_ii, - inz_stride_nz, - vnz_stride, - x_ptr, - batch_size, - nlat_in, - nlon_in, - x_stride_b, - x_stride_t, - x_stride_p, - y_ptr, - kernel_size, - nlat_out, - nlon_out, - y_stride_b, - y_stride_f, - y_stride_t, - y_stride_p, - pscale, - backward: tl.constexpr, - BLOCK_SIZE_BATCH: tl.constexpr, - BLOCK_SIZE_NZ: tl.constexpr, - BLOCK_SIZE_POUT: tl.constexpr, -): - """ - Kernel for the sparse-dense contraction for the S2 DISCO convolution. - """ - - pid_batch = tl.program_id(0) - pid_pout = tl.program_id(2) - - # pid_nz should always be 0 as we do not account for larger grids in this dimension - pid_nz = tl.program_id(1) # should be always 0 - tl.device_assert(pid_nz == 0) - - # create the pointer block for pout - pout = pid_pout * BLOCK_SIZE_POUT + tl.arange(0, BLOCK_SIZE_POUT) - b = pid_batch * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH) - - # create pointer blocks for the psi datastructure - iinz = tl.arange(0, BLOCK_SIZE_NZ) - - # get the initial pointers - fout_ptrs = inz_ptr + iinz * inz_stride_nz - tout_ptrs = inz_ptr + iinz * inz_stride_nz + inz_stride_ii - tpnz_ptrs = inz_ptr + iinz * inz_stride_nz + 2 * inz_stride_ii - vals_ptrs = vnz_ptr + iinz * vnz_stride - - # iterate in a blocked fashion over the non-zero entries - for offs_nz in range(0, nnz, BLOCK_SIZE_NZ): - # load input output latitude coordinate pairs - fout = tl.load(fout_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1) - tout = tl.load(tout_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1) - tpnz = tl.load(tpnz_ptrs + offs_nz * inz_stride_nz, mask=(offs_nz + iinz < nnz), other=-1) - - # load corresponding values - vals = tl.load(vals_ptrs + offs_nz * vnz_stride, mask=(offs_nz + iinz < nnz), other=0.0) - - # compute the shifted longitude coordinates p+p' to read in a coalesced fashion - tnz = tpnz // nlon_in - pnz = tpnz % nlon_in - - # make sure the value is not out of bounds - tl.device_assert(fout < kernel_size) - tl.device_assert(tout < nlat_out) - tl.device_assert(tnz < nlat_in) - tl.device_assert(pnz < nlon_in) - - # load corresponding portion of the input array - x_ptrs = ( - x_ptr - + tnz[None, :, None] * x_stride_t - + ((pnz[None, :, None] + pout[None, None, :] * pscale) % nlon_in) * x_stride_p - + b[:, None, None] * x_stride_b - ) - y_ptrs = ( - y_ptr - + fout[None, :, None] * y_stride_f - + tout[None, :, None] * y_stride_t - + (pout[None, None, :] % nlon_out) * y_stride_p - + b[:, None, None] * y_stride_b - ) - - # precompute the mask - mask = ((b[:, None, None] < batch_size) and (offs_nz + iinz[None, :, None] < nnz)) and ( - pout[None, None, :] < nlon_out - ) - - # do the actual computation. Backward is essentially just the same operation with swapped tensors. - if not backward: - x = tl.load(x_ptrs, mask=mask, other=0.0) - y = vals[None, :, None] * x - - # store it to the output array - tl.atomic_add(y_ptrs, y, mask=mask) - else: - y = tl.load(y_ptrs, mask=mask, other=0.0) - x = vals[None, :, None] * y - - # store it to the output array - tl.atomic_add(x_ptrs, x, mask=mask) - - -def _disco_s2_contraction_fwd(x: torch.Tensor, psi: torch.Tensor, nlon_out: int): - """ - Wrapper function for the triton implementation of the efficient DISCO convolution on the sphere. - - Parameters - ---------- - x: torch.Tensor - Input signal on the sphere. Expects a tensor of shape batch_size x channels x nlat_in x nlon_in). - psi : torch.Tensor - Pre-computed convolution tensor. Expects a sparse tensor of shape kernel_size x nlat_out x (nlat_in * nlon_in). - nlon_out: int - Number of longitude points the output should have. - """ - - # check the shapes of all input tensors - assert len(psi.shape) == 3 - assert len(x.shape) == 4 - assert psi.is_sparse, "Psi must be a sparse COO tensor" - - # TODO: check that Psi is also coalesced - - # get the dimensions of the problem - kernel_size, nlat_out, n_in = psi.shape - nnz = psi.indices().shape[-1] - batch_size, n_chans, nlat_in, nlon_in = x.shape - assert nlat_in * nlon_in == n_in - - # TODO: check that Psi index vector is of type long - - # make sure that the grid-points of the output grid fall onto the grid points of the input grid - assert nlon_in % nlon_out == 0 - pscale = nlon_in // nlon_out - - # to simplify things, we merge batch and channel dimensions - x = x.reshape(batch_size * n_chans, nlat_in, nlon_in) - - # prepare the output tensor - y = torch.zeros(batch_size * n_chans, kernel_size, nlat_out, nlon_out, device=x.device, dtype=x.dtype) - - # determine the grid for the computation - grid = ( - triton.cdiv(batch_size * n_chans, BLOCK_SIZE_BATCH), - 1, - triton.cdiv(nlon_out, BLOCK_SIZE_POUT), - ) - - # launch the kernel - _disco_s2_contraction_kernel[grid]( - psi.indices(), - psi.values(), - nnz, - psi.indices().stride(-2), - psi.indices().stride(-1), - psi.values().stride(-1), - x, - batch_size * n_chans, - nlat_in, - nlon_in, - x.stride(0), - x.stride(-2), - x.stride(-1), - y, - kernel_size, - nlat_out, - nlon_out, - y.stride(0), - y.stride(1), - y.stride(-2), - y.stride(-1), - pscale, - False, - BLOCK_SIZE_BATCH, - BLOCK_SIZE_NZ, - BLOCK_SIZE_POUT, - ) - - # reshape y back to expose the correct dimensions - y = y.reshape(batch_size, n_chans, kernel_size, nlat_out, nlon_out) - - return y - - -def _disco_s2_contraction_bwd(grad_y: torch.Tensor, psi: torch.Tensor, nlon_in: int): - """ - Backward pass for the triton implementation of the efficient DISCO convolution on the sphere. - - Parameters - ---------- - grad_y: torch.Tensor - Input gradient on the sphere. Expects a tensor of shape batch_size x channels x kernel_size x nlat_out x nlon_out. - psi : torch.Tensor - Pre-computed convolution tensor. Expects a sparse tensor of shape kernel_size x nlat_out x (nlat_in * nlon_in). - nlon_in: int - Number of longitude points the input used. Is required to infer the correct dimensions - """ - - # check the shapes of all input tensors - assert len(psi.shape) == 3 - assert len(grad_y.shape) == 5 - assert psi.is_sparse, "psi must be a sparse COO tensor" - - # TODO: check that Psi is also coalesced - - # get the dimensions of the problem - kernel_size, nlat_out, n_in = psi.shape - nnz = psi.indices().shape[-1] - assert grad_y.shape[-2] == nlat_out - assert grad_y.shape[-3] == kernel_size - assert n_in % nlon_in == 0 - nlat_in = n_in // nlon_in - batch_size, n_chans, _, _, nlon_out = grad_y.shape - - # make sure that the grid-points of the output grid fall onto the grid points of the input grid - assert nlon_in % nlon_out == 0 - pscale = nlon_in // nlon_out - - # to simplify things, we merge batch and channel dimensions - grad_y = grad_y.reshape(batch_size * n_chans, kernel_size, nlat_out, nlon_out) - - # prepare the output tensor - grad_x = torch.zeros(batch_size * n_chans, nlat_in, nlon_in, device=grad_y.device, dtype=grad_y.dtype) - - # determine the grid for the computation - grid = ( - triton.cdiv(batch_size * n_chans, BLOCK_SIZE_BATCH), - 1, - triton.cdiv(nlon_out, BLOCK_SIZE_POUT), - ) - - # launch the kernel - _disco_s2_contraction_kernel[grid]( - psi.indices(), - psi.values(), - nnz, - psi.indices().stride(-2), - psi.indices().stride(-1), - psi.values().stride(-1), - grad_x, - batch_size * n_chans, - nlat_in, - nlon_in, - grad_x.stride(0), - grad_x.stride(-2), - grad_x.stride(-1), - grad_y, - kernel_size, - nlat_out, - nlon_out, - grad_y.stride(0), - grad_y.stride(1), - grad_y.stride(-2), - grad_y.stride(-1), - pscale, - True, - BLOCK_SIZE_BATCH, - BLOCK_SIZE_NZ, - BLOCK_SIZE_POUT, - ) - - # reshape y back to expose the correct dimensions - grad_x = grad_x.reshape(batch_size, n_chans, nlat_in, nlon_in) - - return grad_x - - -class _DiscoS2ContractionTriton(torch.autograd.Function): - """ - Helper function to make the triton implementation work with PyTorch autograd functionality - """ - - @staticmethod - def forward(ctx, x: torch.Tensor, psi: torch.Tensor, nlon_out: int): - ctx.save_for_backward(psi) - ctx.nlon_in = x.shape[-1] - - return _disco_s2_contraction_fwd(x, psi, nlon_out) - - @staticmethod - def backward(ctx, grad_output): - (psi,) = ctx.saved_tensors - grad_input = _disco_s2_contraction_bwd(grad_output, psi, ctx.nlon_in) - grad_x = grad_psi = None - - return grad_input, None, None - -class _DiscoS2TransposeContractionTriton(torch.autograd.Function): - """ - Helper function to make the triton implementation work with PyTorch autograd functionality - """ - - @staticmethod - def forward(ctx, x: torch.Tensor, psi: torch.Tensor, nlon_out: int): - ctx.save_for_backward(psi) - ctx.nlon_in = x.shape[-1] - - return _disco_s2_contraction_bwd(x, psi, nlon_out) - - @staticmethod - def backward(ctx, grad_output): - (psi,) = ctx.saved_tensors - grad_input = _disco_s2_contraction_fwd(grad_output, psi, ctx.nlon_in) - grad_x = grad_psi = None - - return grad_input, None, None + _cuda_extension_available = True +except ImportError as err: + disco_cuda_extension = None + _cuda_extension_available = False class _DiscoS2ContractionCuda(torch.autograd.Function): @@ -368,12 +51,12 @@ def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, ctx.nlat_in = x.shape[-2] ctx.nlon_in = x.shape[-1] - return disco_cuda.forward(x, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out) + return disco_cuda_extension.forward(x.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out) @staticmethod def backward(ctx, grad_output): roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors - grad_input = disco_cuda.backward(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals, + grad_input = disco_cuda_extension.backward(grad_output.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals, ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) return grad_input, None, None, None, None, None, None, None, None @@ -389,23 +72,16 @@ def forward(ctx, x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, ctx.nlat_in = x.shape[-2] ctx.nlon_in = x.shape[-1] - return disco_cuda.backward(x, roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out) + return disco_cuda_extension.backward(x.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals, kernel_size, nlat_out, nlon_out) @staticmethod def backward(ctx, grad_output): roff_idx, ker_idx, row_idx, col_idx, vals = ctx.saved_tensors - grad_input = disco_cuda.forward(grad_output, roff_idx, ker_idx, row_idx, col_idx, vals, + grad_input = disco_cuda_extension.forward(grad_output.contiguous(), roff_idx, ker_idx, row_idx, col_idx, vals, ctx.kernel_size, ctx.nlat_in, ctx.nlon_in) return grad_input, None, None, None, None, None, None, None, None -# triton -def _disco_s2_contraction_triton(x: torch.Tensor, psi: torch.Tensor, nlon_out: int) -> torch.Tensor: - return _DiscoS2ContractionTriton.apply(x, psi, nlon_out) - -def _disco_s2_transpose_contraction_triton(x: torch.Tensor, psi: torch.Tensor, nlon_out: int) -> torch.Tensor: - return _DiscoS2TransposeContractionTriton.apply(x, psi, nlon_out) - # CUDA def _disco_s2_contraction_cuda(x: torch.Tensor, roff_idx: torch.Tensor, ker_idx: torch.Tensor, row_idx: torch.Tensor, col_idx: torch.Tensor, vals: torch.Tensor, @@ -424,7 +100,7 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in """ Reference implementation of the custom contraction as described in [1]. This requires repeated shifting of the input tensor, which can potentially be costly. For an efficient implementation - on GPU, make sure to use the custom kernel written in Triton. + on GPU, make sure to use the custom kernel written in CUDA. """ assert len(psi.shape) == 3 assert len(x.shape) == 4 @@ -460,7 +136,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl """ Reference implementation of the custom contraction as described in [1]. This requires repeated shifting of the input tensor, which can potentially be costly. For an efficient implementation - on GPU, make sure to use the custom kernel written in Triton. + on GPU, make sure to use the custom kernel written in CUDA. """ assert len(psi.shape) == 3 assert len(x.shape) == 5 diff --git a/torch_harmonics/convolution.py b/torch_harmonics/convolution.py index a775918..2d98237 100644 --- a/torch_harmonics/convolution.py +++ b/torch_harmonics/convolution.py @@ -31,6 +31,7 @@ import abc from typing import List, Tuple, Union, Optional +from warnings import warn import math @@ -40,19 +41,18 @@ from functools import partial from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes -from torch_harmonics._disco_convolution import ( - _disco_s2_contraction_torch, - _disco_s2_transpose_contraction_torch, - _disco_s2_contraction_triton, - _disco_s2_transpose_contraction_triton, - _disco_s2_contraction_cuda, - _disco_s2_transpose_contraction_cuda, -) +from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch +from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda # import custom C++/CUDA extensions from disco_helpers import preprocess_psi -if torch.cuda.is_available(): + +try: import disco_cuda_extension + _cuda_extension_available = True +except ImportError as err: + disco_cuda_extension = None + _cuda_extension_available = False def _compute_support_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float, norm: str = "s2"): @@ -331,23 +331,18 @@ def __init__( col_idx = idx[2, ...].contiguous() roff_idx = preprocess_psi(self.kernel_size, out_shape[0], ker_idx, row_idx, col_idx, vals) - # GPU kernel + # preprocessed data-structure for GPU kernel self.register_buffer("psi_roff_idx", roff_idx, persistent=False) self.register_buffer("psi_ker_idx", ker_idx, persistent=False) self.register_buffer("psi_row_idx", row_idx, persistent=False) self.register_buffer("psi_col_idx", col_idx, persistent=False) self.register_buffer("psi_vals", vals, persistent=False) - # cpu kernel - #self.register_buffer("psi_idx", idx, persistent=False) - #self.register_buffer("psi_vals", vals, persistent=False) - @property def psi_idx(self): return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous() def get_psi(self): - #psi_idx = torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous() psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce() return psi @@ -355,11 +350,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # pre-multiply x with the quadrature weights x = self.quad_weights * x - if x.is_cuda: - #x = _disco_s2_contraction_triton(x, psi, self.nlon_out) + if x.is_cuda and _cuda_extension_available: x = _disco_s2_contraction_cuda(x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out) else: + if x.is_cuda: + warn("couldn't find CUDA extension, falling back to slow PyTorch implementation") psi = self.get_psi() x = _disco_s2_contraction_torch(x, psi, self.nlon_out) @@ -417,11 +413,25 @@ def __init__( # switch in_shape and out_shape since we want transpose conv idx, vals = _precompute_convolution_tensor_s2(out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff) - self.register_buffer("psi_idx", idx, persistent=False) + # sort the values + ker_idx = idx[0, ...].contiguous() + row_idx = idx[1, ...].contiguous() + col_idx = idx[2, ...].contiguous() + roff_idx = preprocess_psi(self.kernel_size, out_shape[0], ker_idx, row_idx, col_idx, vals) + + # preprocessed data-structure for GPU kernel + self.register_buffer("psi_roff_idx", roff_idx, persistent=False) + self.register_buffer("psi_ker_idx", ker_idx, persistent=False) + self.register_buffer("psi_row_idx", row_idx, persistent=False) + self.register_buffer("psi_col_idx", col_idx, persistent=False) self.register_buffer("psi_vals", vals, persistent=False) - def get_psi(self, use_triton_kernel=True): - if not use_triton_kernel: + @property + def psi_idx(self): + return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous() + + def get_psi(self, semi_transposed: bool=False): + if semi_transposed: # we do a semi-transposition to faciliate the computation tout = self.psi_idx[2] // self.nlon_out pout = self.psi_idx[2] % self.nlon_out @@ -432,9 +442,10 @@ def get_psi(self, use_triton_kernel=True): psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce() else: psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce() + return psi - def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: # extract shape B, C, H, W = x.shape x = x.reshape(B, self.groups, self.groupsize, H, W) @@ -446,11 +457,13 @@ def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tens # pre-multiply x with the quadrature weights x = self.quad_weights * x - psi = self.get_psi(x.is_cuda and use_triton_kernel) - - if x.is_cuda and use_triton_kernel: - out = _disco_s2_transpose_contraction_triton(x, psi, self.nlon_out) + if x.is_cuda and _cuda_extension_available: + out = _disco_s2_transpose_contraction_cuda(x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, + self.kernel_size, self.nlat_out, self.nlon_out) else: + if x.is_cuda: + warn("couldn't find CUDA extension, falling back to slow PyTorch implementation") + psi = self.get_psi(semi_transposed=True) out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out) if self.bias is not None: From d3a6b3e2253c970b610fbd44a642acba73e7ad59 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Wed, 13 Mar 2024 17:46:06 +0100 Subject: [PATCH 2/3] fixing unittest --- tests/test_convolution.py | 69 +++++++++++++++------------------------ 1 file changed, 26 insertions(+), 43 deletions(-) diff --git a/tests/test_convolution.py b/tests/test_convolution.py index f6f3de1..a13c069 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -49,16 +49,7 @@ def _compute_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, ikernel = torch.arange(ntheta).reshape(-1, 1, 1) itheta = ikernel * dtheta - norm_factor = ( - 2 - * math.pi - * ( - 1 - - math.cos(theta_cutoff - dtheta) - + math.cos(theta_cutoff - dtheta) - + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta - ) - ) + norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta) vals = torch.where( ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff), @@ -67,6 +58,7 @@ def _compute_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, ) return vals + def _compute_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, nphi: int, theta_cutoff: float): """ helper routine to compute the values of the anisotropic kernel densely @@ -75,7 +67,7 @@ def _compute_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: in # compute the support dtheta = (theta_cutoff - 0.0) / ntheta dphi = 2.0 * math.pi / nphi - kernel_size = (ntheta-1)*nphi + 1 + kernel_size = (ntheta - 1) * nphi + 1 ikernel = torch.arange(kernel_size).reshape(-1, 1, 1) itheta = ((ikernel - 1) // nphi + 1) * dtheta iphi = ((ikernel - 1) % nphi) * dphi @@ -84,15 +76,14 @@ def _compute_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: in # find the indices where the rotated position falls into the support of the kernel cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff) - cond_phi = ((phi - iphi).abs() <= dphi) | ((2*math.pi - (phi - iphi).abs()) <= dphi) + cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi) theta_vals = torch.where(cond_theta, (1 - (theta - itheta).abs() / dtheta) / norm_factor, 0.0) - phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2*math.pi - (phi - iphi).abs()) ) / dphi ), 0.0) + phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0) vals = torch.where(ikernel > 0, theta_vals * phi_vals, theta_vals) return vals -def _precompute_convolution_tensor_dense( - in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi -): + +def _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi): """ Helper routine to compute the convolution Tensor in a dense fashion """ @@ -105,7 +96,7 @@ def _precompute_convolution_tensor_dense( kernel_size = kernel_shape[0] elif len(kernel_shape) == 2: kernel_handle = partial(_compute_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff) - kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1 + kernel_size = (kernel_shape[0] - 1) * kernel_shape[1] + 1 else: raise ValueError("kernel_shape should be either one- or two-dimensional.") @@ -162,27 +153,24 @@ def setUp(self): self.device = torch.device("cpu") torch.manual_seed(333) - # self.device = torch.device("cpu") - - @parameterized.expand( [ # regular convolution - [8, 4, 2, (16, 32), (16, 32), [2 ], "equiangular", "equiangular", False, 5e-5], - [8, 4, 2, (16, 32), ( 8, 16), [3 ], "equiangular", "equiangular", False, 5e-5], - [8, 4, 2, (16, 32), ( 8, 16), [2, 3], "equiangular", "equiangular", False, 5e-5], - [8, 4, 2, (18, 36), ( 6, 12), [4 ], "equiangular", "equiangular", False, 5e-5], - [8, 4, 2, (16, 32), ( 8, 16), [3 ], "equiangular", "legendre-gauss", False, 5e-5], - [8, 4, 2, (16, 32), ( 8, 16), [3 ], "legendre-gauss", "equiangular", False, 5e-5], - [8, 4, 2, (16, 32), ( 8, 16), [3 ], "legendre-gauss", "legendre-gauss", False, 5e-5], + [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", False, 5e-5], + [8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "equiangular", False, 5e-5], + [8, 4, 2, (16, 32), (8, 16), [2, 3], "equiangular", "equiangular", False, 5e-5], + [8, 4, 2, (18, 36), (6, 12), [4], "equiangular", "equiangular", False, 5e-5], + [8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "legendre-gauss", False, 5e-5], + [8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "equiangular", False, 5e-5], + [8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "legendre-gauss", False, 5e-5], # transpose convolution - [8, 4, 2, (16, 32), (16, 32), [2 ], "equiangular", "equiangular", True, 5e-5], - [8, 4, 2, ( 8, 16), (16, 32), [3 ], "equiangular", "equiangular", True, 5e-5], - [8, 4, 2, ( 8, 16), (16, 32), [2, 3], "equiangular", "equiangular", True, 5e-5], - [8, 4, 2, ( 6, 12), (18, 36), [4 ], "equiangular", "equiangular", True, 5e-5], - [8, 4, 2, ( 8, 16), (16, 32), [3 ], "equiangular", "legendre-gauss", True, 5e-5], - [8, 4, 2, ( 8, 16), (16, 32), [3 ], "legendre-gauss", "equiangular", True, 5e-5], - [8, 4, 2, ( 8, 16), (16, 32), [3 ], "legendre-gauss", "legendre-gauss", True, 5e-5], + [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", True, 5e-5], + [8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "equiangular", True, 5e-5], + [8, 4, 2, (8, 16), (16, 32), [2, 3], "equiangular", "equiangular", True, 5e-5], + [8, 4, 2, (6, 12), (18, 36), [4], "equiangular", "equiangular", True, 5e-5], + [8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "legendre-gauss", True, 5e-5], + [8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "equiangular", True, 5e-5], + [8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "legendre-gauss", True, 5e-5], ] ) def test_disco_convolution( @@ -217,19 +205,13 @@ def test_disco_convolution( theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1) if transpose: - psi_dense = _precompute_convolution_tensor_dense( - out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff - ).to(self.device) + psi_dense = _precompute_convolution_tensor_dense(out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff).to(self.device) else: - psi_dense = _precompute_convolution_tensor_dense( - in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff - ).to(self.device) + psi_dense = _precompute_convolution_tensor_dense(in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff).to(self.device) psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense() - self.assertTrue( - torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in)) - ) + self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))) # create a copy of the weight w_ref = torch.empty_like(conv.weight) @@ -266,5 +248,6 @@ def test_disco_convolution( self.assertTrue(torch.allclose(x_grad, x_ref_grad, rtol=tol, atol=tol)) self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol)) + if __name__ == "__main__": unittest.main() From 5ce34f856adbf84706e7717f4874e6181383ed6e Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Wed, 13 Mar 2024 17:51:27 +0100 Subject: [PATCH 3/3] changing version to experimental 0.7.0a --- torch_harmonics/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_harmonics/__init__.py b/torch_harmonics/__init__.py index 97dd543..9c4e7ec 100644 --- a/torch_harmonics/__init__.py +++ b/torch_harmonics/__init__.py @@ -29,7 +29,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # -__version__ = '0.7.0' +__version__ = "0.7.0a" from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2