Commit

Bbonev/disco refactor (#27)
* Moved convolutions and exposed them directly

* Added transposition to the unit test

* Minor bugfix in CPU version of DISCO transpose code

* Adding convolution tests to CI

* Added gradient check

* Checking the weight grad as well

* Added test for anisotropic kernels
bonevbs authored Dec 20, 2023
1 parent 1e5f7a2 commit 942aa4e
Showing 7 changed files with 330 additions and 29 deletions.
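
Since this commit exposes the convolution modules directly at the package top level, a minimal usage sketch follows; the constructor arguments mirror those exercised by the new unit test below, so treat the exact signature as inferred rather than documented:

import torch
from torch_harmonics import DiscreteContinuousConvS2

# DISCO convolution from an equiangular 16x32 grid down to an 8x16 grid,
# with an isotropic kernel of 3 basis functions (cf. tests/test_convolution.py)
conv = DiscreteContinuousConvS2(
    4, 2, (16, 32), (8, 16), [3],
    groups=1, grid_in="equiangular", grid_out="equiangular", bias=False,
)
x = torch.randn(8, 4, 16, 32)  # (batch, channels, nlat, nlon)
y = conv(x)                    # -> (8, 2, 8, 16)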
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -23,4 +23,4 @@ jobs:
    - name: Test with pytest
      run: |
        python -m pip install pytest pytest-cov parameterized
-        python -m pytest --cov-report term --cov-config=.coveragerc --cov=torch_harmonics ./tests/test_sht.py
+        python -m pytest --cov-report term --cov-config=.coveragerc --cov=torch_harmonics ./tests/test_sht.py ./tests/test_convolution.py
15 changes: 12 additions & 3 deletions Changelog.md
@@ -2,10 +2,19 @@

## Versioning

+### v0.6.5
+
+* Discrete-continuous (DISCO) convolutions on the sphere
+* Isotropic and anisotropic DISCO convolutions
+* Accelerated DISCO convolutions on GPU via a Triton implementation
+* Unit tests for DISCO convolutions
+
### v0.6.4
-* reworking distributed to allow for uneven split tensors, effectively removing the necessity of padding the transformed tensors
-* distributed SHT tests are now using unittest. Test extended to vector SHT versions. Tests are defined in `torch_harmonics/distributed/distributed_tests.py`
-* base pytorch container version bumped up to 23.11 in Dockerfile
+
+* Reworked the distributed module to allow for unevenly split tensors, effectively removing the need to pad the transformed tensors
+* Distributed SHT tests now use unittest and have been extended to the vector SHT versions
+* Tests are defined in `torch_harmonics/distributed/distributed_tests.py`
+* Base PyTorch container version bumped up to 23.11 in the Dockerfile

### v0.6.3

262 changes: 262 additions & 0 deletions tests/test_convolution.py
@@ -0,0 +1,262 @@
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import unittest
from parameterized import parameterized
from functools import partial
import math
import numpy as np
import torch
from torch.autograd import gradcheck
from torch_harmonics import *


def _compute_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, theta_cutoff: float):
    """
    helper routine to compute the values of the isotropic kernel densely
    """

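    # the kernel consists of `ntheta` piecewise-linear "hat" basis functions in
    # colatitude, spaced dtheta apart and supported on [0, theta_cutoff]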
    # compute the support
    dtheta = (theta_cutoff - 0.0) / ntheta
    ikernel = torch.arange(ntheta).reshape(-1, 1, 1)
    itheta = ikernel * dtheta

    norm_factor = (
        2
        * math.pi
        * (
            1
            - math.cos(theta_cutoff - dtheta)
            + math.cos(theta_cutoff - dtheta)
            + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta
        )
    )

    vals = torch.where(
        ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff),
        (1 - (theta - itheta).abs() / dtheta) / norm_factor,
        0,
    )
    return vals

def _compute_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, nphi: int, theta_cutoff: float):
    """
    helper routine to compute the values of the anisotropic kernel densely
    """

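    # on top of the hat functions in colatitude, anisotropic kernels tile the
    # support with hat functions in longitude; kernel index 0 is the single
    # basis function centered at theta = 0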
    # compute the support
    dtheta = (theta_cutoff - 0.0) / ntheta
    dphi = 2.0 * math.pi / nphi
    kernel_size = (ntheta - 1) * nphi + 1
    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
    itheta = ((ikernel - 1) // nphi + 1) * dtheta
    iphi = ((ikernel - 1) % nphi) * dphi

    norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)

    # find the indices where the rotated position falls into the support of the kernel
    cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff)
    cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
    theta_vals = torch.where(cond_theta, (1 - (theta - itheta).abs() / dtheta) / norm_factor, 0.0)
    phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
    vals = torch.where(ikernel > 0, theta_vals * phi_vals, theta_vals)
    return vals

def _precompute_convolution_tensor_dense(
    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi
):
    """
    Helper routine to compute the convolution tensor in a dense fashion
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    if len(kernel_shape) == 1:
        kernel_handle = partial(_compute_vals_isotropic, ntheta=kernel_shape[0], theta_cutoff=theta_cutoff)
        kernel_size = kernel_shape[0]
    elif len(kernel_shape) == 2:
        kernel_handle = partial(_compute_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff)
        kernel_size = (kernel_shape[0] - 1) * kernel_shape[1] + 1
    else:
        raise ValueError("kernel_shape should be either one- or two-dimensional.")

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    lats_in, _ = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in).float()
    lats_out, _ = quadrature._precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out).float()

    # compute the phi differences; the linspace is made exclusive so the last point is not doubled
    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
    lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1)[:-1]

    out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in)

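    # for every output grid point, compute where each input grid point lands in a
    # coordinate system rotated such that the output point sits at the pole; theta
    # then measures the great-circle distance to the output point, phi the azimuth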
    for t in range(nlat_out):
        for p in range(nlon_out):
            alpha = -lats_out[t]
            beta = lons_in - lons_out[p]
            gamma = lats_in.reshape(-1, 1)

            # compute latitude of the rotated position
            z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)

            # compute cartesian coordinates of the rotated position
            x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
            y = torch.sin(beta) * torch.sin(gamma)

            # normalize instead of clipping to ensure correct range
            norm = torch.sqrt(x * x + y * y + z * z)
            x = x / norm
            y = y / norm
            z = z / norm

            # compute spherical coordinates
            theta = torch.arccos(z)
            phi = torch.arctan2(y, x) + torch.pi

            # evaluate the kernel at the rotated position
            out[:, t, p, :, :] = kernel_handle(theta, phi)

    return out


class TestDiscreteContinuousConvolution(unittest.TestCase):
    def setUp(self):
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # note: the device is overridden here, so these tests currently run on the CPU
        self.device = torch.device("cpu")

    @parameterized.expand(
        [
            # regular convolution
            [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", False, 1e-5],
            [8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "equiangular", False, 1e-5],
            [8, 4, 2, (16, 32), (8, 16), [2, 3], "equiangular", "equiangular", False, 1e-5],
            [8, 4, 2, (18, 36), (6, 12), [4], "equiangular", "equiangular", False, 1e-5],
            [8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "legendre-gauss", False, 1e-5],
            [8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "equiangular", False, 1e-5],
            [8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "legendre-gauss", False, 1e-5],
            # transpose convolution
            [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", True, 1e-5],
            [8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "equiangular", True, 1e-5],
            [8, 4, 2, (8, 16), (16, 32), [2, 3], "equiangular", "equiangular", True, 1e-5],
            [8, 4, 2, (6, 12), (18, 36), [4], "equiangular", "equiangular", True, 1e-5],
            [8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "legendre-gauss", True, 1e-5],
            [8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "equiangular", True, 1e-5],
            [8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "legendre-gauss", True, 1e-5],
        ]
    )
    def test_disco_convolution(
        self,
        batch_size,
        in_channels,
        out_channels,
        in_shape,
        out_shape,
        kernel_shape,
        grid_in,
        grid_out,
        transpose,
        tol,
    ):
        Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
        conv = Conv(
            in_channels,
            out_channels,
            in_shape,
            out_shape,
            kernel_shape,
            groups=1,
            grid_in=grid_in,
            grid_out=grid_out,
            bias=False,
        ).to(self.device)

        nlat_in, nlon_in = in_shape
        nlat_out, nlon_out = out_shape

        theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)

        if transpose:
            psi_dense = _precompute_convolution_tensor_dense(
                out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff
            ).to(self.device)
        else:
            psi_dense = _precompute_convolution_tensor_dense(
                in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
            ).to(self.device)

        self.assertTrue(
            torch.allclose(conv.psi.to_dense(), psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))
        )

        # create a copy of the weight
        w_ref = conv.weight.detach().clone()
        w_ref.requires_grad_(True)

        # create an input signal
        torch.manual_seed(333)
        x = torch.randn(batch_size, in_channels, *in_shape, requires_grad=True).to(self.device)

        # perform the reference computation
        x_ref = x.clone().detach()
        x_ref.requires_grad_(True)
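        # reference computation: scale the input by the quadrature weights,
        # contract with the dense convolution tensor psi, then mix channels with
        # the weight tensor; the transpose case applies the adjoint contraction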
        if transpose:
            y_ref = torch.einsum("oif,biqr->bofqr", w_ref, x_ref)
            y_ref = torch.einsum("fqrtp,bofqr->botp", psi_dense, y_ref * conv.quad_weights)
        else:
            y_ref = torch.einsum("ftpqr,bcqr->bcftp", psi_dense, x_ref * conv.quad_weights)
            y_ref = torch.einsum("oif,biftp->botp", w_ref, y_ref)

        # use the convolution module
        y = conv(x)

        # compare results
        self.assertTrue(torch.allclose(y, y_ref, rtol=tol, atol=tol))

        # compute gradients and compare results
        grad_input = torch.randn_like(y)
        y_ref.backward(grad_input)
        y.backward(grad_input)

        # compare
        self.assertTrue(torch.allclose(x.grad, x_ref.grad, rtol=tol, atol=tol))
        self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol))


if __name__ == "__main__":
    unittest.main()
4 changes: 2 additions & 2 deletions tests/test_sht.py
@@ -149,9 +149,9 @@ def test_sht_grad(self, nlat, nlon, batch_size, norm, grid, tol):
        coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
        signal = isht(coeffs)

-        input = torch.randn_like(signal, requires_grad=True)
+        grad_input = torch.randn_like(signal, requires_grad=True)
        err_handle = lambda x : torch.mean(torch.norm( isht(sht(x)) - signal , p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) )
-        test_result = gradcheck(err_handle, input, eps=1e-6, atol=tol)
+        test_result = gradcheck(err_handle, grad_input, eps=1e-6, atol=tol)
        self.assertTrue(test_result)


3 changes: 1 addition & 2 deletions torch_harmonics/__init__.py
@@ -32,8 +32,7 @@
__version__ = '0.6.4'

from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
+from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from . import quadrature
-from . import s2_convolutions
-from . import disco_convolutions
from . import random_fields
from . import examples
@@ -377,7 +377,7 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in

    assert psi.shape[-1] == nlat_in * nlon_in
    assert nlon_in % nlon_out == 0
-
+    assert nlon_in >= nlat_out
    pscale = nlon_in // nlon_out

    # add a dummy dimension for nkernel and move the batch and channel dims to the end
@@ -414,7 +414,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl
    assert psi.shape[-2] == nlat_in
    assert n_out % nlon_out == 0
    nlat_out = n_out // nlon_out
-
+    assert nlon_out >= nlat_in
    pscale = nlon_out // nlon_in

    # we do a semi-transposition to facilitate the computation
@@ -429,7 +429,7 @@

    # interleave zeros along the longitude dimension to allow for fractional offsets to be considered
    x_ext = torch.zeros(kernel_size, nlat_in, nlon_out, batch_size * n_chans, device=x.device, dtype=x.dtype)
-    x_ext[:, :, (pscale-1)::pscale, :] = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0)
+    x_ext[:, :, ::pscale, :] = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0)
    # we need to go backwards through the vector, so we flip the axis
    x_ext = x_ext.contiguous()

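The one-line change above fixes the offset at which the coarse longitude samples are written into the zero-padded array. A minimal standalone sketch of the interleaving (illustration only, not the library code):

import torch

# upsampling along longitude by interleaving zeros: with pscale = nlon_out // nlon_in,
# the coarse samples must land at offsets 0, pscale, 2*pscale, ... (the ::pscale slice);
# the previous (pscale-1)::pscale slice shifted the signal by one fine-grid cell
pscale = 2
x = torch.tensor([1.0, 2.0, 3.0, 4.0])   # coarse longitude samples
x_ext = torch.zeros(pscale * x.shape[0])  # fine longitude grid
x_ext[::pscale] = x                       # -> [1, 0, 2, 0, 3, 0, 4, 0]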