Skip to content

Commit

Permalink
Rename internal functions with _ prefix
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed May 16, 2024
1 parent ee33330 commit 4f35924
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 68 deletions.
64 changes: 26 additions & 38 deletions milatools/utils/vscode_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
logger = get_logger(__name__)


def running_inside_WSL() -> bool:
def _running_inside_WSL() -> bool:
return sys.platform == "linux" and bool(shutil.which("powershell.exe"))


Expand All @@ -44,7 +44,7 @@ def get_expected_vscode_settings_json_path() -> Path:
/ "User"
/ "settings.json"
)
if running_inside_WSL():
if _running_inside_WSL():
# Need to get the Windows Home directory, not the WSL one!
windows_username = subprocess.getoutput("powershell.exe '$env:UserName'")
return Path(
Expand All @@ -58,7 +58,7 @@ def get_code_command() -> str:
return os.environ.get("MILATOOLS_CODE_COMMAND", "code")


def get_local_vscode_executable_path(code_command: str | None = None) -> str:
def _get_local_vscode_executable_path(code_command: str | None = None) -> str:
if code_command is None:
code_command = get_code_command()

Expand All @@ -70,7 +70,7 @@ def get_local_vscode_executable_path(code_command: str | None = None) -> str:

def vscode_installed() -> bool:
try:
_ = get_local_vscode_executable_path()
_ = _get_local_vscode_executable_path()
except CommandNotFoundError:
return False
return True
Expand All @@ -97,7 +97,7 @@ async def sync_vscode_extensions(
logger.info("No destinations to sync extensions to!")
return {}

source_hostname, source_extensions = await get_vscode_extensions(source)
source_extensions = await _get_vscode_extensions(source)

task_hostnames: list[str] = []
tasks: list[AsyncTaskFn[ProgressDict]] = []
Expand All @@ -111,14 +111,14 @@ async def sync_vscode_extensions(
task_hostnames.append(dest_hostname)
tasks.append(
functools.partial(
install_vscode_extensions_task_function,
_install_vscode_extensions_task_function,
dest_hostname=dest_hostname,
source_extensions=source_extensions,
remote=dest_runner,
source_name=source_hostname,
source_name=source.hostname,
)
)
task_descriptions.append(f"{source_hostname} -> {dest_hostname}")
task_descriptions.append(f"{source.hostname} -> {dest_hostname}")

return {
hostname: result
Expand Down Expand Up @@ -161,24 +161,23 @@ async def _get_runner_and_hostname(
return dest_remote, dest_remote.hostname


async def get_vscode_extensions(
async def _get_vscode_extensions(
source: LocalV2 | RemoteV2,
) -> tuple[str, dict[str, str]]:
) -> dict[str, str]:
if isinstance(source, LocalV2):
code_server_executable = get_local_vscode_executable_path(code_command=None)
code_server_executable = _get_local_vscode_executable_path(code_command=None)
else:
code_server_executable = await find_code_server_executable(
code_server_executable = await _find_code_server_executable(
source, remote_vscode_server_dir="~/.vscode-server"
)
if not code_server_executable:
raise RuntimeError(
f"The vscode-server executable was not found on {source.hostname}."
)
source_extensions = await _get_vscode_extensions(source, code_server_executable)
return source.hostname, source_extensions
return await _get_vscode_extensions_dict(source, code_server_executable)


async def install_vscode_extensions_task_function(
async def _install_vscode_extensions_task_function(
task_progress_dict: dict[TaskID, ProgressDict],
task_id: TaskID,
dest_hostname: str | Literal["localhost"],
Expand Down Expand Up @@ -219,15 +218,15 @@ def _update_progress(

if isinstance(remote, LocalV2):
assert dest_hostname == "localhost"
code_server_executable = get_local_vscode_executable_path()
extensions_on_dest = await _get_vscode_extensions(
code_server_executable = _get_local_vscode_executable_path()
extensions_on_dest = await _get_vscode_extensions_dict(
remote, code_server_executable
)
else:
dest_hostname = remote.hostname
remote_vscode_server_dir = "~/.vscode-server"
_update_progress(0, f"Looking for code-server in {remote_vscode_server_dir}")
code_server_executable = await find_code_server_executable(
code_server_executable = await _find_code_server_executable(
remote,
remote_vscode_server_dir=remote_vscode_server_dir,
)
Expand All @@ -244,12 +243,12 @@ def _update_progress(
status="code-server executable not found!",
)
_update_progress(0, status="fetching installed extensions...")
extensions_on_dest = await _get_vscode_extensions(
extensions_on_dest = await _get_vscode_extensions_dict(
remote, code_server_executable
)

logger.debug(f"{len(source_extensions)=}, {len(extensions_on_dest)=}")
to_install = extensions_to_install(
to_install = _extensions_to_install(
source_extensions,
extensions_on_dest,
source_name=source_name,
Expand All @@ -270,7 +269,7 @@ def _update_progress(
total=len(to_install),
status=f"Installing {extension_name}",
)
result = await install_vscode_extension(
result = await _install_vscode_extension(
remote,
code_server_executable,
extension=f"{extension_name}@{extension_version}",
Expand All @@ -293,7 +292,7 @@ def _update_progress(
)


async def install_vscode_extension(
async def _install_vscode_extension(
remote: LocalV2 | RemoteV2,
code_server_executable: str,
extension: str,
Expand All @@ -311,24 +310,13 @@ async def install_vscode_extension(
return result


async def _get_local_vscode_extensions(
code_command: str | None = None,
) -> dict[str, str]:
return await _get_vscode_extensions(
LocalV2(),
code_server_executable=get_local_vscode_executable_path(
code_command=code_command
),
)


async def _get_vscode_extensions(
async def _get_vscode_extensions_dict(
remote: RemoteV2 | LocalV2,
code_server_executable: str,
) -> dict[str, str]:
"""Returns the list of isntalled extensions and the path to the code-server
executable."""
return parse_vscode_extensions_versions(
return _parse_vscode_extensions_versions(
stripped_lines_of(
await remote.get_output_async(
f"{code_server_executable} --list-extensions --show-versions",
Expand All @@ -339,7 +327,7 @@ async def _get_vscode_extensions(
)


def extensions_to_install(
def _extensions_to_install(
source_extensions: dict[str, str],
dest_extensions: dict[str, str],
source_name: str,
Expand Down Expand Up @@ -374,7 +362,7 @@ def extensions_to_install(
return extensions_to_install_on_dest


async def find_code_server_executable(
async def _find_code_server_executable(
remote: RemoteV2, remote_vscode_server_dir: str = "~/.vscode-server"
) -> str | None:
"""Find the most recent `code-server` executable on the remote.
Expand Down Expand Up @@ -449,7 +437,7 @@ async def find_code_server_executable(
return most_recent_code_server_executable


def parse_vscode_extensions_versions(
def _parse_vscode_extensions_versions(
list_extensions_output_lines: list[str],
) -> dict[str, str]:
extensions = [line for line in list_extensions_output_lines if "@" in line]
Expand Down
12 changes: 6 additions & 6 deletions tests/integration/test_sync_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from milatools.utils.local_v2 import LocalV2
from milatools.utils.remote_v2 import RemoteV2
from milatools.utils.vscode_utils import (
extensions_to_install,
find_code_server_executable,
install_vscode_extensions_task_function,
_extensions_to_install,
_find_code_server_executable,
_install_vscode_extensions_task_function,
sync_vscode_extensions,
)

Expand Down Expand Up @@ -65,15 +65,15 @@ def mock_and_patch(wraps: Callable, *mock_args, **mock_kwargs):
return mock

mock_task_function = mock_and_patch(
wraps=install_vscode_extensions_task_function,
wraps=_install_vscode_extensions_task_function,
)
# Make it so we only need to install this particular extension.
mock_extensions_to_install = mock_and_patch(
wraps=extensions_to_install,
wraps=_extensions_to_install,
return_value={"ms-python.python": "v2024.0.1"},
)
mock_find_code_server_executable = mock_and_patch(
wraps=find_code_server_executable,
wraps=_find_code_server_executable,
)

await sync_vscode_extensions(
Expand Down
50 changes: 26 additions & 24 deletions tests/utils/test_vscode_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,20 @@
import pytest_asyncio

from milatools.cli.utils import MilatoolsUserError, running_inside_WSL
from milatools.utils.local_v2 import LocalV2
from milatools.utils.parallel_progress import ProgressDict
from milatools.utils.remote_v1 import RemoteV1
from milatools.utils.remote_v2 import RemoteV2
from milatools.utils.vscode_utils import (
_get_local_vscode_extensions,
_extensions_to_install,
_find_code_server_executable,
_get_local_vscode_executable_path,
_get_vscode_extensions,
extensions_to_install,
find_code_server_executable,
_get_vscode_extensions_dict,
_install_vscode_extension,
_install_vscode_extensions_task_function,
get_code_command,
get_expected_vscode_settings_json_path,
get_local_vscode_executable_path,
install_vscode_extension,
install_vscode_extensions_task_function,
sync_vscode_extensions,
vscode_installed,
)
Expand Down Expand Up @@ -84,13 +85,13 @@ def test_running_inside_WSL():

def test_get_vscode_executable_path():
if vscode_installed():
code = get_local_vscode_executable_path()
code = _get_local_vscode_executable_path()
assert Path(code).exists()
else:
with pytest.raises(
MilatoolsUserError, match="Command 'code' does not exist locally."
):
get_local_vscode_executable_path()
_get_local_vscode_executable_path()


@pytest.fixture
Expand All @@ -103,12 +104,12 @@ def mock_find_code_server_executable(monkeypatch: pytest.MonkeyPatch):
import milatools.utils.vscode_utils

mock_find_code_server_executable = AsyncMock(
spec=find_code_server_executable,
return_value=get_local_vscode_executable_path(),
spec=_find_code_server_executable,
return_value=_get_local_vscode_executable_path(),
)
monkeypatch.setattr(
milatools.utils.vscode_utils,
find_code_server_executable.__name__,
_find_code_server_executable.__name__,
mock_find_code_server_executable,
)
return mock_find_code_server_executable
Expand Down Expand Up @@ -143,7 +144,7 @@ async def vscode_extensions(
Here we pretend like some local vscode extensions are missing by patching the
function that returns the local extensions to return only part of its actual result.
"""
all_extensions = await _get_local_vscode_extensions()
all_extensions = await _get_vscode_extensions(LocalV2())

installed_extensions = all_extensions.copy()
num_missing_extensions = 3
Expand All @@ -157,12 +158,12 @@ async def vscode_extensions(
# `localhost` is the source, so it has all the extensions
# the "remote" (just to localhost during tests) is missing some extensions
mock_remote_extensions = AsyncMock(
spec=_get_vscode_extensions,
return_value=(installed_extensions, str(get_local_vscode_executable_path())),
spec=_get_vscode_extensions_dict,
return_value=(installed_extensions, str(_get_local_vscode_executable_path())),
)
monkeypatch.setattr(
milatools.utils.vscode_utils,
_get_vscode_extensions.__name__,
_get_vscode_extensions_dict.__name__,
mock_remote_extensions,
)

Expand Down Expand Up @@ -215,7 +216,7 @@ async def test_install_vscode_extensions_task_function(
task_progress_dict: dict[TaskID, ProgressDict] = {}
_fake_remote = await RemoteV2.connect("localhost")

result = await install_vscode_extensions_task_function(
result = await _install_vscode_extensions_task_function(
task_progress_dict=task_progress_dict,
task_id=TaskID(0),
dest_hostname="fake_cluster",
Expand All @@ -241,9 +242,9 @@ async def test_install_vscode_extensions_task_function(
@pytest.mark.asyncio
async def test_install_vscode_extension(missing_extensions: dict[str, str]):
extension_name, version = next(iter(missing_extensions.items()))
result = await install_vscode_extension(
result = await _install_vscode_extension(
remote=(await RemoteV2.connect("localhost")),
code_server_executable=str(get_local_vscode_executable_path()),
code_server_executable=str(_get_local_vscode_executable_path()),
extension=f"{extension_name}@{version}",
verbose=False,
)
Expand All @@ -258,7 +259,8 @@ async def test_install_vscode_extension(missing_extensions: dict[str, str]):
@requires_vscode
@pytest.mark.asyncio
async def test_get_local_vscode_extensions():
local_extensions = await _get_local_vscode_extensions()
local_extensions = await _get_vscode_extensions(LocalV2())

assert local_extensions and all(
isinstance(ext, str) and isinstance(version, str)
for ext, version in local_extensions.items()
Expand All @@ -274,13 +276,13 @@ async def test_get_remote_vscode_extensions(mock_find_code_server_executable):
# therefore the "remote" extensions are the same as the local extensions.
fake_remote = await RemoteV2.connect("localhost")

local_vscode_executable = get_local_vscode_executable_path()
local_vscode_executable = _get_local_vscode_executable_path()
assert local_vscode_executable is not None

fake_remote_extensions = await _get_vscode_extensions(
fake_remote_extensions = await _get_vscode_extensions_dict(
fake_remote, code_server_executable=local_vscode_executable
)
assert fake_remote_extensions == await _get_local_vscode_extensions()
assert fake_remote_extensions == await _get_vscode_extensions(LocalV2())


@requires_vscode
Expand All @@ -289,7 +291,7 @@ def test_extensions_to_install(
installed_extensions: dict[str, str],
missing_extensions: dict[str, str],
):
to_install = extensions_to_install(
to_install = _extensions_to_install(
source_extensions=all_extensions,
dest_extensions=installed_extensions,
source_name="foo",
Expand All @@ -313,7 +315,7 @@ async def test_find_code_server_executable(cluster: str, login_node_v2: RemoteV2
remote_vscode_server_dir = "~/.vscode-server"
should_exist = True

code_server_exe_path = await find_code_server_executable(
code_server_exe_path = await _find_code_server_executable(
login_node_v2,
remote_vscode_server_dir=remote_vscode_server_dir,
)
Expand Down

0 comments on commit 4f35924

Please sign in to comment.