Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support downloading OpenMMLab projects on Ascend NPU #228

Merged
merged 1 commit into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions mim/commands/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
DEFAULT_MMCV_BASE_URL,
PKG2PROJECT,
echo_warning,
get_torch_cuda_version,
exit_with_error,
get_torch_device_version,
)


Expand Down Expand Up @@ -160,16 +161,24 @@ def get_mmcv_full_find_link(mmcv_base_url: str) -> str:

Returns:
str: The mmcv find links corresponding to the current torch version and
cuda version.
CUDA/NPU version.
"""
torch_v, cuda_v = get_torch_cuda_version()
torch_v, device, device_v = get_torch_device_version()
major, minor, *_ = torch_v.split('.')
torch_v = '.'.join([major, minor, '0'])

if cuda_v.isdigit():
cuda_v = f'cu{cuda_v}'
if device == 'cuda' and device_v.isdigit():
device_link = f'cu{device_v}'
elif device == 'ascend':
if not device_v.isdigit():
exit_with_error('Unable to install OpenMMLab projects via mim '
'on the current Ascend NPU, '
'please compile from source code to install.')
device_link = f'ascend{device_v}'
else:
device_link = 'cpu'

find_link = f'{mmcv_base_url}/mmcv/dist/{cuda_v}/torch{torch_v}/index.html' # noqa: E501
find_link = f'{mmcv_base_url}/mmcv/dist/{device_link}/torch{torch_v}/index.html' # noqa: E501
return find_link


Expand Down
4 changes: 2 additions & 2 deletions mim/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
get_package_info_from_pypi,
get_package_version,
get_release_version,
get_torch_cuda_version,
get_torch_device_version,
highlighted_error,
is_installed,
is_version_equal,
Expand All @@ -59,7 +59,7 @@
'get_installed_version',
'get_installed_path',
'get_latest_version',
'get_torch_cuda_version',
'get_torch_device_version',
'is_installed',
'parse_url',
'PKG2PROJECT',
Expand Down
54 changes: 47 additions & 7 deletions mim/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@
from .default import PKG2PROJECT
from .progress_bars import rich_progress_bar

# Probe for Ascend NPU support once at import time.  `torch_npu` is the
# Ascend adapter package for PyTorch; importing it patches a `npu` namespace
# onto `torch`.  Any failure (either package missing, or the device probe
# itself raising) is treated as "no NPU" rather than propagated, so plain
# CPU/CUDA installations are unaffected.
try:
    import torch
    import torch_npu

    # hasattr guard: torch only gains the `npu` attribute after torch_npu
    # has been imported successfully.
    IS_NPU_AVAILABLE = hasattr(
        torch, 'npu') and torch.npu.is_available()  # type: ignore
except Exception:
    IS_NPU_AVAILABLE = False


def parse_url(url: str) -> Tuple[str, str]:
"""Parse username and repo from url.
Expand Down Expand Up @@ -327,12 +336,37 @@ def get_installed_path(package: str) -> str:
return osp.join(pkg.location, package2module(package))


def get_torch_cuda_version() -> Tuple[str, str]:
"""Get PyTorch version and CUDA version if it is available.
def is_npu_available() -> bool:
    """Return True if Ascend PyTorch (torch_npu) and NPU devices exist.

    The result is computed once at module import time (see the
    ``IS_NPU_AVAILABLE`` probe above) and cached for the process lifetime.
    """
    return IS_NPU_AVAILABLE


def get_npu_version() -> str:
    """Return the CANN version when Ascend NPU devices exist.

    Returns:
        str: The CANN toolkit version reported by ``torch_npu`` (e.g.
        ``'6.0.2'``), or an empty string when no NPU device is available.

    Raises:
        RuntimeError: If ``ASCEND_HOME_PATH`` is unset or does not point to
            an existing directory, i.e. the CANN environment has not been
            sourced.
    """
    if not is_npu_available():
        return ''
    ascend_home_path = os.environ.get('ASCEND_HOME_PATH', None)
    if not ascend_home_path or not os.path.exists(ascend_home_path):
        # Fix: 'exists' -> 'exist' and add the missing space between the
        # concatenated sentences (previously rendered as "NPU.Please run").
        raise RuntimeError(
            highlighted_error(
                f'ASCEND_HOME_PATH:{ascend_home_path} does not exist when '
                'installing OpenMMLab projects on Ascend NPU. '
                "Please run 'source set_env.sh' in the CANN installation path."
            ))
    # torch_npu knows which CANN toolkit is installed at the given path.
    npu_version = torch_npu.get_cann_version(ascend_home_path)
    return npu_version


def get_torch_device_version() -> Tuple[str, str, str]:
"""Get PyTorch version and CUDA/NPU version if it is available.

Example:
>>> get_torch_cuda_version()
'1.8.0', '102'
>>> get_torch_device_version()
'1.8.0', 'cpu', ''
>>> get_torch_device_version()
'1.8.0', 'cuda', '102'
>>> get_torch_device_version()
'1.11.0', 'ascend', '602'
"""
try:
import torch
Expand All @@ -344,11 +378,17 @@ def get_torch_cuda_version() -> Tuple[str, str]:
torch_v = torch_v.split('+')[0]

if torch.version.cuda is not None:
device = 'cuda'
# torch.version.cuda like 10.2 -> 102
cuda_v = ''.join(torch.version.cuda.split('.'))
device_v = ''.join(torch.version.cuda.split('.'))
elif is_npu_available():
device = 'ascend'
device_v = get_npu_version()
device_v = ''.join(device_v.split('.'))
else:
cuda_v = 'cpu'
return torch_v, cuda_v
device = 'cpu'
device_v = ''
return torch_v, device, device_v


def cast2lowercase(input: Union[list, tuple, str]) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ default_section = THIRDPARTY
include_trailing_comma = true

[codespell]
ignore-words-list = te
ignore-words-list = te, cann
9 changes: 9 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from mim.commands.install import cli as install
from mim.commands.uninstall import cli as uninstall
from mim.utils import get_github_url, parse_home_page
from mim.utils.utils import get_torch_device_version, is_npu_available


def setup_module():
Expand Down Expand Up @@ -39,6 +40,14 @@ def test_get_github_url():
'mmcls') == 'https://github.com/open-mmlab/mmclassification.git'


def test_get_torch_device_version():
    """Smoke-test get_torch_device_version() on the current machine."""
    torch_version, device_name, device_version = get_torch_device_version()
    # A well-formed torch version is purely numeric once the dots are gone.
    assert torch_version.replace('.', '').isdigit()
    if is_npu_available():
        # Ascend machines must report the 'ascend' device with a numeric
        # CANN version; on CPU/CUDA hosts nothing further is checked here.
        assert device_name == 'ascend'
        assert device_version.replace('.', '').isdigit()


def teardown_module():
runner = CliRunner()
result = runner.invoke(uninstall, ['mmcv-full', '--yes'])
Expand Down
Loading