Skip to content

Commit

Permalink
Move functions a bit in remote_v2.py
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed May 1, 2024
1 parent d9022c9 commit 5963880
Showing 1 changed file with 105 additions and 105 deletions.
210 changes: 105 additions & 105 deletions milatools/utils/remote_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,111 +30,6 @@ class UnsupportedPlatformError(MilatoolsUserError):
...


def raise_error_if_running_on_windows():
if sys.platform == "win32":
raise UnsupportedPlatformError(
"This feature isn't supported on Windows, as it requires an SSH client "
"with SSH multiplexing support (ControlMaster, ControlPath and "
"ControlPersist).\n"
"Please consider switching to the Windows Subsystem for Linux (WSL).\n"
"See https://learn.microsoft.com/en-us/windows/wsl/install for a guide on "
"setting up WSL."
)


# note: Could potentially cache the results of this function if we wanted to, assuming
# that the ssh config file doesn't change.


def ssh_command(
hostname: str,
control_path: Path,
command: str,
control_master: Literal["yes", "no", "auto", "ask", "autoask"] = "auto",
control_persist: int | str | Literal["yes", "no"] = "yes",
ssh_config_path: Path = SSH_CONFIG_FILE,
):
"""Returns a tuple of strings to be used as the command to be run in a subprocess.
When the path to the SSH config file is passed and exists, this will only add the
options which aren't already set in the SSH config, so as to avoid redundant
arguments to the `ssh` command.
Parameters
----------
hostname: The hostname to connect to.
control_path : See https://man.openbsd.org/ssh_config#ControlPath
command: The command to run on the remote host (kept as a string).
control_master: See https://man.openbsd.org/ssh_config#ControlMaster
control_persist: See https://man.openbsd.org/ssh_config#ControlPersist
ssh_config_path: Path to the ssh config file.
Returns
-------
The tuple of strings to pass to `subprocess.run` or similar.
"""
control_path = control_path.expanduser()
ssh_config_path = ssh_config_path.expanduser()
if not ssh_config_path.exists():
return (
"ssh",
f"-oControlMaster={control_master}",
f"-oControlPersist={control_persist}",
f"-oControlPath={control_path}",
hostname,
command,
)

ssh_command: list[str] = ["ssh"]
ssh_config_entry = SSHConfig.from_path(str(ssh_config_path)).lookup(hostname)
if ssh_config_entry.get("controlmaster") != control_master:
ssh_command.append(f"-oControlMaster={control_master}")
if ssh_config_entry.get("controlpersist") != control_persist:
ssh_command.append(f"-oControlPersist={control_persist}")

control_path_in_config = ssh_config_entry.get("controlpath")
if (
control_path_in_config is None
or Path(control_path_in_config).expanduser() != control_path
):
# Only add the ControlPath arg if it is not in the config, or if it differs from
# the value in the config.
ssh_command.append(f"-oControlPath={control_path}")
ssh_command.append(hostname)
# NOTE: Not quoting the command here, `subprocess.run` does it (since shell=False).
ssh_command.append(command)
return tuple(ssh_command)


def control_socket_is_running(host: str, control_path: Path) -> bool:
"""Check whether the control socket at the given path is running."""
control_path = control_path.expanduser()
if not control_path.exists():
return False

result = subprocess.run(
(
"ssh",
"-O",
"check",
f"-oControlPath={control_path}",
host,
),
check=False,
capture_output=True,
text=True,
shell=False,
)
if (
result.returncode != 0
or not result.stderr
or not result.stderr.startswith("Master running")
):
logger.debug(f"{control_path=} doesn't exist or isn't running: {result=}.")
return False
return True


@dataclasses.dataclass(init=False)
class RemoteV2(Runner):
"""Simpler Remote where commands are run in subprocesses sharing an SSH connection.
Expand Down Expand Up @@ -290,6 +185,111 @@ async def _start_async(self) -> None:
self._started = True


def raise_error_if_running_on_windows():
if sys.platform == "win32":
raise UnsupportedPlatformError(
"This feature isn't supported on Windows, as it requires an SSH client "
"with SSH multiplexing support (ControlMaster, ControlPath and "
"ControlPersist).\n"
"Please consider switching to the Windows Subsystem for Linux (WSL).\n"
"See https://learn.microsoft.com/en-us/windows/wsl/install for a guide on "
"setting up WSL."
)


# note: Could potentially cache the results of this function if we wanted to, assuming
# that the ssh config file doesn't change.


def ssh_command(
hostname: str,
control_path: Path,
command: str,
control_master: Literal["yes", "no", "auto", "ask", "autoask"] = "auto",
control_persist: int | str | Literal["yes", "no"] = "yes",
ssh_config_path: Path = SSH_CONFIG_FILE,
):
"""Returns a tuple of strings to be used as the command to be run in a subprocess.
When the path to the SSH config file is passed and exists, this will only add the
options which aren't already set in the SSH config, so as to avoid redundant
arguments to the `ssh` command.
Parameters
----------
hostname: The hostname to connect to.
control_path : See https://man.openbsd.org/ssh_config#ControlPath
command: The command to run on the remote host (kept as a string).
control_master: See https://man.openbsd.org/ssh_config#ControlMaster
control_persist: See https://man.openbsd.org/ssh_config#ControlPersist
ssh_config_path: Path to the ssh config file.
Returns
-------
The tuple of strings to pass to `subprocess.run` or similar.
"""
control_path = control_path.expanduser()
ssh_config_path = ssh_config_path.expanduser()
if not ssh_config_path.exists():
return (
"ssh",
f"-oControlMaster={control_master}",
f"-oControlPersist={control_persist}",
f"-oControlPath={control_path}",
hostname,
command,
)

ssh_command: list[str] = ["ssh"]
ssh_config_entry = SSHConfig.from_path(str(ssh_config_path)).lookup(hostname)
if ssh_config_entry.get("controlmaster") != control_master:
ssh_command.append(f"-oControlMaster={control_master}")
if ssh_config_entry.get("controlpersist") != control_persist:
ssh_command.append(f"-oControlPersist={control_persist}")

control_path_in_config = ssh_config_entry.get("controlpath")
if (
control_path_in_config is None
or Path(control_path_in_config).expanduser() != control_path
):
# Only add the ControlPath arg if it is not in the config, or if it differs from
# the value in the config.
ssh_command.append(f"-oControlPath={control_path}")
ssh_command.append(hostname)
# NOTE: Not quoting the command here, `subprocess.run` does it (since shell=False).
ssh_command.append(command)
return tuple(ssh_command)


def control_socket_is_running(host: str, control_path: Path) -> bool:
"""Check whether the control socket at the given path is running."""
control_path = control_path.expanduser()
if not control_path.exists():
return False

result = subprocess.run(
(
"ssh",
"-O",
"check",
f"-oControlPath={control_path}",
host,
),
check=False,
capture_output=True,
text=True,
shell=False,
)
if (
result.returncode != 0
or not result.stderr
or not result.stderr.startswith("Master running")
):
logger.debug(f"{control_path=} doesn't exist or isn't running: {result=}.")
return False
return True


async def control_socket_is_running_async(host: str, control_path: Path) -> bool:
"""Check whether the control socket at the given path is running asynchronously."""
control_path = control_path.expanduser()
Expand Down

0 comments on commit 5963880

Please sign in to comment.