diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index 6a449782..86286b12 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -38,6 +38,8 @@ ) from ..__version__ import __version__ +from ..utils.local_v1 import LocalV1 +from ..utils.remote_v1 import RemoteV1, SlurmRemote from .init_command import ( print_welcome_message, setup_keys_on_login_node, @@ -46,9 +48,7 @@ setup_vscode_settings, setup_windows_ssh_config_from_wsl, ) -from .local import Local from .profile import ensure_program, setup_profile -from .remote import Remote, SlurmRemote from .utils import ( CLUSTERS, Cluster, @@ -518,7 +518,7 @@ def forward( pass local_proc, _ = _forward( - local=Local(), + local=LocalV1(), node=f"{node}.server.mila.quebec", to_forward=remote_port, page=page, @@ -553,8 +553,8 @@ def code( node: Node to connect to alloc: Extra options to pass to slurm """ - here = Local() - remote = Remote(cluster) + here = LocalV1() + remote = RemoteV1(cluster) if cluster != "mila" and job is None and node is None: if not any("--account" in flag for flag in alloc): @@ -603,7 +603,7 @@ def code( copy_vscode_extensions_process.start() else: sync_vscode_extensions( - Local(), + LocalV1(), [cluster], ) @@ -697,10 +697,10 @@ def code( def connect(identifier: str, port: int | None): """Reconnect to a persistent server.""" - remote = Remote("mila") + remote = RemoteV1("mila") info = _get_server_info(remote, identifier) local_proc, _ = _forward( - local=Local(), + local=LocalV1(), node=f"{info['node_name']}.server.mila.quebec", to_forward=info["to_forward"], options={"token": info.get("token", None)}, @@ -718,7 +718,7 @@ def connect(identifier: str, port: int | None): def kill(identifier: str | None, all: bool = False): """Kill a persistent server.""" - remote = Remote("mila") + remote = RemoteV1("mila") if all: for identifier in remote.get_lines("ls .milatools/control", hide=True): @@ -740,7 +740,7 @@ def kill(identifier: str | None, all: bool = False): def serve_list(purge: bool): """List active servers.""" - remote = Remote("mila") + remote = RemoteV1("mila") to_purge = [] @@ -899,7 +899,7 @@ def aim(logdir: str, **kwargs: Unpack[StandardServerArgs]): def _get_server_info( - remote: Remote, identifier: str, hide: bool = False + remote: RemoteV1, identifier: str, hide: bool = False ) -> dict[str, str]: text = remote.get_output(f"cat .milatools/control/{identifier}", hide=hide) info = dict(line.split(" = ") for line in text.split("\n") if line) @@ -993,7 +993,7 @@ def _standard_server( elif persist: name = program - remote = Remote("mila") + remote = RemoteV1("mila") path = path or "~" if path == "~" or path.startswith("~/"): @@ -1107,7 +1107,7 @@ def _standard_server( options = {} local_proc, local_port = _forward( - local=Local(), + local=LocalV1(), node=get_fully_qualified_hostname_of_compute_node(node_name, cluster="mila"), to_forward=to_forward, options=options, @@ -1181,7 +1181,7 @@ def _parse_lfs_quota_output( return (used_gb, max_gb), (used_files, max_files) -def check_disk_quota(remote: Remote | RemoteV2) -> None: +def check_disk_quota(remote: RemoteV1 | RemoteV2) -> None: cluster = remote.hostname # NOTE: This is what the output of the command looks like on the Mila cluster: @@ -1262,7 +1262,7 @@ def get_colour(used: float, max: float) -> str: def _find_allocation( - remote: Remote, + remote: RemoteV1, node: str | None, job: str | None, alloc: list[str], @@ -1274,11 +1274,15 @@ def _find_allocation( if node is not None: node_name = get_fully_qualified_hostname_of_compute_node(node, cluster=cluster) - return Remote(node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster)) + return RemoteV1( + node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster) + ) elif job is not None: node_name = remote.get_output(f"squeue --jobs {job} -ho %N") - return Remote(node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster)) + return RemoteV1( + node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster) + ) else: alloc = ["-J", job_name, *alloc] @@ -1290,7 +1294,7 @@ def _find_allocation( def _forward( - local: Local, + local: LocalV1, node: str, to_forward: int | str, port: int | None, diff --git a/milatools/cli/init_command.py b/milatools/cli/init_command.py index d3b16981..c58ebcf5 100644 --- a/milatools/cli/init_command.py +++ b/milatools/cli/init_command.py @@ -17,12 +17,12 @@ from milatools.utils.remote_v2 import SSH_CONFIG_FILE +from ..utils.local_v1 import LocalV1, check_passwordless, display +from ..utils.remote_v1 import RemoteV1 from ..utils.vscode_utils import ( get_expected_vscode_settings_json_path, vscode_installed, ) -from .local import Local, check_passwordless, display -from .remote import Remote from .utils import SSHConfig, T, running_inside_WSL, yn logger = get_logger(__name__) @@ -238,7 +238,7 @@ def setup_passwordless_ssh_access(ssh_config: SSHConfig) -> bool: """ print("Checking passwordless authentication") - here = Local() + here = LocalV1() sshdir = Path.home() / ".ssh" # Check if there is a public key file in ~/.ssh @@ -294,7 +294,7 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool: Returns whether the operation completed successfully or not. """ - here = Local() + here = LocalV1() # Check that it is possible to connect without using a password. print(f"Checking if passwordless SSH access is setup for the {cluster} cluster.") # TODO: Potentially use a custom key like `~/.ssh/id_milatools.pub` instead of @@ -371,7 +371,7 @@ def setup_keys_on_login_node(cluster: str = "mila"): "This is required for `mila code` to work properly." ) # todo: avoid re-creating the `Remote` here, since it goes through 2FA each time! - remote = Remote(cluster) + remote = RemoteV1(cluster) try: pubkeys = remote.get_lines("ls -t ~/.ssh/id*.pub") print("# OK") @@ -443,7 +443,7 @@ def get_windows_home_path_in_wsl() -> Path: def create_ssh_keypair( ssh_private_key_path: Path, - local: Local | None = None, + local: LocalV1 | None = None, passphrase: str | None = "", ) -> None: """Creates a public/private key pair at the given path using ssh-keygen. @@ -452,7 +452,7 @@ def create_ssh_keypair( Otherwise, if passphrase is an empty string, no passphrase will be used (default). If a string is passed, it is passed to ssh-keygen and used as the passphrase. """ - local = local or Local() + local = local or LocalV1() command = [ "ssh-keygen", "-f", diff --git a/milatools/cli/profile.py b/milatools/cli/profile.py index d4ff7638..1a71af45 100644 --- a/milatools/cli/profile.py +++ b/milatools/cli/profile.py @@ -16,7 +16,7 @@ from .utils import askpath, yn if typing.TYPE_CHECKING: - from milatools.cli.remote import Remote + from milatools.utils.remote_v1 import RemoteV1 style = qn.Style( [ @@ -38,7 +38,7 @@ def _ask_name(message: str, default: str = "") -> str: qn.print(f"Invalid name: {name}", style="bold red") -def setup_profile(remote: Remote, path: str) -> str: +def setup_profile(remote: RemoteV1, path: str) -> str: profile = select_preferred(remote, path) preferred = profile is not None if not preferred: @@ -58,7 +58,7 @@ def setup_profile(remote: Remote, path: str) -> str: return profile -def select_preferred(remote: Remote, path: str) -> str | None: +def select_preferred(remote: RemoteV1, path: str) -> str | None: preferred = f"{path}/.milatools-profile" qn.print(f"Checking for preferred profile in {preferred}") @@ -71,7 +71,7 @@ def select_preferred(remote: Remote, path: str) -> str | None: return preferred -def select_profile(remote: Remote) -> str | None: +def select_profile(remote: RemoteV1) -> str | None: profdir = "~/.milatools/profiles" qn.print(f"Fetching profiles in {profdir}") @@ -109,7 +109,7 @@ def select_profile(remote: Remote) -> str | None: return profile -def create_profile(remote: Remote, path: str = "~"): +def create_profile(remote: RemoteV1, path: str = "~"): modules = select_modules(remote) mload = f"module load {' '.join(modules)}" @@ -139,7 +139,7 @@ def create_profile(remote: Remote, path: str = "~"): return prof_file -def select_modules(remote: Remote): +def select_modules(remote: RemoteV1): choices = [ Choice( title="miniconda/3", @@ -202,7 +202,7 @@ def _env_basename(pth: str) -> str | None: return base -def select_conda_environment(remote: Remote, loader: str = "module load miniconda/3"): +def select_conda_environment(remote: RemoteV1, loader: str = "module load miniconda/3"): qn.print("Fetching the list of conda environments...") envstr = remote.get_output("conda env list --json", hide=True) envlist: list[str] = json.loads(envstr)["envs"] @@ -254,7 +254,7 @@ def select_conda_environment(remote: Remote, loader: str = "module load minicond return env -def select_virtual_environment(remote: Remote, path): +def select_virtual_environment(remote: RemoteV1, path): envstr = remote.get_output( ( f"ls -d {path}/venv {path}/.venv {path}/virtualenv ~/virtualenvs/* " @@ -293,7 +293,7 @@ def select_virtual_environment(remote: Remote, path): return env -def ensure_program(remote: Remote, program: str, installers: dict[str, str]): +def ensure_program(remote: RemoteV1, program: str, installers: dict[str, str]): to_test = [program, *installers.keys()] progs = [ Path(p).name diff --git a/milatools/cli/utils.py b/milatools/cli/utils.py index e62b79ce..c26ced0b 100644 --- a/milatools/cli/utils.py +++ b/milatools/cli/utils.py @@ -24,7 +24,7 @@ from typing_extensions import ParamSpec, TypeGuard if typing.TYPE_CHECKING: - from milatools.cli.remote import Remote + from milatools.utils.remote_v1 import RemoteV1 control_file_var = contextvars.ContextVar("control_file", default="/dev/null") @@ -96,7 +96,7 @@ def randname(): @contextmanager -def with_control_file(remote: Remote, name=None): +def with_control_file(remote: RemoteV1, name=None): name = name or randname() pth = f".milatools/control/{name}" remote.run("mkdir -p ~/.milatools/control", hide=True) @@ -172,7 +172,7 @@ def yn(prompt: str, default: bool = True) -> bool: return qn.confirm(prompt, default=default).unsafe_ask() -def askpath(prompt: str, remote: Remote) -> str: +def askpath(prompt: str, remote: RemoteV1) -> str: while True: pth = qn.text(prompt).unsafe_ask() try: diff --git a/milatools/cli/local.py b/milatools/utils/local_v1.py similarity index 97% rename from milatools/cli/local.py rename to milatools/utils/local_v1.py index fb8e2ee7..34a4d68a 100644 --- a/milatools/cli/local.py +++ b/milatools/utils/local_v1.py @@ -14,12 +14,12 @@ from milatools.utils.remote_v2 import SSH_CONFIG_FILE, is_already_logged_in -from .utils import CommandNotFoundError, T, cluster_to_connect_kwargs +from ..cli.utils import CommandNotFoundError, T, cluster_to_connect_kwargs logger = get_logger(__name__) -class Local: +class LocalV1: def display(self, args: list[str] | tuple[str, ...]) -> None: display(args) diff --git a/milatools/cli/remote.py b/milatools/utils/remote_v1.py similarity index 99% rename from milatools/cli/remote.py rename to milatools/utils/remote_v1.py index 02b77201..e9e9620a 100644 --- a/milatools/cli/remote.py +++ b/milatools/utils/remote_v1.py @@ -19,7 +19,7 @@ from fabric import Connection from typing_extensions import Self, TypedDict, deprecated -from .utils import ( +from ..cli.utils import ( SSHConnectionError, T, cluster_to_connect_kwargs, @@ -109,7 +109,7 @@ def get_first_node_name(node_names_out: str) -> str: return base + inside_brackets.split("-")[0] -class Remote: +class RemoteV1: def __init__( self, hostname: str, @@ -444,7 +444,7 @@ def extract_script( return self.extract(shlex.join([dest, *args]), pattern=pattern, **kwargs) -class SlurmRemote(Remote): +class SlurmRemote(RemoteV1): def __init__( self, connection: fabric.Connection, @@ -533,7 +533,7 @@ def ensure_allocation( "jobid": results["jobid"], }, login_node_runner else: - remote = Remote(hostname=self.hostname, connection=self.connection) + remote = RemoteV1(hostname=self.hostname, connection=self.connection) command = shlex.join(["salloc", *self.alloc]) # We need to cd to $SCRATCH before we can run `salloc` on some clusters. command = f"cd $SCRATCH && {command}" diff --git a/milatools/utils/remote_v2.py b/milatools/utils/remote_v2.py index 37166e60..40ce0541 100644 --- a/milatools/utils/remote_v2.py +++ b/milatools/utils/remote_v2.py @@ -12,8 +12,8 @@ from paramiko import SSHConfig from milatools.cli import console -from milatools.cli.remote import Hide from milatools.cli.utils import DRAC_CLUSTERS, MilatoolsUserError +from milatools.utils.remote_v1 import Hide logger = get_logger(__name__) diff --git a/milatools/utils/vscode_utils.py b/milatools/utils/vscode_utils.py index d99816e5..bcd587b3 100644 --- a/milatools/utils/vscode_utils.py +++ b/milatools/utils/vscode_utils.py @@ -11,13 +11,12 @@ from pathlib import Path from typing import Literal, Sequence -from milatools.cli.local import Local -from milatools.cli.remote import Remote from milatools.cli.utils import ( CLUSTERS, batched, stripped_lines_of, ) +from milatools.utils.local_v1 import LocalV1 from milatools.utils.parallel_progress import ( DictProxy, ProgressDict, @@ -25,6 +24,7 @@ TaskID, parallel_progress_bar, ) +from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import RemoteV2 logger = get_logger(__name__) @@ -85,20 +85,20 @@ def sync_vscode_extensions_with_hostnames( if len(set(destinations)) != len(destinations): raise ValueError(f"{destinations=} contains duplicate hostnames!") - source_obj = Local() if source == "localhost" else RemoteV2(source) + source_obj = LocalV1() if source == "localhost" else RemoteV2(source) return sync_vscode_extensions(source_obj, destinations) def sync_vscode_extensions( - source: str | Local | RemoteV2, - dest_clusters: Sequence[str | Local | RemoteV2], + source: str | LocalV1 | RemoteV2, + dest_clusters: Sequence[str | LocalV1 | RemoteV2], ): """Syncs vscode extensions between `source` all all the clusters in `dest`. This spawns a thread for each cluster in `dest` and displays a parallel progress bar for the syncing of vscode extensions to each cluster. """ - if isinstance(source, Local): + if isinstance(source, LocalV1): source_hostname = "localhost" source_extensions = get_local_vscode_extensions() elif isinstance(source, RemoteV2): @@ -125,14 +125,14 @@ def sync_vscode_extensions( if dest_remote == "localhost": dest_hostname = dest_remote # type: ignore - dest_remote = Local() # pickleable - elif isinstance(dest_remote, Local): + dest_remote = LocalV1() # pickleable + elif isinstance(dest_remote, LocalV1): dest_hostname = "localhost" dest_remote = dest_remote # again, pickleable elif isinstance(dest_remote, RemoteV2): dest_hostname = dest_remote.hostname dest_remote = dest_remote # pickleable - elif isinstance(dest_remote, Remote): + elif isinstance(dest_remote, RemoteV1): # We unfortunately can't pass this kind of object to another process or # thread because it uses `fabric.Connection` which don't appear to be # pickleable. This means we will have to re-connect in the subprocess. @@ -180,7 +180,7 @@ def install_vscode_extensions_task_function( task_id: TaskID, dest_hostname: str | Literal["localhost"], source_extensions: dict[str, str], - remote: RemoteV2 | Local | None, + remote: RemoteV2 | LocalV1 | None, source_name: str, verbose: bool = False, ) -> ProgressDict: @@ -209,12 +209,12 @@ def _update_progress( if remote is None: if dest_hostname == "localhost": - remote = Local() + remote = LocalV1() else: _update_progress(0, "Connecting...") remote = RemoteV2(dest_hostname) - if isinstance(remote, Local): + if isinstance(remote, LocalV1): assert dest_hostname == "localhost" code_server_executable = get_vscode_executable_path() assert code_server_executable @@ -290,7 +290,7 @@ def _update_progress( def install_vscode_extension( - remote: Local | RemoteV2, + remote: LocalV1 | RemoteV2, code_server_executable: str, extension: str, verbose: bool = False, @@ -334,7 +334,7 @@ def get_local_vscode_extensions() -> dict[str, str]: def get_remote_vscode_extensions( - remote: Remote | RemoteV2, + remote: RemoteV1 | RemoteV2, remote_code_server_executable: str, ) -> dict[str, str]: """Returns the list of isntalled extensions and the path to the code-server @@ -387,7 +387,7 @@ def extensions_to_install( def find_code_server_executable( - remote: Remote | RemoteV2, remote_vscode_server_dir: str = "~/.vscode-server" + remote: RemoteV1 | RemoteV2, remote_vscode_server_dir: str = "~/.vscode-server" ) -> str | None: """Find the most recent `code-server` executable on the remote. diff --git a/tests/cli/test_init_command.py b/tests/cli/test_init_command.py index bf65b28e..84bd7ef7 100644 --- a/tests/cli/test_init_command.py +++ b/tests/cli/test_init_command.py @@ -38,12 +38,12 @@ setup_vscode_settings, setup_windows_ssh_config_from_wsl, ) -from milatools.cli.local import Local, check_passwordless -from milatools.cli.remote import Remote from milatools.cli.utils import ( SSHConfig, running_inside_WSL, ) +from milatools.utils.local_v1 import LocalV1, check_passwordless +from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import ( SSH_CACHE_DIR, SSH_CONFIG_FILE, @@ -1101,7 +1101,9 @@ def copy_fn(source, dest): @contextlib.contextmanager def backup_remote_dir( - remote: RemoteV2 | Remote, directory: PurePosixPath, backup_directory: PurePosixPath + remote: RemoteV2 | RemoteV1, + directory: PurePosixPath, + backup_directory: PurePosixPath, ): # IDEA: Make the equivalent function, but that backs up a directory on a remote # machine. @@ -1194,7 +1196,7 @@ def backup_local_ssh_cache_dir(): @pytest.fixture -def backup_remote_ssh_dir(login_node: RemoteV2 | Remote, cluster: str): +def backup_remote_ssh_dir(login_node: RemoteV2 | RemoteV1, cluster: str): """Creates a backup of the ~/.ssh directory on the remote cluster.""" if USE_MY_REAL_SSH_DIR: logger.critical( @@ -1235,7 +1237,7 @@ def backup_remote_ssh_dir(login_node: RemoteV2 | Remote, cluster: str): ) def test_setup_passwordless_ssh_access_to_cluster( cluster: str, - login_node: Remote | RemoteV2, + login_node: RemoteV1 | RemoteV2, input_pipe: PipeInput, backup_local_ssh_dir: Path, backup_local_ssh_cache_dir: Path, @@ -1539,7 +1541,7 @@ def test_setup_passwordless_ssh_access( # There should be an ssh key in the .ssh dir. # Won't ask to generate a key. create_ssh_keypair( - ssh_private_key_path=ssh_dir / "id_rsa_milatools", local=Local() + ssh_private_key_path=ssh_dir / "id_rsa_milatools", local=LocalV1() ) if drac_clusters_in_ssh_config: # We should get a prompt asking if we want to register the public key diff --git a/tests/conftest.py b/tests/conftest.py index b00a377b..088a2610 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ from milatools.cli import console from milatools.cli.init_command import DRAC_CLUSTERS -from milatools.cli.remote import Remote +from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import ( RemoteV2, get_controlpath_for, @@ -90,9 +90,9 @@ def MockConnection( __repr__=lambda _: f"Connection({repr(host)})", ), ) - import milatools.cli.remote + import milatools.utils.remote_v1 - monkeypatch.setattr(milatools.cli.remote, Connection.__name__, MockConnection) + monkeypatch.setattr(milatools.utils.remote_v1, Connection.__name__, MockConnection) return MockConnection @@ -111,11 +111,11 @@ def mock_connection( @pytest.fixture(scope="function") def remote(mock_connection: Connection): assert isinstance(mock_connection.host, str) - return Remote(hostname=mock_connection.host, connection=mock_connection) + return RemoteV1(hostname=mock_connection.host, connection=mock_connection) @pytest.fixture(scope="function") -def login_node(cluster: str) -> Remote | RemoteV2: +def login_node(cluster: str) -> RemoteV1 | RemoteV2: """Fixture that gives a Remote connected to the login node of a slurm cluster. NOTE: Making this a function-scoped fixture because the Connection object of the @@ -132,7 +132,7 @@ def login_node(cluster: str) -> Remote | RemoteV2: "prior connection to the cluster." ) if sys.platform == "win32": - return Remote(cluster) + return RemoteV1(cluster) return RemoteV2(cluster) @@ -161,7 +161,7 @@ def test_something(remote: Remote): @contextlib.contextmanager -def cancel_all_milatools_jobs_before_and_after_tests(login_node: Remote | RemoteV2): +def cancel_all_milatools_jobs_before_and_after_tests(login_node: RemoteV1 | RemoteV2): from .integration.conftest import WCKEY logger.info( diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index f709cf24..c0da07c7 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -8,7 +8,7 @@ import pytest -from milatools.cli.remote import Remote +from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import ( SSH_CONFIG_FILE, RemoteV2, @@ -94,7 +94,7 @@ def get_slurm_account(cluster: str) -> str: f"Fetching the list of SLURM accounts available on the {cluster} cluster." ) if sys.platform == "win32": - result = Remote(cluster).run( + result = RemoteV1(cluster).run( "sacctmgr --noheader show associations where user=$USER format=Account%50" ) else: diff --git a/tests/integration/test_code_command.py b/tests/integration/test_code_command.py index b6452a42..378a7e4c 100644 --- a/tests/integration/test_code_command.py +++ b/tests/integration/test_code_command.py @@ -10,8 +10,8 @@ import pytest from milatools.cli.commands import check_disk_quota, code -from milatools.cli.remote import Remote from milatools.cli.utils import get_fully_qualified_hostname_of_compute_node +from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import RemoteV2 from ..cli.common import in_github_CI, skip_param_if_on_github_ci @@ -49,7 +49,7 @@ indirect=True, ) def test_check_disk_quota( - login_node: Remote | RemoteV2, + login_node: RemoteV1 | RemoteV2, capsys: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture, ): # noqa: F811 @@ -90,7 +90,7 @@ def test_check_disk_quota( ) @pytest.mark.parametrize("persist", [True, False]) def test_code( - login_node: Remote | RemoteV2, + login_node: RemoteV1 | RemoteV2, persist: bool, capsys: pytest.CaptureFixture, allocation_flags: list[str], diff --git a/tests/integration/test_slurm_remote.py b/tests/integration/test_slurm_remote.py index 3c961542..9e0b2377 100644 --- a/tests/integration/test_slurm_remote.py +++ b/tests/integration/test_slurm_remote.py @@ -17,8 +17,8 @@ import milatools import milatools.cli import milatools.cli.utils -from milatools.cli.remote import Remote, SlurmRemote from milatools.cli.utils import CLUSTERS +from milatools.utils.remote_v1 import RemoteV1, SlurmRemote from milatools.utils.remote_v2 import RemoteV2 from ..cli.common import on_windows @@ -49,7 +49,7 @@ def can_run_on_all_clusters(): def get_recent_jobs_info_dicts( - login_node: Remote | RemoteV2, + login_node: RemoteV1 | RemoteV2, since=datetime.timedelta(minutes=5), fields=("JobID", "JobName", "Node", "State"), ) -> list[dict[str, str]]: @@ -60,7 +60,7 @@ def get_recent_jobs_info_dicts( def get_recent_jobs_info( - login_node: Remote | RemoteV2, + login_node: RemoteV1 | RemoteV2, since=datetime.timedelta(minutes=5), fields=("JobID", "JobName", "Node", "State"), ) -> list[tuple[str, ...]]: @@ -83,7 +83,7 @@ def sleep_so_sacct_can_update(): @requires_access_to_slurm_cluster -def test_cluster_setup(login_node: Remote | RemoteV2, allocation_flags: list[str]): +def test_cluster_setup(login_node: RemoteV1 | RemoteV2, allocation_flags: list[str]): """Sanity Checks for the SLURM cluster of the CI: checks that `srun` works. NOTE: This is more-so a test to check that the slurm cluster used in the GitHub CI @@ -110,9 +110,9 @@ def test_cluster_setup(login_node: Remote | RemoteV2, allocation_flags: list[str @pytest.fixture def fabric_connection_to_login_node( - login_node: Remote | RemoteV2, request: pytest.FixtureRequest + login_node: RemoteV1 | RemoteV2, request: pytest.FixtureRequest ): - if isinstance(login_node, Remote): + if isinstance(login_node, RemoteV1): return login_node.connection if login_node.hostname not in ["localhost", "mila"]: @@ -122,7 +122,7 @@ def fabric_connection_to_login_node( f"might go through 2FA!" ) ) - return Remote(login_node.hostname).connection + return RemoteV1(login_node.hostname).connection @pytest.fixture @@ -166,7 +166,7 @@ def sbatch_slurm_remote( @PARAMIKO_SSH_BANNER_BUG @requires_access_to_slurm_cluster def test_run( - login_node: Remote | RemoteV2, + login_node: RemoteV1 | RemoteV2, salloc_slurm_remote: SlurmRemote, ): """Test for `SlurmRemote.run` with persist=False without an initial call to @@ -210,7 +210,7 @@ def test_run( @hangs_in_github_CI @requires_access_to_slurm_cluster def test_ensure_allocation( - login_node: Remote | RemoteV2, + login_node: RemoteV1 | RemoteV2, salloc_slurm_remote: SlurmRemote, capsys: pytest.CaptureFixture[str], ): @@ -266,7 +266,7 @@ def test_ensure_allocation( if isinstance(login_node, RemoteV2) and login_node.hostname == "mila": assert hostname_from_remote_runner.startswith("login-") assert hostname_from_login_node_runner.startswith("login-") - elif isinstance(login_node, Remote): + elif isinstance(login_node, RemoteV1): assert hostname_from_remote_runner == hostname_from_login_node_runner # TODO: IF the remote runner was to be connected to the compute node through the @@ -317,7 +317,7 @@ def test_ensure_allocation( @hangs_in_github_CI @requires_access_to_slurm_cluster def test_ensure_allocation_sbatch( - login_node: Remote | RemoteV2, sbatch_slurm_remote: SlurmRemote + login_node: RemoteV1 | RemoteV2, sbatch_slurm_remote: SlurmRemote ): job_data, login_node_remote_runner = sbatch_slurm_remote.ensure_allocation() print(job_data, login_node_remote_runner) diff --git a/tests/integration/test_sync_command.py b/tests/integration/test_sync_command.py index 6d9a926f..0671ca57 100644 --- a/tests/integration/test_sync_command.py +++ b/tests/integration/test_sync_command.py @@ -8,7 +8,7 @@ import pytest from typing_extensions import ParamSpec -from milatools.cli.local import Local +from milatools.utils.local_v1 import LocalV1 from milatools.utils.remote_v2 import RemoteV2 from milatools.utils.vscode_utils import ( extensions_to_install, @@ -80,7 +80,7 @@ def _mock_and_patch( ) sync_vscode_extensions( - source=Local() if source == "localhost" else RemoteV2(source), + source=LocalV1() if source == "localhost" else RemoteV2(source), dest_clusters=[dest], ) diff --git a/tests/cli/test_local.py b/tests/utils/test_local_v1.py similarity index 91% rename from tests/cli/test_local.py rename to tests/utils/test_local_v1.py index ae08771a..a2b5e224 100644 --- a/tests/cli/test_local.py +++ b/tests/utils/test_local_v1.py @@ -6,10 +6,10 @@ import pytest from pytest_regressions.file_regression import FileRegressionFixture -from milatools.cli.local import CommandNotFoundError, Local, check_passwordless +from milatools.utils.local_v1 import CommandNotFoundError, LocalV1, check_passwordless from milatools.utils.remote_v2 import is_already_logged_in -from .common import ( +from ..cli.common import ( in_github_CI, in_self_hosted_github_CI, output_tester, @@ -38,7 +38,7 @@ def test_display( capsys: pytest.CaptureFixture, file_regression: FileRegressionFixture, ): - output_tester(lambda: (Local().display(cmd), None), capsys, file_regression) + output_tester(lambda: (LocalV1().display(cmd), None), capsys, file_regression) prints_unexpected_text_to_stdout_on_windows = xfails_on_windows( @@ -58,7 +58,7 @@ def test_silent_get( capsys: pytest.CaptureFixture, file_regression: FileRegressionFixture, ): - output_tester(lambda: (Local().silent_get(*cmd), None), capsys, file_regression) + output_tester(lambda: (LocalV1().silent_get(*cmd), None), capsys, file_regression) @prints_unexpected_text_to_stdout_on_windows @@ -69,7 +69,7 @@ def test_get( capsys: pytest.CaptureFixture, file_regression: FileRegressionFixture, ): - output_tester(lambda: (Local().get(*cmd), None), capsys, file_regression) + output_tester(lambda: (LocalV1().get(*cmd), None), capsys, file_regression) @prints_unexpected_text_to_stdout_on_windows @@ -81,7 +81,7 @@ def test_run( file_regression: FileRegressionFixture, ): def func(): - return Local().run(*cmd, capture_output=True), None + return LocalV1().run(*cmd, capture_output=True), None if cmd in [_FAKE_CMD, _FAIL_CODE_CMD]: @@ -108,7 +108,7 @@ def test_popen( file_regression: FileRegressionFixture, ): output_tester( - lambda: Local().popen(*cmd, stdout=PIPE, stderr=PIPE).communicate(), + lambda: LocalV1().popen(*cmd, stdout=PIPE, stderr=PIPE).communicate(), capsys, file_regression, ) diff --git a/tests/cli/test_local/test_display_cmd0_.txt b/tests/utils/test_local_v1/test_display_cmd0_.txt similarity index 100% rename from tests/cli/test_local/test_display_cmd0_.txt rename to tests/utils/test_local_v1/test_display_cmd0_.txt diff --git a/tests/cli/test_local/test_display_cmd1_.txt b/tests/utils/test_local_v1/test_display_cmd1_.txt similarity index 100% rename from tests/cli/test_local/test_display_cmd1_.txt rename to tests/utils/test_local_v1/test_display_cmd1_.txt diff --git a/tests/cli/test_local/test_get_cmd0_.txt b/tests/utils/test_local_v1/test_get_cmd0_.txt similarity index 100% rename from tests/cli/test_local/test_get_cmd0_.txt rename to tests/utils/test_local_v1/test_get_cmd0_.txt diff --git a/tests/cli/test_local/test_popen_cmd0_.txt b/tests/utils/test_local_v1/test_popen_cmd0_.txt similarity index 100% rename from tests/cli/test_local/test_popen_cmd0_.txt rename to tests/utils/test_local_v1/test_popen_cmd0_.txt diff --git a/tests/cli/test_local/test_run_cmd0_.txt b/tests/utils/test_local_v1/test_run_cmd0_.txt similarity index 100% rename from tests/cli/test_local/test_run_cmd0_.txt rename to tests/utils/test_local_v1/test_run_cmd0_.txt diff --git a/tests/cli/test_local/test_run_cmd1_.txt b/tests/utils/test_local_v1/test_run_cmd1_.txt similarity index 100% rename from tests/cli/test_local/test_run_cmd1_.txt rename to tests/utils/test_local_v1/test_run_cmd1_.txt diff --git a/tests/cli/test_local/test_run_cmd2_.txt b/tests/utils/test_local_v1/test_run_cmd2_.txt similarity index 100% rename from tests/cli/test_local/test_run_cmd2_.txt rename to tests/utils/test_local_v1/test_run_cmd2_.txt diff --git a/tests/cli/test_local/test_silent_get_cmd0_.txt b/tests/utils/test_local_v1/test_silent_get_cmd0_.txt similarity index 100% rename from tests/cli/test_local/test_silent_get_cmd0_.txt rename to tests/utils/test_local_v1/test_silent_get_cmd0_.txt diff --git a/tests/cli/test_remote.py b/tests/utils/test_remote_v1.py similarity index 95% rename from tests/cli/test_remote.py rename to tests/utils/test_remote_v1.py index b1bdcbb3..6277498c 100644 --- a/tests/cli/test_remote.py +++ b/tests/utils/test_remote_v1.py @@ -17,15 +17,15 @@ from fabric.connection import Connection from pytest_regressions.file_regression import FileRegressionFixture -from milatools.cli.remote import ( +from milatools.cli.utils import T, cluster_to_connect_kwargs +from milatools.utils.remote_v1 import ( QueueIO, - Remote, + RemoteV1, SlurmRemote, get_first_node_name, ) -from milatools.cli.utils import T, cluster_to_connect_kwargs -from .common import function_call_string +from ..cli.common import function_call_string @pytest.mark.parametrize("keepalive", [0, 123]) @@ -39,7 +39,7 @@ def test_init( # This should have called the `Connection` class with the host, which we patched in # the fixture above. - r = Remote(host, keepalive=keepalive) + r = RemoteV1(host, keepalive=keepalive) # The Remote should have created a Connection instance (which happens to be # the mock_connection we made above). MockConnection.assert_called_once_with( @@ -65,7 +65,7 @@ def test_init_with_connection( mock_connection: Mock, ): """This test shows the behaviour of __init__ to isolate it from other tests.""" - r = Remote(mock_connection.host, connection=mock_connection, keepalive=keepalive) + r = RemoteV1(mock_connection.host, connection=mock_connection, keepalive=keepalive) MockConnection.assert_not_called() assert r.connection is mock_connection # The connection is not opened, and the transport is also not opened. @@ -81,23 +81,23 @@ def test_init_with_connection( ("method", "args"), [ ( - Remote.with_transforms, + RemoteV1.with_transforms, ( lambda cmd: cmd.replace("OK", "NOT_OK"), lambda cmd: f"echo 'command before' && {cmd}", ), ), ( - Remote.wrap, + RemoteV1.wrap, ("echo 'echo wrap' && {}",), ), ( - Remote.with_precommand, + RemoteV1.with_precommand, ("echo 'echo precommand'",), ), # this need to be a file to source before running the command. - (Remote.with_profile, (".bashrc",)), - (Remote.with_bash, ()), + (RemoteV1.with_profile, (".bashrc",)), + (RemoteV1.with_bash, ()), ], ) def test_remote_transform_methods( @@ -113,13 +113,13 @@ def test_remote_transform_methods( """Test the methods of `Remote` that modify the commands passed to `run` before it gets passed to the connection and run on the server.""" mock_connection = mock_connection - r = Remote( + r = RemoteV1( host, connection=mock_connection, transforms=initial_transforms, ) # Call the method on the remote, which should return a new Remote. - modified_remote: Remote = method(r, *args) + modified_remote: RemoteV1 = method(r, *args) assert modified_remote.hostname == r.hostname assert modified_remote.connection is r.connection @@ -144,7 +144,7 @@ def test_remote_transform_methods( After creating a Remote like so: ```python -remote = {function_call_string(Remote, host, connection=mock_connection, transforms=())} +remote = {function_call_string(RemoteV1, host, connection=mock_connection, transforms=())} ``` and then calling: @@ -174,7 +174,7 @@ def test_remote_transform_methods( @pytest.mark.parametrize("message", ["foobar"]) def test_display( message: str, - remote: Remote, + remote: RemoteV1, capsys: pytest.CaptureFixture, ): remote.display(message) @@ -206,7 +206,7 @@ def hide( @pytest.mark.parametrize("warn", [True, False]) @pytest.mark.parametrize("display", [True, False, None]) def test_run( - remote: Remote, + remote: RemoteV1, command: str, expected_output: str, asynchronous: bool, @@ -278,7 +278,7 @@ def test_get_output( mock_result = Mock(wraps=invoke.runners.Result(Mock(wraps=command_output))) mock_connection.run.return_value = mock_result - r = Remote(host, connection=mock_connection) + r = RemoteV1(host, connection=mock_connection) output = r.get_output(command, display=None, hide=hide, warn=warn) assert output == command_output @@ -305,7 +305,7 @@ def test_get_lines( command = " && ".join(f"echo '{line}'" for line in expected_lines) command_output = "\n".join(expected_lines) mock_connection.run.return_value = invoke.runners.Result(stdout=command_output) - r = Remote(host, connection=mock_connection) + r = RemoteV1(host, connection=mock_connection) lines = r.get_lines(command, hide=hide, warn=warn) # NOTE: We'd expect this, but instead we get ['Line', '1', 'has', 'this', 'value', # TODO: Uncomment this if we fix `get_lines` to split based on lines, or remove this @@ -328,7 +328,7 @@ def write_lines_with_sleeps(lines: Iterable[str], sleep_time: float = 0.1): @pytest.mark.parametrize("wait", [True, False]) @pytest.mark.parametrize("pty", [True, False]) def test_extract( - remote: Remote, + remote: RemoteV1, wait: bool, pty: bool, ): @@ -362,7 +362,7 @@ def _xfail_if_not_on_localhost(host: str): pytest.xfail("This test only works on localhost.") -def test_get(remote: Remote, tmp_path: Path, host: str): +def test_get(remote: RemoteV1, tmp_path: Path, host: str): # TODO: Make this test smarter? or no need? (because we'd be testing fabric at that # point?) _xfail_if_not_on_localhost(remote.hostname) @@ -375,7 +375,7 @@ def test_get(remote: Remote, tmp_path: Path, host: str): assert dest.read_text() == source_content -def test_put(remote: Remote, tmp_path: Path): +def test_put(remote: RemoteV1, tmp_path: Path): _xfail_if_not_on_localhost(remote.hostname) src = tmp_path / "foo" dest = tmp_path / "bar" @@ -390,7 +390,7 @@ def test_put(remote: Remote, tmp_path: Path): assert dest.read_text() == source_content -def test_puttext(remote: Remote, tmp_path: Path): +def test_puttext(remote: RemoteV1, tmp_path: Path): _xfail_if_not_on_localhost(remote.hostname) dest_dir = tmp_path / "bar/baz" dest = tmp_path / f"{dest_dir}/bob.txt" @@ -403,7 +403,7 @@ def test_puttext(remote: Remote, tmp_path: Path): assert dest.read_text() == some_text -def test_home(remote: Remote): +def test_home(remote: RemoteV1): home_dir = remote.home() remote.connection.run.assert_called_once() assert remote.connection.run.mock_calls[0].args[0] == "echo $HOME" @@ -414,11 +414,11 @@ def test_home(remote: Remote): assert home_dir == str(Path.home()) -def test_persist(remote: Remote): +def test_persist(remote: RemoteV1): assert remote.persist() is remote -def test_ensure_allocation(remote: Remote): +def test_ensure_allocation(remote: RemoteV1): assert remote.ensure_allocation() == ({"node_name": remote.hostname}, None) @@ -544,7 +544,7 @@ def test_with_transforms(self, mock_connection: Connection, persist: bool | None # It isn't a very useful test, but it's better than not having one for now. # The test for Remote.run above checks that `run` on the transformed remote # does what we expect. - assert SlurmRemote.run is Remote.run + assert SlurmRemote.run is RemoteV1.run alloc = ["--time=00:01:00"] transforms = [some_transform] new_transforms = [some_other_transform] diff --git a/tests/cli/test_remote/test_QueueIO.txt b/tests/utils/test_remote_v1/test_QueueIO.txt similarity index 100% rename from tests/cli/test_remote/test_QueueIO.txt rename to tests/utils/test_remote_v1/test_QueueIO.txt diff --git a/tests/cli/test_remote/test_remote_transform_methods_localhost_with_bash_args4_initial_transforms0_echo_OK_.md b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_bash_args4_initial_transforms0_echo_OK_.md similarity index 82% rename from tests/cli/test_remote/test_remote_transform_methods_localhost_with_bash_args4_initial_transforms0_echo_OK_.md rename to tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_bash_args4_initial_transforms0_echo_OK_.md index 4c739f1e..d54e4e87 100644 --- a/tests/cli/test_remote/test_remote_transform_methods_localhost_with_bash_args4_initial_transforms0_echo_OK_.md +++ b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_bash_args4_initial_transforms0_echo_OK_.md @@ -1,7 +1,7 @@ After creating a Remote like so: ```python -remote = Remote('localhost', connection=Connection('localhost'), transforms=()) +remote = RemoteV1('localhost', connection=Connection('localhost'), transforms=()) ``` and then calling: diff --git a/tests/cli/test_remote/test_remote_transform_methods_localhost_with_precommand_args2_initial_transforms0_echo_OK_.md b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_precommand_args2_initial_transforms0_echo_OK_.md similarity index 84% rename from tests/cli/test_remote/test_remote_transform_methods_localhost_with_precommand_args2_initial_transforms0_echo_OK_.md rename to tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_precommand_args2_initial_transforms0_echo_OK_.md index dfa0609f..df7127a4 100644 --- a/tests/cli/test_remote/test_remote_transform_methods_localhost_with_precommand_args2_initial_transforms0_echo_OK_.md +++ b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_precommand_args2_initial_transforms0_echo_OK_.md @@ -1,7 +1,7 @@ After creating a Remote like so: ```python -remote = Remote('localhost', connection=Connection('localhost'), transforms=()) +remote = RemoteV1('localhost', connection=Connection('localhost'), transforms=()) ``` and then calling: diff --git a/tests/cli/test_remote/test_remote_transform_methods_localhost_with_profile_args3_initial_transforms0_echo_OK_.md b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_profile_args3_initial_transforms0_echo_OK_.md similarity index 82% rename from tests/cli/test_remote/test_remote_transform_methods_localhost_with_profile_args3_initial_transforms0_echo_OK_.md rename to tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_profile_args3_initial_transforms0_echo_OK_.md index f67b9f0e..efb7bcc7 100644 --- a/tests/cli/test_remote/test_remote_transform_methods_localhost_with_profile_args3_initial_transforms0_echo_OK_.md +++ b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_profile_args3_initial_transforms0_echo_OK_.md @@ -1,7 +1,7 @@ After creating a Remote like so: ```python -remote = Remote('localhost', connection=Connection('localhost'), transforms=()) +remote = RemoteV1('localhost', connection=Connection('localhost'), transforms=()) ``` and then calling: diff --git a/tests/cli/test_remote/test_remote_transform_methods_localhost_with_transforms_args0_initial_transforms0_echo_OK_.md b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_transforms_args0_initial_transforms0_echo_OK_.md similarity index 86% rename from tests/cli/test_remote/test_remote_transform_methods_localhost_with_transforms_args0_initial_transforms0_echo_OK_.md rename to tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_transforms_args0_initial_transforms0_echo_OK_.md index c7227892..2c16d3ad 100644 --- a/tests/cli/test_remote/test_remote_transform_methods_localhost_with_transforms_args0_initial_transforms0_echo_OK_.md +++ b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_with_transforms_args0_initial_transforms0_echo_OK_.md @@ -1,7 +1,7 @@ After creating a Remote like so: ```python -remote = Remote('localhost', connection=Connection('localhost'), transforms=()) +remote = RemoteV1('localhost', connection=Connection('localhost'), transforms=()) ``` and then calling: diff --git a/tests/cli/test_remote/test_remote_transform_methods_localhost_wrap_args1_initial_transforms0_echo_OK_.md b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_wrap_args1_initial_transforms0_echo_OK_.md similarity index 83% rename from tests/cli/test_remote/test_remote_transform_methods_localhost_wrap_args1_initial_transforms0_echo_OK_.md rename to tests/utils/test_remote_v1/test_remote_transform_methods_localhost_wrap_args1_initial_transforms0_echo_OK_.md index f3bc1b5a..2c0cf474 100644 --- a/tests/cli/test_remote/test_remote_transform_methods_localhost_wrap_args1_initial_transforms0_echo_OK_.md +++ b/tests/utils/test_remote_v1/test_remote_transform_methods_localhost_wrap_args1_initial_transforms0_echo_OK_.md @@ -1,7 +1,7 @@ After creating a Remote like so: ```python -remote = Remote('localhost', connection=Connection('localhost'), transforms=()) +remote = RemoteV1('localhost', connection=Connection('localhost'), transforms=()) ``` and then calling: diff --git a/tests/cli/test_remote/test_srun_transform_persist_localhost_.md b/tests/utils/test_remote_v1/test_srun_transform_persist_localhost_.md similarity index 100% rename from tests/cli/test_remote/test_srun_transform_persist_localhost_.md rename to tests/utils/test_remote_v1/test_srun_transform_persist_localhost_.md diff --git a/tests/utils/test_vscode_utils.py b/tests/utils/test_vscode_utils.py index 7b2126aa..ac4e8480 100644 --- a/tests/utils/test_vscode_utils.py +++ b/tests/utils/test_vscode_utils.py @@ -11,10 +11,10 @@ import pytest -from milatools.cli.local import Local -from milatools.cli.remote import Remote from milatools.cli.utils import running_inside_WSL +from milatools.utils.local_v1 import LocalV1 from milatools.utils.parallel_progress import ProgressDict +from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import RemoteV2, UnsupportedPlatformError from milatools.utils.vscode_utils import ( extensions_to_install, @@ -141,7 +141,7 @@ def test_sync_vscode_extensions_in_parallel_with_hostnames( @requires_vscode @requires_ssh_to_localhost def test_sync_vscode_extensions_in_parallel(): - results = sync_vscode_extensions(Local(), dest_clusters=[Local()]) + results = sync_vscode_extensions(LocalV1(), dest_clusters=[LocalV1()]) assert results == {"localhost": {"info": "Done.", "progress": 0, "total": 0}} @@ -206,7 +206,7 @@ def missing_extensions( def _remote(hostname: str): - return RemoteV2(hostname) if sys.platform != "win32" else Remote(hostname) + return RemoteV2(hostname) if sys.platform != "win32" else RemoteV1(hostname) @uses_remote_v2