From eeaa20f576516cc25bc41720c8db69a543c39874 Mon Sep 17 00:00:00 2001 From: Raul Date: Wed, 20 Sep 2023 14:35:19 +0200 Subject: [PATCH] Compile torch extension in setup.py (#216) * Compile torch extension in setup.py * Fix extension library discovery * Make CI more verbose * Install gxx from conda-forge to ensure ABI compatibility * Change extension name so it does not collide with NNPOps --- .github/workflows/CI.yml | 4 ++-- environment.yml | 1 + setup.py | 24 +++++++++++++++++++++++- torchmdnet/neighbors/__init__.py | 26 ++++++++++++-------------- 4 files changed, 38 insertions(+), 17 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a3efe044c..1f9067d1f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -30,7 +30,7 @@ jobs: use-mamba: true - name: Install the package - run: pip install . + run: pip -vv install . - name: List the conda environment run: conda list @@ -43,4 +43,4 @@ jobs: flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Run tests - run: pytest -v + run: pytest -v -s diff --git a/environment.yml b/environment.yml index d7a5207e5..900781ab5 100644 --- a/environment.yml +++ b/environment.yml @@ -20,3 +20,4 @@ dependencies: - pytest - psutil - ninja + - gxx diff --git a/setup.py b/setup.py index 2fd821ed0..8950a7662 100644 --- a/setup.py +++ b/setup.py @@ -11,11 +11,33 @@ print("Failed to retrieve the current version, defaulting to 0") version = "0" +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, include_paths, CppExtension +import torch + +neighs = CppExtension( + name='torchmdnet.neighbors.torchmdnet_neighbors', + sources=["torchmdnet/neighbors/neighbors.cpp", "torchmdnet/neighbors/neighbors_cpu.cpp"], + include_dirs=include_paths(), + language='c++') + +if torch.cuda._is_compiled(): + neighs = CUDAExtension( + name='torchmdnet.neighbors.torchmdnet_neighbors', + sources=["torchmdnet/neighbors/neighbors.cpp", "torchmdnet/neighbors/neighbors_cpu.cpp", "torchmdnet/neighbors/neighbors_cuda.cu"], + include_dirs=include_paths(), + language='cuda' + ) + setup( name="torchmd-net", version=version, packages=find_packages(), - package_data={"torchmdnet": ["neighbors/neighbors*", "neighbors/*.cu*"]}, + ext_modules=[neighs,], + cmdclass={ + 'build_ext': BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)}, include_package_data=True, entry_points={"console_scripts": ["torchmd-train = torchmdnet.scripts.train:main"]}, + package_data={"torchmdnet": ["neighbors/torchmdnet_neighbors.so"]}, + ) diff --git a/torchmdnet/neighbors/__init__.py b/torchmdnet/neighbors/__init__.py index 3f4fdc29d..4f96c8334 100644 --- a/torchmdnet/neighbors/__init__.py +++ b/torchmdnet/neighbors/__init__.py @@ -1,16 +1,14 @@ -import os +import os.path as osp import torch -from torch.utils import cpp_extension - -def compile_extension(): - src_dir = os.path.dirname(__file__) - sources = ["neighbors.cpp", "neighbors_cpu.cpp"] + ( - ["neighbors_cuda.cu"] if torch.cuda.is_available() else [] - ) - sources = [os.path.join(src_dir, name) for name in sources] - cpp_extension.load( - name="torchmdnet_neighbors", sources=sources, is_python_module=False - ) - -compile_extension() +import importlib.machinery +library = "torchmdnet_neighbors" +# Find the specification for the library +spec = importlib.machinery.PathFinder().find_spec( + library, [osp.dirname(__file__)] +) +# Check if the specification is found and load the library +if spec is not None: + torch.ops.load_library(spec.origin) +else: + raise ImportError(f"Could not find module '{library}' in {osp.dirname(__file__)}") get_neighbor_pairs_kernel = torch.ops.torchmdnet_neighbors.get_neighbor_pairs