From 4f4c96ab7c94b3adef68d61515def60dec928894 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 11 Apr 2024 14:31:35 -0400 Subject: [PATCH] Add ComputeNode, LocalV2, reorganize commands - Rename `Local` to `LocalV1` - Rename `Remote` -> `RemoteV1` - Mark LocalV1 and RemoteV1 as deprecated - Move remote_v1.py and local_v1.py under `utils` Signed-off-by: Fabrice Normandin --- .github/workflows/build.yml | 5 +- README.md | 23 +- milatools/cli/__init__.py | 15 +- milatools/cli/code_command.py | 238 ++++++ milatools/cli/commands.py | 697 ++---------------- milatools/cli/common.py | 426 +++++++++++ milatools/cli/init_command.py | 19 +- milatools/cli/profile.py | 18 +- milatools/cli/utils.py | 77 +- milatools/utils/compute_node.py | 387 ++++++++++ milatools/{cli/local.py => utils/local_v1.py} | 9 +- milatools/utils/local_v2.py | 238 ++++++ .../{cli/remote.py => utils/remote_v1.py} | 9 +- milatools/utils/remote_v2.py | 414 +++++++---- milatools/utils/runner.py | 101 +++ milatools/utils/vscode_utils.py | 69 +- poetry.lock | 54 +- pyproject.toml | 2 + tests/cli/common.py | 44 +- tests/cli/test_commands.py | 3 +- tests/cli/test_init_command.py | 16 +- tests/cli/test_utils.py | 93 ++- tests/conftest.py | 267 ++++++- tests/integration/conftest.py | 81 +- tests/integration/test_code_command.py | 15 +- tests/integration/test_slurm_remote.py | 92 +-- tests/integration/test_sync_command.py | 28 +- tests/utils/runner_tests.py | 254 +++++++ tests/utils/test_compute_node.py | 213 ++++++ .../test_local.py => utils/test_local_v1.py} | 19 +- .../test_local_v1}/test_display_cmd0_.txt | 0 .../test_local_v1}/test_display_cmd1_.txt | 0 .../test_local_v1}/test_get_cmd0_.txt | 0 .../test_local_v1}/test_popen_cmd0_.txt | 0 .../test_local_v1}/test_run_cmd0_.txt | 0 .../test_local_v1}/test_run_cmd1_.txt | 0 .../test_local_v1}/test_run_cmd2_.txt | 0 .../test_local_v1}/test_silent_get_cmd0_.txt | 0 tests/utils/test_local_v2.py | 14 + .../test_parallel_progress_bar.txt | 10 +- .../test_remote_v1.py} | 52 +- .../test_remote_v1}/test_QueueIO.txt | 0 ...bash_args4_initial_transforms0_echo_OK_.md | 2 +- ...mand_args2_initial_transforms0_echo_OK_.md | 2 +- ...file_args3_initial_transforms0_echo_OK_.md | 2 +- ...orms_args0_initial_transforms0_echo_OK_.md | 2 +- ...wrap_args1_initial_transforms0_echo_OK_.md | 2 +- .../test_srun_transform_persist_localhost_.md | 0 tests/utils/test_remote_v2.py | 139 ++-- tests/utils/test_vscode_utils.py | 29 +- 50 files changed, 2972 insertions(+), 1208 deletions(-) create mode 100644 milatools/cli/code_command.py create mode 100644 milatools/cli/common.py create mode 100644 milatools/utils/compute_node.py rename milatools/{cli/local.py => utils/local_v1.py} (90%) create mode 100644 milatools/utils/local_v2.py rename milatools/{cli/remote.py => utils/remote_v1.py} (98%) create mode 100644 milatools/utils/runner.py create mode 100644 tests/utils/runner_tests.py create mode 100644 tests/utils/test_compute_node.py rename tests/{cli/test_local.py => utils/test_local_v1.py} (88%) rename tests/{cli/test_local => utils/test_local_v1}/test_display_cmd0_.txt (100%) rename tests/{cli/test_local => utils/test_local_v1}/test_display_cmd1_.txt (100%) rename tests/{cli/test_local => utils/test_local_v1}/test_get_cmd0_.txt (100%) rename tests/{cli/test_local => utils/test_local_v1}/test_popen_cmd0_.txt (100%) rename tests/{cli/test_local => utils/test_local_v1}/test_run_cmd0_.txt (100%) rename tests/{cli/test_local => utils/test_local_v1}/test_run_cmd1_.txt (100%) rename tests/{cli/test_local => 
utils/test_local_v1}/test_run_cmd2_.txt (100%) rename tests/{cli/test_local => utils/test_local_v1}/test_silent_get_cmd0_.txt (100%) create mode 100644 tests/utils/test_local_v2.py rename tests/{cli/test_remote.py => utils/test_remote_v1.py} (95%) rename tests/{cli/test_remote => utils/test_remote_v1}/test_QueueIO.txt (100%) rename tests/{cli/test_remote => utils/test_remote_v1}/test_remote_transform_methods_localhost_with_bash_args4_initial_transforms0_echo_OK_.md (82%) rename tests/{cli/test_remote => utils/test_remote_v1}/test_remote_transform_methods_localhost_with_precommand_args2_initial_transforms0_echo_OK_.md (84%) rename tests/{cli/test_remote => utils/test_remote_v1}/test_remote_transform_methods_localhost_with_profile_args3_initial_transforms0_echo_OK_.md (82%) rename tests/{cli/test_remote => utils/test_remote_v1}/test_remote_transform_methods_localhost_with_transforms_args0_initial_transforms0_echo_OK_.md (86%) rename tests/{cli/test_remote => utils/test_remote_v1}/test_remote_transform_methods_localhost_wrap_args1_initial_transforms0_echo_OK_.md (83%) rename tests/{cli/test_remote => utils/test_remote_v1}/test_srun_transform_persist_localhost_.md (100%) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ec1b9b0c..c8745f12 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -98,6 +98,7 @@ jobs: # NOTE: Replacing this with our customized version of # - uses: koesterlab/setup-slurm-action@v1 - uses: ./.github/custom_setup_slurm_action + timeout-minutes: 5 - name: Test if the slurm cluster is setup correctly run: srun --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=00:01:00 hostname @@ -170,8 +171,8 @@ jobs: - name: Launch integration tests id: self_hosted_integration_tests - run: poetry run pytest --slow --cov=milatools --cov-report=xml --cov-append -vvv --log-level=DEBUG - timeout-minutes: 30 + run: poetry run pytest --slow -n 5 --cov=milatools --cov-report=xml --cov-append -vvv --log-level=DEBUG + timeout-minutes: 60 env: SLURM_CLUSTER: mila diff --git a/README.md b/README.md index 4153ab4c..04b818d6 100644 --- a/README.md +++ b/README.md @@ -61,16 +61,23 @@ Connect a VSCode instance to a compute node. `mila code` first allocates a compu You can simply Ctrl+C the process to end the session. ``` -usage: mila code [-h] [--alloc ...] [--job VALUE] [--node VALUE] PATH +usage: mila code [-h] [--cluster {mila,cedar,narval,beluga,graham}] [--alloc ...] + [--command VALUE] [--job VALUE] [--node VALUE] [--persist] + PATH positional arguments: - PATH Path to open on the remote machine + PATH Path to open on the remote machine -optional arguments: - -h, --help show this help message and exit - --alloc ... Extra options to pass to slurm - --job VALUE Job ID to connect to - --node VALUE Node to connect to +options: + -h, --help show this help message and exit + --alloc ... Extra options to pass to slurm + --cluster {mila,cedar,narval,beluga,graham} + Which cluster to connect to. + --command VALUE Command to use to start vscode (defaults to "code" or the value + of $MILATOOLS_CODE_COMMAND) + --job VALUE Job ID to connect to + --node VALUE Node to connect to + --persist Whether the server should persist or not ``` For example: @@ -79,7 +86,7 @@ For example: mila code path/to/my/experiment ``` -The `--alloc` option may be used to pass extra arguments to `salloc` when allocating a node (for example, `--alloc --gres=cpu:8` to allocate 8 CPUs). `--alloc` should be at the end, because it will take all of the arguments that come after it. 
+The `--alloc` option may be used to pass extra arguments to `salloc` when allocating a node (for example, `--alloc --gres=gpu:1` to allocate 1 GPU). `--alloc` should be at the end, because it will take all of the arguments that come after it. If you already have an allocation on a compute node, you may use the `--node NODENAME` or `--job JOBID` options to connect to that node. diff --git a/milatools/cli/__init__.py b/milatools/cli/__init__.py index 247a0631..4c3fbc77 100644 --- a/milatools/cli/__init__.py +++ b/milatools/cli/__init__.py @@ -1,3 +1,16 @@ from rich.console import Console -console = Console(record=True) + +def _currently_in_a_test() -> bool: + """Returns True during unit tests (pytest) and False during normal execution.""" + import sys + + return "pytest" in sys.modules + + +if _currently_in_a_test(): + # Make the console very wide so commands are not wrapped across multiple lines. + # This makes tests that check the output of commands easier to write. + console = Console(record=True, width=200, log_time=False, log_path=False) +else: + console = Console(record=True) diff --git a/milatools/cli/code_command.py b/milatools/cli/code_command.py new file mode 100644 index 00000000..148962ec --- /dev/null +++ b/milatools/cli/code_command.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +import argparse +import shutil +import sys +from logging import getLogger as get_logger + +from milatools.cli import console +from milatools.cli.common import ( + check_disk_quota, + find_allocation, +) +from milatools.cli.utils import ( + CLUSTERS, + Cluster, + CommandNotFoundError, + MilatoolsUserError, + SortingHelpFormatter, + currently_in_a_test, + get_hostname_to_use_for_compute_node, + make_process, + no_internet_on_compute_nodes, + running_inside_WSL, +) +from milatools.utils.local_v1 import LocalV1 +from milatools.utils.local_v2 import LocalV2 +from milatools.utils.remote_v1 import RemoteV1 +from milatools.utils.vscode_utils import ( + get_code_command, + sync_vscode_extensions, + sync_vscode_extensions_with_hostnames, +) + +logger = get_logger(__name__) + + +def add_mila_code_arguments(subparsers: argparse._SubParsersAction): + code_parser: argparse.ArgumentParser = subparsers.add_parser( + "code", + help="Open a remote VSCode session on a compute node.", + formatter_class=SortingHelpFormatter, + ) + code_parser.add_argument( + "PATH", help="Path to open on the remote machine", type=str + ) + code_parser.add_argument( + "--cluster", + choices=CLUSTERS, # todo: widen based on the entries in ssh config? 
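# [Illustrative aside, not part of the patch] Why the README above insists that
# `--alloc` must come last: the option is declared a few lines below with
# nargs=argparse.REMAINDER, so it swallows every remaining token, even ones that
# look like other options. A minimal sketch of that behaviour (the values are made up):
import argparse

_parser = argparse.ArgumentParser()
_parser.add_argument("PATH")
_parser.add_argument("--job", type=int, default=None)
_parser.add_argument("--alloc", nargs=argparse.REMAINDER, default=[])
_args = _parser.parse_args(["proj", "--alloc", "--gres=gpu:1", "--job", "123"])
assert _args.alloc == ["--gres=gpu:1", "--job", "123"]
assert _args.job is None  # "--job 123" was captured by --alloc, not parsed on its own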
+ default="mila", + help="Which cluster to connect to.", + ) + code_parser.add_argument( + "--alloc", + nargs=argparse.REMAINDER, + help="Extra options to pass to slurm", + metavar="VALUE", + default=[], + ) + code_parser.add_argument( + "--command", + default=get_code_command(), + help=( + "Command to use to start vscode\n" + '(defaults to "code" or the value of $MILATOOLS_CODE_COMMAND)' + ), + metavar="VALUE", + ) + code_parser.add_argument( + "--job", + type=int, + default=None, + help="Job ID to connect to", + metavar="VALUE", + ) + code_parser.add_argument( + "--node", + type=str, + default=None, + help="Node to connect to", + metavar="VALUE", + ) + code_parser.add_argument( + "--persist", + action="store_true", + help="Whether the server should persist or not", + ) + code_parser.set_defaults(function=code) + + +def code( + path: str, + command: str, + persist: bool, + job: int | None, + node: str | None, + alloc: list[str], + cluster: Cluster = "mila", +): + """Open a remote VSCode session on a compute node. + + Arguments: + path: Path to open on the remote machine + command: Command to use to start vscode + (defaults to "code" or the value of $MILATOOLS_CODE_COMMAND) + persist: Whether the server should persist or not + job: Job ID to connect to + node: Node to connect to + alloc: Extra options to pass to slurm + """ + here = LocalV1() + remote = RemoteV1(cluster) + + if cluster != "mila" and job is None and node is None: + if not any("--account" in flag for flag in alloc): + logger.warning( + "Warning: When using the DRAC clusters, you usually need to " + "specify the account to use when submitting a job. You can specify " + "this in the job resources with `--alloc`, like so: " + "`--alloc --account=`, for example:\n" + f"mila code {path} --cluster {cluster} --alloc " + f"--account=your-account-here" + ) + + try: + check_disk_quota(remote) + except MilatoolsUserError: + raise + except Exception as exc: + logger.warning(f"Unable to check the disk-quota on the cluster: {exc}") + + if sys.platform == "win32": + print( + "Syncing vscode extensions in the background isn't supported on " + "Windows. Skipping." + ) + elif no_internet_on_compute_nodes(cluster): + # Sync the VsCode extensions from the local machine over to the target cluster. + run_in_the_background = False # if "pytest" not in sys.modules else True + print( + console.log( + f"[cyan]Installing VSCode extensions that are on the local machine on " + f"{cluster}" + (" in the background." if run_in_the_background else ".") + ) + ) + if run_in_the_background: + copy_vscode_extensions_process = make_process( + sync_vscode_extensions_with_hostnames, + # todo: use the mila cluster as the source for vscode extensions? Or + # `localhost`? + source="localhost", + destinations=[cluster], + ) + copy_vscode_extensions_process.start() + else: + sync_vscode_extensions( + LocalV2(), + [cluster], + ) + + if node is None: + cnode = find_allocation( + remote, + job_name="mila-code", + job=job, + node=node, + alloc=alloc, + cluster=cluster, + ) + if persist: + cnode = cnode.persist() + + data, proc = cnode.ensure_allocation() + + node_name = data["node_name"] + else: + node_name = node + proc = None + data = None + + if not path.startswith("/"): + # Get $HOME because we have to give the full path to code + home = remote.home() + path = home if path == "." 
else f"{home}/{path}" + + command_path = shutil.which(command) + if not command_path: + raise CommandNotFoundError(command) + + # NOTE: Since we have the config entries for the DRAC compute nodes, there is no + # need to use the fully qualified hostname here. + if cluster == "mila": + node_name = get_hostname_to_use_for_compute_node(node_name) + + # Try to detect if this is being run from within the Windows Subsystem for Linux. + # If so, then we run `code` through a powershell.exe command to open VSCode without + # issues. + inside_WSL = running_inside_WSL() + try: + while True: + if inside_WSL: + here.run( + "powershell.exe", + "code", + "-nw", + "--remote", + f"ssh-remote+{node_name}", + path, + ) + else: + here.run( + command_path, + "-nw", + "--remote", + f"ssh-remote+{node_name}", + path, + ) + print( + "The editor was closed. Reopen it with " + " or terminate the process with " + ) + if currently_in_a_test(): + break + input() + + except KeyboardInterrupt: + if not persist: + if proc is not None: + proc.kill() + print(f"Ended session on '{node_name}'") + + if persist: + console.print("This allocation is persistent and is still active.") + console.print("To reconnect to this node:") + console.print(f" mila code {path} --node {node_name}", markup=True) + console.print("To kill this allocation:") + assert data is not None + assert "jobid" in data + console.print(f" ssh mila scancel {data['jobid']}", style="bold") diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index 47743605..52fd70f2 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -2,25 +2,20 @@ Cluster documentation: https://docs.mila.quebec/ """ + from __future__ import annotations import argparse +import asyncio +import inspect import logging -import operator -import re -import shutil -import socket -import subprocess import sys -import time import traceback import typing import webbrowser -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, _HelpAction +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser from collections.abc import Sequence -from contextlib import ExitStack from logging import getLogger as get_logger -from pathlib import Path from typing import Any from urllib.parse import urlencode @@ -28,16 +23,15 @@ import rich.logging from typing_extensions import TypedDict -from milatools.cli import console -from milatools.utils.remote_v2 import RemoteV2 from milatools.utils.vscode_utils import ( - get_code_command, - # install_local_vscode_extensions_on_remote, - sync_vscode_extensions, sync_vscode_extensions_with_hostnames, ) from ..__version__ import __version__ +from ..utils.local_v1 import LocalV1 +from ..utils.remote_v1 import RemoteV1 +from .code_command import add_mila_code_arguments +from .common import forward, standard_server from .init_command import ( print_welcome_message, setup_keys_on_login_node, @@ -46,25 +40,14 @@ setup_vscode_settings, setup_windows_ssh_config_from_wsl, ) -from .local import Local -from .profile import ensure_program, setup_profile -from .remote import Remote, SlurmRemote from .utils import ( CLUSTERS, - Cluster, - CommandNotFoundError, MilatoolsUserError, + SortingHelpFormatter, SSHConnectionError, T, - cluster_to_connect_kwargs, - currently_in_a_test, - get_fully_qualified_hostname_of_compute_node, get_fully_qualified_name, - make_process, - no_internet_on_compute_nodes, - randname, running_inside_WSL, - with_control_file, ) if typing.TYPE_CHECKING: @@ -191,57 +174,7 @@ def mila(): 
forward_parser.set_defaults(function=forward) # ----- mila code ------ - - code_parser = subparsers.add_parser( - "code", - help="Open a remote VSCode session on a compute node.", - formatter_class=SortingHelpFormatter, - ) - code_parser.add_argument( - "PATH", help="Path to open on the remote machine", type=str - ) - code_parser.add_argument( - "--cluster", - choices=CLUSTERS, - default="mila", - help="Which cluster to connect to.", - ) - code_parser.add_argument( - "--alloc", - nargs=argparse.REMAINDER, - help="Extra options to pass to slurm", - metavar="VALUE", - default=[], - ) - code_parser.add_argument( - "--command", - default=get_code_command(), - help=( - "Command to use to start vscode\n" - '(defaults to "code" or the value of $MILATOOLS_CODE_COMMAND)' - ), - metavar="VALUE", - ) - code_parser.add_argument( - "--job", - type=str, - default=None, - help="Job ID to connect to", - metavar="VALUE", - ) - code_parser.add_argument( - "--node", - type=str, - default=None, - help="Node to connect to", - metavar="VALUE", - ) - code_parser.add_argument( - "--persist", - action="store_true", - help="Whether the server should persist or not", - ) - code_parser.set_defaults(function=code) + add_mila_code_arguments(subparsers) # ----- mila sync vscode-extensions ------ @@ -426,6 +359,16 @@ def mila(): setup_logging(verbose) # replace SEARCH -> "search", REMOTE -> "remote", etc. args_dict = _convert_uppercase_keys_to_lowercase(args_dict) + + if inspect.iscoroutinefunction(function): + try: + return asyncio.run(function(**args_dict)) + except KeyboardInterrupt: + from milatools.cli import console + + console.log("Terminated by user.") + exit() + assert callable(function) return function(**args_dict) @@ -434,11 +377,13 @@ def setup_logging(verbose: int) -> None: global_loglevel = ( logging.CRITICAL if verbose == 0 - else logging.WARNING - if verbose == 1 - else logging.INFO - if verbose == 2 - else logging.DEBUG + else ( + logging.WARNING + if verbose == 1 + else logging.INFO + if verbose == 2 + else logging.DEBUG + ) ) package_loglevel = ( logging.WARNING @@ -505,7 +450,7 @@ def init(): print_welcome_message() -def forward( +def forward_command( remote: str, page: str | None, port: int | None, @@ -517,8 +462,8 @@ def forward( except ValueError: pass - local_proc, _ = _forward( - local=Local(), + local_proc, _ = forward( + local=LocalV1(), node=f"{node}.server.mila.quebec", to_forward=remote_port, page=page, @@ -533,167 +478,13 @@ def forward( local_proc.kill() -def code( - path: str, - command: str, - persist: bool, - job: str | None, - node: str | None, - alloc: list[str], - cluster: Cluster = "mila", -): - """Open a remote VSCode session on a compute node. - - Arguments: - path: Path to open on the remote machine - command: Command to use to start vscode - (defaults to "code" or the value of $MILATOOLS_CODE_COMMAND) - persist: Whether the server should persist or not - job: Job ID to connect to - node: Node to connect to - alloc: Extra options to pass to slurm - """ - here = Local() - remote = Remote(cluster) - - if cluster != "mila" and job is None and node is None: - if not any("--account" in flag for flag in alloc): - logger.warning( - "Warning: When using the DRAC clusters, you usually need to " - "specify the account to use when submitting a job. 
You can specify " - "this in the job resources with `--alloc`, like so: " - "`--alloc --account=`, for example:\n" - f"mila code {path} --cluster {cluster} --alloc " - f"--account=your-account-here" - ) - - if command is None: - command = get_code_command() - - try: - check_disk_quota(remote) - except MilatoolsUserError: - raise - except Exception as exc: - logger.warning(f"Unable to check the disk-quota on the cluster: {exc}") - - if sys.platform == "win32": - print( - "Syncing vscode extensions in the background isn't supported on " - "Windows. Skipping." - ) - elif no_internet_on_compute_nodes(cluster): - # Sync the VsCode extensions from the local machine over to the target cluster. - run_in_the_background = False # if "pytest" not in sys.modules else True - print( - console.log( - f"[cyan]Installing VSCode extensions that are on the local machine on " - f"{cluster}" + (" in the background." if run_in_the_background else ".") - ) - ) - if run_in_the_background: - copy_vscode_extensions_process = make_process( - sync_vscode_extensions_with_hostnames, - # todo: use the mila cluster as the source for vscode extensions? Or - # `localhost`? - source="localhost", - destinations=[cluster], - ) - copy_vscode_extensions_process.start() - else: - sync_vscode_extensions( - Local(), - [cluster], - ) - - if node is None: - cnode = _find_allocation( - remote, - job_name="mila-code", - job=job, - node=node, - alloc=alloc, - cluster=cluster, - ) - if persist: - cnode = cnode.persist() - - data, proc = cnode.ensure_allocation() - - node_name = data["node_name"] - else: - node_name = node - proc = None - data = None - - if not path.startswith("/"): - # Get $HOME because we have to give the full path to code - home = remote.home() - path = home if path == "." else f"{home}/{path}" - - command_path = shutil.which(command) - if not command_path: - raise CommandNotFoundError(command) - - # NOTE: Since we have the config entries for the DRAC compute nodes, there is no - # need to use the fully qualified hostname here. - if cluster == "mila": - node_name = get_fully_qualified_hostname_of_compute_node(node_name) - - # Try to detect if this is being run from within the Windows Subsystem for Linux. - # If so, then we run `code` through a powershell.exe command to open VSCode without - # issues. - inside_WSL = running_inside_WSL() - try: - while True: - if inside_WSL: - here.run( - "powershell.exe", - "code", - "-nw", - "--remote", - f"ssh-remote+{node_name}", - path, - ) - else: - here.run( - command_path, - "-nw", - "--remote", - f"ssh-remote+{node_name}", - path, - ) - print( - "The editor was closed. 
Reopen it with " - " or terminate the process with " - ) - if currently_in_a_test(): - break - input() - - except KeyboardInterrupt: - if not persist: - if proc is not None: - proc.kill() - print(f"Ended session on '{node_name}'") - - if persist: - print("This allocation is persistent and is still active.") - print("To reconnect to this node:") - print(T.bold(f" mila code {path} --node {node_name}")) - print("To kill this allocation:") - assert data is not None - assert "jobid" in data - print(T.bold(f" ssh mila scancel {data['jobid']}")) - - def connect(identifier: str, port: int | None): """Reconnect to a persistent server.""" - remote = Remote("mila") + remote = RemoteV1("mila") info = _get_server_info(remote, identifier) - local_proc, _ = _forward( - local=Local(), + local_proc, _ = forward( + local=LocalV1(), node=f"{info['node_name']}.server.mila.quebec", to_forward=info["to_forward"], options={"token": info.get("token", None)}, @@ -711,7 +502,7 @@ def connect(identifier: str, port: int | None): def kill(identifier: str | None, all: bool = False): """Kill a persistent server.""" - remote = Remote("mila") + remote = RemoteV1("mila") if all: for identifier in remote.get_lines("ls .milatools/control", hide=True): @@ -733,7 +524,7 @@ def kill(identifier: str | None, all: bool = False): def serve_list(purge: bool): """List active servers.""" - remote = Remote("mila") + remote = RemoteV1("mila") to_purge = [] @@ -768,7 +559,7 @@ class StandardServerArgs(TypedDict): alloc: list[str] """Extra options to pass to slurm.""" - job: str | None + job: int | None """Job ID to connect to.""" name: str | None @@ -797,7 +588,7 @@ def lab(path: str | None, **kwargs: Unpack[StandardServerArgs]): if path and path.endswith(".ipynb"): exit("Only directories can be given to the mila serve lab command") - _standard_server( + standard_server( path, program="jupyter-lab", installers={ @@ -820,7 +611,7 @@ def notebook(path: str | None, **kwargs: Unpack[StandardServerArgs]): if path and path.endswith(".ipynb"): exit("Only directories can be given to the mila serve notebook command") - _standard_server( + standard_server( path, program="jupyter-notebook", installers={ @@ -841,7 +632,7 @@ def tensorboard(logdir: str, **kwargs: Unpack[StandardServerArgs]): logdir: Path to the experiment logs """ - _standard_server( + standard_server( logdir, program="tensorboard", installers={ @@ -861,7 +652,7 @@ def mlflow(logdir: str, **kwargs: Unpack[StandardServerArgs]): logdir: Path to the experiment logs """ - _standard_server( + standard_server( logdir, program="mlflow", installers={ @@ -879,7 +670,7 @@ def aim(logdir: str, **kwargs: Unpack[StandardServerArgs]): Arguments: logdir: Path to the experiment logs """ - _standard_server( + standard_server( logdir, program="aim", installers={ @@ -892,25 +683,13 @@ def aim(logdir: str, **kwargs: Unpack[StandardServerArgs]): def _get_server_info( - remote: Remote, identifier: str, hide: bool = False + remote: RemoteV1, identifier: str, hide: bool = False ) -> dict[str, str]: text = remote.get_output(f"cat .milatools/control/{identifier}", hide=hide) info = dict(line.split(" = ") for line in text.split("\n") if line) return info -class SortingHelpFormatter(argparse.HelpFormatter): - """Taken and adapted from https://stackoverflow.com/a/12269143/6388696.""" - - def add_arguments(self, actions): - actions = sorted(actions, key=operator.attrgetter("option_strings")) - # put help actions first. 
- actions = sorted( - actions, key=lambda action: not isinstance(action, _HelpAction) - ) - super().add_arguments(actions) - - def _add_standard_server_args(parser: ArgumentParser): parser.add_argument( "--alloc", @@ -921,7 +700,7 @@ def _add_standard_server_args(parser: ArgumentParser): ) parser.add_argument( "--job", - type=str, + type=int, default=None, help="Job ID to connect to", metavar="VALUE", @@ -961,395 +740,5 @@ def _add_standard_server_args(parser: ArgumentParser): ) -def _standard_server( - path: str | None, - *, - program: str, - installers: dict[str, str], - command: str, - profile: str | None, - persist: bool, - port: int | None, - name: str | None, - node: str | None, - job: str | None, - alloc: list[str], - port_pattern=None, - token_pattern=None, -): - # Make the server visible from the login node (other users will be able to connect) - # Temporarily disabled - share = False - - if name is not None: - persist = True - elif persist: - name = program - - remote = Remote("mila") - - path = path or "~" - if path == "~" or path.startswith("~/"): - path = remote.home() + path[1:] - - results: dict | None = None - node_name: str | None = None - to_forward: int | str | None = None - cf: str | None = None - proc = None - with ExitStack() as stack: - if persist: - cf = stack.enter_context(with_control_file(remote, name=name)) - else: - cf = None - - if profile: - prof = f"~/.milatools/profiles/{profile}.bash" - else: - prof = setup_profile(remote, path) - - qn.print(f"Using profile: {prof}") - cat_result = remote.run(f"cat {prof}", hide=True, warn=True) - if cat_result.ok: - qn.print("=" * 50) - qn.print(cat_result.stdout.rstrip()) - qn.print("=" * 50) - else: - exit(f"Could not find or load profile: {prof}") - - premote = remote.with_profile(prof) - - if not ensure_program( - remote=premote, - program=program, - installers=installers, - ): - exit(f"Exit: {program} is not installed.") - - cnode = _find_allocation( - remote, - job_name=f"mila-serve-{program}", - node=node, - job=job, - alloc=alloc, - cluster="mila", - ) - - patterns = { - "node_name": "#### ([A-Za-z0-9_-]+)", - } - - if port_pattern: - patterns["port"] = port_pattern - elif share: - exit( - "Server cannot be shared because it is serving over a Unix domain " - "socket" - ) - else: - remote.run("mkdir -p ~/.milatools/sockets", hide=True) - - if share: - host = "0.0.0.0" - else: - host = "localhost" - - sock_name = name or randname() - command = command.format( - path=path, - sock=f"~/.milatools/sockets/{sock_name}.sock", - host=host, - ) - - if token_pattern: - patterns["token"] = token_pattern - - if persist: - cnode = cnode.persist() - - proc, results = ( - cnode.with_profile(prof) - .with_precommand("echo '####' $(hostname)") - .extract( - command, - patterns=patterns, - ) - ) - node_name = results["node_name"] - - if port_pattern: - to_forward = int(results["port"]) - else: - to_forward = f"{remote.home()}/.milatools/sockets/{sock_name}.sock" - - if cf is not None: - remote.simple_run(f"echo program = {program} >> {cf}") - remote.simple_run(f"echo node_name = {results['node_name']} >> {cf}") - remote.simple_run(f"echo host = {host} >> {cf}") - remote.simple_run(f"echo to_forward = {to_forward} >> {cf}") - if token_pattern: - remote.simple_run(f"echo token = {results['token']} >> {cf}") - - assert results is not None - assert node_name is not None - assert to_forward is not None - assert proc is not None - if token_pattern: - options = {"token": results["token"]} - else: - options = {} - - local_proc, local_port = 
_forward( - local=Local(), - node=get_fully_qualified_hostname_of_compute_node(node_name, cluster="mila"), - to_forward=to_forward, - options=options, - port=port, - ) - - if cf is not None: - remote.simple_run(f"echo local_port = {local_port} >> {cf}") - - try: - local_proc.wait() - except KeyboardInterrupt: - qn.print("Terminated by user.") - if cf is not None: - name = Path(cf).name - qn.print("To reconnect to this server, use the command:") - qn.print(f" mila serve connect {name}", style="bold yellow") - qn.print("To kill this server, use the command:") - qn.print(f" mila serve kill {name}", style="bold red") - finally: - local_proc.kill() - proc.kill() - - -def _parse_lfs_quota_output( - lfs_quota_output: str, -) -> tuple[tuple[float, float], tuple[int, int]]: - """Parses space and # of files (usage, limit) from the output of `lfs quota`.""" - lines = lfs_quota_output.splitlines() - - header_line: str | None = None - header_line_index: int | None = None - for index, line in enumerate(lines): - if ( - len(line_parts := line.strip().split()) == 9 - and line_parts[0].lower() == "filesystem" - ): - header_line = line - header_line_index = index - break - assert header_line - assert header_line_index is not None - - values_line_parts: list[str] = [] - # The next line may overflow to two (or maybe even more?) lines if the name of the - # $HOME dir is too long. - for content_line in lines[header_line_index + 1 :]: - additional_values = content_line.strip().split() - assert len(values_line_parts) < 9 - values_line_parts.extend(additional_values) - if len(values_line_parts) == 9: - break - - assert len(values_line_parts) == 9, values_line_parts - ( - _filesystem, - used_kbytes, - _quota_kbytes, - limit_kbytes, - _grace_kbytes, - files, - _quota_files, - limit_files, - _grace_files, - ) = values_line_parts - - used_gb = int(used_kbytes.strip()) / (1024**2) - max_gb = int(limit_kbytes.strip()) / (1024**2) - used_files = int(files.strip()) - max_files = int(limit_files.strip()) - return (used_gb, max_gb), (used_files, max_files) - - -def check_disk_quota(remote: Remote | RemoteV2) -> None: - cluster = remote.hostname - - # NOTE: This is what the output of the command looks like on the Mila cluster: - # - # Disk quotas for usr normandf (uid 1471600598): - # Filesystem kbytes quota limit grace files quota limit grace - # /home/mila/n/normandf - # 95747836 0 104857600 - 908722 0 1048576 - - # uid 1471600598 is using default block quota setting - # uid 1471600598 is using default file quota setting - - # Need to assert this, otherwise .get_output calls .run which would spawn a job! - assert not isinstance(remote, SlurmRemote) - if not remote.get_output("which lfs", hide=True): - logger.debug("Cluster doesn't have the lfs command. Skipping check.") - return - - console.log("Checking disk quota on $HOME...") - - home_disk_quota_output = remote.get_output("lfs quota -u $USER $HOME", hide=True) - if "not on a mounted Lustre filesystem" in home_disk_quota_output: - logger.debug("Cluster doesn't use lustre on $HOME filesystem. 
Skipping check.") - return - - (used_gb, max_gb), (used_files, max_files) = _parse_lfs_quota_output( - home_disk_quota_output - ) - - def get_colour(used: float, max: float) -> str: - return "red" if used >= max else "orange" if used / max > 0.7 else "green" - - disk_usage_style = get_colour(used_gb, max_gb) - num_files_style = get_colour(used_files, max_files) - from rich.text import Text - - console.log( - "Disk usage:", - Text(f"{used_gb:.2f} / {max_gb:.2f} GiB", style=disk_usage_style), - "and", - Text(f"{used_files} / {max_files} files", style=num_files_style), - markup=False, - ) - size_ratio = used_gb / max_gb - files_ratio = used_files / max_files - reason = ( - f"{used_gb:.1f} / {max_gb} GiB" - if size_ratio > files_ratio - else f"{used_files} / {max_files} files" - ) - - freeing_up_space_instructions = ( - "For example, temporary files (logs, checkpoints, etc.) can be moved to " - "$SCRATCH, while files that need to be stored for longer periods can be moved " - "to $ARCHIVE or to a shared project folder under /network/projects.\n" - "Visit https://docs.mila.quebec/Information.html#storage to learn more about " - "how to best make use of the different filesystems available on the cluster." - ) - - if used_gb >= max_gb or used_files >= max_files: - raise MilatoolsUserError( - T.red( - f"ERROR: Your disk quota on the $HOME filesystem is exceeded! " - f"({reason}).\n" - f"To fix this, login to the cluster with `ssh {cluster}` and free up " - f"some space, either by deleting files, or by moving them to a " - f"suitable filesystem.\n" + freeing_up_space_instructions - ) - ) - if max(size_ratio, files_ratio) > 0.9: - warning_message = ( - f"You are getting pretty close to your disk quota on the $HOME " - f"filesystem: ({reason})\n" - "Please consider freeing up some space in your $HOME folder, either by " - "deleting files, or by moving them to a more suitable filesystem.\n" - + freeing_up_space_instructions - ) - logger.warning(UserWarning(warning_message)) - - -def _find_allocation( - remote: Remote, - node: str | None, - job: str | None, - alloc: list[str], - cluster: Cluster = "mila", - job_name: str = "mila-tools", -): - if (node is not None) + (job is not None) + bool(alloc) > 1: - exit("ERROR: --node, --job and --alloc are mutually exclusive") - - if node is not None: - node_name = get_fully_qualified_hostname_of_compute_node(node, cluster=cluster) - return Remote(node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster)) - - elif job is not None: - node_name = remote.get_output(f"squeue --jobs {job} -ho %N") - return Remote(node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster)) - - else: - alloc = ["-J", job_name, *alloc] - return SlurmRemote( - connection=remote.connection, - alloc=alloc, - hostname=remote.hostname, - ) - - -def _forward( - local: Local, - node: str, - to_forward: int | str, - port: int | None, - page: str | None = None, - options: dict[str, str | None] = {}, - through_login: bool = False, -): - if port is None: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - # Find a free local port by binding to port 0 - sock.bind(("localhost", 0)) - _, port = sock.getsockname() - # Close it for ssh -L. It is *unlikely* it will not be available. 
- sock.close() - - if isinstance(to_forward, int) or re.match("[0-9]+", to_forward): - if through_login: - to_forward = f"{node}:{to_forward}" - args = [f"localhost:{port}:{to_forward}", "mila"] - else: - to_forward = f"localhost:{to_forward}" - args = [f"localhost:{port}:{to_forward}", node] - else: - args = [f"localhost:{port}:{to_forward}", node] - - proc = local.popen( - "ssh", - "-o", - "UserKnownHostsFile=/dev/null", - "-o", - "StrictHostKeyChecking=no", - "-nNL", - *args, - ) - - url = f"http://localhost:{port}" - if page is not None: - if not page.startswith("/"): - page = f"/{page}" - url += page - - options = {k: v for k, v in options.items() if v is not None} - if options: - url += f"?{urlencode(options)}" - - qn.print("Waiting for connection to be active...") - nsecs = 10 - period = 0.2 - for _ in range(int(nsecs / period)): - time.sleep(period) - try: - # This feels stupid, there's probably a better way - local.silent_get("nc", "-z", "localhost", str(port)) - except subprocess.CalledProcessError: - continue - except Exception: - break - break - - qn.print( - "Starting browser. You might need to refresh the page.", - style="bold", - ) - webbrowser.open(url) - return proc, port - - if __name__ == "__main__": main() diff --git a/milatools/cli/common.py b/milatools/cli/common.py new file mode 100644 index 00000000..39818891 --- /dev/null +++ b/milatools/cli/common.py @@ -0,0 +1,426 @@ +from __future__ import annotations + +import re +import socket +import subprocess +import time +import webbrowser +from contextlib import ExitStack +from logging import getLogger as get_logger +from pathlib import Path +from urllib.parse import urlencode + +import questionary as qn +from rich.text import Text + +from milatools.cli import console +from milatools.cli.profile import ensure_program, setup_profile +from milatools.cli.utils import ( + Cluster, + MilatoolsUserError, + T, + cluster_to_connect_kwargs, + get_hostname_to_use_for_compute_node, + randname, + with_control_file, +) +from milatools.utils.local_v1 import LocalV1 +from milatools.utils.remote_v1 import RemoteV1, SlurmRemote +from milatools.utils.remote_v2 import RemoteV2 + +logger = get_logger(__name__) + + +def _parse_lfs_quota_output( + lfs_quota_output: str, +) -> tuple[tuple[float, float], tuple[int, int]]: + """Parses space and # of files (usage, limit) from the output of `lfs quota`.""" + lines = lfs_quota_output.splitlines() + + header_line: str | None = None + header_line_index: int | None = None + for index, line in enumerate(lines): + if ( + len(line_parts := line.strip().split()) == 9 + and line_parts[0].lower() == "filesystem" + ): + header_line = line + header_line_index = index + break + assert header_line + assert header_line_index is not None + + values_line_parts: list[str] = [] + # The next line may overflow to two (or maybe even more?) lines if the name of the + # $HOME dir is too long. 
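# [Illustrative aside, not part of the patch] The overflow case described in the
# comment above, demonstrated on the sample `lfs quota` output quoted in
# check_disk_quota's docstring (the $HOME path is long enough that the numbers wrap
# onto the following line):
_sample_lfs_quota_output = """\
Disk quotas for usr normandf (uid 1471600598):
     Filesystem  kbytes   quota   limit   grace   files   quota   limit   grace
/home/mila/n/normandf
                95747836       0 104857600       -  908722       0 1048576       -
uid 1471600598 is using default block quota setting
uid 1471600598 is using default file quota setting
"""
# _parse_lfs_quota_output(_sample_lfs_quota_output) should give, roughly,
# ((91.31, 100.0), (908722, 1048576)) -- i.e. about 91.31 of 100 GiB and
# 908722 of 1048576 files used.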
+ for content_line in lines[header_line_index + 1 :]: + additional_values = content_line.strip().split() + assert len(values_line_parts) < 9 + values_line_parts.extend(additional_values) + if len(values_line_parts) == 9: + break + + assert len(values_line_parts) == 9, values_line_parts + ( + _filesystem, + used_kbytes, + _quota_kbytes, + limit_kbytes, + _grace_kbytes, + files, + _quota_files, + limit_files, + _grace_files, + ) = values_line_parts + + used_gb = int(used_kbytes.strip()) / (1024**2) + max_gb = int(limit_kbytes.strip()) / (1024**2) + used_files = int(files.strip()) + max_files = int(limit_files.strip()) + return (used_gb, max_gb), (used_files, max_files) + + +def check_disk_quota(remote: RemoteV1 | RemoteV2) -> None: + cluster = remote.hostname + + # NOTE: This is what the output of the command looks like on the Mila cluster: + # + # Disk quotas for usr normandf (uid 1471600598): + # Filesystem kbytes quota limit grace files quota limit grace + # /home/mila/n/normandf + # 95747836 0 104857600 - 908722 0 1048576 - + # uid 1471600598 is using default block quota setting + # uid 1471600598 is using default file quota setting + + # Need to assert this, otherwise .get_output calls .run which would spawn a job! + assert not isinstance(remote, SlurmRemote) + if not remote.get_output("which lfs", display=False, hide=True): + logger.debug("Cluster doesn't have the lfs command. Skipping check.") + return + + console.log("Checking disk quota on $HOME...") + + home_disk_quota_output = remote.get_output( + "lfs quota -u $USER $HOME", display=False, hide=True + ) + if "not on a mounted Lustre filesystem" in home_disk_quota_output: + logger.debug("Cluster doesn't use lustre on $HOME filesystem. Skipping check.") + return + + (used_gb, max_gb), (used_files, max_files) = _parse_lfs_quota_output( + home_disk_quota_output + ) + + def get_colour(used: float, max: float) -> str: + return "red" if used >= max else "orange" if used / max > 0.7 else "green" + + disk_usage_style = get_colour(used_gb, max_gb) + num_files_style = get_colour(used_files, max_files) + + console.log( + "Disk usage:", + Text(f"{used_gb:.2f} / {max_gb:.2f} GiB", style=disk_usage_style), + "and", + Text(f"{used_files} / {max_files} files", style=num_files_style), + markup=False, + ) + size_ratio = used_gb / max_gb + files_ratio = used_files / max_files + reason = ( + f"{used_gb:.1f} / {max_gb} GiB" + if size_ratio > files_ratio + else f"{used_files} / {max_files} files" + ) + + freeing_up_space_instructions = ( + "For example, temporary files (logs, checkpoints, etc.) can be moved to " + "$SCRATCH, while files that need to be stored for longer periods can be moved " + "to $ARCHIVE or to a shared project folder under /network/projects.\n" + "Visit https://docs.mila.quebec/Information.html#storage to learn more about " + "how to best make use of the different filesystems available on the cluster." + ) + + if used_gb >= max_gb or used_files >= max_files: + raise MilatoolsUserError( + T.red( + f"ERROR: Your disk quota on the $HOME filesystem is exceeded! 
" + f"({reason}).\n" + f"To fix this, login to the cluster with `ssh {cluster}` and free up " + f"some space, either by deleting files, or by moving them to a " + f"suitable filesystem.\n" + freeing_up_space_instructions + ) + ) + if max(size_ratio, files_ratio) > 0.9: + warning_message = ( + f"You are getting pretty close to your disk quota on the $HOME " + f"filesystem: ({reason})\n" + "Please consider freeing up some space in your $HOME folder, either by " + "deleting files, or by moving them to a more suitable filesystem.\n" + + freeing_up_space_instructions + ) + logger.warning(UserWarning(warning_message)) + + +def find_allocation( + remote: RemoteV1, + node: str | None, + job: int | None, + alloc: list[str], + cluster: Cluster = "mila", + job_name: str = "mila-tools", +): + if (node is not None) + (job is not None) + bool(alloc) > 1: + exit("ERROR: --node, --job and --alloc are mutually exclusive") + + if node is not None: + node_name = get_hostname_to_use_for_compute_node(node, cluster=cluster) + return RemoteV1( + node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster) + ) + + elif job is not None: + node_name = remote.get_output(f"squeue --jobs {job} -ho %N") + return RemoteV1( + node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster) + ) + + else: + alloc = ["-J", job_name, *alloc] + return SlurmRemote( + connection=remote.connection, + alloc=alloc, + hostname=remote.hostname, + ) + + +def forward( + local: LocalV1, + node: str, + to_forward: int | str, + port: int | None, + page: str | None = None, + options: dict[str, str | None] = {}, + through_login: bool = False, +): + if port is None: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Find a free local port by binding to port 0 + sock.bind(("localhost", 0)) + _, port = sock.getsockname() + # Close it for ssh -L. It is *unlikely* it will not be available. + sock.close() + + if isinstance(to_forward, int) or re.match("[0-9]+", to_forward): + if through_login: + to_forward = f"{node}:{to_forward}" + args = [f"localhost:{port}:{to_forward}", "mila"] + else: + to_forward = f"localhost:{to_forward}" + args = [f"localhost:{port}:{to_forward}", node] + else: + args = [f"localhost:{port}:{to_forward}", node] + + proc = local.popen( + "ssh", + "-o", + "UserKnownHostsFile=/dev/null", + "-o", + "StrictHostKeyChecking=no", + "-nNL", + *args, + ) + + url = f"http://localhost:{port}" + if page is not None: + if not page.startswith("/"): + page = f"/{page}" + url += page + + options = {k: v for k, v in options.items() if v is not None} + if options: + url += f"?{urlencode(options)}" + + qn.print("Waiting for connection to be active...") + nsecs = 10 + period = 0.2 + for _ in range(int(nsecs / period)): + time.sleep(period) + try: + # This feels stupid, there's probably a better way + local.silent_get("nc", "-z", "localhost", str(port)) + except subprocess.CalledProcessError: + continue + except Exception: + break + break + + qn.print( + "Starting browser. 
You might need to refresh the page.", + style="bold", + ) + webbrowser.open(url) + return proc, port + + +def standard_server( + path: str | None, + *, + program: str, + installers: dict[str, str], + command: str, + profile: str | None, + persist: bool, + port: int | None, + name: str | None, + node: str | None, + job: int | None, + alloc: list[str], + port_pattern=None, + token_pattern=None, +): + # Make the server visible from the login node (other users will be able to connect) + # Temporarily disabled + share = False + + if name is not None: + persist = True + elif persist: + name = program + + remote = RemoteV1("mila") + + path = path or "~" + if path == "~" or path.startswith("~/"): + path = remote.home() + path[1:] + + results: dict | None = None + node_name: str | None = None + to_forward: int | str | None = None + cf: str | None = None + proc = None + with ExitStack() as stack: + if persist: + cf = stack.enter_context(with_control_file(remote, name=name)) + else: + cf = None + + if profile: + prof = f"~/.milatools/profiles/{profile}.bash" + else: + prof = setup_profile(remote, path) + + qn.print(f"Using profile: {prof}") + cat_result = remote.run(f"cat {prof}", hide=True, warn=True) + if cat_result.ok: + qn.print("=" * 50) + qn.print(cat_result.stdout.rstrip()) + qn.print("=" * 50) + else: + exit(f"Could not find or load profile: {prof}") + + premote = remote.with_profile(prof) + + if not ensure_program( + remote=premote, + program=program, + installers=installers, + ): + exit(f"Exit: {program} is not installed.") + + cnode = find_allocation( + remote, + job_name=f"mila-serve-{program}", + node=node, + job=job, + alloc=alloc, + cluster="mila", + ) + + patterns = { + "node_name": "#### ([A-Za-z0-9_-]+)", + } + + if port_pattern: + patterns["port"] = port_pattern + elif share: + exit( + "Server cannot be shared because it is serving over a Unix domain " + "socket" + ) + else: + remote.run("mkdir -p ~/.milatools/sockets", hide=True) + + if share: + host = "0.0.0.0" + else: + host = "localhost" + + sock_name = name or randname() + command = command.format( + path=path, + sock=f"~/.milatools/sockets/{sock_name}.sock", + host=host, + ) + + if token_pattern: + patterns["token"] = token_pattern + + if persist: + cnode = cnode.persist() + + proc, results = ( + cnode.with_profile(prof) + .with_precommand("echo '####' $(hostname)") + .extract( + command, + patterns=patterns, + ) + ) + node_name = results["node_name"] + + if port_pattern: + to_forward = int(results["port"]) + else: + to_forward = f"{remote.home()}/.milatools/sockets/{sock_name}.sock" + + if cf is not None: + remote.simple_run(f"echo program = {program} >> {cf}") + remote.simple_run(f"echo node_name = {results['node_name']} >> {cf}") + remote.simple_run(f"echo host = {host} >> {cf}") + remote.simple_run(f"echo to_forward = {to_forward} >> {cf}") + if token_pattern: + remote.simple_run(f"echo token = {results['token']} >> {cf}") + + assert results is not None + assert node_name is not None + assert to_forward is not None + assert proc is not None + if token_pattern: + options = {"token": results["token"]} + else: + options = {} + + local_proc, local_port = forward( + local=LocalV1(), + node=get_hostname_to_use_for_compute_node(node_name, cluster="mila"), + to_forward=to_forward, + options=options, + port=port, + ) + + if cf is not None: + remote.simple_run(f"echo local_port = {local_port} >> {cf}") + + try: + local_proc.wait() + except KeyboardInterrupt: + qn.print("Terminated by user.") + if cf is not None: + name = 
Path(cf).name + qn.print("To reconnect to this server, use the command:") + qn.print(f" mila serve connect {name}", style="bold yellow") + qn.print("To kill this server, use the command:") + qn.print(f" mila serve kill {name}", style="bold red") + finally: + local_proc.kill() + proc.kill() diff --git a/milatools/cli/init_command.py b/milatools/cli/init_command.py index add99165..521983bf 100644 --- a/milatools/cli/init_command.py +++ b/milatools/cli/init_command.py @@ -15,12 +15,12 @@ import questionary as qn from invoke.exceptions import UnexpectedExit +from ..utils.local_v1 import LocalV1, check_passwordless, display +from ..utils.remote_v1 import RemoteV1 from ..utils.vscode_utils import ( get_expected_vscode_settings_json_path, vscode_installed, ) -from .local import Local, check_passwordless, display -from .remote import Remote from .utils import SSHConfig, T, running_inside_WSL, yn logger = get_logger(__name__) @@ -236,7 +236,7 @@ def setup_passwordless_ssh_access(ssh_config: SSHConfig) -> bool: """ print("Checking passwordless authentication") - here = Local() + here = LocalV1() sshdir = Path.home() / ".ssh" ssh_private_key_path = Path.home() / ".ssh" / "id_rsa" @@ -288,7 +288,7 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool: Returns whether the operation completed successfully or not. """ - here = Local() + here = LocalV1() # Check that it is possible to connect without using a password. print(f"Checking if passwordless SSH access is setup for the {cluster} cluster.") # TODO: Potentially use a custom key like `~/.ssh/id_milatools.pub` instead of @@ -344,7 +344,7 @@ def setup_keys_on_login_node(): print("Checking connection to compute nodes") - remote = Remote("mila") + remote = RemoteV1("mila") try: pubkeys = remote.get_lines("ls -t ~/.ssh/id*.pub") print("# OK") @@ -400,7 +400,10 @@ def print_welcome_message(): def _copy_if_needed(linux_key_file: Path, windows_key_file: Path): - if linux_key_file.exists() and not windows_key_file.exists(): + if ( + linux_key_file.expanduser().exists() + and not windows_key_file.expanduser().exists() + ): print( f"Copying {linux_key_file} over to the Windows ssh folder at " f"{windows_key_file}." @@ -416,7 +419,7 @@ def get_windows_home_path_in_wsl() -> Path: def create_ssh_keypair( ssh_private_key_path: Path, - local: Local | None = None, + local: LocalV1 | None = None, passphrase: str | None = "", ) -> None: """Creates a public/private key pair at the given path using ssh-keygen. @@ -425,7 +428,7 @@ def create_ssh_keypair( Otherwise, if passphrase is an empty string, no passphrase will be used (default). If a string is passed, it is passed to ssh-keygen and used as the passphrase. 
""" - local = local or Local() + local = local or LocalV1() command = [ "ssh-keygen", "-f", diff --git a/milatools/cli/profile.py b/milatools/cli/profile.py index d4ff7638..1a71af45 100644 --- a/milatools/cli/profile.py +++ b/milatools/cli/profile.py @@ -16,7 +16,7 @@ from .utils import askpath, yn if typing.TYPE_CHECKING: - from milatools.cli.remote import Remote + from milatools.utils.remote_v1 import RemoteV1 style = qn.Style( [ @@ -38,7 +38,7 @@ def _ask_name(message: str, default: str = "") -> str: qn.print(f"Invalid name: {name}", style="bold red") -def setup_profile(remote: Remote, path: str) -> str: +def setup_profile(remote: RemoteV1, path: str) -> str: profile = select_preferred(remote, path) preferred = profile is not None if not preferred: @@ -58,7 +58,7 @@ def setup_profile(remote: Remote, path: str) -> str: return profile -def select_preferred(remote: Remote, path: str) -> str | None: +def select_preferred(remote: RemoteV1, path: str) -> str | None: preferred = f"{path}/.milatools-profile" qn.print(f"Checking for preferred profile in {preferred}") @@ -71,7 +71,7 @@ def select_preferred(remote: Remote, path: str) -> str | None: return preferred -def select_profile(remote: Remote) -> str | None: +def select_profile(remote: RemoteV1) -> str | None: profdir = "~/.milatools/profiles" qn.print(f"Fetching profiles in {profdir}") @@ -109,7 +109,7 @@ def select_profile(remote: Remote) -> str | None: return profile -def create_profile(remote: Remote, path: str = "~"): +def create_profile(remote: RemoteV1, path: str = "~"): modules = select_modules(remote) mload = f"module load {' '.join(modules)}" @@ -139,7 +139,7 @@ def create_profile(remote: Remote, path: str = "~"): return prof_file -def select_modules(remote: Remote): +def select_modules(remote: RemoteV1): choices = [ Choice( title="miniconda/3", @@ -202,7 +202,7 @@ def _env_basename(pth: str) -> str | None: return base -def select_conda_environment(remote: Remote, loader: str = "module load miniconda/3"): +def select_conda_environment(remote: RemoteV1, loader: str = "module load miniconda/3"): qn.print("Fetching the list of conda environments...") envstr = remote.get_output("conda env list --json", hide=True) envlist: list[str] = json.loads(envstr)["envs"] @@ -254,7 +254,7 @@ def select_conda_environment(remote: Remote, loader: str = "module load minicond return env -def select_virtual_environment(remote: Remote, path): +def select_virtual_environment(remote: RemoteV1, path): envstr = remote.get_output( ( f"ls -d {path}/venv {path}/.venv {path}/virtualenv ~/virtualenvs/* " @@ -293,7 +293,7 @@ def select_virtual_environment(remote: Remote, path): return env -def ensure_program(remote: Remote, program: str, installers: dict[str, str]): +def ensure_program(remote: RemoteV1, program: str, installers: dict[str, str]): to_test = [program, *installers.keys()] progs = [ Path(p).name diff --git a/milatools/cli/utils.py b/milatools/cli/utils.py index f4ca0964..c7e9343b 100644 --- a/milatools/cli/utils.py +++ b/milatools/cli/utils.py @@ -1,9 +1,11 @@ from __future__ import annotations +import argparse import contextvars import functools import itertools import multiprocessing +import operator import random import shutil import socket @@ -11,6 +13,7 @@ import sys import typing import warnings +from argparse import _HelpAction from collections.abc import Callable, Iterable from contextlib import contextmanager from pathlib import Path @@ -24,10 +27,13 @@ from typing_extensions import ParamSpec, TypeGuard if typing.TYPE_CHECKING: - from 
milatools.cli.remote import Remote + from milatools.utils.remote_v1 import RemoteV1 control_file_var = contextvars.ContextVar("control_file", default="/dev/null") +SSH_CONFIG_FILE = Path.home() / ".ssh" / "config" +SSH_CACHE_DIR = Path.home() / ".cache" / "ssh" + T = blessed.Terminal() @@ -96,7 +102,7 @@ def randname(): @contextmanager -def with_control_file(remote: Remote, name=None): +def with_control_file(remote: RemoteV1, name=None): name = name or randname() pth = f".milatools/control/{name}" remote.run("mkdir -p ~/.milatools/control", hide=True) @@ -172,7 +178,7 @@ def yn(prompt: str, default: bool = True) -> bool: return qn.confirm(prompt, default=default).unsafe_ask() -def askpath(prompt: str, remote: Remote) -> str: +def askpath(prompt: str, remote: RemoteV1) -> str: while True: pth = qn.text(prompt).unsafe_ask() try: @@ -249,23 +255,49 @@ def hoststring(self, host: str) -> str: return "\n".join(lines) -def get_fully_qualified_hostname_of_compute_node( - node_name: str, cluster: str = "mila" +def get_hostname_to_use_for_compute_node( + node_name: str, cluster: str = "mila", ssh_config_path: Path = SSH_CONFIG_FILE ) -> str: - """Return the fully qualified name corresponding to this node name.""" - if cluster == "mila": - if node_name.endswith(".server.mila.quebec"): - return node_name - return f"{node_name}.server.mila.quebec" - if cluster in CLUSTERS: - # For the other explicitly supported clusters in the SSH config, the node name - # of the compute node can be used directly with ssh from the local machine, no - # need to use a fully qualified name. + """Return the hostname to use to connect to this compute note via ssh.""" + # TODO: Raise an error if the SSH config file doesn't exist, to remove the + # hard-coded code below below this 'if' and remove a level of nesting. + if not ssh_config_path.exists(): + # If the SSH config file doesn't exist, we can't do much. + raise MilatoolsUserError( + f"SSH Config doesn't exist at {ssh_config_path}, did you run `mila init`?" + ) + + ssh_config = paramiko.SSHConfig.from_path(str(ssh_config_path)) + + # If there is an entry matching for the compute node name (cn-a001) and there + # isn't one matching the fully qualified compute node name + # (cn-a001.(...).quebec), + # then use the compute node name. + + def should_be_used_to_connect(hostname: str) -> bool: + """Returns whether `hostname` should be used to run `ssh {hostname}`. + + Returns True if an entry matches `hostname` and returns a different hostname to + use, or if the "proxyjump" option is set. + """ + options = ssh_config.lookup(hostname) + return bool(options.get("proxyjump")) or options["hostname"] != hostname + + if should_be_used_to_connect(node_name): + # There is an entry in the sshconfig for e.g. `cn-a001` that sets the + # hostname to use as `cn-a001.(...).quebec` or similar. return node_name + if cluster == "mila" and should_be_used_to_connect( + fully_qualified_name := f"{node_name}.server.mila.quebec" + ): + return fully_qualified_name warnings.warn( UserWarning( - f"Using a custom cluster {cluster}. Assuming that we can ssh directly to " - f"its compute node {node_name!r}." + f"Unable to find the hostname to use to connect to node {node_name} of " + f"the {cluster} cluster.\n" + f"Assuming that we can ssh directly to {node_name} for now. 
To fix " + f"this, consider adding an entry that matches the compute node " + f"{node_name} in the SSH config file at {ssh_config_path}" ) ) return node_name @@ -343,5 +375,18 @@ def removesuffix(s: str, suffix: str) -> str: return s[: -len(suffix)] else: return s + else: removesuffix = str.removesuffix + + +class SortingHelpFormatter(argparse.HelpFormatter): + """Taken and adapted from https://stackoverflow.com/a/12269143/6388696.""" + + def add_arguments(self, actions): + actions = sorted(actions, key=operator.attrgetter("option_strings")) + # put help actions first. + actions = sorted( + actions, key=lambda action: not isinstance(action, _HelpAction) + ) + super().add_arguments(actions) diff --git a/milatools/utils/compute_node.py b/milatools/utils/compute_node.py new file mode 100644 index 00000000..0a7b81d9 --- /dev/null +++ b/milatools/utils/compute_node.py @@ -0,0 +1,387 @@ +from __future__ import annotations + +import asyncio.subprocess +import contextlib +import dataclasses +import datetime +import inspect +import re +import shlex +import signal +import subprocess +import sys +import warnings + +from milatools.cli import console +from milatools.cli.init_command import DRAC_CLUSTERS +from milatools.cli.utils import get_hostname_to_use_for_compute_node +from milatools.utils.remote_v1 import Hide +from milatools.utils.remote_v2 import RemoteV2, logger, ssh_command +from milatools.utils.runner import Runner + + +@dataclasses.dataclass +class ComputeNode(Runner): + """Runs commands on a compute node with `srun --jobid {job_id}` from the login node. + + This essentially runs this: + `ssh {cluster} srun --overlap --jobid {job_id} {command}` + in a subprocess each time `run` is called. + + NOTE: Found out about this trick from https://hpc.fau.de/faq/how-can-i-attach-to-a-running-slurm-job/ + """ + + login_node: RemoteV2 + job_id: int + salloc_subprocess: asyncio.subprocess.Process | None = None + + def __post_init__(self): + # The hostname will be of the compute node, not the login node. + # NOTE: On DRAC clusters, we don't actually want the full hostname here, because + # the SSH config we make with `mila init` matches the node names like `cdr****`, + # but not the fully qualified hostnames (`cdr2052.int.cedar.computecanada.ca`). + cluster = self.login_node.hostname + node_name = self.get_output("echo $SLURMD_NODENAME", display=False, hide=True) + self.hostname: str = get_hostname_to_use_for_compute_node( + node_name, + cluster=cluster, + ssh_config_path=self.login_node.ssh_config_path, + ) + + def run( + self, command: str, display: bool = True, warn: bool = False, hide: Hide = False + ): + if display: + # Show the compute node hostname instead of the login node. + console.log(f"({self.hostname}) $ {command}", style="green") + if shlex.quote(command) == command: + return self.login_node.run( + command=f"srun --ntasks=1 --overlap --quiet --jobid {self.job_id} {command}", + display=False, + warn=warn, + hide=hide, + ) + + return self.login_node.run( + command=f"srun --ntasks=1 --overlap --quiet --jobid {self.job_id} bash", + input=f"{command}\n", + display=False, + warn=warn, + hide=hide, + ) + + async def run_async( + self, + command: str, + display: bool = True, + warn: bool = False, + hide: Hide = False, + ) -> subprocess.CompletedProcess[str]: + if display: + # Show the compute node hostname instead of the login node. 
+ console.log(f"({self.hostname}) $ {command}", style="green") + if shlex.quote(command) == command: + return await self.login_node.run_async( + command=f"srun --ntasks=1 --overlap --quiet --jobid {self.job_id} {command}", + display=False, + warn=warn, + hide=hide, + ) + + return await self.login_node.run_async( + command=f"srun --ntasks=1 --overlap --quiet --jobid {self.job_id} bash", + input=f"{command}\n", + display=False, + warn=warn, + hide=hide, + ) + + async def close(self): + """Cancels the running job using `scancel`.""" + logger.info(f"Stopping job {self.job_id}.") + if self.salloc_subprocess is not None: + # NOTE: This will exit cleanly because we don't have nested terminals or + # job steps. + if self.salloc_subprocess.stdin is not None: + await self.salloc_subprocess.communicate("exit\n".encode()) # noqa: UP012 + else: + self.salloc_subprocess.send_signal(signal=signal.SIGINT) + # The scancel below is done even though it's redundant, just to be safe. + await self.login_node.run_async( + f"scancel {self.job_id}", display=True, hide=False + ) + + def __repr__(self) -> str: + params = ", ".join( + f"{k}={repr(getattr(self, k))}" + for k in inspect.signature(type(self)).parameters + ) + return f"{type(self).__name__}({params})" + + +async def get_queued_milatools_job_ids( + login_node: RemoteV2, job_name="mila-code" +) -> set[int]: + # NOTE: `since` is unused in this case. + jobs = await login_node.get_output_async( + f"squeue --noheader --me --format=%A --name={job_name}" + ) + return set([int(job_id_str) for job_id_str in jobs.splitlines()]) + + +async def get_milatools_job_ids( + login_node: RemoteV2, + job_name="mila-code", + since: datetime.timedelta = datetime.timedelta(hours=1), +) -> set[int]: + """Get the IDS of jobs created by milatools from the output of `sacct`.""" + jobs = await login_node.get_output_async( + f"sacct --noheader --allocations --user=$USER " + f"--starttime=now-{int(since.total_seconds())}seconds --format=JobId " + f"--name={job_name}" + ) + return set([int(job_id_str) for job_id_str in jobs.splitlines()]) + + +@contextlib.asynccontextmanager +async def cancel_new_jobs_on_interrupt(login_node: RemoteV2, job_name: str): + """ContextManager that handles interruptions while creating a new allocation. + + This handles the case where an interrupt is raised while running a command over SSH + that creates a new job allocation (either salloc or sbatch) before we are able to + parse the job id. (In the case where we have the job ID, we simply use + `scancel {job_id}`). + + In this case, we try to cancel the new job(s) that have appeared since entering the + `async with` block that have the name `job_name`. Emits a warning in the (unlikely) + case where such jobs are not found, as it means that there could be a "zombie" job + allocation on the cluster. + """ + jobs_before = await get_queued_milatools_job_ids(login_node, job_name=job_name) + try: + yield + except (KeyboardInterrupt, asyncio.CancelledError): + jobs_after = await get_queued_milatools_job_ids(login_node, job_name=job_name) + logger.warning("Interrupted before we were able to parse a job id!") + # We were unable to get the job id, so we'll try to cancel only the newly + # spawned jobs from this user that match the set name. 
+ new_jobs = list(set(jobs_after) - set(jobs_before)) + if len(new_jobs) == 1: + job_id = new_jobs[0] + console.log( + f"Cancelling job {job_id} since it is the only new job of this " + f"user with name {job_name!r} since the last call to salloc or sbatch.", + style="yellow", + ) + login_node.run(f"scancel {new_jobs[0]}", display=True, hide=False) + elif len(new_jobs) > 1: + console.log( + f"There appears to be more than one new jobs from this user with " + f"name {job_name!r} since the initial call to salloc or sbatch: " + f"{new_jobs}\n" + "Cancelling all of them to be safe...", + style="yellow", + ) + login_node.run( + "scancel " + " ".join(str(job_id) for job_id in new_jobs), + display=True, + hide=False, + ) + else: + warnings.warn( + RuntimeWarning( + f"Unable to find any new job IDs with name {job_name!r} since the last " + f"job allocation. This means that if job allocations were created, " + "they might not have been properly cancelled. Please check that there " + f"are no leftover jobs with {job_name!r} on the cluster!" + ) + ) + raise + + +async def salloc( + login_node: RemoteV2, salloc_flags: list[str], job_name: str +) -> ComputeNode: + """Runs `salloc` and returns a remote connected to the compute node.""" + + salloc_command = "salloc " + shlex.join(salloc_flags) + if login_node.hostname in DRAC_CLUSTERS: + salloc_command = f"cd $SCRATCH && {salloc_command}" + + command = ssh_command( + hostname=login_node.hostname, + control_path=login_node.control_path, + control_master="auto", + control_persist="yes", + command=salloc_command, + ssh_config_path=login_node.ssh_config_path, + ) + + salloc_subprocess = None + job_id: int | None = None + + # "Why not just use `subprocess.Popen`?", you might ask. Well, we're essentially + # trying to go full-async so that the parsing of the job-id from stderr can + # eventually be done at the same time as something else (while waiting for the + # job to start) using things like `asyncio.gather` and `asyncio.wait_for`. + logger.debug(f"(local) $ {shlex.join(command)}") + console.log(f"({login_node.hostname}) $ {salloc_command}", style="green") + async with cancel_new_jobs_on_interrupt(login_node, job_name): + salloc_subprocess = await asyncio.subprocess.create_subprocess_exec( + *command, + shell=False, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + assert salloc_subprocess.stderr is not None + while job_id is None: + error_line = (await salloc_subprocess.stderr.readline()).decode() + print(error_line, end="", file=sys.stderr) + if job_id_match := re.findall(r"job allocation [0-9]+", error_line): + job_id = int(job_id_match[0].split()[-1]) + break + if not error_line: + # Getting an empty line (only?) after salloc errors are done printing. + break + + if job_id is None: + salloc_subprocess.kill() + raise RuntimeError("Unable to parse the job ID from the output of salloc!") + + try: + console.log(f"Waiting for job {job_id} to start.", style="green") + _node, state = await wait_while_job_is_pending(login_node, job_id) + except (KeyboardInterrupt, asyncio.CancelledError): + if salloc_subprocess is not None: + logger.debug("Killing the salloc subprocess following a KeyboardInterrupt.") + salloc_subprocess.send_signal(signal.SIGINT) + salloc_subprocess.terminate() + login_node.run(f"scancel {job_id}", display=True, hide=False) + raise + + # todo: Are there are states between `PENDING` and `RUNNING`? 
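The job id is scraped from `salloc`'s stderr, which prints a line such as `salloc: Granted job allocation 4123456` once the allocation is granted. The parsing step in isolation (using a capture group rather than the `findall` + `split` combination above, so the details differ slightly):

```python
from __future__ import annotations

import re


def parse_salloc_job_id(stderr_line: str) -> int | None:
    """Return the job id announced in a line of salloc's stderr, or None if absent."""
    match = re.search(r"job allocation ([0-9]+)", stderr_line)
    return int(match.group(1)) if match else None


assert parse_salloc_job_id("salloc: Granted job allocation 4123456\n") == 4123456
assert parse_salloc_job_id("salloc: Waiting for resource configuration\n") is None
```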
+ # if state != "RUNNING": + # raise RuntimeError( + # f"Error: Expected job {job_id} to be running, but it is in state {state!r}!" + # ) + + # NOTE: passing the process handle to this ComputeNodeRemote so it doesn't go out of + # scope and die (which would maybe kill the job, not 100% sure). + + return ComputeNode( + job_id=job_id, + login_node=login_node, + salloc_subprocess=salloc_subprocess, + ) + + +async def sbatch( + login_node: RemoteV2, sbatch_flags: list[str], job_name: str +) -> ComputeNode: + """Runs `sbatch` and returns a remote connected to the compute node. + + The job script is actually the `sleep` command wrapped in an sbatch script thanks to + [the '--wrap' argument of sbatch](https://slurm.schedmd.com/sbatch.html#OPT_wrap) + + This then waits asynchronously until the job show us as RUNNING in the output of the + `sacct` command. + """ + # idea: Find the job length from the sbatch flags if possible so we can do + # --wrap='sleep {job_duration}' instead of 'sleep 7d' so the job doesn't look + # like it failed or was interrupted, just cleanly exits before the end time. + sbatch_command = ( + "sbatch --parsable " + shlex.join(sbatch_flags) + " --wrap 'srun sleep 7d'" + ) + cluster = login_node.hostname + if cluster in DRAC_CLUSTERS: + sbatch_command = f"cd $SCRATCH && {sbatch_command}" + + job_id = None + async with cancel_new_jobs_on_interrupt(login_node, job_name): + job_id = await login_node.get_output_async( + sbatch_command, display=True, hide=False + ) + job_id = int(job_id) + + try: + await wait_while_job_is_pending(login_node, job_id) + except (KeyboardInterrupt, asyncio.CancelledError): + console.log(f"Received KeyboardInterrupt, cancelling job {job_id}") + login_node.run(f"scancel {job_id}", display=True, hide=False) + raise + + return ComputeNode(job_id=job_id, login_node=login_node) + + +async def _wait_while_job_is_in_state(login_node: RemoteV2, job_id: int, state: str): + nodes: str | None = None + current_state: str | None = None + wait_time_seconds = 1 + attempt = 1 + + while True: + result = await login_node.run_async( + f"sacct --jobs {job_id} --allocations --noheader --format=Node,State", + display=False, + warn=True, # don't raise an error if the command fails. + hide=True, + ) + stdout = result.stdout.strip() + nodes, _, current_state = stdout.rpartition(" ") + nodes = nodes.strip() + current_state = current_state.strip() + logger.debug(f"{nodes=}, {current_state=}") + + if ( + result.returncode == 0 + and nodes + and nodes != "None assigned" + and current_state + and current_state != state + ): + logger.info( + f"Job {job_id} was allocated node(s) {nodes!r} and is in state " + f"{current_state!r}." + ) + return nodes, current_state + + waiting_until = f"Waiting {wait_time_seconds} seconds until job {job_id} " + condition: str | None = None + if result.returncode == 0 and not nodes and not current_state: + condition = "shows up in the output of `sacct`." + elif result.returncode != 0: + # todo: Look into this case a bit more deeply. Seems like sometimes sacct + # gives errors for example right after salloc, when the job id is not yet + # in the slurm DB. + condition = "shows up in the output of `sacct`." + elif nodes == "None assigned": + condition = "is allocated a node." + elif current_state == state: + condition = f"is no longer {state}." + else: + # TODO: Don't yet understand when this case could occur. 
+ logger.warning( + f"Unexpected result from `sacct` for job {job_id}: {result.stdout=}, {result.stderr=}" + ) + condition = "shows up correctly in the output of sacct." + logger.info(waiting_until + condition) + + if attempt > 1: + logger.debug(f"Attempt #{attempt}") + + await asyncio.sleep(wait_time_seconds) + wait_time_seconds *= 2 + # wait at most 30 seconds for each attempt. + wait_time_seconds = min(30, wait_time_seconds) + attempt += 1 + + +async def wait_while_job_is_pending( + login_node: RemoteV2, job_id: int +) -> tuple[str, str]: + """Waits until a job show up in `sacct` then waits until its state is not PENDING. + + Returns the `Node` and `State` from `sacct` after the job is no longer pending. + """ + return await _wait_while_job_is_in_state(login_node, job_id, state="PENDING") diff --git a/milatools/cli/local.py b/milatools/utils/local_v1.py similarity index 90% rename from milatools/cli/local.py rename to milatools/utils/local_v1.py index fb8e2ee7..3ea24e84 100644 --- a/milatools/cli/local.py +++ b/milatools/utils/local_v1.py @@ -12,14 +12,14 @@ import paramiko.ssh_exception from typing_extensions import deprecated +from milatools.cli.utils import CommandNotFoundError, T, cluster_to_connect_kwargs from milatools.utils.remote_v2 import SSH_CONFIG_FILE, is_already_logged_in -from .utils import CommandNotFoundError, T, cluster_to_connect_kwargs - logger = get_logger(__name__) -class Local: +@deprecated("LocalV1 is being deprecated. Use LocalV2 instead.", category=None) +class LocalV1: def display(self, args: list[str] | tuple[str, ...]) -> None: display(args) @@ -90,9 +90,12 @@ def check_passwordless(host: str) -> bool: return True try: + # TODO: Would need to use a key to somehow say "you can ask for a passphrase, + # but not for a password" somehow. 
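`LocalV1` here (like `RemoteV1` further down) is marked with `typing_extensions.deprecated(..., category=None)`. Passing `category=None` means no `DeprecationWarning` is emitted at runtime; the deprecation is only surfaced by static type checkers. A tiny illustration with a made-up class name:

```python
from typing_extensions import deprecated


@deprecated("OldRunner is being deprecated. Use NewRunner instead.", category=None)
class OldRunner:
    """Stand-in for LocalV1 / RemoteV1 in this example."""


runner = OldRunner()  # No runtime warning; type checkers (e.g. pyright) flag this call site.
```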
connect_kwargs_for_host = {"allow_agent": False} if host in cluster_to_connect_kwargs: connect_kwargs_for_host.update(cluster_to_connect_kwargs[host]) + logger.debug(f"Connecting with connect_kwargs: {connect_kwargs_for_host}") with fabric.Connection( host, connect_kwargs=connect_kwargs_for_host, diff --git a/milatools/utils/local_v2.py b/milatools/utils/local_v2.py new file mode 100644 index 00000000..d8fc1a74 --- /dev/null +++ b/milatools/utils/local_v2.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +import asyncio +import dataclasses +import shlex +import subprocess +import sys +from logging import getLogger as get_logger +from subprocess import CompletedProcess + +from milatools.cli import console +from milatools.utils.remote_v1 import Hide +from milatools.utils.runner import Runner + +logger = get_logger(__name__) + + +@dataclasses.dataclass(init=False, frozen=True) +class LocalV2(Runner): + """A runner that runs commands in subprocesses on the local machine.""" + + hostname = "localhost" + + @staticmethod + def run( + command: str | tuple[str, ...], + input: str | None = None, + display: bool = True, + warn: bool = False, + hide: Hide = False, + ) -> CompletedProcess[str]: + program_and_args = _display_command(command, input=input, display=display) + return run(program_and_args=program_and_args, input=input, warn=warn, hide=hide) + + @staticmethod + def get_output( + command: str | tuple[str, ...], + *, + display: bool = False, + warn: bool = False, + hide: Hide = True, + ) -> str: + return LocalV2.run( + command, display=display, warn=warn, hide=hide + ).stdout.strip() + + @staticmethod + async def run_async( + command: str | tuple[str, ...], + input: str | None = None, + display: bool = True, + warn: bool = False, + hide: Hide = False, + ) -> CompletedProcess[str]: + program_and_args = _display_command(command, input=input, display=display) + return await run_async(program_and_args, input=input, warn=warn, hide=hide) + + @staticmethod + async def get_output_async( + command: str | tuple[str, ...], + *, + display: bool = False, + warn: bool = False, + hide: Hide = True, + ) -> str: + """Runs the command asynchronously and returns the stripped output string.""" + return ( + await LocalV2.run_async(command, display=display, warn=warn, hide=hide) + ).stdout.strip() + + +def _display_command( + command: str | tuple[str, ...], input: str | None, display: bool +) -> tuple[str, ...]: + """Converts the command to a tuple of strings if needed with `shlex.split` and + optionally logs it to the console. + + Also shows the input that would be passed to the command, if any. + """ + if isinstance(command, str): + program_and_args = tuple(shlex.split(command)) + displayed_command = command + else: + program_and_args = command + displayed_command = shlex.join(command) + if display: + if not input: + console.log( + f"(localhost) $ {displayed_command}", + style="green", + _stack_offset=2, + ) + else: + console.log( + f"(localhost) $ {displayed_command}\n{input}", + style="green", + _stack_offset=2, + ) + return program_and_args + + +def run( + program_and_args: tuple[str, ...], + input: str | None = None, + warn: bool = False, + hide: Hide = False, +) -> subprocess.CompletedProcess[str]: + """Runs the command *synchronously* in a subprocess and returns the result. + + Parameters + ---------- + program_and_args: The program and arguments to pass to it. This is a tuple of \ + strings, same as in `subprocess.Popen`. + input: The optional 'input' argument to `subprocess.Popen.communicate()`. 
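`_display_command` accepts either a single string or a pre-split tuple and normalizes between the two with `shlex`. The round-trip it relies on, in isolation:

```python
import shlex

command = "bash -c 'echo hello world'"
program_and_args = tuple(shlex.split(command))

assert program_and_args == ("bash", "-c", "echo hello world")
assert shlex.join(program_and_args) == command  # join re-quotes tokens that contain spaces
```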
+ warn: When `True` and an exception occurs, warn instead of raising the exception. + hide: Controls the printing of the subprocess' stdout and stderr. + + Returns + ------- + The `subprocess.CompletedProcess` object with the result of the subprocess. + + Raises + ------ + subprocess.CalledProcessError + If an error occurs when running the command and `warn` is `False`. + """ + displayed_command = shlex.join(program_and_args) + if not input: + logger.debug(f"Calling `subprocess.run` with {program_and_args=}") + else: + logger.debug(f"Calling `subprocess.run` with {program_and_args=} and {input=}") + result = subprocess.run( + program_and_args, + shell=False, + capture_output=True, + text=True, + check=not warn, + input=input, + ) + assert result.returncode is not None + if warn and result.returncode != 0: + message = ( + f"Command {displayed_command!r}" + + (f" with {input=!r}" if input else "") + + f" exited with {result.returncode}: {result.stderr=}" + ) + logger.debug(message) + if hide is not True: # don't warn if hide is True. + logger.warning(RuntimeWarning(message), stacklevel=2) + + if result.stdout: + if hide not in [True, "out", "stdout"]: + print(result.stdout) + logger.debug(f"{result.stdout=}") + if result.stderr: + if hide not in [True, "err", "stderr"]: + print(result.stderr, file=sys.stderr) + logger.debug(f"{result.stderr=}") + return result + + +async def run_async( + program_and_args: tuple[str, ...], + input: str | None = None, + warn: bool = False, + hide: Hide = False, +) -> subprocess.CompletedProcess[str]: + """Runs the command *asynchronously* in a subprocess and returns the result. + + Parameters + ---------- + program_and_args: The program and arguments to pass to it. This is a tuple of \ + strings, same as in `subprocess.Popen`. + input: The optional 'input' argument to `subprocess.Popen.communicate()`. + warn: When `True` and an exception occurs, warn instead of raising the exception. + hide: Controls the printing of the subprocess' stdout and stderr. + + Returns + ------- + A `subprocess.CompletedProcess` object with the result of the asyncio.Process. + + Raises + ------ + subprocess.CalledProcessError + If an error occurs when running the command and `warn` is `False`. + """ + command_with_input = program_and_args + ((input,) if input else ()) + + logger.debug(f"Calling `asyncio.create_subprocess_exec` with {program_and_args=}") + proc = await asyncio.create_subprocess_exec( + *program_and_args, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + stdin=asyncio.subprocess.PIPE if input else None, + start_new_session=False, + ) + if input: + logger.debug(f"Sending {input=!r} to the subprocess' stdin.") + + stdout, stderr = await proc.communicate(input.encode() if input else None) + + assert proc.returncode is not None + if proc.returncode != 0: + logger.debug( + f"[{command_with_input!r}" + + (f" with input {input!r}" if input else "") + + f" exited with {proc.returncode}]" + ) + if not warn: + if stderr: + logger.error(stderr) + raise subprocess.CalledProcessError( + returncode=proc.returncode, + cmd=program_and_args, + output=stdout, + stderr=stderr, + ) + if hide is not True: # don't warn if hide is True. 
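Taken together, `LocalV2` is a thin, stateless wrapper around `subprocess.run` and `asyncio.create_subprocess_exec`. A usage sketch (the commands are arbitrary examples, not taken from the patch):

```python
from milatools.utils.local_v2 import LocalV2

# Echo the command to the console and raise CalledProcessError on a non-zero exit code:
LocalV2.run("echo hello", display=True)

# Quietly capture stripped stdout:
version = LocalV2.get_output(("python", "--version"))

# Tolerate failure: the CompletedProcess is returned as-is instead of raising.
result = LocalV2.run("ls /does/not/exist", warn=True, hide=True)
assert result.returncode != 0
```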
+ logger.warning( + RuntimeWarning( + f"Command {program_and_args!r} returned non-zero exit code {proc.returncode}: {stderr}" + ) + ) + result = subprocess.CompletedProcess( + args=program_and_args, + returncode=proc.returncode, + stdout=stdout.decode(), + stderr=stderr.decode(), + ) + if result.stdout: + if hide not in [True, "out", "stdout"]: + print(result.stdout) + logger.debug(f"{result.stdout}") + if result.stderr: + if hide not in [True, "err", "stderr"]: + print(result.stderr, file=sys.stderr) + logger.debug(f"{result.stderr}") + return result diff --git a/milatools/cli/remote.py b/milatools/utils/remote_v1.py similarity index 98% rename from milatools/cli/remote.py rename to milatools/utils/remote_v1.py index 43997856..89a2229c 100644 --- a/milatools/cli/remote.py +++ b/milatools/utils/remote_v1.py @@ -19,7 +19,7 @@ from fabric import Connection from typing_extensions import Self, TypedDict, deprecated -from .utils import ( +from ..cli.utils import ( DRAC_CLUSTERS, SSHConnectionError, T, @@ -110,7 +110,8 @@ def get_first_node_name(node_names_out: str) -> str: return base + inside_brackets.split("-")[0] -class Remote: +@deprecated("RemoteV1 is being deprecated. Use RemoteV2 instead.", category=None) +class RemoteV1: def __init__( self, hostname: str, @@ -446,7 +447,7 @@ def extract_script( return self.extract(shlex.join([dest, *args]), pattern=pattern, **kwargs) -class SlurmRemote(Remote): +class SlurmRemote(RemoteV1): def __init__( self, connection: fabric.Connection, @@ -533,7 +534,7 @@ def ensure_allocation( "jobid": results["jobid"], }, login_node_runner else: - remote = Remote(hostname=self.hostname, connection=self.connection) + remote = RemoteV1(hostname=self.hostname, connection=self.connection) command = shlex.join(["salloc", *self.alloc]) # NOTE: On some DRAC clusters, it's required to first cd to $SCRATCH or # /projects before submitting a job. diff --git a/milatools/utils/remote_v2.py b/milatools/utils/remote_v2.py index 37166e60..34d4e8fa 100644 --- a/milatools/utils/remote_v2.py +++ b/milatools/utils/remote_v2.py @@ -1,25 +1,29 @@ from __future__ import annotations +import dataclasses import getpass -import shlex import shutil import subprocess import sys from logging import getLogger as get_logger from pathlib import Path -from typing import Any, Literal +from typing import Literal from paramiko import SSHConfig from milatools.cli import console -from milatools.cli.remote import Hide -from milatools.cli.utils import DRAC_CLUSTERS, MilatoolsUserError +from milatools.cli.utils import ( + DRAC_CLUSTERS, + SSH_CACHE_DIR, + SSH_CONFIG_FILE, + MilatoolsUserError, +) +from milatools.utils.local_v2 import LocalV2, run_async +from milatools.utils.remote_v1 import Hide +from milatools.utils.runner import Runner logger = get_logger(__name__) -SSH_CONFIG_FILE = Path.home() / ".ssh" / "config" -SSH_CACHE_DIR = Path.home() / ".cache" / "ssh" - class UnsupportedPlatformError(MilatoolsUserError): ... @@ -37,68 +41,24 @@ def raise_error_if_running_on_windows(): ) -def ssh_command( - hostname: str, - control_path: Path | Literal["none"], - command: str, - control_master: Literal["yes", "no", "auto", "ask", "autoask"] = "auto", - control_persist: int | str | Literal["yes", "no"] = "yes", -): - """Returns a tuple of strings to be used as the command to be run in a subprocess. - - Parameters - ---------- - hostname: The hostname to connect to. - control_path : See https://man.openbsd.org/ssh_config#ControlPath - command: The command to run on the remote host (kept as a string). 
- control_master: See https://man.openbsd.org/ssh_config#ControlMaster - control_persist: See https://man.openbsd.org/ssh_config#ControlPersist - - Returns - ------- - The tuple of strings to pass to `subprocess.run` or similar. - """ - return ( - "ssh", - f"-oControlMaster={control_master}", - f"-oControlPersist={control_persist}", - f"-oControlPath={control_path}", - hostname, - command, - ) - - -def control_socket_is_running(host: str, control_path: Path) -> bool: - """Check whether the control socket at the given path is running.""" - if not control_path.exists(): - return False - result = subprocess.run( - ("ssh", "-O", "check", f"-oControlPath={control_path}", host), - shell=False, - text=True, - capture_output=True, - ) - if ( - result.returncode != 0 - or not result.stderr - or not result.stderr.startswith("Master running") - ): - logger.debug(f"{control_path=} doesn't exist or isn't running: {result=}.") - return False - return True - - -class RemoteV2: +@dataclasses.dataclass(init=False) +class RemoteV2(Runner): """Simpler Remote where commands are run in subprocesses sharing an SSH connection. This doesn't work on Windows, as it assumes that the SSH client has SSH multiplexing support (ControlMaster, ControlPath and ControlPersist). """ + hostname: str + control_path: Path + ssh_config_path: Path + def __init__( self, hostname: str, + *, control_path: Path | None = None, + ssh_config_path: Path = SSH_CONFIG_FILE, ): """Create an SSH connection using this control_path, creating it if necessary. @@ -108,9 +68,15 @@ def __init__( control_path: The path where the control socket will be created if it doesn't \ already exist. You can use `get_controlpath_for` to get this for a given hostname. + ssh_config_path: Path to the ssh config file. Defaults to `SSH_CONFIG_FILE`. """ self.hostname = hostname - self.control_path = control_path or get_controlpath_for(hostname) + self.ssh_config_path = ssh_config_path + self.control_path = control_path or get_controlpath_for( + hostname, ssh_config_path=self.ssh_config_path + ) + self.control_path = self.control_path.expanduser() + self.local_runner = LocalV2() if not control_socket_is_running(self.hostname, self.control_path): logger.info( @@ -118,17 +84,55 @@ def __init__( ) setup_connection_with_controlpath( self.hostname, - self.control_path, - timeout=None, - display=False, + control_path=self.control_path, + ssh_config_path=self.ssh_config_path, ) else: - logger.info(f"Reusing an existing SSH socket at {self.control_path}.") + logger.debug(f"Reusing an existing SSH socket at {self.control_path}.") assert control_socket_is_running(self.hostname, self.control_path) + @staticmethod + async def connect( + hostname: str, + *, + control_path: Path | None = None, + ssh_config_path: Path = SSH_CONFIG_FILE, + ) -> RemoteV2: + """Async constructor. + + Note: Would be nice to remove the duplicated code between this method and init, + but it isn't clear to me how to use async functions in an __init__ method. 
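A usage sketch for `RemoteV2` and its async constructor, assuming a `mila` host entry exists in the SSH config written by `mila init`:

```python
import asyncio

from milatools.utils.remote_v2 import RemoteV2


def sync_example() -> None:
    # The constructor creates (or reuses) the ControlMaster socket for this host.
    login_node = RemoteV2("mila")
    print(login_node.get_output("hostname"))


async def async_example() -> None:
    # The async constructor checks the control socket without blocking the event loop.
    login_node = await RemoteV2.connect("mila")
    print(await login_node.get_output_async("sbatch --version"))


# sync_example()
# asyncio.run(async_example())
```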
+ """ + control_path = control_path or get_controlpath_for( + hostname, ssh_config_path=ssh_config_path + ) + control_path = control_path.expanduser() + + if not await control_socket_is_running_async(hostname, control_path): + logger.info(f"Creating a reusable connection to the {hostname} cluster.") + setup_connection_with_controlpath( + hostname, + control_path=control_path, + ssh_config_path=ssh_config_path, + ) + else: + logger.info(f"Reusing an existing SSH socket at {control_path}.") + return RemoteV2( + hostname, + control_path=control_path, + # ssh_options=ssh_options, + ssh_config_path=ssh_config_path, + ) + def run( - self, command: str, display: bool = True, warn: bool = False, hide: Hide = False + self, + command: str, + *, + input: str | None = None, + display: bool = True, + warn: bool = False, + hide: Hide = False, ): assert self.control_path.exists() run_command = ssh_command( @@ -137,45 +141,175 @@ def run( control_master="auto", control_persist="yes", command=command, + ssh_config_path=self.ssh_config_path, + ) + if display: + # NOTE: Only display the input if it is passed. + if not input: + console.log(f"({self.hostname}) $ {command}", style="green") + else: + console.log(f"({self.hostname}) $ {command=}\n{input}", style="green") + return self.local_runner.run( + command=run_command, input=input, display=False, warn=warn, hide=hide + ) + + async def run_async( + self, + command: str, + *, + input: str | None = None, + display: bool = True, + warn: bool = False, + hide: Hide = False, + ) -> subprocess.CompletedProcess[str]: + assert self.control_path.exists() + run_command = ssh_command( + hostname=self.hostname, + control_path=self.control_path, + control_master="auto", + control_persist="yes", + command=command, + ssh_config_path=self.ssh_config_path, ) - logger.debug(f"(local) $ {shlex.join(run_command)}") if display: - console.log(f"({self.hostname}) $ {command}", style="green") - result = subprocess.run( - run_command, - capture_output=True, - check=not warn, - text=True, - bufsize=1, # 1 means line buffered + if not input: + console.log(f"({self.hostname}) $ {command}", style="green") + else: + console.log(f"({self.hostname}) $ {command=}\n{input}", style="green") + return await self.local_runner.run_async( + command=run_command, input=input, display=False, warn=warn, hide=hide ) - if result.stdout: - if hide not in [True, "out", "stdout"]: - print(result.stdout) - logger.debug(f"{result.stdout}") - if result.stderr: - if hide not in [True, "err", "stderr"]: - print(result.stderr) - logger.debug(f"{result.stderr}") - return result - - def __eq__(self, other: Any) -> bool: + + +# note: Could potentially cache the results of this function if we wanted to, assuming +# that the ssh config file doesn't change. + + +def ssh_command( + hostname: str, + control_path: Path, + command: str, + control_master: Literal["yes", "no", "auto", "ask", "autoask"] = "auto", + control_persist: int | str | Literal["yes", "no"] = "yes", + ssh_config_path: Path = SSH_CONFIG_FILE, +): + """Returns a tuple of strings to be used as the command to be run in a subprocess. + + When the path to the SSH config file is passed and exists, this will only add the + options which aren't already set in the SSH config, so as to avoid redundant + arguments to the `ssh` command. + + Parameters + ---------- + hostname: The hostname to connect to. + control_path : See https://man.openbsd.org/ssh_config#ControlPath + command: The command to run on the remote host (kept as a string). 
+ control_master: See https://man.openbsd.org/ssh_config#ControlMaster + control_persist: See https://man.openbsd.org/ssh_config#ControlPersist + ssh_config_path: Path to the ssh config file. + + Returns + ------- + The tuple of strings to pass to `subprocess.run` or similar. + """ + control_path = control_path.expanduser() + ssh_config_path = ssh_config_path.expanduser() + if not ssh_config_path.exists(): return ( - isinstance(other, type(self)) - and other.hostname == self.hostname - and other.control_path == self.control_path + "ssh", + f"-oControlMaster={control_master}", + f"-oControlPersist={control_persist}", + f"-oControlPath={control_path}", + hostname, + command, ) - def __repr__(self) -> str: - return f"{type(self).__name__}(hostname={self.hostname!r}, control_path={str(self.control_path)})" + ssh_command: list[str] = ["ssh"] + ssh_config_entry = SSHConfig.from_path(str(ssh_config_path)).lookup(hostname) + if ssh_config_entry.get("controlmaster") != control_master: + ssh_command.append(f"-oControlMaster={control_master}") + if ssh_config_entry.get("controlpersist") != control_persist: + ssh_command.append(f"-oControlPersist={control_persist}") - def get_output( - self, - command: str, - display=False, - warn=False, + control_path_in_config = ssh_config_entry.get("controlpath") + if ( + control_path_in_config is None + or Path(control_path_in_config).expanduser() != control_path + ): + # Only add the ControlPath arg if it is not in the config, or if it differs from + # the value in the config. + ssh_command.append(f"-oControlPath={control_path}") + ssh_command.append(hostname) + # NOTE: Not quoting the command here, `subprocess.run` does it (since shell=False). + ssh_command.append(command) + return tuple(ssh_command) + + +def control_socket_is_running(host: str, control_path: Path) -> bool: + """Check whether the control socket at the given path is running.""" + control_path = control_path.expanduser() + if not control_path.exists(): + return False + + result = subprocess.run( + ( + "ssh", + "-O", + "check", + f"-oControlPath={control_path}", + host, + ), + check=False, + capture_output=True, + text=True, + shell=False, + ) + if ( + result.returncode != 0 + or not result.stderr + or not result.stderr.startswith("Master running") + ): + logger.debug(f"{control_path=} doesn't exist or isn't running: {result=}.") + return False + return True + + +async def control_socket_is_running_async(host: str, control_path: Path) -> bool: + """Check whether the control socket at the given path is running asynchronously.""" + control_path = control_path.expanduser() + if not control_path.exists(): + return False + + result = await run_async( + ( + "ssh", + "-O", + "check", + f"-oControlPath={control_path}", + host, + ), + warn=True, hide=True, + ) + if ( + result.returncode != 0 + or not result.stderr + or not result.stderr.startswith("Master running") ): - return self.run(command, display=display, warn=warn, hide=hide).stdout.strip() + logger.debug(f"{control_path=} doesn't exist or isn't running: {result=}.") + return False + return True + + +def option_dict_to_flags(options: dict[str, str]) -> list[str]: + return [ + ( + f"--{key.removeprefix('--')}={value}" + if value is not None + else f"--{key.removeprefix('--')}" + ) + for key, value in options.items() + ] def is_already_logged_in(cluster: str, also_run_command_to_check: bool = False) -> bool: @@ -226,33 +360,40 @@ def get_controlpath_for( If `ssh_cache_dir` is not set, and the `ControlPath` option doesn't apply for that hostname, a 
`RuntimeError` is raised. """ - if not ssh_config_path.exists(): - raise MilatoolsUserError(f"SSH config file doesn't exist at {ssh_config_path}.") - - ssh_config = SSHConfig.from_path(str(ssh_config_path)) - values = ssh_config.lookup(cluster) - if not (control_path := values.get("controlpath")): - if ssh_cache_dir is None: - raise RuntimeError( - f"ControlPath isn't set in the ssh config for {cluster}, and " - "ssh_cache_dir isn't set." - ) - logger.debug( - f"ControlPath isn't set for host {cluster}. Falling back to the ssh cache " - f"directory at {ssh_cache_dir}." + ssh_config_values: dict[str, str] = {} + + if ssh_config_path.exists(): + ssh_config_values = SSHConfig.from_path(str(ssh_config_path)).lookup(cluster) + + if control_path := ssh_config_values.get("controlpath"): + # Controlpath is set in the SSH config. + return Path(control_path).expanduser() + + if ssh_cache_dir is None: + raise RuntimeError( + f"ControlPath isn't set in the ssh config for {cluster}, and " + "ssh_cache_dir isn't set." ) - hostname = values.get("hostname", cluster) - username = values.get("user", getpass.getuser()) - port = values.get("port", 22) - control_path = ssh_cache_dir / f"{username}@{hostname}:{port}" - return Path(control_path).expanduser() + + ssh_cache_dir = ssh_cache_dir.expanduser() + logger.debug( + f"ControlPath isn't set for host {cluster}. Falling back to the ssh cache " + f"directory at {ssh_cache_dir}." + ) + # Assume that the hostname is the same if not set. + hostname = ssh_config_values.get("hostname", cluster) + if "@" in hostname: + logger.debug(f"Username is already in the hostname: {hostname}") + return ssh_cache_dir / hostname + username = ssh_config_values.get("user", getpass.getuser()) + port = int(ssh_config_values.get("port", 22)) + return ssh_cache_dir / f"{username}@{hostname}:{port}" def setup_connection_with_controlpath( cluster: str, control_path: Path, - display: bool = True, - timeout: int | None = None, + ssh_config_path: Path = SSH_CONFIG_FILE, ) -> None: """Setup (or test) an SSH connection to this cluster using this control path. @@ -262,13 +403,10 @@ def setup_connection_with_controlpath( ---------- cluster: name of the cluster to connect to. control_path: Path to the control socket file. - display: Whether to display the command being run. - timeout: Timeout in seconds for the subprocess. Set to `None` for no timeout. + ssh_config_path: Path to the ssh config file. Raises ------ - subprocess.TimeoutExpired - If `timeout` was passed and the subprocess times out. subprocess.CalledProcessError If the subprocess call raised an error. RuntimeError @@ -277,6 +415,10 @@ def setup_connection_with_controlpath( """ raise_error_if_running_on_windows() + # if control_socket_is_running(cluster, control_path): + # logger.debug(f"Connection to {cluster} is already running.") + # return + control_path = control_path.expanduser() if not control_path.exists(): control_path.parent.mkdir(parents=True, exist_ok=True) @@ -287,6 +429,7 @@ def setup_connection_with_controlpath( control_master="auto", control_persist="yes", command=command, + ssh_config_path=ssh_config_path, ) if cluster in DRAC_CLUSTERS: console.log( @@ -311,7 +454,7 @@ def setup_connection_with_controlpath( *first_command_args, ) else: - logger.debug( + logger.warning( f"`sshpass` is not installed. If 2FA is setup on {cluster}, you might " "be asked to press 1 or enter a 2fa passcode." 
) @@ -322,30 +465,21 @@ def setup_connection_with_controlpath( pass logger.info(f"Making the first connection to {cluster}...") - logger.debug(f"(local) $ {first_command_args}") - if display: - console.log(f"({cluster}) $ {command}", style="green") try: - first_connection_result = subprocess.run( - first_command_args, - shell=False, - text=True, - bufsize=1, # line buffered - timeout=timeout, - capture_output=True, - check=True, + first_connection_output = LocalV2().get_output( + command=first_command_args, + display=False, + hide="out", + warn=False, ) - first_connection_output = first_connection_result.stdout - except subprocess.TimeoutExpired as err: + except subprocess.TimeoutExpired: console.log( f"Timeout while setting up a reusable SSH connection to cluster {cluster}!" ) - raise err - except subprocess.CalledProcessError as err: - console.log( - f"Unable to setup a reusable SSH connection to cluster {cluster}!", err - ) - raise err + raise + except subprocess.CalledProcessError: + console.log(f"Unable to setup a reusable SSH connection to cluster {cluster}!") + raise if "OK" not in first_connection_output: raise RuntimeError( f"Did not receive the expected output ('OK') from {cluster}: " diff --git a/milatools/utils/runner.py b/milatools/utils/runner.py new file mode 100644 index 00000000..3596b656 --- /dev/null +++ b/milatools/utils/runner.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import subprocess +from abc import ABC, abstractmethod + +from milatools.utils.remote_v1 import Hide + + +class Runner(ABC): + """ABC for a Runner that runs commands on a (local or remote) machine.""" + + hostname: str + """Hostname of the machine that commands are ultimately being run on.""" + + @abstractmethod + def run( + self, + command: str, + *, + input: str | None = None, + display: bool = True, + warn: bool = False, + hide: Hide = False, + ) -> subprocess.CompletedProcess[str]: + """Runs the given command on the remote and returns the result. + + This executes the command in an ssh subprocess, which, thanks to the + ControlMaster/ControlPath/ControlPersist options, will reuse the existing + connection to the remote. + + Parameters + ---------- + command: The command to run. + input: Input to pass to the program (argument to `subprocess.run`). + display: Display the command on the console before it is run. + warn: If `true` and an exception occurs, warn instead of raising the exception. + hide: Controls the printing of the subprocess' stdout and stderr. + + Returns + ------- + A `subprocess.CompletedProcess` object with the output of the subprocess. + """ + # note: Could also have a default implementation just waits on the async method: + # return asyncio.get_event_loop().run_until_complete( + # self.run_async(command, input=input, display=display, warn=warn, hide=hide) + # ) + raise NotImplementedError() + + @abstractmethod + async def run_async( + self, + command: str, + *, + input: str | None = None, + display: bool = True, + warn: bool = False, + hide: Hide = False, + ) -> subprocess.CompletedProcess[str]: + """Runs the given command asynchronously and returns the result. + + This executes the command over ssh in an asyncio subprocess, which reuses the + existing connection to the remote. + + Parameters + ---------- + command: The command to run. + input: Input to pass to the program (as if it was the 'input' argument to \ + `asyncio.subprocess.Process.communicate`). + display: Display the command on the console before it is run. 
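The `Runner` ABC only requires concrete `run` and `run_async` implementations; `get_output` and `get_output_async` are derived from them. A toy subclass, purely illustrative and not part of the patch, that "runs" commands by echoing them back:

```python
from __future__ import annotations

import subprocess

from milatools.utils.remote_v1 import Hide
from milatools.utils.runner import Runner


class EchoRunner(Runner):
    """Fake runner that pretends every command simply printed its own text."""

    hostname = "echo"

    def run(
        self, command: str, *, input: str | None = None, display: bool = True,
        warn: bool = False, hide: Hide = False,
    ) -> subprocess.CompletedProcess[str]:
        return subprocess.CompletedProcess(args=command, returncode=0, stdout=command, stderr="")

    async def run_async(
        self, command: str, *, input: str | None = None, display: bool = True,
        warn: bool = False, hide: Hide = False,
    ) -> subprocess.CompletedProcess[str]:
        return self.run(command, input=input, display=display, warn=warn, hide=hide)


assert EchoRunner().get_output("echo hi") == "echo hi"
```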
+ warn: If `true` and an exception occurs, warn instead of raising the exception. + hide: Controls the printing of the subprocess' stdout and stderr. + + Returns + ------- + A `subprocess.CompletedProcess` object with the output of the subprocess. + """ + raise NotImplementedError() + + def get_output( + self, + command: str, + *, + display: bool = False, + warn: bool = False, + hide: Hide = True, + ) -> str: + """Runs the command and returns the stripped output string.""" + return self.run(command, display=display, warn=warn, hide=hide).stdout.strip() + + async def get_output_async( + self, + command: str, + *, + display: bool = False, + warn: bool = False, + hide: Hide = True, + ) -> str: + """Runs the command asynchronously and returns the stripped output string.""" + return ( + await self.run_async(command, display=display, warn=warn, hide=hide) + ).stdout.strip() diff --git a/milatools/utils/vscode_utils.py b/milatools/utils/vscode_utils.py index 95479d39..36d8fb14 100644 --- a/milatools/utils/vscode_utils.py +++ b/milatools/utils/vscode_utils.py @@ -11,13 +11,13 @@ from pathlib import Path from typing import Literal, Sequence -from milatools.cli.local import Local -from milatools.cli.remote import Remote from milatools.cli.utils import ( CLUSTERS, + CommandNotFoundError, batched, stripped_lines_of, ) +from milatools.utils.local_v2 import LocalV2 from milatools.utils.parallel_progress import ( DictProxy, ProgressDict, @@ -25,6 +25,7 @@ TaskID, parallel_progress_bar, ) +from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import RemoteV2 logger = get_logger(__name__) @@ -60,12 +61,22 @@ def get_code_command() -> str: return os.environ.get("MILATOOLS_CODE_COMMAND", "code") -def get_vscode_executable_path() -> str | None: - return shutil.which(get_code_command()) +def get_vscode_executable_path(code_command: str | None = None) -> str: + if code_command is None: + code_command = get_code_command() + + code_command_path = shutil.which(code_command) + if not code_command_path: + raise CommandNotFoundError(code_command) + return code_command_path def vscode_installed() -> bool: - return bool(get_vscode_executable_path()) + try: + _ = get_vscode_executable_path() + except CommandNotFoundError: + return False + return True def sync_vscode_extensions_with_hostnames( @@ -78,27 +89,27 @@ def sync_vscode_extensions_with_hostnames( logger.info("Assuming you want to sync from mila to all DRAC/CC clusters.") else: logger.warning( - f"{source=} is also in the destinations to sync to. " f"Removing it." + f"{source=} is also in the destinations to sync to. Removing it." ) destinations.remove(source) if len(set(destinations)) != len(destinations): raise ValueError(f"{destinations=} contains duplicate hostnames!") - source_obj = Local() if source == "localhost" else RemoteV2(source) + source_obj = LocalV2() if source == "localhost" else RemoteV2(source) return sync_vscode_extensions(source_obj, destinations) def sync_vscode_extensions( - source: str | Local | RemoteV2, - dest_clusters: Sequence[str | Local | RemoteV2], + source: str | LocalV2 | RemoteV2, + destinations: Sequence[str | LocalV2 | RemoteV2], ): - """Syncs vscode extensions between `source` all all the clusters in `dest`. + """Syncs vscode extensions between `source` all all the destination clusters. - This spawns a thread for each cluster in `dest` and displays a parallel progress bar - for the syncing of vscode extensions to each cluster. 
+ This spawns a thread for each cluster and displays a parallel progress bar for the + syncing of vscode extensions to each cluster. """ - if isinstance(source, Local): + if isinstance(source, LocalV2): source_hostname = "localhost" source_extensions = get_local_vscode_extensions() elif isinstance(source, RemoteV2): @@ -120,19 +131,19 @@ def sync_vscode_extensions( task_fns: list[TaskFn[ProgressDict]] = [] task_descriptions: list[str] = [] - for dest_remote in dest_clusters: + for dest_remote in destinations: dest_hostname: str if dest_remote == "localhost": dest_hostname = dest_remote # type: ignore - dest_remote = Local() # pickleable - elif isinstance(dest_remote, Local): + dest_remote = LocalV2() # pickleable + elif isinstance(dest_remote, LocalV2): dest_hostname = "localhost" dest_remote = dest_remote # again, pickleable elif isinstance(dest_remote, RemoteV2): dest_hostname = dest_remote.hostname dest_remote = dest_remote # pickleable - elif isinstance(dest_remote, Remote): + elif isinstance(dest_remote, RemoteV1): # We unfortunately can't pass this kind of object to another process or # thread because it uses `fabric.Connection` which don't appear to be # pickleable. This means we will have to re-connect in the subprocess. @@ -180,7 +191,7 @@ def install_vscode_extensions_task_function( task_id: TaskID, dest_hostname: str | Literal["localhost"], source_extensions: dict[str, str], - remote: RemoteV2 | Local | None, + remote: RemoteV2 | LocalV2 | None, source_name: str, verbose: bool = False, ) -> ProgressDict: @@ -209,15 +220,14 @@ def _update_progress( if remote is None: if dest_hostname == "localhost": - remote = Local() + remote = LocalV2() else: _update_progress(0, "Connecting...") remote = RemoteV2(dest_hostname) - if isinstance(remote, Local): + if isinstance(remote, LocalV2): assert dest_hostname == "localhost" code_server_executable = get_vscode_executable_path() - assert code_server_executable extensions_on_dest = get_local_vscode_extensions() else: dest_hostname = remote.hostname @@ -290,7 +300,7 @@ def _update_progress( def install_vscode_extension( - remote: Local | RemoteV2, + remote: LocalV2 | RemoteV2, code_server_executable: str, extension: str, verbose: bool = False, @@ -309,19 +319,20 @@ def install_vscode_extension( ) else: result = remote.run( - *command, - capture_output=not verbose, - display_command=verbose, + command, + display=verbose, + warn=True, + hide=not verbose, ) if result.stdout: logger.debug(result.stdout) return result -def get_local_vscode_extensions() -> dict[str, str]: +def get_local_vscode_extensions(code_command: str | None = None) -> dict[str, str]: output = subprocess.run( ( - get_vscode_executable_path() or get_code_command(), + get_vscode_executable_path(code_command=code_command), "--list-extensions", "--show-versions", ), @@ -334,7 +345,7 @@ def get_local_vscode_extensions() -> dict[str, str]: def get_remote_vscode_extensions( - remote: Remote | RemoteV2, + remote: RemoteV1 | RemoteV2, remote_code_server_executable: str, ) -> dict[str, str]: """Returns the list of isntalled extensions and the path to the code-server @@ -387,7 +398,7 @@ def extensions_to_install( def find_code_server_executable( - remote: Remote | RemoteV2, remote_vscode_server_dir: str = "~/.vscode-server" + remote: RemoteV1 | RemoteV2, remote_vscode_server_dir: str = "~/.vscode-server" ) -> str | None: """Find the most recent `code-server` executable on the remote. 
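`get_local_vscode_extensions` and `get_remote_vscode_extensions` both parse the output of `code --list-extensions --show-versions`, which prints one `publisher.extension@version` entry per line. The parsing step in isolation, with made-up extension versions (the helper in the patch may differ in its details):

```python
def parse_vscode_extensions(output: str) -> dict[str, str]:
    """Map extension id -> version from `code --list-extensions --show-versions` output."""
    extensions: dict[str, str] = {}
    for line in output.strip().splitlines():
        name, _, version = line.partition("@")
        extensions[name] = version
    return extensions


example_output = "ms-python.python@2024.2.1\nms-toolsai.jupyter@2024.2.0\n"
assert parse_vscode_extensions(example_output) == {
    "ms-python.python": "2024.2.1",
    "ms-toolsai.jupyter": "2024.2.0",
}
```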
diff --git a/poetry.lock b/poetry.lock index 2c0f8342..5443f68b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -520,6 +520,20 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "execnet" +version = "2.1.1" +description = "execnet: rapid multi-Python deployment" +optional = false +python-versions = ">=3.8" +files = [ + {file = "execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc"}, + {file = "execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3"}, +] + +[package.extras] +testing = ["hatch", "pre-commit", "pytest", "tox"] + [[package]] name = "fabric" version = "3.2.2" @@ -959,6 +973,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.23.6" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-asyncio-0.23.6.tar.gz", hash = "sha256:ffe523a89c1c222598c76856e76852b787504ddb72dd5d9b6617ffa8aa2cde5f"}, + {file = "pytest_asyncio-0.23.6-py3-none-any.whl", hash = "sha256:68516fdd1018ac57b846c9846b954f0393b26f094764a28c955eabb0536a4e8a"}, +] + +[package.dependencies] +pytest = ">=7.0.0,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-cov" version = "4.1.0" @@ -1075,6 +1107,26 @@ files = [ [package.dependencies] pytest = ">=5.0.0" +[[package]] +name = "pytest-xdist" +version = "3.5.0" +description = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-xdist-3.5.0.tar.gz", hash = "sha256:cbb36f3d67e0c478baa57fa4edc8843887e0f6cfc42d677530a36d7472b32d8a"}, + {file = "pytest_xdist-3.5.0-py3-none-any.whl", hash = "sha256:d075629c7e00b611df89f490a5063944bee7a4362a5ff11c7cc7824a03dfce24"}, +] + +[package.dependencies] +execnet = ">=1.1" +pytest = ">=6.2.0" + +[package.extras] +psutil = ["psutil (>=3.0)"] +setproctitle = ["setproctitle"] +testing = ["filelock"] + [[package]] name = "pytz" version = "2023.3.post1" @@ -1573,4 +1625,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "ee3d9cbcbc70bc1585411a3ae62f2ec6d825a6d2493764e49ef36a710797ebbf" +content-hash = "1941e7cf8f32199cd373687ed0611714869671c70ec97dc81f8b223d6226708c" diff --git a/pyproject.toml b/pyproject.toml index 62c22df3..eff3b66f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,8 @@ pytest-timeout = "^2.2.0" Sphinx = "^5.0.1" sphinx-rtd-theme = "^1.0.0" toml = "^0.10.0" +pytest-asyncio = "^0.23.6" +pytest-xdist = "^3.5.0" [tool.pytest.ini_options] diff --git a/tests/cli/common.py b/tests/cli/common.py index dc679f56..477fc09c 100644 --- a/tests/cli/common.py +++ b/tests/cli/common.py @@ -7,6 +7,7 @@ import sys import typing from collections.abc import Callable +from logging import getLogger as get_logger from subprocess import CompletedProcess from typing import Any @@ -16,12 +17,14 @@ from pytest_regressions.file_regression import FileRegressionFixture from typing_extensions import ParamSpec -from milatools.cli.utils import MilatoolsUserError, removesuffix -from 
milatools.utils.remote_v2 import RemoteV2 +from milatools.cli.utils import SSH_CACHE_DIR, SSH_CONFIG_FILE, removesuffix +from milatools.utils.remote_v2 import RemoteV2, get_controlpath_for if typing.TYPE_CHECKING: from typing_extensions import TypeGuard +logger = get_logger(__name__) + in_github_CI = os.environ.get("GITHUB_ACTIONS") == "true" """True if this is being run inside the GitHub CI.""" @@ -38,25 +41,36 @@ skip_param_if_on_github_ci = functools.partial(pytest.param, marks=skip_if_on_github_CI) -passwordless_ssh_connection_to_localhost_is_setup = False +def ssh_to_localhost_is_setup() -> bool: + SSH_CACHE_DIR.mkdir(parents=True, exist_ok=True) + control_path = get_controlpath_for( + "localhost", + ssh_config_path=SSH_CONFIG_FILE, + ssh_cache_dir=SSH_CACHE_DIR, + ) + if sys.platform != "win32": + try: + _localhost_remote = RemoteV2("localhost", control_path=control_path) + except ( + subprocess.CalledProcessError, + subprocess.TimeoutExpired, + ) as err: + logger.error(f"SSH connection to localhost is not setup: {err}") + return False + return True -try: - localhost_remote = RemoteV2("localhost") -except ( - subprocess.CalledProcessError, - subprocess.TimeoutExpired, - RuntimeError, - MilatoolsUserError, -): try: - connection = fabric.Connection("localhost") + _connection = fabric.Connection("localhost") + _connection.open() except ( paramiko.ssh_exception.SSHException, paramiko.ssh_exception.NoValidConnectionsError, ): - passwordless_ssh_connection_to_localhost_is_setup = True -else: - passwordless_ssh_connection_to_localhost_is_setup = True + return False + return True + + +passwordless_ssh_connection_to_localhost_is_setup = ssh_to_localhost_is_setup() requires_ssh_to_localhost = pytest.mark.skipif( not passwordless_ssh_connection_to_localhost_is_setup, diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index 98e3a40a..20e2e48f 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -8,7 +8,8 @@ import pytest from pytest_regressions.file_regression import FileRegressionFixture -from milatools.cli.commands import _parse_lfs_quota_output, main +from milatools.cli.commands import main +from milatools.cli.common import _parse_lfs_quota_output from .common import requires_no_s_flag diff --git a/tests/cli/test_init_command.py b/tests/cli/test_init_command.py index d404a8d3..9c490b3a 100644 --- a/tests/cli/test_init_command.py +++ b/tests/cli/test_init_command.py @@ -37,12 +37,12 @@ setup_vscode_settings, setup_windows_ssh_config_from_wsl, ) -from milatools.cli.local import Local, check_passwordless -from milatools.cli.remote import Remote from milatools.cli.utils import ( SSHConfig, running_inside_WSL, ) +from milatools.utils.local_v1 import LocalV1, check_passwordless +from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import ( SSH_CACHE_DIR, SSH_CONFIG_FILE, @@ -101,7 +101,7 @@ def input_pipe(monkeypatch: pytest.MonkeyPatch, request: pytest.FixtureRequest): For confirmation prompts, just send one letter, otherwise the '\r' is passed to the next prompt, which sees it as just pressing enter, which uses the default value. 
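The refactored `ssh_to_localhost_is_setup()` helper is evaluated once at import time and its result feeds a `pytest.mark.skipif` marker. The general pattern, reduced to a minimal sketch with a simpler runtime check:

```python
import shutil

import pytest

_ssh_client_available = shutil.which("ssh") is not None

requires_ssh_client = pytest.mark.skipif(
    not _ssh_client_available, reason="Requires an `ssh` client on PATH."
)


@requires_ssh_client
def test_something_that_shells_out_to_ssh():
    assert _ssh_client_available
```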
""" - request.node.add_marker(raises_NoConsoleScreenBufferError_on_windows_ci_action()) + request.applymarker(raises_NoConsoleScreenBufferError_on_windows_ci_action()) with create_pipe_input() as input_pipe: monkeypatch.setattr( "questionary.confirm", @@ -1100,7 +1100,9 @@ def copy_fn(source, dest): @contextlib.contextmanager def backup_remote_dir( - remote: RemoteV2 | Remote, directory: PurePosixPath, backup_directory: PurePosixPath + remote: RemoteV2 | RemoteV1, + directory: PurePosixPath, + backup_directory: PurePosixPath, ): # IDEA: Make the equivalent function, but that backs up a directory on a remote # machine. @@ -1193,7 +1195,7 @@ def backup_local_ssh_cache_dir(): @pytest.fixture -def backup_remote_ssh_dir(login_node: RemoteV2 | Remote, cluster: str): +def backup_remote_ssh_dir(login_node: RemoteV2 | RemoteV1, cluster: str): """Creates a backup of the ~/.ssh directory on the remote cluster.""" if USE_MY_REAL_SSH_DIR: logger.critical( @@ -1234,7 +1236,7 @@ def backup_remote_ssh_dir(login_node: RemoteV2 | Remote, cluster: str): ) def test_setup_passwordless_ssh_access_to_cluster( cluster: str, - login_node: Remote | RemoteV2, + login_node: RemoteV1 | RemoteV2, input_pipe: PipeInput, backup_local_ssh_dir: Path, backup_local_ssh_cache_dir: Path, @@ -1538,7 +1540,7 @@ def test_setup_passwordless_ssh_access( # There should be an ssh key in the .ssh dir. # Won't ask to generate a key. create_ssh_keypair( - ssh_private_key_path=ssh_dir / "id_rsa_milatools", local=Local() + ssh_private_key_path=ssh_dir / "id_rsa_milatools", local=LocalV1() ) if drac_clusters_in_ssh_config: # We should get a prompt asking if we want to register the public key diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index ebfcf0af..f755d467 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -1,20 +1,28 @@ import functools import multiprocessing import random -from unittest.mock import patch +from pathlib import Path +from unittest.mock import Mock, patch import pytest +import questionary from prompt_toolkit.input.defaults import create_pipe_input +import milatools +import milatools.cli +import milatools.cli.init_command +from milatools.cli.init_command import setup_ssh_config from milatools.cli.utils import ( - get_fully_qualified_hostname_of_compute_node, get_fully_qualified_name, + get_hostname_to_use_for_compute_node, make_process, qn, randname, yn, ) +from .test_init_command import PipeInput, input_pipe # noqa + def test_randname(file_regression): random.seed(0) @@ -40,6 +48,64 @@ def test_hostname(): assert get_fully_qualified_name() +@pytest.fixture() +def ssh_config_file( + tmp_path_factory: pytest.TempPathFactory, + monkeypatch: pytest.MonkeyPatch, +) -> Path: + """Fixture that creates the SSH config as setup by `mila init`.""" + from milatools.cli.init_command import yn + + # NOTE: might want to put this in a fixture if we wanted the "real" mila / drac + # usernames in the config. + mila_username = drac_username = "bob" + + ssh_config_path = tmp_path_factory.mktemp(".ssh") / "ssh_config" + + def _yn(question: str) -> bool: + question = question.strip() + known_questions = { + f"There is no {ssh_config_path} file. 
Create one?": True, + "Do you also have an account on the ComputeCanada/DRAC clusters?": True, + "Is this OK?": True, + } + if question in known_questions: + return known_questions[question] + raise NotImplementedError(f"Unexpected question: {question}") + + mock_yn = Mock(spec=yn, side_effect=_yn) + monkeypatch.setattr(milatools.cli.init_command, yn.__name__, mock_yn) + + def _mock_unsafe_ask(question: str, *args, **kwargs) -> str: + question = question.strip() + known_questions = { + "What's your username on the mila cluster?": mila_username, + "What's your username on the CC/DRAC clusters?": drac_username, + } + if question in known_questions: + return known_questions[question] + raise NotImplementedError(f"Unexpected question: {question}") + + def _mock_text(message: str, *args, **kwargs): + return Mock( + spec=questionary.Question, + unsafe_ask=Mock( + spec=questionary.Question.unsafe_ask, + side_effect=functools.partial(_mock_unsafe_ask, message), + ), + ) + + mock_text = Mock( + spec=questionary.text, + side_effect=_mock_text, + ) + monkeypatch.setattr(questionary, questionary.text.__name__, mock_text) + + setup_ssh_config(ssh_config_path) + assert ssh_config_path.exists() + return ssh_config_path + + @pytest.mark.parametrize( ("cluster_name", "node", "expected"), [ @@ -54,7 +120,7 @@ def test_hostname(): + [ # Host !cedar cdr? cdr?? cdr??? cdr???? ("cedar", cnode, cnode) - for n in range(5) + for n in range(1, 5) for cnode in [f"cdr{'0' * n}"] ] + [ @@ -77,23 +143,30 @@ def test_hostname(): # for cnode in ["nia1234"] # ], ) -def test_get_fully_qualified_hostname_of_compute_node( - cluster_name: str, node: str, expected: str +def test_get_hostname_to_use_for_compute_node( + cluster_name: str, + node: str, + expected: str, + ssh_config_file: Path, ): assert ( - get_fully_qualified_hostname_of_compute_node( - node_name=node, cluster=cluster_name + get_hostname_to_use_for_compute_node( + node_name=node, cluster=cluster_name, ssh_config_path=ssh_config_file ) == expected ) -def test_get_fully_qualified_hostname_of_compute_node_unknown_cluster(): +def test_get_fully_qualified_hostname_of_compute_node_unknown_cluster( + ssh_config_file: Path, +): node_name = "some-node" with pytest.warns(UserWarning): assert ( - get_fully_qualified_hostname_of_compute_node( - node_name=node_name, cluster="unknown-cluster" + get_hostname_to_use_for_compute_node( + node_name=node_name, + cluster="unknown-cluster", + ssh_config_path=ssh_config_file, ) == node_name ) diff --git a/tests/conftest.py b/tests/conftest.py index b00a377b..56cc52db 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,11 @@ from __future__ import annotations import contextlib +import datetime +import functools +import os import shutil +import subprocess import sys import time from collections.abc import Generator @@ -9,32 +13,40 @@ from pathlib import Path from unittest.mock import Mock -import paramiko.ssh_exception import pytest +import pytest_asyncio from fabric.connection import Connection from milatools.cli import console from milatools.cli.init_command import DRAC_CLUSTERS -from milatools.cli.remote import Remote +from milatools.utils.compute_node import get_queued_milatools_job_ids +from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import ( RemoteV2, + UnsupportedPlatformError, get_controlpath_for, is_already_logged_in, ) -from tests.integration.conftest import SLURM_CLUSTER + +from .cli.common import ( + in_github_CI, + passwordless_ssh_connection_to_localhost_is_setup, + 
skip_if_on_github_CI, + xfails_on_windows, +) +from .integration.conftest import ( + JOB_NAME, + MAX_JOB_DURATION, + SLURM_CLUSTER, + WCKEY, + skip_if_not_already_logged_in, +) logger = get_logger(__name__) -passwordless_ssh_connection_to_localhost_is_setup = False +unsupported_on_windows = xfails_on_windows(raises=UnsupportedPlatformError, strict=True) -try: - Connection("localhost").open() -except ( - paramiko.ssh_exception.SSHException, - paramiko.ssh_exception.NoValidConnectionsError, -): - pass -else: - passwordless_ssh_connection_to_localhost_is_setup = True + +pytest.register_assert_rewrite("tests.utils.runner_tests") @pytest.fixture( @@ -90,9 +102,9 @@ def MockConnection( __repr__=lambda _: f"Connection({repr(host)})", ), ) - import milatools.cli.remote + import milatools.utils.remote_v1 - monkeypatch.setattr(milatools.cli.remote, Connection.__name__, MockConnection) + monkeypatch.setattr(milatools.utils.remote_v1, Connection.__name__, MockConnection) return MockConnection @@ -111,11 +123,11 @@ def mock_connection( @pytest.fixture(scope="function") def remote(mock_connection: Connection): assert isinstance(mock_connection.host, str) - return Remote(hostname=mock_connection.host, connection=mock_connection) + return RemoteV1(hostname=mock_connection.host, connection=mock_connection) @pytest.fixture(scope="function") -def login_node(cluster: str) -> Remote | RemoteV2: +def login_node(cluster: str) -> RemoteV1 | RemoteV2: """Fixture that gives a Remote connected to the login node of a slurm cluster. NOTE: Making this a function-scoped fixture because the Connection object of the @@ -132,11 +144,82 @@ def login_node(cluster: str) -> Remote | RemoteV2: "prior connection to the cluster." ) if sys.platform == "win32": - return Remote(cluster) + return RemoteV1(cluster) return RemoteV2(cluster) -@pytest.fixture(scope="session", params=[SLURM_CLUSTER]) +@pytest.fixture(scope="session") +def login_node_v2(cluster: str) -> RemoteV2: + if sys.platform == "win32": + pytest.skip("Test uses RemoteV2.") + if cluster not in ["mila", "localhost"] and not is_already_logged_in(cluster): + pytest.skip( + f"Requires ssh access to the login node of the {cluster} cluster, and a " + "prior connection to the cluster." + ) + return RemoteV2(cluster) + + +from .integration.conftest import ( # noqa: E402 + hangs_in_github_CI, +) + + +def _drac_cluster_param(hostname: str): + return pytest.param( + hostname, + marks=[ + pytest.mark.slow, + pytest.mark.xdist_group(name=hostname), + skip_if_not_already_logged_in(hostname), + pytest.mark.xfail( + hostname in clusters_under_maintenance, + reason=f"{hostname} cluster is down for maintenance.", + raises=subprocess.CalledProcessError, + ), + ], + ) + + +clusters_under_maintenance = os.environ.get("CLUSTER_DOWN", "").split(",") + +_cluster_params = [ + pytest.param( + "localhost", + marks=[ + pytest.mark.xdist_group(name="localhost"), + pytest.mark.skipif( + not (in_github_CI and SLURM_CLUSTER == "localhost"), + reason=( + "Only runs in the GitHub CI when localhost is a slurm cluster." + ), + ), + # todo: remove this mark once we're able to do sbatch and salloc in the + # GitHub CI. 
+ hangs_in_github_CI, + ], + ), + pytest.param( + "mila", + marks=[ + pytest.mark.xdist_group(name="mila"), + skip_if_on_github_CI, + pytest.mark.xfail( + "mila" in clusters_under_maintenance, + reason="mila cluster is down for maintenance.", + raises=subprocess.CalledProcessError, + ), + ], + ), + _drac_cluster_param("narval"), + _drac_cluster_param("beluga"), + _drac_cluster_param("cedar"), + _drac_cluster_param("graham"), + _drac_cluster_param("niagara"), +] + ([SLURM_CLUSTER] if SLURM_CLUSTER not in ["mila", "localhost"] else []) + + +@pytest.fixture(scope="session", params=_cluster_params) def cluster(request: pytest.FixtureRequest) -> str: """Fixture that gives the hostname of the slurm cluster to use for tests. @@ -149,32 +232,72 @@ def test_something(remote: Remote): ``` """ - slurm_cluster_hostname = request.param - - if not slurm_cluster_hostname: + cluster_name = request.param + if not cluster_name: pytest.skip("Requires ssh access to a SLURM cluster.") # TODO: Re-enable this, but only on tests that say that they run jobs on the # cluster. # with cancel_all_milatools_jobs_before_and_after_tests(slurm_cluster_hostname): - return slurm_cluster_hostname + clusters_in_maintenance = os.environ.get("CLUSTER_DOWN", "").split(",") + if cluster_name in clusters_in_maintenance: + pytest.skip(reason=f"Cluster {cluster_name} is down for maintenance.") + # NOTE: Seems not possible to add this marker to all tests?. + # request.node.add_marker( + # pytest.mark.xfail( + # reason=f"Cluster {cluster_name} is down for maintenance.", + # raises=subprocess.CalledProcessError, + # ) + # ) + return cluster_name + + +@pytest.fixture(scope="session") +def job_name(request: pytest.FixtureRequest) -> str | None: + return getattr(request, "param", JOB_NAME) + + +@pytest_asyncio.fixture +async def launches_job_fixture(login_node_v2: RemoteV2, job_name: str): + jobs_before = await get_queued_milatools_job_ids(login_node_v2, job_name=job_name) + if jobs_before: + logger.info(f"Jobs in squeue before test: {jobs_before}") + try: + yield + finally: + jobs_after = await get_queued_milatools_job_ids( + login_node_v2, job_name=job_name + ) + if jobs_before: + logger.info(f"Jobs after test: {jobs_before}") + + new_jobs = jobs_after - jobs_before + if new_jobs: + login_node_v2.run( + "scancel " + " ".join(str(job_id) for job_id in new_jobs), display=True + ) + else: + logger.debug("Test apparently didn't launch any new jobs.") + + +launches_jobs = pytest.mark.usefixtures(launches_job_fixture.__name__) @contextlib.contextmanager -def cancel_all_milatools_jobs_before_and_after_tests(login_node: Remote | RemoteV2): +def cancel_all_milatools_jobs_before_and_after_tests(login_node: RemoteV1 | RemoteV2): from .integration.conftest import WCKEY logger.info( f"Cancelling milatools test jobs on {cluster} before running integration tests." ) - login_node.run(f"scancel -u $USER --wckey={WCKEY}") + login_node.run(f"scancel -u $USER --wckey={WCKEY}", display=False, hide=True) time.sleep(1) # Note: need to recreate this because login_node is a function-scoped fixture. yield logger.info( f"Cancelling milatools test jobs on {cluster} after running integration tests." ) - login_node.run(f"scancel -u $USER --wckey={WCKEY}") + login_node.run(f"scancel -u $USER --wckey={WCKEY}", display=False, hide=True) time.sleep(1) # Display the output of squeue just to be sure that the jobs were cancelled. 
logger.info(f"Checking that all jobs have been cancelked on {cluster}...") @@ -249,3 +372,95 @@ def already_logged_in( if control_path.exists(): control_path.unlink() shutil.move(moved_path, control_path) + + +@functools.lru_cache +def get_slurm_account(cluster: str) -> str: + """Gets the SLURM account of the user using sacctmgr on the slurm cluster. + + When there are multiple accounts, this selects the first account, alphabetically. + + On DRAC cluster, this uses the `def` allocations instead of `rrg`, and when + the rest of the accounts are the same up to a '_cpu' or '_gpu' suffix, it uses + '_cpu'. + + For example: + + ```text + def-someprofessor_cpu <-- this one is used. + def-someprofessor_gpu + rrg-someprofessor_cpu + rrg-someprofessor_gpu + ``` + """ + logger.info( + f"Fetching the list of SLURM accounts available on the {cluster} cluster." + ) + assert cluster in ["mila", "localhost"] or is_already_logged_in(cluster) + result = RemoteV2(cluster).run( + "sacctmgr --noheader show associations where user=$USER format=Account%50" + ) + accounts = [line.strip() for line in result.stdout.splitlines()] + assert accounts + logger.info(f"Accounts on the slurm cluster {cluster}: {accounts}") + account = sorted(accounts)[0] + logger.info(f"Using account {account} to launch jobs in tests.") + return account + + +@pytest.fixture(scope="session") +def slurm_account_on_cluster(cluster: str) -> str: + if cluster not in ["mila", "localhost"] and not is_already_logged_in(cluster): + # avoid test hanging on 2FA prompt. + pytest.skip(reason=f"Test needs an existing connection to {cluster} to run.") + return get_slurm_account(cluster) + + +@pytest.fixture(scope="session") +def max_job_duration( + request: pytest.FixtureRequest, cluster: str +) -> datetime.timedelta: + """Fixture that allows test to parametrize the duration of their jobs.""" + return getattr(request, "param", MAX_JOB_DURATION) + + +@pytest.fixture(scope="session") +def allocation_flags( + request: pytest.FixtureRequest, + slurm_account_on_cluster: str, + job_name: str | None, + max_job_duration: datetime.timedelta, +) -> list[str]: + """Flags passed to salloc or sbatch during tests. + + When parametrized, overrides individual flags: + ```python + @pytest.mark.parametrize("allocation_flags", [{"some_flag": "some_value"}], indirect=True) + def some_test(allocation_flags: list[str]) + assert "--some_flag=some_value" in allocation_flags + ``` + """ + default_allocation_options = { + "wckey": WCKEY, + "account": slurm_account_on_cluster, + "nodes": 1, + "ntasks": 1, + "cpus-per-task": 1, + "mem": "1G", + "time": max_job_duration, + "oversubscribe": None, # allow multiple such jobs to share resources. + } + if job_name is not None: + # Only set the job name when needed. For example, `mila code` tests don't want + # it to be set. 
+ default_allocation_options["job-name"] = job_name + overrides = getattr(request, "param", {}) + assert isinstance(overrides, dict) + if overrides: + print(f"Overriding allocation options with {overrides}") + default_allocation_options = default_allocation_options.copy() + default_allocation_options.update(overrides) + return [ + f"--{key}={value}" if value is not None else f"--{key}" + for key, value in default_allocation_options.items() + ] diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index f709cf24..e343f00e 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,19 +1,14 @@ from __future__ import annotations import datetime -import functools import os import sys from logging import getLogger as get_logger import pytest -from milatools.cli.remote import Remote -from milatools.utils.remote_v2 import ( - SSH_CONFIG_FILE, - RemoteV2, - is_already_logged_in, -) +from milatools.cli.utils import SSH_CONFIG_FILE +from milatools.utils.remote_v2 import is_already_logged_in from tests.cli.common import in_github_CI, in_self_hosted_github_CI logger = get_logger(__name__) @@ -29,7 +24,7 @@ `None` when running on the github CI. """ -MAX_JOB_DURATION = datetime.timedelta(seconds=10) +MAX_JOB_DURATION = datetime.timedelta(minutes=5) hangs_in_github_CI = pytest.mark.skipif( SLURM_CLUSTER == "localhost", @@ -69,73 +64,3 @@ def skip_param_if_not_already_logged_in(cluster: str): skip_if_not_already_logged_in(cluster), ], ) - - -@functools.lru_cache -def get_slurm_account(cluster: str) -> str: - """Gets the SLURM account of the user using sacctmgr on the slurm cluster. - - When there are multiple accounts, this selects the first account, alphabetically. - - On DRAC cluster, this uses the `def` allocations instead of `rrg`, and when - the rest of the accounts are the same up to a '_cpu' or '_gpu' suffix, it uses - '_cpu'. - - For example: - - ```text - def-someprofessor_cpu <-- this one is used. - def-someprofessor_gpu - rrg-someprofessor_cpu - rrg-someprofessor_gpu - ``` - """ - logger.info( - f"Fetching the list of SLURM accounts available on the {cluster} cluster." - ) - if sys.platform == "win32": - result = Remote(cluster).run( - "sacctmgr --noheader show associations where user=$USER format=Account%50" - ) - else: - result = RemoteV2(cluster).run( - "sacctmgr --noheader show associations where user=$USER format=Account%50" - ) - accounts = [line.strip() for line in result.stdout.splitlines()] - assert accounts - logger.info(f"Accounts on the slurm cluster {cluster}: {accounts}") - account = sorted(accounts)[0] - logger.info(f"Using account {account} to launch jobs in tests.") - return account - - -@pytest.fixture(scope="session") -def slurm_account(cluster: str): - return get_slurm_account(cluster) - - -@pytest.fixture() -def allocation_flags( - cluster: str, slurm_account: str, request: pytest.FixtureRequest -) -> list[str]: - # note: thanks to lru_cache, this is only making one ssh connection per cluster. - allocation_options = { - "job-name": JOB_NAME, - "wckey": WCKEY, - "account": slurm_account, - "nodes": 1, - "ntasks": 1, - "cpus-per-task": 1, - "mem": "1G", - "time": MAX_JOB_DURATION, - "oversubscribe": None, # allow multiple such jobs to share resources. 
- } - overrides = getattr(request, "param", {}) - assert isinstance(overrides, dict) - if overrides: - print(f"Overriding allocation options with {overrides}") - allocation_options.update(overrides) - return [ - f"--{key}={value}" if value is not None else f"--{key}" - for key, value in allocation_options.items() - ] diff --git a/tests/integration/test_code_command.py b/tests/integration/test_code_command.py index 0f42c712..7a0dc38c 100644 --- a/tests/integration/test_code_command.py +++ b/tests/integration/test_code_command.py @@ -9,9 +9,10 @@ import pytest -from milatools.cli.commands import check_disk_quota, code -from milatools.cli.remote import Remote -from milatools.cli.utils import get_fully_qualified_hostname_of_compute_node +from milatools.cli.code_command import code +from milatools.cli.common import check_disk_quota +from milatools.cli.utils import get_hostname_to_use_for_compute_node +from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import RemoteV2 from ..cli.common import in_github_CI, skip_param_if_on_github_ci @@ -26,6 +27,7 @@ logger = get_logger(__name__) +@pytest.mark.slow @pytest.mark.parametrize( "cluster", [ @@ -49,7 +51,7 @@ indirect=True, ) def test_check_disk_quota( - login_node: Remote | RemoteV2, + login_node: RemoteV1 | RemoteV2, capsys: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture, ): # noqa: F811 @@ -60,6 +62,7 @@ def test_check_disk_quota( # IF the quota is met, then a `MilatoolsUserError` is logged. +@pytest.mark.slow @PARAMIKO_SSH_BANNER_BUG @pytest.mark.parametrize( "cluster", @@ -90,7 +93,7 @@ def test_check_disk_quota( ) @pytest.mark.parametrize("persist", [True, False]) def test_code( - login_node: Remote | RemoteV2, + login_node: RemoteV1 | RemoteV2, persist: bool, capsys: pytest.CaptureFixture, allocation_flags: list[str], @@ -135,7 +138,7 @@ def test_code( job_info = job_id_to_job_info[job_id] node = job_info["Node"] - node_hostname = get_fully_qualified_hostname_of_compute_node( + node_hostname = get_hostname_to_use_for_compute_node( node, cluster=login_node.hostname ) expected_line = f"(local) $ /usr/bin/echo -nw --remote ssh-remote+{node_hostname} {home}/{relative_path}" diff --git a/tests/integration/test_slurm_remote.py b/tests/integration/test_slurm_remote.py index 6e84f35f..7bc3442c 100644 --- a/tests/integration/test_slurm_remote.py +++ b/tests/integration/test_slurm_remote.py @@ -17,11 +17,12 @@ import milatools import milatools.cli import milatools.cli.utils -from milatools.cli.remote import Remote, SlurmRemote from milatools.cli.utils import CLUSTERS +from milatools.utils.remote_v1 import RemoteV1, SlurmRemote from milatools.utils.remote_v2 import RemoteV2 from ..cli.common import on_windows +from ..conftest import launches_jobs from .conftest import JOB_NAME, MAX_JOB_DURATION, SLURM_CLUSTER, hangs_in_github_CI logger = get_logger(__name__) @@ -49,7 +50,7 @@ def can_run_on_all_clusters(): def get_recent_jobs_info_dicts( - login_node: Remote | RemoteV2, + login_node: RemoteV1 | RemoteV2, since=datetime.timedelta(minutes=5), fields=("JobID", "JobName", "Node", "State"), ) -> list[dict[str, str]]: @@ -60,7 +61,7 @@ def get_recent_jobs_info_dicts( def get_recent_jobs_info( - login_node: Remote | RemoteV2, + login_node: RemoteV1 | RemoteV2, since=datetime.timedelta(minutes=5), fields=("JobID", "JobName", "Node", "State"), ) -> list[tuple[str, ...]]: @@ -68,13 +69,14 @@ def get_recent_jobs_info( # otherwise this would launch a job! 
assert not isinstance(login_node, SlurmRemote) lines = login_node.run( - f"sacct --noheader --allocations " + f"sacct --noheader --allocations --user=$USER " f"--starttime=now-{int(since.total_seconds())}seconds " "--format=" + ",".join(f"{field}%40" for field in fields), - display=True, + display=False, + hide=True, ).stdout.splitlines() # note: using maxsplit because the State field can contain spaces: "canceled by ..." - return [tuple(line.strip().split(maxsplit=len(fields))) for line in lines] + return [tuple(line.strip().split(maxsplit=len(fields) - 1)) for line in lines] def sleep_so_sacct_can_update(): @@ -82,37 +84,11 @@ def sleep_so_sacct_can_update(): time.sleep(_SACCT_UPDATE_DELAY.total_seconds()) -@requires_access_to_slurm_cluster -def test_cluster_setup(login_node: Remote | RemoteV2, allocation_flags: list[str]): - """Sanity Checks for the SLURM cluster of the CI: checks that `srun` works. - - NOTE: This is more-so a test to check that the slurm cluster used in the GitHub CI - is setup correctly, rather than to check that the Remote/SlurmRemote work correctly. - """ - - job_id, compute_node = ( - login_node.get_output( - f"srun {' '.join(allocation_flags)} bash -c 'echo $SLURM_JOB_ID $SLURMD_NODENAME'" - ) - .strip() - .split() - ) - assert compute_node - assert job_id.isdigit() - - sleep_so_sacct_can_update() - - # NOTE: the job should be done by now, since `.run` of the Remote is called with - # asynchronous=False. - sacct_output = get_recent_jobs_info(login_node, fields=("JobID", "JobName", "Node")) - assert (job_id, JOB_NAME, compute_node) in sacct_output - - @pytest.fixture def fabric_connection_to_login_node( - login_node: Remote | RemoteV2, request: pytest.FixtureRequest + login_node: RemoteV1 | RemoteV2, request: pytest.FixtureRequest ): - if isinstance(login_node, Remote): + if isinstance(login_node, RemoteV1): return login_node.connection if login_node.hostname not in ["localhost", "mila"]: @@ -122,7 +98,7 @@ def fabric_connection_to_login_node( f"might go through 2FA!" 
) ) - return Remote(login_node.hostname).connection + return RemoteV1(login_node.hostname).connection @pytest.fixture @@ -163,10 +139,12 @@ def sbatch_slurm_remote( ) +@pytest.mark.slow @PARAMIKO_SSH_BANNER_BUG +@launches_jobs @requires_access_to_slurm_cluster def test_run( - login_node: Remote | RemoteV2, + login_node: RemoteV1 | RemoteV2, salloc_slurm_remote: SlurmRemote, ): """Test for `SlurmRemote.run` with persist=False without an initial call to @@ -206,11 +184,14 @@ def test_run( assert (job_id, JOB_NAME, compute_node) in sacct_output +@pytest.mark.skip(reason="The way this test checks if the job ran is brittle.") +@pytest.mark.slow @PARAMIKO_SSH_BANNER_BUG +@launches_jobs @hangs_in_github_CI @requires_access_to_slurm_cluster def test_ensure_allocation( - login_node: Remote | RemoteV2, + login_node: RemoteV1 | RemoteV2, salloc_slurm_remote: SlurmRemote, capsys: pytest.CaptureFixture[str], ): @@ -266,44 +247,19 @@ def test_ensure_allocation( if isinstance(login_node, RemoteV2) and login_node.hostname == "mila": assert hostname_from_remote_runner.startswith("login-") assert hostname_from_login_node_runner.startswith("login-") - elif isinstance(login_node, Remote): + elif isinstance(login_node, RemoteV1): assert hostname_from_remote_runner == hostname_from_login_node_runner - # TODO: IF the remote runner was to be connected to the compute node through the - # same interactive terminal, then we'd use this: - # result = remote_runner.run( - # "echo $SLURM_JOB_ID $SLURMD_NODENAME", - # echo=True, - # echo_format=T.bold_cyan( - # f"({compute_node_from_salloc_output})" + " $ {command}" - # ), - # in_stream=False, - # ) - # assert result - # assert not result.stderr - # assert result.stdout.strip() - # job_id, compute_node = result.stdout.strip().split() - # # cn-a001 vs cn-a001.server.mila.quebec for example. - # assert compute_node.startswith(compute_node_from_salloc_output) - # assert compute_node != login_node.hostname # hopefully also works in CI... - - # NOTE: too brittle. - # if datetime.datetime.now() - start_time < MAX_JOB_DURATION: - # # Check that the job shows up as still running in the output of `sacct`, since - # # we should not have reached the end time yet. 
- # sacct_output = get_recent_jobs_info( - # login_node, fields=("JobName", "Node", "State") - # ) - # assert [JOB_NAME, compute_node_from_salloc_output, "RUNNING"] in sacct_output - print(f"Sleeping for {MAX_JOB_DURATION.total_seconds()}s until job finishes...") time.sleep(MAX_JOB_DURATION.total_seconds()) - sacct_output = get_recent_jobs_info(login_node, fields=("JobName", "Node", "State")) - assert (JOB_NAME, compute_node_from_salloc_output, "COMPLETED") in sacct_output + sacct_output = get_recent_jobs_info(login_node, fields=("JobName", "Node")) + assert (JOB_NAME, compute_node_from_salloc_output) in sacct_output +@pytest.mark.slow @PARAMIKO_SSH_BANNER_BUG +@launches_jobs @pytest.mark.xfail( on_windows, raises=PermissionError, @@ -312,7 +268,7 @@ def test_ensure_allocation( @hangs_in_github_CI @requires_access_to_slurm_cluster def test_ensure_allocation_sbatch( - login_node: Remote | RemoteV2, sbatch_slurm_remote: SlurmRemote + login_node: RemoteV1 | RemoteV2, sbatch_slurm_remote: SlurmRemote ): job_data, login_node_remote_runner = sbatch_slurm_remote.ensure_allocation() print(job_data, login_node_remote_runner) diff --git a/tests/integration/test_sync_command.py b/tests/integration/test_sync_command.py index 6d9a926f..be7aa2fb 100644 --- a/tests/integration/test_sync_command.py +++ b/tests/integration/test_sync_command.py @@ -8,7 +8,7 @@ import pytest from typing_extensions import ParamSpec -from milatools.cli.local import Local +from milatools.utils.local_v2 import LocalV2 from milatools.utils.remote_v2 import RemoteV2 from milatools.utils.vscode_utils import ( extensions_to_install, @@ -16,7 +16,6 @@ install_vscode_extensions_task_function, sync_vscode_extensions, ) -from tests.integration.conftest import skip_param_if_not_already_logged_in from ..cli.common import ( requires_ssh_to_localhost, @@ -30,31 +29,27 @@ "source", [ pytest.param("localhost", marks=requires_ssh_to_localhost), - skip_param_if_not_already_logged_in("mila"), - skip_param_if_not_already_logged_in("narval"), - skip_param_if_not_already_logged_in("beluga"), - skip_param_if_not_already_logged_in("cedar"), - skip_param_if_not_already_logged_in("graham"), - skip_param_if_not_already_logged_in("niagara"), + "cluster", ], ) @pytest.mark.parametrize( "dest", [ pytest.param("localhost", marks=requires_ssh_to_localhost), - skip_param_if_not_already_logged_in("mila"), - skip_param_if_not_already_logged_in("narval"), - skip_param_if_not_already_logged_in("beluga"), - skip_param_if_not_already_logged_in("cedar"), - skip_param_if_not_already_logged_in("graham"), - skip_param_if_not_already_logged_in("niagara"), + "cluster", ], ) def test_sync_vscode_extensions( source: str, dest: str, + cluster: str, monkeypatch: pytest.MonkeyPatch, ): + if source == "cluster": + source = cluster + if dest == "cluster": + dest = cluster + if source == dest: pytest.skip("Source and destination are the same.") @@ -80,10 +75,9 @@ def _mock_and_patch( ) sync_vscode_extensions( - source=Local() if source == "localhost" else RemoteV2(source), - dest_clusters=[dest], + source=LocalV2() if source == "localhost" else RemoteV2(source), + destinations=[dest], ) - mock_task_function.assert_called_once() mock_extensions_to_install.assert_called_once() if source == "localhost": diff --git a/tests/utils/runner_tests.py b/tests/utils/runner_tests.py new file mode 100644 index 00000000..da2dba89 --- /dev/null +++ b/tests/utils/runner_tests.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +import abc +import asyncio +import logging +import re +import 
subprocess +import time + +import pytest + +from milatools.utils.remote_v2 import RemoteV2 +from milatools.utils.runner import Runner + + +class RunnerTests(abc.ABC): + """Tests for a `Runner` implementation. + + Subclasses have to implement these fixtures: + - the `runner` fixture which should ideally be class or session-scoped; + - the `command_output_err` fixture should return a tuple containing 3 items: + - the command to run successfully + - the expected stdout or a `re.Pattern` to match against the stdout + - the expected stderr or a `re.Pattern` to match against the stderr + - the `command_exception_err` fixture should return a tuple containing 3 items: + (the command to run unsuccessfully, the expected exception, the expected stderr). + """ + + @abc.abstractmethod + @pytest.fixture(scope="class") + def runner(self) -> Runner: + raise NotImplementedError() + + @pytest.fixture( + scope="class", + params=[ + ("echo OK", "OK", ""), + # TODO: Test the proper escaping of variables. + # ("echo $USER", "todo", ""), + ], + ) + def command_output_err(self, request: pytest.FixtureRequest): + return request.param + + @pytest.fixture( + scope="class", + params=[ + ( + "cat /does/not/exist", + subprocess.CalledProcessError, + re.compile(r"cat: /does/not/exist: No such file or directory"), + ), + ], + ) + def command_exception_err(self, request: pytest.FixtureRequest): + return request.param + + @pytest.mark.parametrize("display", [True, False]) + @pytest.mark.parametrize("hide", [True, False, "out", "err", "stdout", "stderr"]) + def test_run( + self, + runner: Runner, + command_output_err: tuple[str, str | re.Pattern, str | re.Pattern], + hide: bool, + display: bool, + capsys: pytest.CaptureFixture, + ): + command, expected_output, expected_err = command_output_err + result = runner.run(command, display=display, hide=hide) + + if isinstance(expected_output, re.Pattern): + assert expected_output.search(result.stdout) + else: + assert result.stdout.strip() == expected_output + + if isinstance(expected_err, re.Pattern): + assert expected_err.search(result.stderr) + else: + assert result.stderr.strip() == expected_err + + printed_output, printed_err = capsys.readouterr() + assert isinstance(printed_output, str) + assert isinstance(printed_err, str) + + assert (f"({runner.hostname}) $ {command}" in printed_output) == display + + if result.stdout: + stdout_should_be_printed = hide not in [ + True, + "out", + "stdout", + ] + stdout_was_printed = result.stdout in printed_output + assert stdout_was_printed == stdout_should_be_printed + + if result.stderr: + error_should_be_printed = hide not in [ + True, + "err", + "stderr", + ] + error_was_printed = result.stderr in printed_err + assert error_was_printed == error_should_be_printed, ( + result.stderr, + printed_err, + ) + + @pytest.mark.parametrize("warn", [True, False]) + @pytest.mark.parametrize("display", [True, False]) + @pytest.mark.parametrize("hide", [True, False, "out", "err", "stdout", "stderr"]) + def test_run_with_error( + self, + runner: Runner, + command_exception_err: tuple[str, type[Exception], str | re.Pattern], + hide: bool, + warn: bool, + display: bool, + capsys: pytest.CaptureFixture, + caplog: pytest.LogCaptureFixture, + ): + command, expected_exception, expected_err = command_exception_err + + assert isinstance(expected_exception, type) and issubclass( + expected_exception, Exception + ) + + if not warn: + # Should raise an exception of this type.
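# For reference, a concrete subclass of `RunnerTests` only needs to provide the
# `runner` fixture (the command fixtures have defaults); this sketch mirrors the
# LocalV2 tests added later in this patch:
#
#     class TestLocalV2(RunnerTests):
#         @pytest.fixture(scope="class")
#         def runner(self) -> Runner:
#             return LocalV2()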
+ with pytest.raises(expected_exception=expected_exception): + _ = runner.run(command, display=display, hide=hide, warn=warn) + # unreachable code here, so just pretend like it returns directly. + return + + with caplog.at_level(logging.WARNING): + result = runner.run(command, display=display, hide=hide, warn=warn) + + assert result.stdout == "" + if isinstance(expected_err, re.Pattern): + assert expected_err.search(result.stderr) + else: + assert result.stderr.strip() == expected_err + + if hide is True: + # Warnings not logged at all (because `warn=True` and `hide=True`). + assert caplog.records == [] + elif isinstance(expected_err, str): + assert len(caplog.records) == 1 + assert ( + caplog.records[0].message.strip() + == f"Command {command!r} returned non-zero exit code 1: {expected_err}" + ) + elif isinstance(expected_err, re.Pattern): + assert len(caplog.records) == 1 + message = caplog.records[0].message.strip() + # assert message.startswith( + # f"Command {command!r} returned non-zero exit code 1:" + # ) + assert expected_err.search(message) + + printed_output, printed_err = capsys.readouterr() + assert isinstance(printed_output, str) + assert isinstance(printed_err, str) + + assert (f"({runner.hostname}) $ {command}" in printed_output) == display + + if result.stdout: + stdout_should_be_printed = hide not in [ + True, + "out", + "stdout", + ] + stdout_was_printed = result.stdout in printed_output + assert stdout_was_printed == stdout_should_be_printed + + if result.stderr: + error_should_be_printed = hide not in [ + True, + "err", + "stderr", + ] + error_was_printed = result.stderr in printed_err + assert error_was_printed == error_should_be_printed, ( + result.stderr, + printed_err, + ) + + @pytest.mark.parametrize("display", [True, False]) + @pytest.mark.parametrize("hide", [True, False, "out", "err", "stdout", "stderr"]) + @pytest.mark.asyncio + async def test_run_async( + self, + runner: Runner, + command_output_err: tuple[str, str | re.Pattern, str | re.Pattern], + hide: bool, + display: bool, + capsys: pytest.CaptureFixture, + ): + command, expected_output, expected_err = command_output_err + result = await runner.run_async(command, display=display, hide=hide) + + if isinstance(expected_output, re.Pattern): + assert expected_output.match(result.stdout) + else: + assert result.stdout.strip() == expected_output + + if isinstance(expected_err, re.Pattern): + assert expected_err.match(result.stderr) + else: + assert result.stderr.strip() == expected_err + + printed_output, printed_err = capsys.readouterr() + assert isinstance(printed_output, str) + assert isinstance(printed_err, str) + + assert (f"({runner.hostname}) $ {command}" in printed_output) == display + + if result.stdout: + stdout_should_be_printed = hide not in [ + True, + "out", + "stdout", + ] + stdout_was_printed = result.stdout in printed_output + assert stdout_was_printed == stdout_should_be_printed + + if result.stderr: + error_should_be_printed = hide not in [ + True, + "err", + "stderr", + ] + error_was_printed = result.stderr in printed_err + assert error_was_printed == error_should_be_printed, ( + result.stderr, + printed_err, + ) + + @pytest.mark.asyncio + async def test_run_async_runs_in_parallel(self, runner: RemoteV2): + commands = [f"sleep {i}" for i in range(1, 3)] + start_time = time.time() + # Sequential time: + sequential_results = [runner.get_output(command) for command in commands] + sequential_time = time.time() - start_time + + start_time = time.time() + parallel_results = await asyncio.gather( + 
*(runner.get_output_async(command) for command in commands), + return_exceptions=False, + ) + parallel_time = time.time() - start_time + + assert sequential_results == parallel_results + assert parallel_time < sequential_time diff --git a/tests/utils/test_compute_node.py b/tests/utils/test_compute_node.py new file mode 100644 index 00000000..3461fef0 --- /dev/null +++ b/tests/utils/test_compute_node.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import asyncio +import re +from logging import getLogger as get_logger + +import pytest +import pytest_asyncio + +from milatools.utils.compute_node import ( + ComputeNode, + get_queued_milatools_job_ids, + salloc, + sbatch, +) +from milatools.utils.remote_v2 import RemoteV2 +from tests.utils.runner_tests import RunnerTests + +from ..conftest import launches_jobs, unsupported_on_windows + +logger = get_logger(__name__) +pytestmark = [launches_jobs, unsupported_on_windows] + + +@pytest.mark.slow +@pytest.mark.asyncio +async def test_salloc( + login_node_v2: RemoteV2, allocation_flags: list[str], job_name: str +): + compute_node = await salloc(login_node_v2, allocation_flags, job_name=job_name) + assert isinstance(compute_node, ComputeNode) + assert compute_node.hostname != login_node_v2.hostname + + # note: needs to be properly quoted so as not to evaluate the variable here! + job_id = compute_node.get_output("echo $SLURM_JOB_ID") + assert job_id.isdigit() + assert compute_node.job_id == int(job_id) + + all_slurm_env_vars = { + (split := line.split("="))[0]: split[1] + for line in compute_node.get_output("env | grep SLURM").splitlines() + } + # NOTE: We actually do have all the other SLURM env variables here because we're + # using `srun` with the job id on the login node to run our jobs. + assert all_slurm_env_vars["SLURM_JOB_ID"] == str(compute_node.job_id) + assert len(all_slurm_env_vars) > 1 + await compute_node.close() + + +@pytest.mark.slow +@pytest.mark.asyncio +async def test_sbatch( + login_node_v2: RemoteV2, allocation_flags: list[str], job_name: str +): + compute_node = await sbatch(login_node_v2, allocation_flags, job_name=job_name) + assert isinstance(compute_node, ComputeNode) + + assert compute_node.hostname != login_node_v2.hostname + job_id = compute_node.get_output("echo $SLURM_JOB_ID") + assert job_id.isdigit() + assert compute_node.job_id == int(job_id) + all_slurm_env_vars = { + (split := line.split("="))[0]: split[1] + for line in compute_node.get_output("env | grep SLURM").splitlines() + } + assert all_slurm_env_vars["SLURM_JOB_ID"] == str(compute_node.job_id) + assert len(all_slurm_env_vars) > 1 + await compute_node.close() + + +@pytest.fixture(scope="session", params=[True, False], ids=["sbatch", "salloc"]) +def persist(request: pytest.FixtureRequest): + return request.param + + +@pytest.mark.slow +@pytest.mark.asyncio +async def test_interrupt_allocation( + login_node_v2: RemoteV2, + allocation_flags: list[str], + job_name: str, + persist: bool, +): + """Test that checks that interrupting `salloc` or `sbatch` cancels the job + allocation. + + TODO: Try to get better control over when the interrupt happens, for example: + - while connecting via ssh; + - while waiting for the job to show up in `sacct`; + - while waiting for the job to start running. 
+ """ + + async def get_jobs_in_squeue() -> set[int]: + return await get_queued_milatools_job_ids(login_node_v2, job_name=job_name) + + _jobs_before = await get_jobs_in_squeue() + + async def get_new_job_ids() -> set[int]: + """Retrieves the ID of the new jobs since we called `salloc` or `sbatch`.""" + new_job_ids: set[int] = set() + queued_jobs = await get_jobs_in_squeue() + new_job_ids = queued_jobs - _jobs_before + + while not new_job_ids: + queued_jobs = await get_jobs_in_squeue() + logger.info(f"{_jobs_before=}, {queued_jobs=}") + new_job_ids = queued_jobs - _jobs_before + if new_job_ids: + break + logger.info("Waiting for the job to show up in the output of `squeue`.") + await asyncio.sleep(0.1) + return new_job_ids + + # Check that a job allocation was indeed created. + # NOTE: Assuming that it takes more time for the job to be allocated than it takes for + # the job to show up in `squeue`. + salloc_task = asyncio.create_task( + sbatch(login_node_v2, sbatch_flags=allocation_flags, job_name=job_name) + if persist + else salloc(login_node_v2, salloc_flags=allocation_flags, job_name=job_name), + name="sbatch" if persist else "salloc", + ) + get_new_job_ids_task = asyncio.create_task( + get_new_job_ids(), name="get_new_job_ids" + ) + + new_job_ids = await asyncio.wait_for(get_new_job_ids_task, timeout=None) + assert not salloc_task.done() # hopefully we get the job ID from SQUEUE before the + # job is actually running... + salloc_task.cancel( + msg="Interrupting the job allocation as soon as the job ID shows up in squeue." + ) + assert new_job_ids and len(new_job_ids) == 1 + new_job_id = new_job_ids.pop() + # wait long enough for `squeue` to update and not show the job anymore. + await asyncio.sleep(10) + jobs_after = await get_jobs_in_squeue() + assert new_job_id not in jobs_after + assert jobs_after <= _jobs_before + + +@pytest.mark.slow +@launches_jobs +class TestComputeNode(RunnerTests): + @pytest_asyncio.fixture(scope="class") + async def runner( + self, login_node_v2: RemoteV2, persist: bool, allocation_flags: list[str] + ): + if persist: + runner = await sbatch( + login_node_v2, sbatch_flags=allocation_flags, job_name="mila-code" + ) + else: + runner = await salloc( + login_node_v2, salloc_flags=allocation_flags, job_name="mila-code" + ) + yield runner + await runner.close() + + @pytest.fixture( + scope="class", + params=[ + ("echo OK", "OK", ""), + ("echo $SLURM_JOB_ID", re.compile(r"^[0-9]+"), ""), + ], + ) + def command_output_err(self, request: pytest.FixtureRequest): + return request.param + + def test_run_gets_executed_in_job_step(self, runner: ComputeNode): + job_step_a = int(runner.get_output("echo $SLURM_STEP_ID")) + job_step_b = int(runner.get_output("echo $SLURM_STEP_ID")) + assert job_step_a >= 0 + assert job_step_b == job_step_a + 1 + + @pytest.mark.asyncio + async def test_run_async_gets_executed_in_job_step(self, runner: ComputeNode): + job_step_a = int(await runner.get_output_async("echo $SLURM_STEP_ID")) + job_step_b = int(await runner.get_output_async("echo $SLURM_STEP_ID")) + assert job_step_a >= 0 + assert job_step_b == job_step_a + 1 + + @pytest.mark.asyncio + async def test_close( + self, + login_node_v2: RemoteV2, + persist: bool, + allocation_flags: list[str], + job_name: str, + ): + # needs to be the last test with this remote though! 
+ if persist: + compute_node = await sbatch( + login_node_v2, sbatch_flags=allocation_flags, job_name=job_name + ) + else: + compute_node = await salloc( + login_node_v2, salloc_flags=allocation_flags, job_name=job_name + ) + + await compute_node.close() + + job_state = await login_node_v2.get_output_async( + f"sacct --noheader --allocations --jobs {compute_node.job_id} --format=State%100", + display=True, + hide=False, + ) + if persist: + # batch jobs are scancelled. + assert job_state.startswith("CANCELLED") + else: + # interactive jobs are exited cleanly by just exiting in the terminal. + assert job_state == "COMPLETED" diff --git a/tests/cli/test_local.py b/tests/utils/test_local_v1.py similarity index 88% rename from tests/cli/test_local.py rename to tests/utils/test_local_v1.py index ae08771a..4882c977 100644 --- a/tests/cli/test_local.py +++ b/tests/utils/test_local_v1.py @@ -1,15 +1,16 @@ from __future__ import annotations +import os import sys from subprocess import PIPE import pytest from pytest_regressions.file_regression import FileRegressionFixture -from milatools.cli.local import CommandNotFoundError, Local, check_passwordless +from milatools.utils.local_v1 import CommandNotFoundError, LocalV1, check_passwordless from milatools.utils.remote_v2 import is_already_logged_in -from .common import ( +from ..cli.common import ( in_github_CI, in_self_hosted_github_CI, output_tester, @@ -38,7 +39,7 @@ def test_display( capsys: pytest.CaptureFixture, file_regression: FileRegressionFixture, ): - output_tester(lambda: (Local().display(cmd), None), capsys, file_regression) + output_tester(lambda: (LocalV1().display(cmd), None), capsys, file_regression) prints_unexpected_text_to_stdout_on_windows = xfails_on_windows( @@ -58,7 +59,7 @@ def test_silent_get( capsys: pytest.CaptureFixture, file_regression: FileRegressionFixture, ): - output_tester(lambda: (Local().silent_get(*cmd), None), capsys, file_regression) + output_tester(lambda: (LocalV1().silent_get(*cmd), None), capsys, file_regression) @prints_unexpected_text_to_stdout_on_windows @@ -69,7 +70,7 @@ def test_get( capsys: pytest.CaptureFixture, file_regression: FileRegressionFixture, ): - output_tester(lambda: (Local().get(*cmd), None), capsys, file_regression) + output_tester(lambda: (LocalV1().get(*cmd), None), capsys, file_regression) @prints_unexpected_text_to_stdout_on_windows @@ -81,7 +82,7 @@ def test_run( file_regression: FileRegressionFixture, ): def func(): - return Local().run(*cmd, capture_output=True), None + return LocalV1().run(*cmd, capture_output=True), None if cmd in [_FAKE_CMD, _FAIL_CODE_CMD]: @@ -108,7 +109,7 @@ def test_popen( file_regression: FileRegressionFixture, ): output_tester( - lambda: Local().popen(*cmd, stdout=PIPE, stderr=PIPE).communicate(), + lambda: LocalV1().popen(*cmd, stdout=PIPE, stderr=PIPE).communicate(), capsys, file_regression, ) @@ -191,4 +192,8 @@ def test_popen( def test_check_passwordless(hostname: str, expected: bool): # TODO: Maybe also test how `check_passwordless` behaves when using a key with a # passphrase. 
+ clusters_in_maintenance = os.environ.get("CLUSTER_DOWN", "").split(",") + if hostname.partition("@")[0] in clusters_in_maintenance: + pytest.skip(reason=f"Cluster {hostname} is down for maintenance.") + assert check_passwordless(hostname) == expected diff --git a/tests/cli/test_local/test_display_cmd0_.txt b/tests/utils/test_local_v1/test_display_cmd0_.txt similarity index 100% rename from tests/cli/test_local/test_display_cmd0_.txt rename to tests/utils/test_local_v1/test_display_cmd0_.txt diff --git a/tests/cli/test_local/test_display_cmd1_.txt b/tests/utils/test_local_v1/test_display_cmd1_.txt similarity index 100% rename from tests/cli/test_local/test_display_cmd1_.txt rename to tests/utils/test_local_v1/test_display_cmd1_.txt diff --git a/tests/cli/test_local/test_get_cmd0_.txt b/tests/utils/test_local_v1/test_get_cmd0_.txt similarity index 100% rename from tests/cli/test_local/test_get_cmd0_.txt rename to tests/utils/test_local_v1/test_get_cmd0_.txt diff --git a/tests/cli/test_local/test_popen_cmd0_.txt b/tests/utils/test_local_v1/test_popen_cmd0_.txt similarity index 100% rename from tests/cli/test_local/test_popen_cmd0_.txt rename to tests/utils/test_local_v1/test_popen_cmd0_.txt diff --git a/tests/cli/test_local/test_run_cmd0_.txt b/tests/utils/test_local_v1/test_run_cmd0_.txt similarity index 100% rename from tests/cli/test_local/test_run_cmd0_.txt rename to tests/utils/test_local_v1/test_run_cmd0_.txt diff --git a/tests/cli/test_local/test_run_cmd1_.txt b/tests/utils/test_local_v1/test_run_cmd1_.txt similarity index 100% rename from tests/cli/test_local/test_run_cmd1_.txt rename to tests/utils/test_local_v1/test_run_cmd1_.txt diff --git a/tests/cli/test_local/test_run_cmd2_.txt b/tests/utils/test_local_v1/test_run_cmd2_.txt similarity index 100% rename from tests/cli/test_local/test_run_cmd2_.txt rename to tests/utils/test_local_v1/test_run_cmd2_.txt diff --git a/tests/cli/test_local/test_silent_get_cmd0_.txt b/tests/utils/test_local_v1/test_silent_get_cmd0_.txt similarity index 100% rename from tests/cli/test_local/test_silent_get_cmd0_.txt rename to tests/utils/test_local_v1/test_silent_get_cmd0_.txt diff --git a/tests/utils/test_local_v2.py b/tests/utils/test_local_v2.py new file mode 100644 index 00000000..70f8ad8a --- /dev/null +++ b/tests/utils/test_local_v2.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +import pytest + +from milatools.utils.local_v2 import LocalV2 +from milatools.utils.runner import Runner + +from .runner_tests import RunnerTests + + +class TestLocalV2(RunnerTests): + @pytest.fixture(scope="class") + def runner(self) -> Runner: + return LocalV2() diff --git a/tests/utils/test_parallel_progress/test_parallel_progress_bar.txt b/tests/utils/test_parallel_progress/test_parallel_progress_bar.txt index 32dac6ec..cf817d7b 100644 --- a/tests/utils/test_parallel_progress/test_parallel_progress_bar.txt +++ b/tests/utils/test_parallel_progress/test_parallel_progress_bar.txt @@ -1,5 +1,5 @@ -✓ All jobs progress: 20/20 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 -✓ Task 0 - Done. 5/5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 -✓ Task 1 - Done. 5/5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 -✓ Task 2 - Done. 5/5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 -✓ Task 3 - Done. 
5/5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 \ No newline at end of file +✓ All jobs progress: 20/20 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 +✓ Task 0 - Done. 5/5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 +✓ Task 1 - Done. 5/5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 +✓ Task 2 - Done. 5/5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 +✓ Task 3 - Done. 5/5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 \ No newline at end of file diff --git a/tests/cli/test_remote.py b/tests/utils/test_remote_v1.py similarity index 95% rename from tests/cli/test_remote.py rename to tests/utils/test_remote_v1.py index 0c495c86..f8e0342c 100644 --- a/tests/cli/test_remote.py +++ b/tests/utils/test_remote_v1.py @@ -16,15 +16,15 @@ from fabric.connection import Connection from pytest_regressions.file_regression import FileRegressionFixture -from milatools.cli.remote import ( +from milatools.cli.utils import T, cluster_to_connect_kwargs +from milatools.utils.remote_v1 import ( QueueIO, - Remote, + RemoteV1, SlurmRemote, get_first_node_name, ) -from milatools.cli.utils import T, cluster_to_connect_kwargs -from .common import function_call_string +from ..cli.common import function_call_string @pytest.mark.parametrize("keepalive", [0, 123]) @@ -38,7 +38,7 @@ def test_init( # This should have called the `Connection` class with the host, which we patched in # the fixture above. - r = Remote(host, keepalive=keepalive) + r = RemoteV1(host, keepalive=keepalive) # The Remote should have created a Connection instance (which happens to be # the mock_connection we made above). MockConnection.assert_called_once_with( @@ -64,7 +64,7 @@ def test_init_with_connection( mock_connection: Mock, ): """This test shows the behaviour of __init__ to isolate it from other tests.""" - r = Remote(mock_connection.host, connection=mock_connection, keepalive=keepalive) + r = RemoteV1(mock_connection.host, connection=mock_connection, keepalive=keepalive) MockConnection.assert_not_called() assert r.connection is mock_connection # The connection is not opened, and the transport is also not opened. @@ -80,23 +80,23 @@ def test_init_with_connection( ("method", "args"), [ ( - Remote.with_transforms, + RemoteV1.with_transforms, ( lambda cmd: cmd.replace("OK", "NOT_OK"), lambda cmd: f"echo 'command before' && {cmd}", ), ), ( - Remote.wrap, + RemoteV1.wrap, ("echo 'echo wrap' && {}",), ), ( - Remote.with_precommand, + RemoteV1.with_precommand, ("echo 'echo precommand'",), ), # this need to be a file to source before running the command. 
- (Remote.with_profile, (".bashrc",)), - (Remote.with_bash, ()), + (RemoteV1.with_profile, (".bashrc",)), + (RemoteV1.with_bash, ()), ], ) def test_remote_transform_methods( @@ -112,13 +112,13 @@ def test_remote_transform_methods( """Test the methods of `Remote` that modify the commands passed to `run` before it gets passed to the connection and run on the server.""" mock_connection = mock_connection - r = Remote( + r = RemoteV1( host, connection=mock_connection, transforms=initial_transforms, ) # Call the method on the remote, which should return a new Remote. - modified_remote: Remote = method(r, *args) + modified_remote: RemoteV1 = method(r, *args) assert modified_remote.hostname == r.hostname assert modified_remote.connection is r.connection @@ -143,7 +143,7 @@ def test_remote_transform_methods( After creating a Remote like so: ```python -remote = {function_call_string(Remote, host, connection=mock_connection, transforms=())} +remote = {function_call_string(RemoteV1, host, connection=mock_connection, transforms=())} ``` and then calling: @@ -173,7 +173,7 @@ def test_remote_transform_methods( @pytest.mark.parametrize("message", ["foobar"]) def test_display( message: str, - remote: Remote, + remote: RemoteV1, capsys: pytest.CaptureFixture, ): remote.display(message) @@ -205,7 +205,7 @@ def hide( @pytest.mark.parametrize("warn", [True, False]) @pytest.mark.parametrize("display", [True, False, None]) def test_run( - remote: Remote, + remote: RemoteV1, command: str, expected_output: str, asynchronous: bool, @@ -277,7 +277,7 @@ def test_get_output( mock_result = Mock(wraps=invoke.runners.Result(Mock(wraps=command_output))) mock_connection.run.return_value = mock_result - r = Remote(host, connection=mock_connection) + r = RemoteV1(host, connection=mock_connection) output = r.get_output(command, display=None, hide=hide, warn=warn) assert output == command_output @@ -304,7 +304,7 @@ def test_get_lines( command = " && ".join(f"echo '{line}'" for line in expected_lines) command_output = "\n".join(expected_lines) mock_connection.run.return_value = invoke.runners.Result(stdout=command_output) - r = Remote(host, connection=mock_connection) + r = RemoteV1(host, connection=mock_connection) lines = r.get_lines(command, hide=hide, warn=warn) # NOTE: We'd expect this, but instead we get ['Line', '1', 'has', 'this', 'value', # TODO: Uncomment this if we fix `get_lines` to split based on lines, or remove this @@ -327,7 +327,7 @@ def write_lines_with_sleeps(lines: Iterable[str], sleep_time: float = 0.1): @pytest.mark.parametrize("wait", [True, False]) @pytest.mark.parametrize("pty", [True, False]) def test_extract( - remote: Remote, + remote: RemoteV1, wait: bool, pty: bool, ): @@ -361,7 +361,7 @@ def _xfail_if_not_on_localhost(host: str): pytest.xfail("This test only works on localhost.") -def test_get(remote: Remote, tmp_path: Path, host: str): +def test_get(remote: RemoteV1, tmp_path: Path, host: str): # TODO: Make this test smarter? or no need? (because we'd be testing fabric at that # point?) 
_xfail_if_not_on_localhost(remote.hostname) @@ -374,7 +374,7 @@ def test_get(remote: Remote, tmp_path: Path, host: str): assert dest.read_text() == source_content -def test_put(remote: Remote, tmp_path: Path): +def test_put(remote: RemoteV1, tmp_path: Path): _xfail_if_not_on_localhost(remote.hostname) src = tmp_path / "foo" dest = tmp_path / "bar" @@ -389,7 +389,7 @@ def test_put(remote: Remote, tmp_path: Path): assert dest.read_text() == source_content -def test_puttext(remote: Remote, tmp_path: Path): +def test_puttext(remote: RemoteV1, tmp_path: Path): _xfail_if_not_on_localhost(remote.hostname) dest_dir = tmp_path / "bar/baz" dest = tmp_path / f"{dest_dir}/bob.txt" @@ -402,7 +402,7 @@ def test_puttext(remote: Remote, tmp_path: Path): assert dest.read_text() == some_text -def test_home(remote: Remote): +def test_home(remote: RemoteV1): home_dir = remote.home() remote.connection.run.assert_called_once() assert remote.connection.run.mock_calls[0].args[0] == "echo $HOME" @@ -413,11 +413,11 @@ def test_home(remote: Remote): assert home_dir == str(Path.home()) -def test_persist(remote: Remote): +def test_persist(remote: RemoteV1): assert remote.persist() is remote -def test_ensure_allocation(remote: Remote): +def test_ensure_allocation(remote: RemoteV1): assert remote.ensure_allocation() == ({"node_name": remote.hostname}, None) @@ -540,7 +540,7 @@ def test_with_transforms(self, mock_connection: Connection, persist: bool | None # It isn't a very useful test, but it's better than not having one for now. # The test for Remote.run above checks that `run` on the transformed remote # does what we expect. - assert SlurmRemote.run is Remote.run + assert SlurmRemote.run is RemoteV1.run alloc = ["--time=00:01:00"] transforms = [some_transform] new_transforms = [some_other_transform] diff --git a/tests/cli/test_remote/test_QueueIO.txt b/tests/utils/test_remote_v1/test_QueueIO.txt similarity index 100% rename from tests/cli/test_remote/test_QueueIO.txt rename to tests/utils/test_remote_v1/test_QueueIO.txt diff --git a/tests/cli/test_remote/test_remote_transform_methods_localhost_with_bash_args4_initial_transforms0_echo_OK_.md b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_bash_args4_initial_transforms0_echo_OK_.md similarity index 82% rename from tests/cli/test_remote/test_remote_transform_methods_localhost_with_bash_args4_initial_transforms0_echo_OK_.md rename to tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_bash_args4_initial_transforms0_echo_OK_.md index 4c739f1e..d54e4e87 100644 --- a/tests/cli/test_remote/test_remote_transform_methods_localhost_with_bash_args4_initial_transforms0_echo_OK_.md +++ b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_bash_args4_initial_transforms0_echo_OK_.md @@ -1,7 +1,7 @@ After creating a Remote like so: ```python -remote = Remote('localhost', connection=Connection('localhost'), transforms=()) +remote = RemoteV1('localhost', connection=Connection('localhost'), transforms=()) ``` and then calling: diff --git a/tests/cli/test_remote/test_remote_transform_methods_localhost_with_precommand_args2_initial_transforms0_echo_OK_.md b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_precommand_args2_initial_transforms0_echo_OK_.md similarity index 84% rename from tests/cli/test_remote/test_remote_transform_methods_localhost_with_precommand_args2_initial_transforms0_echo_OK_.md rename to 
tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_precommand_args2_initial_transforms0_echo_OK_.md index dfa0609f..df7127a4 100644 --- a/tests/cli/test_remote/test_remote_transform_methods_localhost_with_precommand_args2_initial_transforms0_echo_OK_.md +++ b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_precommand_args2_initial_transforms0_echo_OK_.md @@ -1,7 +1,7 @@ After creating a Remote like so: ```python -remote = Remote('localhost', connection=Connection('localhost'), transforms=()) +remote = RemoteV1('localhost', connection=Connection('localhost'), transforms=()) ``` and then calling: diff --git a/tests/cli/test_remote/test_remote_transform_methods_localhost_with_profile_args3_initial_transforms0_echo_OK_.md b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_profile_args3_initial_transforms0_echo_OK_.md similarity index 82% rename from tests/cli/test_remote/test_remote_transform_methods_localhost_with_profile_args3_initial_transforms0_echo_OK_.md rename to tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_profile_args3_initial_transforms0_echo_OK_.md index f67b9f0e..efb7bcc7 100644 --- a/tests/cli/test_remote/test_remote_transform_methods_localhost_with_profile_args3_initial_transforms0_echo_OK_.md +++ b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_profile_args3_initial_transforms0_echo_OK_.md @@ -1,7 +1,7 @@ After creating a Remote like so: ```python -remote = Remote('localhost', connection=Connection('localhost'), transforms=()) +remote = RemoteV1('localhost', connection=Connection('localhost'), transforms=()) ``` and then calling: diff --git a/tests/cli/test_remote/test_remote_transform_methods_localhost_with_transforms_args0_initial_transforms0_echo_OK_.md b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_transforms_args0_initial_transforms0_echo_OK_.md similarity index 86% rename from tests/cli/test_remote/test_remote_transform_methods_localhost_with_transforms_args0_initial_transforms0_echo_OK_.md rename to tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_transforms_args0_initial_transforms0_echo_OK_.md index c7227892..2c16d3ad 100644 --- a/tests/cli/test_remote/test_remote_transform_methods_localhost_with_transforms_args0_initial_transforms0_echo_OK_.md +++ b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_transforms_args0_initial_transforms0_echo_OK_.md @@ -1,7 +1,7 @@ After creating a Remote like so: ```python -remote = Remote('localhost', connection=Connection('localhost'), transforms=()) +remote = RemoteV1('localhost', connection=Connection('localhost'), transforms=()) ``` and then calling: diff --git a/tests/cli/test_remote/test_remote_transform_methods_localhost_wrap_args1_initial_transforms0_echo_OK_.md b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_wrap_args1_initial_transforms0_echo_OK_.md similarity index 83% rename from tests/cli/test_remote/test_remote_transform_methods_localhost_wrap_args1_initial_transforms0_echo_OK_.md rename to tests/utils/test_remote_v1/test_remote_transform_methods_localhost_wrap_args1_initial_transforms0_echo_OK_.md index f3bc1b5a..2c0cf474 100644 --- a/tests/cli/test_remote/test_remote_transform_methods_localhost_wrap_args1_initial_transforms0_echo_OK_.md +++ b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_wrap_args1_initial_transforms0_echo_OK_.md @@ -1,7 +1,7 @@ After creating a Remote like so: ```python 
-remote = Remote('localhost', connection=Connection('localhost'), transforms=()) +remote = RemoteV1('localhost', connection=Connection('localhost'), transforms=()) ``` and then calling: diff --git a/tests/cli/test_remote/test_srun_transform_persist_localhost_.md b/tests/utils/test_remote_v1/test_srun_transform_persist_localhost_.md similarity index 100% rename from tests/cli/test_remote/test_srun_transform_persist_localhost_.md rename to tests/utils/test_remote_v1/test_srun_transform_persist_localhost_.md diff --git a/tests/utils/test_remote_v2.py b/tests/utils/test_remote_v2.py index 1fe92d21..5f911cc2 100644 --- a/tests/utils/test_remote_v2.py +++ b/tests/utils/test_remote_v2.py @@ -1,68 +1,106 @@ +from __future__ import annotations + from pathlib import Path from unittest.mock import Mock import pytest import milatools.utils.remote_v2 +from milatools.cli.init_command import DRAC_CLUSTERS +from milatools.cli.utils import SSH_CONFIG_FILE +from milatools.utils.local_v2 import run_async from milatools.utils.remote_v2 import ( RemoteV2, - UnsupportedPlatformError, control_socket_is_running, + control_socket_is_running_async, get_controlpath_for, is_already_logged_in, ) -from tests.integration.conftest import skip_param_if_not_already_logged_in +from tests.utils.runner_tests import RunnerTests -from ..cli.common import requires_ssh_to_localhost, xfails_on_windows +from ..cli.common import ( + requires_ssh_to_localhost, +) +from ..conftest import unsupported_on_windows -pytestmark = [xfails_on_windows(raises=UnsupportedPlatformError, strict=True)] +pytestmark = [unsupported_on_windows] -@requires_ssh_to_localhost -def test_init_with_controlpath(tmp_path: Path): - control_path = tmp_path / "socketfile" - remote = RemoteV2("localhost", control_path=control_path) - assert control_path.exists() - files = remote.get_output(f"ls {control_path.parent}").split() - assert files == [control_path.name] +class TestRemoteV2(RunnerTests): + @pytest.fixture(scope="session") + def hostname(self, request: pytest.FixtureRequest, cluster: str): + return getattr(request, "param", cluster) + @pytest.fixture(scope="class") + def runner(self, hostname: str): + return RemoteV2(hostname) -@requires_ssh_to_localhost -def test_init_with_none_controlpath(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): - control_path = tmp_path / "socketfile" - mock_get_controlpath_for = Mock( - wraps=get_controlpath_for, return_value=control_path - ) + @pytest.mark.slow # can be slow due to initiating a new connection (~30s for Mila) + @pytest.mark.parametrize("use_async_init", [False, True], ids=["sync", "async"]) + @pytest.mark.asyncio + async def test_init_with_controlpath( + self, hostname: str, tmp_path: Path, use_async_init: bool + ): + # NOTE: Need to skip any cluster where 2FA might be enabled, because we're + # specifying a different controlpath here so it would go through 2FA. 
+ if hostname in DRAC_CLUSTERS: + pytest.skip(reason="2FA might be enabled on this cluster.") - monkeypatch.setattr( - milatools.utils.remote_v2, - get_controlpath_for.__name__, - mock_get_controlpath_for, - ) - remote = RemoteV2("localhost", control_path=None) - mock_get_controlpath_for.assert_called_once_with("localhost") - assert control_path.exists() - files = remote.get_output(f"ls {control_path.parent}").split() - assert files == [control_path.name] - - -@pytest.mark.parametrize( - "hostname", - [ - pytest.param("localhost", marks=requires_ssh_to_localhost), - skip_param_if_not_already_logged_in("mila"), - skip_param_if_not_already_logged_in("narval"), - skip_param_if_not_already_logged_in("beluga"), - skip_param_if_not_already_logged_in("cedar"), - skip_param_if_not_already_logged_in("graham"), - skip_param_if_not_already_logged_in("niagara"), - ], -) -def test_run(hostname: str): - command = "echo Hello World" - remote = RemoteV2(hostname) - output = remote.get_output(command) - assert output == "Hello World" + control_path = tmp_path / "socketfile" + assert not control_path.exists() + + remote = ( + (await RemoteV2.connect(hostname, control_path=control_path)) + if use_async_init + else RemoteV2(hostname, control_path=control_path) + ) + assert control_path.exists() + + if hostname == "localhost": + files = remote.get_output(f"ls {control_path.parent}").split() + assert files == [control_path.name] + + await run_async( + ("ssh", f"-oControlPath={control_path}", "-O", "exit", hostname), + ) + assert not control_path.exists() + + @requires_ssh_to_localhost + @pytest.mark.parametrize("use_async_init", [False, True]) + @pytest.mark.asyncio + async def test_init_with_none_controlpath( + self, + hostname: str, + use_async_init: bool, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + ): + if hostname in DRAC_CLUSTERS: + pytest.skip(reason="2FA might be enabled on this cluster.") + + control_path = tmp_path / "socketfile" + mock_get_controlpath_for = Mock( + wraps=get_controlpath_for, return_value=control_path + ) + + monkeypatch.setattr( + milatools.utils.remote_v2, + get_controlpath_for.__name__, + mock_get_controlpath_for, + ) + remote = ( + (await RemoteV2.connect(hostname, control_path=None)) + if use_async_init + else RemoteV2(hostname, control_path=None) + ) + mock_get_controlpath_for.assert_called_once_with( + hostname, ssh_config_path=SSH_CONFIG_FILE + ) + assert control_path.exists() + + if hostname == "localhost": + files = remote.get_output(f"ls {control_path.parent}").split() + assert files == [control_path.name] # NOTE: The timeout here is a part of the test: if we are already connected, running the @@ -85,3 +123,12 @@ def test_is_already_logged_in( def test_controlsocket_is_running(cluster: str, already_logged_in: bool): control_path = get_controlpath_for(cluster) assert control_socket_is_running(cluster, control_path) == already_logged_in + + +@pytest.mark.asyncio +async def test_controlsocket_is_running_async(cluster: str, already_logged_in: bool): + control_path = get_controlpath_for(cluster) + assert ( + await control_socket_is_running_async(cluster, control_path) + == already_logged_in + ) diff --git a/tests/utils/test_vscode_utils.py b/tests/utils/test_vscode_utils.py index 7b2126aa..217766b7 100644 --- a/tests/utils/test_vscode_utils.py +++ b/tests/utils/test_vscode_utils.py @@ -1,6 +1,5 @@ from __future__ import annotations -import getpass import multiprocessing import shutil import sys @@ -11,9 +10,8 @@ import pytest -from milatools.cli.local import Local 
-from milatools.cli.remote import Remote -from milatools.cli.utils import running_inside_WSL +from milatools.cli.utils import MilatoolsUserError, running_inside_WSL +from milatools.utils.local_v2 import LocalV2 from milatools.utils.parallel_progress import ProgressDict from milatools.utils.remote_v2 import RemoteV2, UnsupportedPlatformError from milatools.utils.vscode_utils import ( @@ -91,11 +89,14 @@ def test_running_inside_WSL(): def test_get_vscode_executable_path(): - code = get_vscode_executable_path() if vscode_installed(): - assert code is not None and Path(code).exists() + code = get_vscode_executable_path() + assert Path(code).exists() else: - assert code is None + with pytest.raises( + MilatoolsUserError, match="Command 'code' does not exist locally." + ): + get_vscode_executable_path() @pytest.fixture @@ -134,14 +135,14 @@ def test_sync_vscode_extensions_in_parallel_with_hostnames( # Make the destination slightly different so it actually gets wrapped as a # `Remote(v2)` object. "localhost", - destinations=[f"{getpass.getuser()}@localhost"], + destinations=["localhost"], ) @requires_vscode @requires_ssh_to_localhost def test_sync_vscode_extensions_in_parallel(): - results = sync_vscode_extensions(Local(), dest_clusters=[Local()]) + results = sync_vscode_extensions(LocalV2(), destinations=[LocalV2()]) assert results == {"localhost": {"info": "Done.", "progress": 0, "total": 0}} @@ -205,10 +206,6 @@ def missing_extensions( return vscode_extensions[2] -def _remote(hostname: str): - return RemoteV2(hostname) if sys.platform != "win32" else Remote(hostname) - - @uses_remote_v2 @requires_ssh_to_localhost @requires_vscode @@ -225,7 +222,7 @@ def test_install_vscode_extensions_task_function( task_progress_dict: DictProxy[TaskID, ProgressDict] = manager.dict() - _fake_remote = _remote("localhost") + _fake_remote = RemoteV2("localhost") result = install_vscode_extensions_task_function( task_progress_dict=task_progress_dict, @@ -253,7 +250,7 @@ def test_install_vscode_extensions_task_function( def test_install_vscode_extension(missing_extensions: dict[str, str]): extension_name, version = next(iter(missing_extensions.items())) result = install_vscode_extension( - remote=_remote("localhost"), + remote=RemoteV2("localhost"), code_server_executable=str(get_vscode_executable_path()), extension=f"{extension_name}@{version}", verbose=False, @@ -281,7 +278,7 @@ def test_get_local_vscode_extensions(): def test_get_remote_vscode_extensions(): # We make it so this calls the local `code` command over SSH to localhost, # therefore the "remote" extensions are the same as the local extensions. - fake_remote = _remote("localhost") + fake_remote = RemoteV2("localhost") local_vscode_executable = get_vscode_executable_path() assert local_vscode_executable is not None