Skip to content

Commit

Permalink
reworked setup.py with delayed torch import
Browse files Browse the repository at this point in the history
  • Loading branch information
bonevbs committed Aug 27, 2024
1 parent ed72700 commit abfaa79
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 44 deletions.
11 changes: 5 additions & 6 deletions .github/workflows/deploy_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: "3.9"
- name: Install dependencies
- name: Install build tools
run: |
python -m pip install --upgrade pip
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
python -m pip install numpy
python -m pip install setuptools wheel build
# - name: Install package
# run: |
# python -m pip install -e .
- name: Install dependencies
run: |
python -m pip install torch
python -m pip install numpy
- name: Build a binary wheel and a source tarball
run: |
python -m build
Expand Down
60 changes: 22 additions & 38 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,16 @@
#

import sys
from pathlib import Path
import re

try:
from setuptools import setup, find_packages
except ImportError:
from distutils.core import setup, find_packages

import re
from pathlib import Path

import torch
from torch.utils import cpp_extension

def version(root_path):
"""Returns the version taken from __init__.py
Parameters
----------
root_path : pathlib.Path
path to the root of the package
Reference
---------
https://packaging.python.org/guides/single-sourcing-package-version/
"""
"""Returns the version taken from __init__.py"""
version_path = root_path.joinpath("torch_harmonics", "__init__.py")
with version_path.open() as f:
version_file = f.read()
Expand All @@ -62,26 +48,19 @@ def version(root_path):
return version_match.group(1)
raise RuntimeError("Unable to find version string.")


def readme(root_path):
"""Returns the text content of the README.md of the package
Parameters
----------
root_path : pathlib.Path
path to the root of the package
"""
"""Returns the text content of README.md"""
with root_path.joinpath("README.md").open(encoding="UTF-8") as f:
return f.read()


def get_ext_modules(argv):
# Delay import of torch and cpp_extension
import torch
from torch.utils import cpp_extension

compile_cuda_extension = False

if "--cuda_ext" in sys.argv:
sys.argv.remove("--cuda_ext")
compile_cuda_extension = True
compile_cuda_extension = "--cuda_ext" in argv
if "--cuda_ext" in argv:
argv.remove("--cuda_ext")

ext_modules = [
cpp_extension.CppExtension("disco_helpers", ["torch_harmonics/csrc/disco/disco_helpers.cpp"]),
Expand All @@ -101,14 +80,10 @@ def get_ext_modules(argv):

return ext_modules


root_path = Path(__file__).parent
README = readme(root_path)
VERSION = version(root_path)

# external modules
ext_modules = get_ext_modules(sys.argv)

config = {
"name": "torch_harmonics",
"packages": find_packages(),
Expand All @@ -127,8 +102,17 @@ def get_ext_modules(argv):
"scripts": [],
"include_package_data": True,
"classifiers": ["Topic :: Scientific/Engineering", "License :: OSI Approved :: BSD License", "Programming Language :: Python :: 3"],
"ext_modules": ext_modules,
"cmdclass": {"build_ext": cpp_extension.BuildExtension} if ext_modules else {},
}

setup(**config)
def setup_package():
# Delay the decision about ext_modules and cmdclass
ext_modules = get_ext_modules(sys.argv)
if ext_modules:
from torch.utils import cpp_extension
config["ext_modules"] = ext_modules
config["cmdclass"] = {"build_ext": cpp_extension.BuildExtension}

setup(**config)

if __name__ == "__main__":
setup_package()

0 comments on commit abfaa79

Please sign in to comment.