-
Notifications
You must be signed in to change notification settings - Fork 82
/
setup.py
137 lines (121 loc) · 4.23 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
from setuptools import setup, find_packages
from distutils.errors import (
DistutilsPlatformError,
)
from setuptools_rust import Binding, RustExtension
import shutil
import sys
from typing import List
from pathlib import Path
# Module-level dict, created empty. Nothing in the visible part of this file
# reads or writes it. NOTE(review): possibly vestigial -- confirm before removing.
library_records = {}
def check_torch_version():
    """Return the version string of the installed PyTorch.

    Returns:
        The value of ``torch.__version__``.

    Raises:
        DistutilsPlatformError: If PyTorch cannot be imported, or if the
            imported module exposes no usable ``__version__``.
    """
    try:
        import torch
    except ImportError as err:
        print("import torch failed, is it installed?")
        # Stop here: falling through would hit the unbound local ``torch``
        # and surface as an UnboundLocalError instead of a clear message.
        raise DistutilsPlatformError(
            "Unable to import PyTorch; install it before building Bagua"
        ) from err

    # getattr keeps a module without ``__version__`` on the same error path
    # as one whose ``__version__`` is None.
    version = getattr(torch, "__version__", None)
    if version is None:
        raise DistutilsPlatformError(
            "Unable to determine PyTorch version from the version string '%s'"
            % version
        )
    return version
def install_baguanet(destination):
    """Build bagua-net and move its ``libnccl-net.so`` into ``destination``.

    Runs ``make`` inside ``rust/bagua-net/cc`` (resolved against the current
    working directory), then moves the produced ``libnccl-net.so`` into
    ``destination``, creating that directory first if it does not exist.

    Args:
        destination: Directory that receives ``libnccl-net.so``.

    Raises:
        subprocess.CalledProcessError: If ``make`` exits with a non-zero
            status.
    """
    # Local import: the file's top-level import block does not provide it.
    import subprocess

    build_dir = os.path.join("rust", "bagua-net", "cc")
    os.makedirs(destination, exist_ok=True)
    # check=True aborts on a failed build instead of continuing to the move,
    # which would otherwise fail obscurely or pick up a stale library.
    subprocess.run(["make"], cwd=build_dir, check=True)
    shutil.move(
        os.path.join(build_dir, "libnccl-net.so"),
        os.path.join(destination, "libnccl-net.so"),
    )
def install_dependency_library():
    """Print the nvcc release string and install bagua-net under ``bagua_core/.data``.

    Reads the module-level ``cwd`` that the ``__main__`` section assigns
    before calling this function.
    """
    # Shell pipeline that reduces ``nvcc --version`` to the text after
    # "release " and before the following comma.
    nvcc_query = "nvcc --version | grep release | sed 's/.*release //' | sed 's/,.*//'"
    raw_output = os.popen(nvcc_query).read()
    nvcc_version = raw_output.strip()
    print("nvcc_version: ", nvcc_version)

    baguanet_dir = os.path.join(cwd, "bagua_core", ".data", "bagua-net")
    install_baguanet(baguanet_dir)
if __name__ == "__main__":
    import colorama

    colorama.init(autoreset=True)
    # Directory containing this setup.py; install_dependency_library() reads
    # it as a module-level global.
    cwd = os.path.dirname(os.path.abspath(__file__))

    def check_args(args: List[str]) -> bool:
        """Return True if ``args`` contains any build/install style command."""
        build_commands = ("build", "install", "develop", "bdist_wheel", "wheel")
        return any(command in args for command in build_commands)

    if len(sys.argv) > 1 and check_args(sys.argv):
        # NOTE(review): the nesting of this inner if/else was reconstructed
        # from a copy of the file without indentation -- confirm that the
        # ``else`` belongs to the BAGUA_NO_INSTALL_DEPS check.
        if int(os.getenv("BAGUA_NO_INSTALL_DEPS", 0)) == 0:
            print(
                colorama.Fore.BLACK
                + colorama.Back.CYAN
                + "Bagua is automatically installing some system dependencies like bagua-net, to disable set env variable BAGUA_NO_INSTALL_DEPS=1",
            )
            os.system("python3 bagua_core/bagua_install_deps.py")
            install_dependency_library()
        else:
            # No dependency install: still make sure the data directory exists.
            os.makedirs(os.path.join(cwd, "bagua_core", ".data"), exist_ok=True)

    # BAGUA_CUDA_VERSION, when set, is appended to the distribution name as
    # "-cuda<value>".
    name_suffix = os.getenv("BAGUA_CUDA_VERSION", "")
    if name_suffix != "":
        name_suffix = "-cuda" + name_suffix

    this_directory = Path(__file__).parent
    # Explicit encoding so the README decodes the same way regardless of the
    # build machine's locale.
    long_description = (this_directory / "README.md").read_text(encoding="utf-8")

    setup(
        name="bagua" + name_suffix,
        use_scm_version={
            "local_scheme": "no-local-version",
            "fallback_version": "0.9.2",
        },
        setup_requires=["setuptools_scm"],
        url="https://github.com/BaguaSys/bagua",
        python_requires=">=3.7",
        description="Bagua is a deep learning training acceleration framework for PyTorch. It provides a one-stop training acceleration solution, including faster distributed training compared to PyTorch DDP, faster dataloader, kernel fusion, and more.",
        long_description=long_description,
        long_description_content_type="text/markdown",
        # Must be a real tuple: the former ("tests") was a plain string, which
        # setuptools unpacks into the single-character patterns "t", "e", "s",
        # so the tests package was never excluded.
        packages=find_packages(exclude=("tests", "tests.*")),
        package_data={
            "": [
                ".data/bagua-net/libnccl-net.so",
            ]
        },
        rust_extensions=[
            RustExtension(
                "bagua_core.bagua_core",
                path="rust/bagua-core/bagua-core-py/Cargo.toml",
                binding=Binding.PyO3,
                native=False,
            ),
        ],
        author="Kuaishou AI Platform & DS3 Lab",
        author_email="admin@mail.xrlian.com",
        install_requires=[
            "setuptools_rust",
            "colorama",
            "tqdm",
            "deprecation>=2.1",
            "pytest-benchmark>=3.4",
            "scikit-optimize>=0.8.1",
            "scikit-learn>=0.24,!=1.0,<1.2.2",
            "numpy",
            "flask>=2.0",
            "prometheus_client>=0.11",
            "parallel-ssh==2.9.1",
            "pydantic>=1.8",
            "requests>=2.25",
            "gorilla==0.4.0",
            "gevent>=21.8",
            "xxhash>=2.0",
        ],
        entry_points={
            "console_scripts": [
                "baguarun = bagua.script.baguarun:main",
            ],
        },
        scripts=["bagua/script/bagua_sys_perf", "bagua_core/bagua_install_deps.py"],
        zip_safe=False,
    )