Tkurth/cuda disco reduce scatter #39

Closed
43 commits
432ce9a
adding cuda kernels for disco conv
azrael417 Mar 5, 2024
b59a595
dbg
azrael417 Mar 5, 2024
5d2d2f3
dbg
azrael417 Mar 5, 2024
0a9272f
dbg
azrael417 Mar 5, 2024
00057be
making psi_idx an attribute
azrael417 Mar 5, 2024
cf803f0
adding license headers
azrael417 Mar 5, 2024
785f37c
adding author files
azrael417 Mar 5, 2024
add47f1
adding author files
azrael417 Mar 5, 2024
3fb9e87
reorganizing files
azrael417 Mar 11, 2024
6d28d78
draft implementation
azrael417 Mar 12, 2024
9fb19b0
added conditional installation to setup.py
bonevbs Mar 12, 2024
86213df
formatting changes
bonevbs Mar 12, 2024
dd8b7ac
removing triton kernel in DISCO convolution
bonevbs Mar 12, 2024
8d1141c
updated github actions
bonevbs Mar 12, 2024
220dc01
updated Readme and changelog
bonevbs Mar 12, 2024
2099340
adding another guard for the cuda installation
bonevbs Mar 13, 2024
e9a8035
renaming the cuda extension
bonevbs Mar 13, 2024
599f5a3
simplifying setup.py
bonevbs Mar 13, 2024
bc03756
minor bugfix
bonevbs Mar 13, 2024
3af2f08
fixing streams
azrael417 Mar 14, 2024
f3b6553
Bbonev/cuda disco cleanup (#32)
bonevbs Mar 13, 2024
15f9618
initial rewrite of the distributed convolution with CUDA
bonevbs Mar 13, 2024
c82dafe
need to fix install options
azrael417 Mar 14, 2024
7877fde
fixing streams
azrael417 Mar 14, 2024
9c20c99
undid setup.py changes
azrael417 Mar 14, 2024
a40ad98
reset setup.py
azrael417 Mar 14, 2024
6441e07
including CUDAStream
azrael417 Mar 14, 2024
61dc795
adjusted the precomputation of theta_cutoff. If you rely on this, you…
bonevbs Mar 20, 2024
3ef3b7c
adjusting theta_cutoff in the unittest
bonevbs Mar 20, 2024
ed2409b
adding newly refactored kernels for faster compile
azrael417 Mar 21, 2024
ac5b9d5
Tkurth/cuda disco distributed fix (#34)
azrael417 Mar 25, 2024
096d5c5
using stream functions from at instead of c10
azrael417 Apr 4, 2024
95131c3
using stream functions from at instead of c10, small fix
azrael417 Apr 4, 2024
d78f7c7
Bbonev/disc even filters (#35)
bonevbs Apr 17, 2024
108631e
reworked normalization of filter basis functions
bonevbs Apr 24, 2024
450bc6b
implemented discrete normalization of disco filters
bonevbs Apr 25, 2024
1cd22f2
relaxing tolerances in convolution unit test
bonevbs Apr 25, 2024
1059810
bugfix to correctly support unequal scale factors in latitudes and lo…
bonevbs Apr 26, 2024
0e9a863
hotfix to a bug in the imports
bonevbs Apr 26, 2024
4d952d1
Bbonev/distributed disco refactor (#37)
bonevbs Apr 29, 2024
f6672fd
fixed initial scale of convolution parameter weights and fixed naming…
bonevbs May 30, 2024
748193e
using fused scatter-reduce routines in fwd and trans conv
azrael417 Aug 19, 2024
b9fac9b
adding cuda disco reduce scatter
azrael417 Aug 19, 2024
10 changes: 5 additions & 5 deletions .github/workflows/deploy_pypi.yml
@@ -11,11 +11,11 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python 3.9
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: 3.9
python-version: 'pypy3.9'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
@@ -30,11 +30,11 @@ jobs:
# - name: Publish package to TestPyPI
# uses: pypa/gh-action-pypi-publish@master
# with:
# user: __token__
# user: __token__
# password: ${{ secrets.TEST_PYPI_PASSWORD }}
# repository_url: https://test.pypi.org/legacy/
- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
user: __token__
password: ${{ secrets.PYPI_PASSWORD }}
6 changes: 3 additions & 3 deletions .github/workflows/tests.yml
@@ -8,11 +8,11 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python 3.9
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: 3.9
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools wheel
6 changes: 4 additions & 2 deletions AUTHORS
@@ -2,6 +2,8 @@ The code was authored by the following people:

Boris Bonev - NVIDIA Corporation
Thorsten Kurth - NVIDIA Corporation
Mauro Bisson - NVIDIA Corporation
Massimiliano Fatica - NVIDIA Corporation
Christian Hundt - NVIDIA Corporation
Nikola Kovachki - NVIDIA Corporation
Jean Kossaifi - NVIDIA Corporation
Jean Kossaifi - NVIDIA Corporation
Nikola Kovachki - NVIDIA Corporation
8 changes: 8 additions & 0 deletions Changelog.md
@@ -2,6 +2,14 @@

## Versioning

### v0.7.0

* CUDA-accelerated DISCO convolutions
* Updated DISCO convolutions to support even number of collocation points across the diameter
* Distributed DISCO convolutions
* Removed DISCO convolution in the plane to focus on the sphere
* Updated unit tests which now include tests for the distributed convolutions

### v0.6.5

* Discrete-continuous (DISCO) convolutions on the sphere and in two dimensions
8 changes: 6 additions & 2 deletions Dockerfile
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
@@ -34,6 +34,10 @@ FROM nvcr.io/nvidia/pytorch:23.11-py3

COPY . /workspace/torch_harmonics

# we need this for tests
RUN pip install parameterized
RUN pip install /workspace/torch_harmonics

# we need to remove old archs
ENV TORCH_CUDA_ARCH_LIST "7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
RUN pip install --global-option --cuda_ext /workspace/torch_harmonics
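
The `TORCH_CUDA_ARCH_LIST` value set above is a space-separated list of compute capabilities, where a `+PTX` suffix additionally requests forward-compatible PTX. As a rough illustration of how such a list is conventionally turned into `nvcc` code-generation flags, here is a simplified sketch. `arch_list_to_gencode` is a hypothetical helper written for this note; it is not the function used by `torch.utils.cpp_extension`, whose actual logic may differ in details.

```python
def arch_list_to_gencode(arch_list: str) -> list[str]:
    """Translate e.g. "8.0 9.0+PTX" into nvcc -gencode flags (illustrative only)."""
    flags = []
    for entry in arch_list.split():
        ptx = entry.endswith("+PTX")
        version = entry[: -len("+PTX")] if ptx else entry
        major, minor = version.split(".")
        num = f"{major}{minor}"
        # Native machine code (SASS) for this exact architecture.
        flags.append(f"-gencode=arch=compute_{num},code=sm_{num}")
        if ptx:
            # Also embed PTX so that newer GPUs can JIT-compile the kernels.
            flags.append(f"-gencode=arch=compute_{num},code=compute_{num}")
    return flags


if __name__ == "__main__":
    # Same list as in the Dockerfile change above.
    for flag in arch_list_to_gencode("7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"):
        print(flag)
```

Restricting the list this way keeps compile time and binary size down, at the cost of not shipping native code for architectures older than 7.0.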

21 changes: 14 additions & 7 deletions README.md
@@ -1,8 +1,8 @@
<!--
<!--
SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.

SPDX-License-Identifier: BSD-3-Clause

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

@@ -160,6 +160,10 @@ $$

Here, $x_j \in [-1,1]$ are the quadrature nodes with the respective quadrature weights $w_j$.

### Discrete-continuous convolutions

torch-harmonics now provides local discrete-continuous (DISCO) convolutions as outlined in [4] on the sphere.

## Getting started

The main functionality of `torch_harmonics` is provided in the form of `torch.nn.Modules` for composability. A minimum example is given by:
@@ -223,15 +227,15 @@ Depending on the problem, it might be beneficial to upcast data to `float64` ins

## Contributors

[Boris Bonev](https://bonevbs.github.io) (bbonev@nvidia.com), [Thorsten Kurth](https://github.com/azrael417) (tkurth@nvidia.com), [Christian Hundt](https://github.com/gravitino) (chundt@nvidia.com), [Nikola Kovachki](https://kovachki.github.io) (nkovachki@nvidia.com), [Jean Kossaifi](http://jeankossaifi.com) (jkossaifi@nvidia.com)
[Boris Bonev](https://bonevbs.github.io) (bbonev@nvidia.com), [Thorsten Kurth](https://github.com/azrael417) (tkurth@nvidia.com), [Mauro Bisson](https://scholar.google.com/citations?hl=en&user=f0JE-0gAAAAJ) , [Massimiliano Fatica](https://scholar.google.com/citations?user=Deaq4uUAAAAJ&hl=en), [Nikola Kovachki](https://kovachki.github.io), [Jean Kossaifi](http://jeankossaifi.com), [Christian Hundt](https://github.com/gravitino)

## Cite us

If you use `torch-harmonics` in an academic paper, please cite [1]

```bibtex
@misc{bonev2023spherical,
title={Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere},
title={Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere},
author={Boris Bonev and Thorsten Kurth and Christian Hundt and Jaideep Pathak and Maximilian Baust and Karthik Kashinath and Anima Anandkumar},
year={2023},
eprint={2306.03838},
@@ -242,17 +246,20 @@ If you use `torch-harmonics` in an academic paper, please cite [1]

## References

<a id="1">[1]</a>
<a id="1">[1]</a>
Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere;
arXiv 2306.03838, 2023.

<a id="1">[2]</a>
<a id="1">[2]</a>
Schaeffer N.;
Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations;
G3: Geochemistry, Geophysics, Geosystems, 2013.

<a id="1">[3]</a>
<a id="1">[3]</a>
Wang B., Wang L., Xie Z.;
Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids;
Adv Comput Math, 2018.

<a id="1">[4]</a>
Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
6 changes: 3 additions & 3 deletions notebooks/plotting.py
@@ -2,7 +2,7 @@

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
@@ -84,7 +84,7 @@ def plot_data(data,
**kwargs):
if fig == None:
fig = plt.figure()

nlat = data.shape[-2]
nlon = data.shape[-1]
if lon is None:
@@ -96,7 +96,7 @@ def plot_data(data,
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(1, 1, 1, projection=projection)
im = ax.pcolormesh(Lon, Lat, data, cmap=cmap, **kwargs)

if colorbar:
plt.colorbar(im)
plt.title(title, y=1.05)
81 changes: 56 additions & 25 deletions setup.py
@@ -2,7 +2,7 @@

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
@@ -28,6 +28,9 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import sys

try:
    from setuptools import setup, find_packages
except ImportError:
@@ -36,6 +39,8 @@
import re
from pathlib import Path

import torch
from torch.utils import cpp_extension

def version(root_path):
    """Returns the version taken from __init__.py
@@ -49,11 +54,10 @@ def version(root_path):
    ---------
    https://packaging.python.org/guides/single-sourcing-package-version/
    """
    version_path = root_path.joinpath('torch_harmonics', '__init__.py')
    version_path = root_path.joinpath("torch_harmonics", "__init__.py")
    with version_path.open() as f:
        version_file = f.read()
    version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
                              version_file, re.M)
    version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M)
    if version_match:
        return version_match.group(1)
    raise RuntimeError("Unable to find version string.")
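
The `version()` helper in this hunk single-sources the package version by scanning `__init__.py` with a regular expression rather than importing the package. The following is a minimal, self-contained sketch of that pattern. `read_version` and `demo` are hypothetical names, and the `0.7.0` string written to the temporary file is only a placeholder for illustration.

```python
import re
import tempfile
from pathlib import Path

# Same pattern as in setup.py: a line-anchored __version__ assignment.
VERSION_RE = re.compile(r"^__version__ = ['\"]([^'\"]*)['\"]", re.M)


def read_version(init_file: Path) -> str:
    """Extract __version__ from a file without importing the package."""
    match = VERSION_RE.search(init_file.read_text())
    if match:
        return match.group(1)
    raise RuntimeError("Unable to find version string.")


def demo() -> str:
    # Write a throwaway __init__.py so the sketch runs anywhere.
    with tempfile.TemporaryDirectory() as tmp:
        init_file = Path(tmp) / "__init__.py"
        init_file.write_text('"""demo package"""\n__version__ = "0.7.0"\n')
        return read_version(init_file)


if __name__ == "__main__":
    print(demo())
```

Reading the file as text avoids importing the package (and its dependencies) at build time.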
@@ -67,37 +71,64 @@ def readme(root_path):
    root_path : pathlib.Path
        path to the root of the package
    """
    with root_path.joinpath('README.md').open(encoding='UTF-8') as f:
    with root_path.joinpath("README.md").open(encoding="UTF-8") as f:
        return f.read()


def get_ext_modules(argv):

    compile_cuda_extension = False

    if "--cuda_ext" in sys.argv:
        sys.argv.remove("--cuda_ext")
        compile_cuda_extension = True

    ext_modules = [
        cpp_extension.CppExtension("disco_helpers", ["torch_harmonics/csrc/disco/disco_helpers.cpp"]),
    ]

    if torch.cuda.is_available() or compile_cuda_extension:
        ext_modules.append(
            cpp_extension.CUDAExtension(
                "disco_cuda_extension",
                [
                    "torch_harmonics/csrc/disco/disco_interface.cu",
                    "torch_harmonics/csrc/disco/disco_cuda_fwd.cu",
                    "torch_harmonics/csrc/disco/disco_cuda_bwd.cu",
                ],
            )
        )

    return ext_modules


root_path = Path(__file__).parent
README = readme(root_path)
VERSION = version(root_path)

# external modules
ext_modules = get_ext_modules(sys.argv)

config = {
    'name': 'torch_harmonics',
    'packages': find_packages(),
    'description': 'A differentiable spherical harmonic transform for PyTorch.',
    'long_description': README,
    'long_description_content_type' : 'text/markdown',
    'url' : 'https://github.com/NVIDIA/torch-harmonics',
    'author': 'Boris Bonev',
    'author_email': 'bbonev@nvidia.com',
    'version': VERSION,
    'install_requires': ['torch', 'numpy', 'triton'],
    'extras_require': {
        'sfno': ['tensorly', 'tensorly-torch'],
    "name": "torch_harmonics",
    "packages": find_packages(),
    "description": "A differentiable spherical harmonic transform for PyTorch.",
    "long_description": README,
    "long_description_content_type": "text/markdown",
    "url": "https://github.com/NVIDIA/torch-harmonics",
    "author": "Boris Bonev",
    "author_email": "bbonev@nvidia.com",
    "version": VERSION,
    "install_requires": ["torch", "numpy"],
    "extras_require": {
        "sfno": ["tensorly", "tensorly-torch"],
    },
    'license': 'Modified BSD',
    'scripts': [],
    'include_package_data': True,
    'classifiers': [
        'Topic :: Scientific/Engineering',
        'License :: OSI Approved :: BSD License',
        'Programming Language :: Python :: 3'
    ],
    "license": "Modified BSD",
    "scripts": [],
    "include_package_data": True,
    "classifiers": ["Topic :: Scientific/Engineering", "License :: OSI Approved :: BSD License", "Programming Language :: Python :: 3"],
    "ext_modules": ext_modules,
    "cmdclass": {"build_ext": cpp_extension.BuildExtension} if ext_modules else {},
}

setup(**config)
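
The new `get_ext_modules` consumes a custom `--cuda_ext` flag before `setup()` runs, presumably because setuptools would otherwise reject the unknown option, and it builds the CUDA extension when either the flag is given or a GPU is visible at install time. The sketch below isolates that decision logic without importing `torch`. `pop_flag` and `select_extensions` are hypothetical helpers, and `cuda_available` stands in for `torch.cuda.is_available()`.

```python
def pop_flag(argv: list[str], flag: str) -> bool:
    """Return True if `flag` is present, removing it from `argv` in place."""
    if flag in argv:
        argv.remove(flag)
        return True
    return False


def select_extensions(argv: list[str], cuda_available: bool) -> list[str]:
    """Decide which extensions to build; names mirror those in the diff."""
    force_cuda = pop_flag(argv, "--cuda_ext")
    extensions = ["disco_helpers"]  # the C++ helper is always built
    if cuda_available or force_cuda:
        extensions.append("disco_cuda_extension")
    return extensions


if __name__ == "__main__":
    argv = ["setup.py", "install", "--cuda_ext"]
    print(select_extensions(argv, cuda_available=False))
    print(argv)  # the custom flag has been removed from the argument list
```

Forcing the build with the flag matters in environments such as `docker build`, where no GPU is usually visible even though the resulting image will run on one. Note that the function in the diff accepts an `argv` parameter but reads and mutates `sys.argv` directly; the sketch passes the list explicitly instead.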