Skip to content

Commit

Permalink
rewrite zluda installer
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed Apr 30, 2024
1 parent b64ac51 commit d1bbe65
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 68 deletions.
24 changes: 14 additions & 10 deletions installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,19 +495,29 @@ def is_rocm_available():
rocm_ver = None
if args.use_zluda:
log.warning("ZLUDA support: experimental")
error = None
from modules import zluda_installer
try:
from modules import zluda_installer
if args.use_zluda_dnn:
if zluda_installer.check_dnn_dependency():
zluda_installer.enable_dnn()
else:
log.warning("Couldn't find the required dependency of ZLUDA DNN.")
zluda_installer.install()
zluda_installer.resolve_path()
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.3.0 torchvision --index-url https://download.pytorch.org/whl/cu118')
log.info(f'Using ZLUDA in {zluda_installer.ZLUDA_PATH}')
zluda_path = zluda_installer.find()
zluda_installer.make_copy(zluda_path)
except Exception as e:
error = e
log.warning(f'Failed to install ZLUDA: {e}')
if error is None:
try:
zluda_installer.load(zluda_path)
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.3.0 torchvision --index-url https://download.pytorch.org/whl/cu118')
log.info(f'Using ZLUDA in {zluda_path}')
except Exception as e:
error = e
log.warning(f'Failed to load ZLUDA: {e}')
if error is not None:
log.info('Using CPU-only torch')
torch_command = os.environ.get('TORCH_COMMAND', 'torch torchvision')
elif is_windows: # TODO TBD after ROCm for Windows is released
Expand Down Expand Up @@ -594,12 +604,6 @@ def is_rocm_available():
if not installed('torch', quiet=True):
log.debug(f'Installing torch: {torch_command}')
install(torch_command, 'torch torchvision')
if args.use_zluda:
try:
from modules.zluda_installer import patch as patch_torch
patch_torch()
except Exception as e:
log.warning(f'ZLUDA: failed to automatically patch torch: {e}')
else:
try:
import torch
Expand Down
100 changes: 42 additions & 58 deletions modules/zluda_installer.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,23 @@
import os
import ctypes
import shutil
import zipfile
import tarfile
import platform
import urllib.request


RELEASE = 'rel.2804604c29b5fa36deca9ece219d3970b61d4c27'
TARGETS = {
RELEASE = f"rel.{os.environ.get('ZLUDA_HASH', '2804604c29b5fa36deca9ece219d3970b61d4c27')}"
DLL_MAPPING = {
'cublas.dll': 'cublas64_11.dll',
'cusparse.dll': 'cusparse64_11.dll',
'nvrtc.dll': 'nvrtc64_112_0.dll',
}
# Directory of the ZLUDA installation; stays None until install() assigns it.
ZLUDA_PATH = None
# torch's `Lib/site-packages/torch/lib` directory; stays None until install() assigns it.
TORCHLIB_PATH = None
# DLLs that load() loads from the `bin` directory under the HIP path.
HIP_TARGETS = ('rocblas.dll', 'rocsolver.dll', 'hiprtc0507.dll',)
# DLLs that load() loads directly from the ZLUDA directory.
ZLUDA_TARGETS = ('nvcuda.dll', 'nvml.dll',)


def find_zluda_path():
    """Locate an existing ZLUDA installation.

    The ``ZLUDA`` environment variable takes precedence. When it is not set,
    every directory listed in ``PATH`` is probed for ``zluda_redirect.dll``
    and the first match wins.

    Returns:
        The ZLUDA directory as a string, or ``None`` when nothing was found.
    """
    zluda_path = os.environ.get('ZLUDA', None)
    if zluda_path is not None:
        return zluda_path
    # Split on the platform separator: a hard-coded ';' never matches on
    # systems where PATH entries are separated by ':'.
    for path in os.environ.get('PATH', '').split(os.pathsep):
        # An empty entry would make the probe look in the current working
        # directory and hand back '' as the installation directory.
        if path and os.path.exists(os.path.join(path, 'zluda_redirect.dll')):
            return path
    return None


def find_venv_dir():
    """Return the root directory of the active Python environment.

    ``VENV_DIR`` from the environment is returned as-is when set. Otherwise
    the directory is derived from the location of the ``python`` executable:
    its containing directory when ``conda`` is on ``PATH``, and the parent of
    that directory when it is not.

    Returns:
        The environment root as a string.
    """
    import sys  # local import: only needed for the fallback below

    venv_dir = os.environ.get('VENV_DIR')
    if venv_dir is not None:
        # An explicit override needs no lookup of the interpreter at all.
        return venv_dir
    # shutil.which() returns None when `python` is not on PATH; fall back to
    # the running interpreter instead of crashing in os.path.dirname(None).
    python_path = shutil.which('python') or sys.executable
    python_dir = os.path.dirname(python_path)
    if shutil.which('conda') is None:
        # NOTE(review): presumably python sits in a Scripts/bin sub-directory
        # of a plain venv, hence one level up — confirm.
        python_dir = os.path.dirname(python_dir)
    return python_dir


def reset_torch():
    """Delete every mapped DLL name from the torch library directory.

    For each target file name in ``TARGETS`` the corresponding entry under
    ``TORCHLIB_PATH`` is removed when present.
    """
    for dll in TARGETS.values():
        path = os.path.join(TORCHLIB_PATH, dll)
        # lexists() also reports dangling symlinks; exists() follows the link
        # and would leave a broken one behind, making a later os.symlink()
        # to the same name fail.
        if os.path.lexists(path):
            os.remove(path)


def is_patched():
    """Report whether every mapped DLL under ``TORCHLIB_PATH`` is a symlink.

    Returns:
        ``True`` when each target name in ``TARGETS`` is a symbolic link in
        the torch library directory, ``False`` as soon as one is not.
    """
    return all(
        os.path.islink(os.path.join(TORCHLIB_PATH, alias))
        for alias in TARGETS.values()
    )
def find():
    """Return the absolute path of the ZLUDA directory.

    The location comes from the ``ZLUDA`` environment variable and defaults
    to ``.zluda`` relative to the current working directory.
    """
    location = os.environ.get('ZLUDA', '.zluda')
    return os.path.abspath(location)


def check_dnn_dependency():
Expand All @@ -59,36 +31,48 @@ def check_dnn_dependency():

def enable_dnn():
global RELEASE # pylint: disable=global-statement
TARGETS['cudnn.dll'] = 'cudnn64_8.dll'
DLL_MAPPING['cudnn.dll'] = 'cudnn64_8.dll'
RELEASE = 'v3.8-pre2-dnn'


def install():
global ZLUDA_PATH, TORCHLIB_PATH # pylint: disable=global-statement
ZLUDA_PATH = find_zluda_path()
TORCHLIB_PATH = os.path.join(find_venv_dir(), 'Lib', 'site-packages', 'torch', 'lib')
zluda_path = find()

if ZLUDA_PATH is not None:
if os.path.exists(zluda_path):
return

is_windows = platform.system() == 'Windows'
archive_type = zipfile.ZipFile if is_windows else tarfile.TarFile
urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/{RELEASE}/ZLUDA-{platform.system().lower()}-amd64.{"zip" if is_windows else "tar.gz"}', '_zluda')
with archive_type('_zluda', 'r') as f:
f.extractall('.zluda')
ZLUDA_PATH = os.path.abspath('./.zluda')
os.remove('_zluda')

if platform.system() != 'Windows': # TODO
return

def resolve_path():
    """Append the ZLUDA directory to ``PATH`` unless it is already listed.

    Raises:
        RuntimeError: if ``ZLUDA_PATH`` has not been assigned yet.
    """
    if ZLUDA_PATH is None:
        # Fail with a clear message instead of an opaque TypeError from the
        # membership test below.
        raise RuntimeError('ZLUDA path is not set. Call install() first.')
    paths = os.environ.get('PATH', '.')
    # Compare against whole entries: a substring test on the raw string
    # treats '/opt/zluda' as present when only '/opt/zluda-old' is listed.
    if ZLUDA_PATH not in paths.split(os.pathsep):
        os.environ['PATH'] = paths + os.pathsep + ZLUDA_PATH
urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/{RELEASE}/ZLUDA-windows-amd64.zip', '_zluda')
with zipfile.ZipFile('_zluda', 'r') as archive:
infos = archive.infolist()
for info in infos:
if not info.is_dir():
info.filename = os.path.basename(info.filename)
archive.extract(info, '.zluda')
os.remove('_zluda')


def patch():
    """Symlink the ZLUDA DLLs into the torch library directory.

    Does nothing when ``is_patched()`` already reports success. Otherwise the
    existing target files are cleared with ``reset_torch()`` and one symlink
    per ``TARGETS`` entry is created, pointing from the target name under
    ``TORCHLIB_PATH`` to the source DLL under ``ZLUDA_PATH``.
    """
    if not is_patched():
        reset_torch()
        for source_name, alias_name in TARGETS.items():
            source = os.path.join(ZLUDA_PATH, source_name)
            alias = os.path.join(TORCHLIB_PATH, alias_name)
            os.symlink(source, alias)
def make_copy(zluda_path: os.PathLike):
    """Create the mapped DLL names next to the ZLUDA originals.

    For every ``source -> alias`` pair in ``DLL_MAPPING`` whose alias does not
    exist yet inside ``zluda_path``, a hard link to the source is attempted;
    if linking fails for any reason the file is copied instead.

    Args:
        zluda_path: directory that holds the ZLUDA DLLs.
    """
    for source_name, alias_name in DLL_MAPPING.items():
        alias = os.path.join(zluda_path, alias_name)
        if os.path.exists(alias):
            continue
        source = os.path.join(zluda_path, source_name)
        try:
            os.link(source, alias)
        except Exception:
            # Best-effort fallback: a plain copy when a hard link cannot be made.
            shutil.copyfile(source, alias)


def load(zluda_path: os.PathLike):
    """Load the HIP and ZLUDA DLLs into the current process.

    The HIP directory is taken from the ``HIP_PATH`` environment variable,
    falling back to ``C:\\Program Files\\AMD\\ROCm\\5.7`` when that directory
    exists. The ``HIP_TARGETS`` are loaded from its ``bin`` sub-directory,
    then ``ZLUDA_TARGETS`` and every ``DLL_MAPPING`` target name are loaded
    from ``zluda_path``.

    Args:
        zluda_path: directory that holds the ZLUDA DLLs.

    Raises:
        RuntimeError: when no HIP directory can be determined.
    """
    default_dir = r'C:\Program Files\AMD\ROCm\5.7'
    fallback = default_dir if os.path.exists(default_dir) else None
    hip_path = os.environ.get('HIP_PATH', fallback)
    if hip_path is None:
        raise RuntimeError('Could not find %HIP_PATH%. Please install AMD HIP SDK.')

    for name in HIP_TARGETS:
        ctypes.windll.LoadLibrary(os.path.join(hip_path, 'bin', name))
    # ZLUDA's own DLLs go first, followed by the mapped DLL names.
    for name in (*ZLUDA_TARGETS, *DLL_MAPPING.values()):
        ctypes.windll.LoadLibrary(os.path.join(zluda_path, name))

0 comments on commit d1bbe65

Please sign in to comment.