Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MT-68] Add ControlMaster ssh options in mila init #58

Merged
merged 18 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions milatools/cli/init_command.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import difflib
from logging import getLogger as get_logger
from pathlib import Path

import questionary as qn

from .utils import SSHConfig, T, yn

logger = get_logger(__name__)


def setup_ssh_config(
ssh_config_path: str | Path = "~/.ssh/config",
Expand All @@ -31,6 +34,11 @@ def setup_ssh_config(
username: str = _get_username(ssh_config)
orig_config = ssh_config.cfg.config()

control_path_dir = Path("~/.cache/ssh")
# note: a bit nicer to keep the "~" in the path in the ssh config file, but we need to make
# sure that the directory actually exists.
control_path_dir.expanduser().mkdir(exist_ok=True, parents=True)

_add_ssh_entry(
ssh_config,
"mila",
Expand All @@ -40,6 +48,12 @@ def setup_ssh_config(
Port=2222,
ServerAliveInterval=120,
ServerAliveCountMax=5,
# Tries to reuse an existing connection, but if it fails, it will create a new one.
ControlMaster="auto",
# This makes a file per connection, like normandf@login.server.mila.quebec:2222
ControlPath=str(control_path_dir / r"%r@%h:%p"),
# persist for 10 minutes after the last connection ends.
ControlPersist=600,
)

_add_ssh_entry(
Expand All @@ -54,7 +68,8 @@ def setup_ssh_config(
RequestTTY="force",
ConnectTimeout=600,
ServerAliveInterval=120,
# NOTE: will not work with --gres prior to Slurm 22.05, because srun --overlap cannot share it
# NOTE: will not work with --gres prior to Slurm 22.05, because srun --overlap cannot share
# it
ProxyCommand=(
'ssh mila "/cvmfs/config.mila.quebec/scripts/milatools/slurm-proxy.sh mila-cpu --mem=8G"'
),
Expand Down Expand Up @@ -82,6 +97,12 @@ def setup_ssh_config(
HostName="%h",
User=username,
ProxyJump="mila",
# Tries to reuse an existing connection, but if it fails, it will create a new one.
ControlMaster="auto",
# This makes a file per connection, like normandf@login.server.mila.quebec:2222
ControlPath=str(control_path_dir / r"%r@%h:%p"),
# persist for 10 minutes after the last connection ends.
ControlPersist=600,
)

new_config = ssh_config.cfg.config()
Expand Down Expand Up @@ -162,7 +183,8 @@ def _get_username(ssh_config: SSHConfig) -> str:

while not username:
username = qn.text(
"What's your username on the mila cluster?\n", validate=_is_valid_username
"What's your username on the mila cluster?\n",
validate=_is_valid_username,
).unsafe_ask()
return username.strip()

Expand All @@ -185,7 +207,7 @@ def _add_ssh_entry(
host: str,
Host: str | None = None,
**entry,
) -> bool:
) -> None:
"""Interactively add an entry to the ssh config file.

Exits if the user doesn't want to add an entry or doesn't confirm the change.
Expand All @@ -196,7 +218,10 @@ def _add_ssh_entry(
assert not (host and Host)
host = Host or host
if host in ssh_config.hosts():
# Don't change an existing entry for now.
return False
ssh_config.add(host, **entry)
return True
existing_entry = ssh_config.host(host)
existing_entry.update(entry)
ssh_config.cfg.set(host, **existing_entry)
logger.debug(f"Updated {host} entry in ssh config.")
else:
ssh_config.add(host, **entry)
logger.debug(f"Adding new {host} entry in ssh config.")
3 changes: 2 additions & 1 deletion milatools/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,15 @@ def shjoin(split_command):
class SSHConfig:
"""Wrapper around sshconf with some extra niceties."""

def __init__(self, path):
def __init__(self, path: str | Path):
lebrice marked this conversation as resolved.
Show resolved Hide resolved
self.cfg = read_ssh_config(path)
self.add = self.cfg.add
self.remove = self.cfg.remove
self.rename = self.cfg.rename
self.save = self.cfg.save
self.host = self.cfg.host
self.hosts = self.cfg.hosts
self.set = self.cfg.set

def hoststring(self, host):
lines = []
Expand Down
Loading