-
Notifications
You must be signed in to change notification settings - Fork 26
/
setup.py
105 lines (88 loc) · 3.2 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from setuptools import setup, find_packages
import os
import glob
# PyTorch (and its C++/CUDA extension helpers) must already be installed:
# they are needed at build time, before install_requires is processed.
try:
    import torch
    from torch.utils.cpp_extension import (
        BuildExtension,
        CUDAExtension,
        CUDA_HOME,
        CppExtension,
    )
except ImportError as err:
    # Only a failed import means torch is missing. A bare ``except:`` would
    # also swallow KeyboardInterrupt/SystemExit and mask unrelated errors,
    # e.g. an OSError from a broken CUDA runtime library.
    raise ModuleNotFoundError(
        "Please install pytorch >= 1.1 before proceeding."
    ) from err
def _env_flag(name):
    """Return True when environment variable ``name`` is set to exactly "1"."""
    return os.getenv(name, "0") == "1"


# Default build switches: build the CUDA extension only when torch reports
# CUDA as available and a CUDA toolkit (CUDA_HOME) was found. The CPU
# extension is always built.
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
WITH_CPU = True

# Environment overrides, applied in this order (later ones win):
#   FORCE_CUDA=1       -> also build the CUDA extension
#   FORCE_ONLY_CUDA=1  -> build only the CUDA extension
#   FORCE_ONLY_CPU=1   -> build only the CPU extension
if _env_flag("FORCE_CUDA"):
    WITH_CUDA = True
if _env_flag("FORCE_ONLY_CUDA"):
    WITH_CUDA, WITH_CPU = True, False
if _env_flag("FORCE_ONLY_CPU"):
    WITH_CUDA, WITH_CPU = False, True
def get_ext_modules():
    """Return the list of C++/CUDA extension modules to build.

    Driven by the module-level ``WITH_CUDA`` / ``WITH_CPU`` switches:

    * ``WITH_CUDA`` -> ``torch_points_kernels.points_cuda``, built from
      ``cuda/src/*.cpp`` and ``cuda/src/*.cu`` with headers in ``cuda/include``.
    * ``WITH_CPU``  -> ``torch_points_kernels.points_cpu``, built from
      ``cpu/src/*.cpp`` with headers in ``cpu/include``.

    Source globs are relative patterns, so they resolve against the current
    working directory. Extra nvcc flags may be supplied through the
    whitespace-separated ``NVCC_FLAGS`` environment variable.
    """
    torch_major, torch_minor = (int(part) for part in torch.__version__.split(".")[:2])

    cxx_flags = ["-O3"]
    if (torch_major, torch_minor) >= (1, 3):
        cxx_flags.append("-DVERSION_GE_1_3")

    ext_modules = []

    if WITH_CUDA:
        ext_src_root = "cuda"
        # sorted(): glob order is filesystem-dependent; keep builds reproducible.
        ext_sources = sorted(glob.glob("{}/src/*.cpp".format(ext_src_root))) + sorted(
            glob.glob("{}/src/*.cu".format(ext_src_root))
        )

        # split() rather than split(" "): repeated spaces must not turn into
        # empty-string arguments on the nvcc command line.
        nvcc_flags = os.getenv("NVCC_FLAGS", "").split()
        # nvcc rejects a redefined -arch/--gpu-architecture option, so only
        # add the project default when the user has not picked an arch.
        # NOTE(review): sm_35 (Kepler) is not supported by CUDA 12.x
        # toolkits; override it through NVCC_FLAGS when building there.
        if not any(flag.startswith(("-arch", "--gpu-architecture")) for flag in nvcc_flags):
            nvcc_flags.append("-arch=sm_35")
        nvcc_flags += ["--expt-relaxed-constexpr", "-O2"]

        ext_modules.append(
            CUDAExtension(
                name="torch_points_kernels.points_cuda",
                sources=ext_sources,
                include_dirs=["{}/include".format(ext_src_root)],
                # Each extension gets its own dict/lists, so neither build
                # can see or mutate the other's flags.
                extra_compile_args={"cxx": list(cxx_flags), "nvcc": nvcc_flags},
            )
        )

    if WITH_CPU:
        cpu_ext_src_root = "cpu"
        ext_modules.append(
            CppExtension(
                name="torch_points_kernels.points_cpu",
                sources=sorted(glob.glob("{}/src/*.cpp".format(cpu_ext_src_root))),
                include_dirs=["{}/include".format(cpu_ext_src_root)],
                extra_compile_args={"cxx": list(cxx_flags)},
            )
        )

    return ext_modules
class CustomBuildExtension(BuildExtension):
    """torch ``BuildExtension`` with two options always set.

    ``no_python_abi_suffix=True`` omits the Python ABI tag from the built
    library filenames, and ``use_ninja=False`` selects the non-ninja build
    backend. Passing either option again from the caller is an error
    (TypeError for a repeated keyword argument).
    """

    def __init__(self, *args, **kwargs):
        fixed_options = {"no_python_abi_suffix": True, "use_ninja": False}
        super().__init__(*args, **fixed_options, **kwargs)
def get_cmdclass():
    """Return the ``cmdclass`` mapping handed to ``setup()``.

    The ``build_ext`` command is routed to ``CustomBuildExtension``.
    """
    cmdclass = dict(build_ext=CustomBuildExtension)
    return cmdclass
# README.md, located next to this setup.py, becomes the long description.
this_directory = os.path.abspath(os.path.dirname(__file__))
readme_path = os.path.join(this_directory, "README.md")
with open(readme_path, encoding="utf-8") as readme_file:
    long_description = readme_file.read()
# Runtime dependencies.
# Under PEP 440, "numpy<=1.21" admits only 1.21.0 and rejects the 1.21.x
# bug-fix releases; "<1.22" caps numpy at the whole 1.21 series.
requirements = ["torch>=1.1.0", "numba", "numpy<1.22", "scikit-learn"]
url = "https://github.com/nicolas-chaulet/torch-points-kernels"
__version__ = "0.7.1"
setup(
    # Distribution identity and source location.
    name="torch-points-kernels",
    version=__version__,
    author="Nicolas Chaulet",
    url=url,
    download_url="{}/archive/{}.tar.gz".format(url, __version__),
    # Human-readable metadata for the package index.
    description="PyTorch kernels for spatial operations on point clouds",
    long_description=long_description,
    long_description_content_type="text/markdown",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
    ],
    # Python packages, runtime dependencies and compiled extensions.
    packages=find_packages(),
    install_requires=requirements,
    ext_modules=get_ext_modules(),
    cmdclass=get_cmdclass(),
)