diff --git a/milatools/cli/init_command.py b/milatools/cli/init_command.py index 27a891c8..a2586bae 100644 --- a/milatools/cli/init_command.py +++ b/milatools/cli/init_command.py @@ -14,6 +14,7 @@ import questionary as qn from invoke.exceptions import UnexpectedExit +from paramiko.config import SSHConfig as SSHConfigReader from ..utils.vscode_utils import ( get_expected_vscode_settings_json_path, @@ -21,7 +22,8 @@ ) from .local import Local, check_passwordless, display from .remote import Remote -from .utils import SSHConfig, T, running_inside_WSL, yn +from .utils import SSHConfig as SSHConfigWriter +from .utils import T, running_inside_WSL, yn logger = get_logger(__name__) @@ -66,8 +68,8 @@ # NOTE: will not work with --gres prior to Slurm 22.05, because srun --overlap # cannot share gpus "ProxyCommand": ( - 'ssh mila "/cvmfs/config.mila.quebec/scripts/milatools/slurm-proxy.sh ' - 'mila-cpu --mem=8G"' + "ssh mila " + '"/cvmfs/config.mila.quebec/scripts/milatools/slurm-proxy.sh mila-cpu --mem=8G"' ), "RemoteCommand": ( "/cvmfs/config.mila.quebec/scripts/milatools/entrypoint.sh mila-cpu" @@ -112,7 +114,7 @@ def setup_ssh_config( ssh_config_path: str | Path = "~/.ssh/config", -) -> SSHConfig: +) -> SSHConfigReader: """Interactively sets up some useful entries in the ~/.ssh/config file on the local machine. @@ -132,14 +134,18 @@ def setup_ssh_config( """ ssh_config_path = _setup_ssh_config_file(ssh_config_path) - ssh_config = SSHConfig(ssh_config_path) + + ssh_config = SSHConfigReader.from_path(str(ssh_config_path)) + ssh_config_writer = SSHConfigWriter(ssh_config_path) + mila_username: str = _get_mila_username(ssh_config) drac_username: str | None = _get_drac_username(ssh_config) - orig_config = ssh_config.cfg.config() + + orig_config = ssh_config_writer.cfg.config() for hostname, entry in MILA_ENTRIES.copy().items(): entry.update(User=mila_username) - _add_ssh_entry(ssh_config, hostname, entry) + _add_ssh_entry(ssh_config_writer, hostname, entry) _make_controlpath_dir(entry) if drac_username: @@ -148,31 +154,31 @@ def setup_ssh_config( ) for hostname, entry in DRAC_ENTRIES.copy().items(): entry.update(User=drac_username) - _add_ssh_entry(ssh_config, hostname, entry) + _add_ssh_entry(ssh_config_writer, hostname, entry) _make_controlpath_dir(entry) # Check for *.server.mila.quebec in ssh config, to connect to compute nodes old_cnode_pattern = "*.server.mila.quebec" - if old_cnode_pattern in ssh_config.hosts(): + if old_cnode_pattern in ssh_config_writer.hosts(): logger.info( f"The '{old_cnode_pattern}' entry in ~/.ssh/config is too general and " "should exclude login.server.mila.quebec. Fixing this." ) - ssh_config.remove(old_cnode_pattern) + ssh_config_writer.remove(old_cnode_pattern) - new_config = ssh_config.cfg.config() + new_config = ssh_config_writer.cfg.config() if orig_config == new_config: print("Did not change ssh config") - elif not _confirm_changes(ssh_config, previous=orig_config): + elif not _confirm_changes(ssh_config_writer, previous=orig_config): exit("Did not change ssh config") else: - ssh_config.save() + ssh_config_writer.save() print(f"Wrote {ssh_config_path}") return ssh_config -def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfig): +def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfigWriter): """Setup the Windows SSH configuration and public key from within WSL. This copies over the entries from the linux ssh configuration file, except for the @@ -191,7 +197,7 @@ def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfig): windows_ssh_config_path = windows_home / ".ssh/config" windows_ssh_config_path = _setup_ssh_config_file(windows_ssh_config_path) - windows_ssh_config = SSHConfig(windows_ssh_config_path) + windows_ssh_config = SSHConfigWriter(windows_ssh_config_path) initial_windows_config_contents = windows_ssh_config.cfg.config() _copy_valid_ssh_entries_to_windows_ssh_config_file( @@ -226,7 +232,7 @@ def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfig): _copy_if_needed(linux_key_file, windows_key_file) -def setup_passwordless_ssh_access(ssh_config: SSHConfig) -> bool: +def setup_passwordless_ssh_access(ssh_config: SSHConfigWriter) -> bool: """Sets up passwordless ssh access to the Mila and optionally also to DRAC. Sets up ssh connection to the DRAC clusters if they are present in the SSH config @@ -586,30 +592,18 @@ def ask_to_confirm_changes(before: str, after: str, path: str | Path) -> bool: return yn("\nIs this OK?") -def _confirm_changes(ssh_config: SSHConfig, previous: str) -> bool: +def _confirm_changes(ssh_config: SSHConfigWriter, previous: str) -> bool: before = previous + "\n" after = ssh_config.cfg.config() + "\n" return ask_to_confirm_changes(before, after, ssh_config.path) -def _get_mila_username(ssh_config: SSHConfig) -> str: +def _get_mila_username(ssh_config: SSHConfigReader) -> str: # Check for a mila entry in ssh config # NOTE: This also supports the case where there's a 'HOST mila some_alias_for_mila' # entry. # NOTE: ssh_config.host(entry) returns an empty dictionary if there is no entry. - username: str | None = None - hosts_with_mila_in_name_and_a_user_entry = [ - host - for host in ssh_config.hosts() - if "mila" in host.split() and "user" in ssh_config.host(host) - ] - # Note: If there are none, or more than one, then we'll ask the user for their - # username, just to be sure. - if len(hosts_with_mila_in_name_and_a_user_entry) == 1: - username = ssh_config.host(hosts_with_mila_in_name_and_a_user_entry[0]).get( - "user" - ) - + username: str | None = ssh_config.lookup("mila").get("user") while not username: username = qn.text( "What's your username on the mila cluster?\n", @@ -618,36 +612,32 @@ def _get_mila_username(ssh_config: SSHConfig) -> str: return username.strip() -def _get_drac_username(ssh_config: SSHConfig) -> str | None: +def _get_drac_username(ssh_config: SSHConfigReader) -> str | None: """Retrieve or ask the user for their username on the ComputeCanada/DRAC clusters.""" # Check for one of the DRAC entries in ssh config - username: str | None = None - hosts_with_cluster_in_name_and_a_user_entry = [ - host - for host in ssh_config.hosts() - if any( - cc_cluster in host.split() or f"!{cc_cluster}" in host.split() - for cc_cluster in DRAC_CLUSTERS - ) - and "user" in ssh_config.host(host) - ] - users_from_drac_config_entries = set( - ssh_config.host(host)["user"] - for host in hosts_with_cluster_in_name_and_a_user_entry + users_from_drac_config_entries: set[str] = set( + drac_cluster_username + for drac_cluster in DRAC_CLUSTERS + if (drac_cluster_username := ssh_config.lookup(drac_cluster).get("user")) + is not None ) + if len(users_from_drac_config_entries) == 1: + return users_from_drac_config_entries.pop().strip() + + username: str | None # Note: If there are none, or more than one, then we'll ask the user for their # username, just to be sure. - if len(users_from_drac_config_entries) == 1: - username = users_from_drac_config_entries.pop() - elif yn("Do you also have an account on the ComputeCanada/DRAC clusters?"): - while not username: - username = qn.text( + if yn("Do you also have an account on the ComputeCanada/DRAC clusters?"): + while not ( + username := qn.text( "What's your username on the CC/DRAC clusters?\n", validate=functools.partial( _is_valid_username, cluster_name="ComputeCanada/DRAC clusters" ), ).unsafe_ask() + ): + pass return username.strip() if username else None @@ -666,7 +656,7 @@ def _is_valid_username(text: str, cluster_name: str = "mila cluster") -> bool | def _add_ssh_entry( - ssh_config: SSHConfig, + ssh_config: SSHConfigWriter, host: str, entry: dict[str, str | int], *, @@ -701,7 +691,7 @@ def _add_ssh_entry( def _copy_valid_ssh_entries_to_windows_ssh_config_file( - linux_ssh_config: SSHConfig, windows_ssh_config: SSHConfig + linux_ssh_config: SSHConfigWriter, windows_ssh_config: SSHConfigWriter ): unsupported_keys_lowercase = set(k.lower() for k in WINDOWS_UNSUPPORTED_KEYS)