Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix for mila code in WSL #131

Merged
merged 7 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- run: pip install pre-commit
- run: pip install "pre-commit<4.0.0"
- run: pre-commit --version
- run: pre-commit install
- run: pre-commit run --all-files
Expand Down
4 changes: 4 additions & 0 deletions milatools/cli/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
MilatoolsUserError,
currently_in_a_test,
internet_on_compute_nodes,
running_inside_WSL,
)
from milatools.utils.compute_node import ComputeNode, salloc, sbatch
from milatools.utils.disk_quota import check_disk_quota
Expand Down Expand Up @@ -193,6 +194,9 @@ async def launch_vscode_loop(code_command: str, compute_node: ComputeNode, path:
f"ssh-remote+{compute_node.hostname}",
path,
)
if running_inside_WSL():
code_command_to_run = ("powershell.exe", *code_command_to_run)

await LocalV2.run_async(code_command_to_run, display=True)
# TODO: BUG: This now requires two Ctrl+C's instead of one!
console.print(
Expand Down
64 changes: 64 additions & 0 deletions tests/cli/test_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Unit tests for the `milatools.cli.code` module.

TODO: There are quite a few tests in `tests/integration/test_code.py` that could be
moved here, since some of them aren't exactly "integration" tests.
"""

from unittest.mock import AsyncMock, Mock

import pytest

import milatools.cli.code
import milatools.cli.utils
from milatools.cli.utils import running_inside_WSL
from milatools.utils.compute_node import ComputeNode
from milatools.utils.local_v2 import LocalV2


@pytest.fixture
def pretend_to_be_in_WSL(
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
):
# By default, pretend to be in WSL. Indirect parametrization can be used to
# overwrite this value for a given test (as is done below).
in_wsl = getattr(request, "param", True)

_mock_running_inside_WSL = Mock(spec=running_inside_WSL, return_value=in_wsl)
monkeypatch.setattr(
milatools.cli.utils,
running_inside_WSL.__name__, # type: ignore
_mock_running_inside_WSL,
)
monkeypatch.setattr(
milatools.cli.code,
running_inside_WSL.__name__, # type: ignore
_mock_running_inside_WSL,
)
return in_wsl


@pytest.mark.parametrize("pretend_to_be_in_WSL", [True, False], indirect=True)
@pytest.mark.asyncio
async def test_code_from_WSL(
monkeypatch: pytest.MonkeyPatch, pretend_to_be_in_WSL: bool
):
# Mock the LocalV2 class so that we can inspect the call to `LocalV2.run_async`.
mock_localv2 = Mock(spec=LocalV2)
monkeypatch.setattr(milatools.cli.code, LocalV2.__name__, mock_localv2)

await milatools.cli.code.launch_vscode_loop(
"code", Mock(spec=ComputeNode, hostname="foo"), "/bob/path"
)
assert isinstance(mock_localv2.run_async, AsyncMock)
mock_localv2.run_async.assert_called_once_with(
(
*(("powershell.exe",) if pretend_to_be_in_WSL else ()),
"code",
"--new-window",
"--wait",
"--remote",
"ssh-remote+foo",
"/bob/path",
),
display=True,
)
2 changes: 1 addition & 1 deletion tests/integration/test_code/test_code_mila__salloc_.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Checking disk quota on $HOME...
Disk usage: X / LIMIT GiB and X / LIMIT files
(mila) $ cd $SCRATCH && salloc --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code
salloc: --------------------------------------------------------------------------------------------------
salloc: # Using default long partition
salloc: # Using default long-cpu partition (CPU-only)
salloc: --------------------------------------------------------------------------------------------------
salloc: Granted job allocation JOB_ID
Waiting for job JOB_ID to start.
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_code/test_code_mila__sbatch_.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Disk usage: X / LIMIT GiB and X / LIMIT files
JOB_ID

sbatch: --------------------------------------------------------------------------------------------------
sbatch: # Using default long partition
sbatch: # Using default long-cpu partition (CPU-only)
sbatch: --------------------------------------------------------------------------------------------------

(localhost) $ echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob
Expand Down
5 changes: 5 additions & 0 deletions tests/integration/test_sync_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
_install_vscode_extensions_task_function,
sync_vscode_extensions,
)
from tests.integration.conftest import SLURM_CLUSTER

from ..cli.common import (
requires_ssh_to_localhost,
Expand All @@ -28,6 +29,10 @@
logger = get_logger(__name__)


@pytest.mark.xfail(
SLURM_CLUSTER == "mila",
reason="`code-server` procs are killed on the login nodes of the Mila cluster.",
)
@pytest.mark.slow
@pytest.mark.parametrize(
"source",
Expand Down
Loading