Skip to content

Commit

Permalink
[Feature] Support downloading OpenMMLab projects on Ascend NPU
Browse files Browse the repository at this point in the history
  • Loading branch information
Ginray committed Aug 28, 2023
1 parent bf69e00 commit f905674
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 16 deletions.
21 changes: 15 additions & 6 deletions mim/commands/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
DEFAULT_MMCV_BASE_URL,
PKG2PROJECT,
echo_warning,
get_torch_cuda_version,
exit_with_error,
get_torch_device_version,
)


Expand Down Expand Up @@ -160,16 +161,24 @@ def get_mmcv_full_find_link(mmcv_base_url: str) -> str:
Returns:
str: The mmcv find links corresponding to the current torch version and
cuda version.
CUDA/NPU version.
"""
torch_v, cuda_v = get_torch_cuda_version()
torch_v, device, device_v = get_torch_device_version()
major, minor, *_ = torch_v.split('.')
torch_v = '.'.join([major, minor, '0'])

if cuda_v.isdigit():
cuda_v = f'cu{cuda_v}'
if device == 'cuda' and device_v.isdigit():
device_link = f'cu{device_v}'
elif device == 'ascend':
if not device_v.isdigit():
exit_with_error('Unable to install OpenMMLab projects via mim '
'on the current Ascend NPU, '
'please compile from source code to install.')
device_link = f'ascend{device_v}'
else:
device_link = 'cpu'

find_link = f'{mmcv_base_url}/mmcv/dist/{cuda_v}/torch{torch_v}/index.html' # noqa: E501
find_link = f'{mmcv_base_url}/mmcv/dist/{device_link}/torch{torch_v}/index.html' # noqa: E501
return find_link


Expand Down
4 changes: 2 additions & 2 deletions mim/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
get_package_info_from_pypi,
get_package_version,
get_release_version,
get_torch_cuda_version,
get_torch_device_version,
highlighted_error,
is_installed,
is_version_equal,
Expand All @@ -59,7 +59,7 @@
'get_installed_version',
'get_installed_path',
'get_latest_version',
'get_torch_cuda_version',
'get_torch_device_version',
'is_installed',
'parse_url',
'PKG2PROJECT',
Expand Down
53 changes: 46 additions & 7 deletions mim/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@
from .default import PKG2PROJECT
from .progress_bars import rich_progress_bar

try:
import torch
import torch_npu

IS_NPU_AVAILABLE = hasattr(
torch, 'npu') and torch.npu.is_available() # type: ignore
except Exception:
IS_NPU_AVAILABLE = False


def parse_url(url: str) -> Tuple[str, str]:
"""Parse username and repo from url.
Expand Down Expand Up @@ -327,12 +336,37 @@ def get_installed_path(package: str) -> str:
return osp.join(pkg.location, package2module(package))


def get_torch_cuda_version() -> Tuple[str, str]:
"""Get PyTorch version and CUDA version if it is available.
def is_npu_available() -> bool:
"""Returns True if Ascend PyTorch and npu devices exist."""
return IS_NPU_AVAILABLE


def get_npu_version() -> str:
"""Returns the version of NPU when npu devices exist."""
if not is_npu_available():
return ''
ascend_home_path = os.environ.get('ASCEND_HOME_PATH', None)
if not ascend_home_path or not os.path.exists(ascend_home_path):
raise RuntimeError(
highlighted_error(
f'ASCEND_HOME_PATH:{ascend_home_path} does not exists when '
'installing OpenMMLab projects on Ascend NPU.'
"Please run 'source set_env.sh' in the CANN installation path."
))
npu_version = torch_npu.get_cann_version(ascend_home_path)
return npu_version


def get_torch_device_version() -> Tuple[str, str, str]:
"""Get PyTorch version and CUDA/NPU version if it is available.
Example:
>>> get_torch_cuda_version()
'1.8.0', '102'
>>> get_torch_device_version()
'1.8.0', 'cpu', ''
>>> get_torch_device_version()
'1.8.0', 'cuda', '102'
>>> get_torch_device_version()
'1.11.0', 'ascend', '6.0.2'
"""
try:
import torch
Expand All @@ -344,11 +378,16 @@ def get_torch_cuda_version() -> Tuple[str, str]:
torch_v = torch_v.split('+')[0]

if torch.version.cuda is not None:
device = 'cuda'
# torch.version.cuda like 10.2 -> 102
cuda_v = ''.join(torch.version.cuda.split('.'))
device_v = ''.join(torch.version.cuda.split('.'))
elif is_npu_available():
device = 'ascend'
device_v = get_npu_version()
else:
cuda_v = 'cpu'
return torch_v, cuda_v
device = 'cpu'
device_v = ''
return torch_v, device, device_v


def cast2lowercase(input: Union[list, tuple, str]) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ default_section = THIRDPARTY
include_trailing_comma = true

[codespell]
ignore-words-list = te
ignore-words-list = te, cann
9 changes: 9 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from mim.commands.install import cli as install
from mim.commands.uninstall import cli as uninstall
from mim.utils import get_github_url, parse_home_page
from mim.utils.utils import get_torch_device_version, is_npu_available


def setup_module():
Expand Down Expand Up @@ -39,6 +40,14 @@ def test_get_github_url():
'mmcls') == 'https://github.com/open-mmlab/mmclassification.git'


def test_get_torch_device_version():
torch_v, device, device_v = get_torch_device_version()
assert torch_v.replace('.', '').isdigit()
if is_npu_available():
assert device == 'ascend'
assert device_v.replace('.', '').isdigit()


def teardown_module():
runner = CliRunner()
result = runner.invoke(uninstall, ['mmcv-full', '--yes'])
Expand Down

0 comments on commit f905674

Please sign in to comment.