diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index d7881018..891d967d 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -31,28 +31,11 @@ import rich.logging from typing_extensions import TypedDict +from milatools.__version__ import __version__ from milatools.cli import console -from milatools.utils.local_v1 import LocalV1 -from milatools.utils.remote_v1 import RemoteV1, SlurmRemote -from milatools.utils.remote_v2 import RemoteV2 -from milatools.utils.vscode_utils import ( - get_code_command, - # install_local_vscode_extensions_on_remote, - sync_vscode_extensions, - sync_vscode_extensions_with_hostnames, -) - -from ..__version__ import __version__ -from .init import ( - print_welcome_message, - setup_keys_on_login_node, - setup_passwordless_ssh_access, - setup_ssh_config, - setup_vscode_settings, - setup_windows_ssh_config_from_wsl, -) -from .profile import ensure_program, setup_profile -from .utils import ( +from milatools.cli.init import init +from milatools.cli.profile import ensure_program, setup_profile +from milatools.cli.utils import ( CLUSTERS, Cluster, CommandNotFoundError, @@ -69,6 +52,14 @@ running_inside_WSL, with_control_file, ) +from milatools.utils.local_v1 import LocalV1 +from milatools.utils.remote_v1 import RemoteV1, SlurmRemote +from milatools.utils.remote_v2 import RemoteV2 +from milatools.utils.vscode_utils import ( + get_code_command, + sync_vscode_extensions, + sync_vscode_extensions_with_hostnames, +) if typing.TYPE_CHECKING: from typing_extensions import Unpack @@ -492,32 +483,6 @@ def intranet(search: Sequence[str]) -> None: webbrowser.open(url) -def init(): - """Set up your configuration and credentials.""" - - ############################# - # Step 1: SSH Configuration # - ############################# - - print("Checking ssh config") - - ssh_config = setup_ssh_config() - - # if we're running on WSL, we actually just copy the id_rsa + id_rsa.pub and the - # ~/.ssh/config to the Windows ssh directory (taking care to remove the - # ControlMaster-related entries) so that the user doesn't need to install Python on - # the Windows side. - if running_inside_WSL(): - setup_windows_ssh_config_from_wsl(linux_ssh_config=ssh_config) - - success = setup_passwordless_ssh_access(ssh_config=ssh_config) - if not success: - exit() - setup_keys_on_login_node() - setup_vscode_settings() - print_welcome_message() - - def forward( remote: str, page: str | None, diff --git a/milatools/cli/init.py b/milatools/cli/init.py index a132ea06..caa81cab 100644 --- a/milatools/cli/init.py +++ b/milatools/cli/init.py @@ -9,23 +9,23 @@ import sys import warnings from logging import getLogger as get_logger -from pathlib import Path +from pathlib import Path, PosixPath from typing import Any import questionary as qn from invoke.exceptions import UnexpectedExit from paramiko.config import SSHConfig as SSHConfigReader -from milatools.utils.remote_v2 import SSH_CONFIG_FILE - -from ..utils.local_v1 import LocalV1, check_passwordless, display -from ..utils.remote_v1 import RemoteV1 -from ..utils.vscode_utils import ( +from milatools.cli import console +from milatools.cli.utils import SSHConfig as SSHConfigWriter +from milatools.cli.utils import T, running_inside_WSL, yn +from milatools.utils.local_v1 import check_passwordless, display +from milatools.utils.local_v2 import LocalV2 +from milatools.utils.remote_v1 import RemoteV1 +from milatools.utils.vscode_utils import ( get_expected_vscode_settings_json_path, vscode_installed, ) -from .utils import SSHConfig as SSHConfigWriter -from .utils import T, running_inside_WSL, yn logger = get_logger(__name__) @@ -114,6 +114,34 @@ } +def init(): + """Set up your configuration and credentials.""" + + ############################# + # Step 1: SSH Configuration # + ############################# + + print("Checking ssh config") + ssh_config_path = Path("~/.ssh/config").expanduser() + + setup_ssh_config(ssh_config_path=ssh_config_path) + + # if we're running on WSL, we actually just copy the id_rsa + id_rsa.pub and the + # ~/.ssh/config to the Windows ssh directory (taking care to remove the + # ControlMaster-related entries) so that the user doesn't need to install Python on + # the Windows side. + if running_inside_WSL(): + assert isinstance(ssh_config_path, PosixPath) # we're running in linux (WSL). + setup_windows_ssh_config_from_wsl(linux_ssh_config_path=ssh_config_path) + + success = setup_passwordless_ssh_access(ssh_config_path) + if not success: + exit() + setup_keys_on_login_node() + setup_vscode_settings() + print_welcome_message() + + def setup_ssh_config( ssh_config_path: str | Path = "~/.ssh/config", ) -> SSHConfigReader: @@ -180,7 +208,7 @@ def setup_ssh_config( return ssh_config -def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfigWriter): +def setup_windows_ssh_config_from_wsl(linux_ssh_config_path: PosixPath): """Setup the Windows SSH configuration and public key from within WSL. This copies over the entries from the linux ssh configuration file, except for the @@ -192,6 +220,8 @@ def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfigWriter): This makes it so the user doesn't need to install Python/Anaconda on the Windows side in order to use `mila code` from within WSL. """ + linux_ssh_config = SSHConfigWriter(linux_ssh_config_path) + assert running_inside_WSL() # NOTE: This also assumes that a public/private key pair has already been generated # at ~/.ssh/id_rsa.pub and ~/.ssh/id_rsa. @@ -234,7 +264,7 @@ def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfigWriter): _copy_if_needed(linux_key_file, windows_key_file) -def setup_passwordless_ssh_access(ssh_config: SSHConfigWriter) -> bool: +def setup_passwordless_ssh_access(ssh_config_path: Path) -> bool: """Sets up passwordless ssh access to the Mila and optionally also to DRAC. Sets up ssh connection to the DRAC clusters if they are present in the SSH config @@ -242,37 +272,26 @@ def setup_passwordless_ssh_access(ssh_config: SSHConfigWriter) -> bool: Returns whether the operation completed successfully or not. """ - print("Checking passwordless authentication") + print("Setting up passwordless SSH access.") - here = LocalV1() - sshdir = Path.home() / ".ssh" + ssh_config = SSHConfigReader.from_path(str(ssh_config_path)) - # Check if there is a public key file in ~/.ssh - if not list(sshdir.glob("id*.pub")): - if yn("You have no public keys. Generate one?"): - # Run ssh-keygen with the given location and no passphrase. - ssh_private_key_path = Path.home() / ".ssh" / "id_rsa" - create_ssh_keypair(ssh_private_key_path, here) - else: - print("No public keys.") - return False + # TODO: Generate SSH keys with ssh-keygen (not setting the passphrase so users can choose to use a passphrase or not). + setup_passwordless_ssh_access_to_cluster("mila", ssh_config_path) - # TODO: This uses the public key set in the SSH config file, which may (or may not) - # be the random id*.pub file that was just checked for above. - success = setup_passwordless_ssh_access_to_cluster("mila") - if not success: - return False - setup_keys_on_login_node("mila") + hosts_in_ssh_config = [ + hostname + for hostname in ssh_config.get_hostnames() + if not any(c in hostname for c in "!*?") + ] - drac_clusters_in_ssh_config: list[str] = [] - hosts_in_config = ssh_config.hosts() - for cluster in DRAC_CLUSTERS: - if any(cluster in hostname for hostname in hosts_in_config): - drac_clusters_in_ssh_config.append(cluster) + drac_clusters_in_ssh_config: list[str] = list( + set(DRAC_CLUSTERS).intersection(hosts_in_ssh_config) + ) if not drac_clusters_in_ssh_config: logger.debug( - f"There are no DRAC clusters in the SSH config at {ssh_config.path}." + f"There are no DRAC clusters in the SSH config at {ssh_config_path}." ) return True @@ -285,14 +304,32 @@ def setup_passwordless_ssh_access(ssh_config: SSHConfigWriter) -> bool: "See https://docs.alliancecan.ca/wiki/SSH_Keys#Using_CCDB for more info." ) for drac_cluster in drac_clusters_in_ssh_config: - success = setup_passwordless_ssh_access_to_cluster(drac_cluster) + success = setup_passwordless_ssh_access_to_cluster( + drac_cluster, ssh_config_path + ) if not success: return False setup_keys_on_login_node(drac_cluster) return True -def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool: +def _get_private_key_path_for_hostname( + hostname: str, ssh_config_path: Path +) -> Path | None: + config = SSHConfigReader.from_path(str(ssh_config_path)) + identity_file = config.lookup(hostname).get("identityfile") + if not identity_file: + return None + # Seems to be a list for some reason? + if isinstance(identity_file, list): + assert identity_file + identity_file = identity_file[0] + return Path(identity_file).expanduser() + + +def setup_passwordless_ssh_access_to_cluster( + cluster: str, ssh_config_path: Path +) -> bool: """Sets up passwordless SSH access to the given hostname. On Mac/Linux, uses `ssh-copy-id`. Performs the steps of ssh-copy-id manually on @@ -300,20 +337,32 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool: Returns whether the operation completed successfully or not. """ - here = LocalV1() + here = LocalV2() # Check that it is possible to connect without using a password. print(f"Checking if passwordless SSH access is setup for the {cluster} cluster.") - # TODO: Potentially use a custom key like `~/.ssh/id_milatools.pub` instead of - # the default. - from paramiko.config import SSHConfig + ssh_private_key_path = _get_private_key_path_for_hostname(cluster, ssh_config_path) + # TODO: Simplify the code here by assuming that users just accepted the changes to + # their SSH config proposed by the first part of `mila init`. + # - Instead of making the code complicated with lots of corner cases, just raise an + # error if the SSH config doesn't match what we expect to see after `mila init`. + raise NotImplementedError() + if ssh_private_key_path is None: + # TODO: What to do if there isn't a private key set in the SSH config, but there + # is already a private key in the SSH dir? (it would be used by ssh). + console.log( + f"There is no private key set to be used for the {cluster} cluster." + ) + ssh_private_key_path = Path("~/.ssh/id_rsa").expanduser() + + if not ssh_private_key_path.exists(): + console.log( + f"The ssh key to use for host {cluster} does not exist at {ssh_private_key_path}. Creating it now." + ) + create_ssh_keypair(ssh_private_key_path) + config_writer = SSHConfigWriter(ssh_config_path) + + config_writer.set(cluster, IdentityFile=str(ssh_private_key_path)) - config = SSHConfig.from_path(str(SSH_CONFIG_FILE)) - identity_file = config.lookup(cluster).get("identityfile", "~/.ssh/id_rsa") - # Seems to be a list for some reason? - if isinstance(identity_file, list): - assert identity_file - identity_file = identity_file[0] - ssh_private_key_path = Path(identity_file).expanduser() ssh_public_key_path = ssh_private_key_path.with_suffix(".pub") assert ssh_public_key_path.exists() @@ -448,7 +497,7 @@ def get_windows_home_path_in_wsl() -> Path: def create_ssh_keypair( ssh_private_key_path: Path, - local: LocalV1 | None = None, + local: LocalV2 | None = None, passphrase: str | None = "", ) -> None: """Creates a public/private key pair at the given path using ssh-keygen. @@ -457,18 +506,17 @@ def create_ssh_keypair( Otherwise, if passphrase is an empty string, no passphrase will be used (default). If a string is passed, it is passed to ssh-keygen and used as the passphrase. """ - local = local or LocalV1() - command = [ + local = local or LocalV2() + command = ( "ssh-keygen", "-f", str(ssh_private_key_path.expanduser()), "-t", - "rsa", - ] + "rsa", # note: Could also let the user choose the type of encryption.. + ) if passphrase is not None: - command.extend(["-N", passphrase]) - display(command) - subprocess.run(command, check=True) + command += ("-N", passphrase) + local.run(command, display=True) def has_passphrase(ssh_private_key_path: Path) -> bool: