Skip to content

Commit

Permalink
Begin transition of mila init to RemoteV2
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Apr 18, 2024
1 parent 7c176b8 commit ff02514
Showing 1 changed file with 42 additions and 53 deletions.
95 changes: 42 additions & 53 deletions milatools/cli/init_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import questionary as qn
from invoke.exceptions import UnexpectedExit
from paramiko.config import SSHConfig as SSHConfigReader

from milatools.utils.remote_v2 import SSH_CONFIG_FILE

Expand All @@ -23,7 +24,8 @@
get_expected_vscode_settings_json_path,
vscode_installed,
)
from .utils import SSHConfig, T, running_inside_WSL, yn
from .utils import SSHConfig as SSHConfigWriter
from .utils import T, running_inside_WSL, yn

logger = get_logger(__name__)

Expand Down Expand Up @@ -68,8 +70,8 @@
# NOTE: will not work with --gres prior to Slurm 22.05, because srun --overlap
# cannot share gpus
"ProxyCommand": (
'ssh mila "/cvmfs/config.mila.quebec/scripts/milatools/slurm-proxy.sh '
'mila-cpu --mem=8G"'
"ssh mila "
'"/cvmfs/config.mila.quebec/scripts/milatools/slurm-proxy.sh mila-cpu --mem=8G"'
),
"RemoteCommand": (
"/cvmfs/config.mila.quebec/scripts/milatools/entrypoint.sh mila-cpu"
Expand Down Expand Up @@ -114,7 +116,7 @@

def setup_ssh_config(
ssh_config_path: str | Path = "~/.ssh/config",
) -> SSHConfig:
) -> SSHConfigReader:
"""Interactively sets up some useful entries in the ~/.ssh/config file on the local
machine.
Expand All @@ -134,14 +136,18 @@ def setup_ssh_config(
"""

ssh_config_path = _setup_ssh_config_file(ssh_config_path)
ssh_config = SSHConfig(ssh_config_path)

ssh_config = SSHConfigReader.from_path(str(ssh_config_path))
ssh_config_writer = SSHConfigWriter(ssh_config_path)

mila_username: str = _get_mila_username(ssh_config)
drac_username: str | None = _get_drac_username(ssh_config)
orig_config = ssh_config.cfg.config()

orig_config = ssh_config_writer.cfg.config()

for hostname, entry in MILA_ENTRIES.copy().items():
entry.update(User=mila_username)
_add_ssh_entry(ssh_config, hostname, entry)
_add_ssh_entry(ssh_config_writer, hostname, entry)
_make_controlpath_dir(entry)

if drac_username:
Expand All @@ -150,31 +156,31 @@ def setup_ssh_config(
)
for hostname, entry in DRAC_ENTRIES.copy().items():
entry.update(User=drac_username)
_add_ssh_entry(ssh_config, hostname, entry)
_add_ssh_entry(ssh_config_writer, hostname, entry)
_make_controlpath_dir(entry)

# Check for *.server.mila.quebec in ssh config, to connect to compute nodes
old_cnode_pattern = "*.server.mila.quebec"

if old_cnode_pattern in ssh_config.hosts():
if old_cnode_pattern in ssh_config_writer.hosts():
logger.info(
f"The '{old_cnode_pattern}' entry in ~/.ssh/config is too general and "
"should exclude login.server.mila.quebec. Fixing this."
)
ssh_config.remove(old_cnode_pattern)
ssh_config_writer.remove(old_cnode_pattern)

new_config = ssh_config.cfg.config()
new_config = ssh_config_writer.cfg.config()
if orig_config == new_config:
print("Did not change ssh config")
elif not _confirm_changes(ssh_config, previous=orig_config):
elif not _confirm_changes(ssh_config_writer, previous=orig_config):
exit("Did not change ssh config")
else:
ssh_config.save()
ssh_config_writer.save()
print(f"Wrote {ssh_config_path}")
return ssh_config


def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfig):
def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfigWriter):
"""Setup the Windows SSH configuration and public key from within WSL.
This copies over the entries from the linux ssh configuration file, except for the
Expand All @@ -193,7 +199,7 @@ def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfig):
windows_ssh_config_path = windows_home / ".ssh/config"
windows_ssh_config_path = _setup_ssh_config_file(windows_ssh_config_path)

windows_ssh_config = SSHConfig(windows_ssh_config_path)
windows_ssh_config = SSHConfigWriter(windows_ssh_config_path)

initial_windows_config_contents = windows_ssh_config.cfg.config()
_copy_valid_ssh_entries_to_windows_ssh_config_file(
Expand Down Expand Up @@ -228,7 +234,7 @@ def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfig):
_copy_if_needed(linux_key_file, windows_key_file)


def setup_passwordless_ssh_access(ssh_config: SSHConfig) -> bool:
def setup_passwordless_ssh_access(ssh_config: SSHConfigWriter) -> bool:
"""Sets up passwordless ssh access to the Mila and optionally also to DRAC.
Sets up ssh connection to the DRAC clusters if they are present in the SSH config
Expand Down Expand Up @@ -299,7 +305,6 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool:
print(f"Checking if passwordless SSH access is setup for the {cluster} cluster.")
# TODO: Potentially use a custom key like `~/.ssh/id_milatools.pub` instead of
# the default.

from paramiko.config import SSHConfig

config = SSHConfig.from_path(str(SSH_CONFIG_FILE))
Expand Down Expand Up @@ -616,30 +621,18 @@ def ask_to_confirm_changes(before: str, after: str, path: str | Path) -> bool:
return yn("\nIs this OK?")


def _confirm_changes(ssh_config: SSHConfig, previous: str) -> bool:
def _confirm_changes(ssh_config: SSHConfigWriter, previous: str) -> bool:
before = previous + "\n"
after = ssh_config.cfg.config() + "\n"
return ask_to_confirm_changes(before, after, ssh_config.path)


def _get_mila_username(ssh_config: SSHConfig) -> str:
def _get_mila_username(ssh_config: SSHConfigReader) -> str:
# Check for a mila entry in ssh config
# NOTE: This also supports the case where there's a 'HOST mila some_alias_for_mila'
# entry.
# NOTE: ssh_config.host(entry) returns an empty dictionary if there is no entry.
username: str | None = None
hosts_with_mila_in_name_and_a_user_entry = [
host
for host in ssh_config.hosts()
if "mila" in host.split() and "user" in ssh_config.host(host)
]
# Note: If there are none, or more than one, then we'll ask the user for their
# username, just to be sure.
if len(hosts_with_mila_in_name_and_a_user_entry) == 1:
username = ssh_config.host(hosts_with_mila_in_name_and_a_user_entry[0]).get(
"user"
)

username: str | None = ssh_config.lookup("mila").get("user")
while not username:
username = qn.text(
"What's your username on the mila cluster?\n",
Expand All @@ -648,36 +641,32 @@ def _get_mila_username(ssh_config: SSHConfig) -> str:
return username.strip()


def _get_drac_username(ssh_config: SSHConfig) -> str | None:
def _get_drac_username(ssh_config: SSHConfigReader) -> str | None:
"""Retrieve or ask the user for their username on the ComputeCanada/DRAC
clusters."""
# Check for one of the DRAC entries in ssh config
username: str | None = None
hosts_with_cluster_in_name_and_a_user_entry = [
host
for host in ssh_config.hosts()
if any(
cc_cluster in host.split() or f"!{cc_cluster}" in host.split()
for cc_cluster in DRAC_CLUSTERS
)
and "user" in ssh_config.host(host)
]
users_from_drac_config_entries = set(
ssh_config.host(host)["user"]
for host in hosts_with_cluster_in_name_and_a_user_entry
users_from_drac_config_entries: set[str] = set(
drac_cluster_username
for drac_cluster in DRAC_CLUSTERS
if (drac_cluster_username := ssh_config.lookup(drac_cluster).get("user"))
is not None
)
if len(users_from_drac_config_entries) == 1:
return users_from_drac_config_entries.pop().strip()

username: str | None
# Note: If there are none, or more than one, then we'll ask the user for their
# username, just to be sure.
if len(users_from_drac_config_entries) == 1:
username = users_from_drac_config_entries.pop()
elif yn("Do you also have an account on the ComputeCanada/DRAC clusters?"):
while not username:
username = qn.text(
if yn("Do you also have an account on the ComputeCanada/DRAC clusters?"):
while not (
username := qn.text(
"What's your username on the CC/DRAC clusters?\n",
validate=functools.partial(
_is_valid_username, cluster_name="ComputeCanada/DRAC clusters"
),
).unsafe_ask()
):
pass
return username.strip() if username else None


Expand All @@ -696,7 +685,7 @@ def _is_valid_username(text: str, cluster_name: str = "mila cluster") -> bool |


def _add_ssh_entry(
ssh_config: SSHConfig,
ssh_config: SSHConfigWriter,
host: str,
entry: dict[str, str | int],
*,
Expand Down Expand Up @@ -731,7 +720,7 @@ def _add_ssh_entry(


def _copy_valid_ssh_entries_to_windows_ssh_config_file(
linux_ssh_config: SSHConfig, windows_ssh_config: SSHConfig
linux_ssh_config: SSHConfigWriter, windows_ssh_config: SSHConfigWriter
):
unsupported_keys_lowercase = set(k.lower() for k in WINDOWS_UNSUPPORTED_KEYS)

Expand Down

0 comments on commit ff02514

Please sign in to comment.