Skip to content

Commit

Permalink
Compile torch extension in setup.py (#216)
Browse files Browse the repository at this point in the history
* Compile torch extension in setup.py

* Fix extension library discovery

* Make CI more verbose

* Install gxx from conda-forge to ensure ABI compatibility

* Change extension name so it does not collide with NNPOps
  • Loading branch information
RaulPPelaez authored Sep 20, 2023
1 parent ac16c09 commit eeaa20f
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 17 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
use-mamba: true

- name: Install the package
run: pip install .
run: pip -vv install .

- name: List the conda environment
run: conda list
Expand All @@ -43,4 +43,4 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Run tests
run: pytest -v
run: pytest -v -s
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ dependencies:
- pytest
- psutil
- ninja
- gxx
24 changes: 23 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,33 @@
print("Failed to retrieve the current version, defaulting to 0")
version = "0"

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, include_paths, CppExtension
import torch

neighs = CppExtension(
name='torchmdnet.neighbors.torchmdnet_neighbors',
sources=["torchmdnet/neighbors/neighbors.cpp", "torchmdnet/neighbors/neighbors_cpu.cpp"],
include_dirs=include_paths(),
language='c++')

if torch.cuda._is_compiled():
neighs = CUDAExtension(
name='torchmdnet.neighbors.torchmdnet_neighbors',
sources=["torchmdnet/neighbors/neighbors.cpp", "torchmdnet/neighbors/neighbors_cpu.cpp", "torchmdnet/neighbors/neighbors_cuda.cu"],
include_dirs=include_paths(),
language='cuda'
)

setup(
name="torchmd-net",
version=version,
packages=find_packages(),
package_data={"torchmdnet": ["neighbors/neighbors*", "neighbors/*.cu*"]},
ext_modules=[neighs,],
cmdclass={
'build_ext': BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)},
include_package_data=True,
entry_points={"console_scripts": ["torchmd-train = torchmdnet.scripts.train:main"]},
package_data={"torchmdnet": ["neighbors/torchmdnet_neighbors.so"]},

)
26 changes: 12 additions & 14 deletions torchmdnet/neighbors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import os
import os.path as osp
import torch
from torch.utils import cpp_extension

def compile_extension():
src_dir = os.path.dirname(__file__)
sources = ["neighbors.cpp", "neighbors_cpu.cpp"] + (
["neighbors_cuda.cu"] if torch.cuda.is_available() else []
)
sources = [os.path.join(src_dir, name) for name in sources]
cpp_extension.load(
name="torchmdnet_neighbors", sources=sources, is_python_module=False
)

compile_extension()
import importlib.machinery
library = "torchmdnet_neighbors"
# Find the specification for the library
spec = importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]
)
# Check if the specification is found and load the library
if spec is not None:
torch.ops.load_library(spec.origin)
else:
raise ImportError(f"Could not find module '{library}' in {osp.dirname(__file__)}")
get_neighbor_pairs_kernel = torch.ops.torchmdnet_neighbors.get_neighbor_pairs

0 comments on commit eeaa20f

Please sign in to comment.