Skip to content

Commit

Permalink
Fix sshkeys passphrase issue with mila init [MT-72] (#93)
Browse files Browse the repository at this point in the history
* Add `has_passphrase` function and test

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add passphrase param to test_create_ssh_keypair

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove check for number of lines in private key

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Update poetry.lock file

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Increase timeout value for test_create_ssh_keypair

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix bug in create_ssh_keypair

Signed-off-by: Fabrice <normandf@mila.quebec>

* Add `use_shjoin` arg to `display`

Signed-off-by: Fabrice <normandf@mila.quebec>

* Simplify passing path of keyfile to ssh-keygen

Signed-off-by: Fabrice <normandf@mila.quebec>

* Simplify sending of ssh key on Windows

Signed-off-by: Fabrice <normandf@mila.quebec>

* Add a hardcore integration test (not in CI yet)

Signed-off-by: Fabrice <normandf@mila.quebec>

* Pass passphrase as f"-N='{passphrase}'"

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Simplify and add a docstring to create_ssh_keypair

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Setup ssh keypair if needed during test

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Tweak comment

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Also catch socket.gaierror for Windows errors

Signed-off-by: Fabrice <normandf@mila.quebec>

* Fix small typing error in utils.py

Signed-off-by: Fabrice <normandf@mila.quebec>

* Create the parent dir of sshkey in test

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Try to fix ssh-keygen errors on Windows (again)

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Create the SSH dir during test

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Update milatools/cli/init_command.py

* Remove the xfails for weird paths for ssh keys

Signed-off-by: Fabrice <normandf@mila.quebec>

* Fix failing test param on windows

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Apply suggestions from code review

Co-authored-by: satyaog <satyaog@gmail.com>

* Change has_passphrase

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix pre-commit hook issues

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove unused "test" dep group

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
Signed-off-by: Fabrice <normandf@mila.quebec>
Co-authored-by: satyaog <satyaog@gmail.com>
  • Loading branch information
lebrice and satyaog authored Feb 9, 2024
1 parent f8bb9c1 commit 68b7dde
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 34 deletions.
81 changes: 63 additions & 18 deletions milatools/cli/init_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import difflib
import functools
import json
import shlex
import shutil
import subprocess
import sys
Expand All @@ -16,7 +15,7 @@
import questionary as qn
from invoke.exceptions import UnexpectedExit

from .local import Local, check_passwordless
from .local import Local, check_passwordless, display
from .remote import Remote
from .utils import SSHConfig, T, running_inside_WSL, yn
from .vscode_utils import (
Expand Down Expand Up @@ -292,11 +291,11 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool:
here = Local()
# Check that it is possible to connect without using a password.
print(f"Checking if passwordless SSH access is setup for the {cluster} cluster.")
# TODO: Potentially use the public key from the SSH config file instead of
# the default. It's also possible that ssh-copy-id selects the key from the
# config file, I'm not sure.
# ssh_private_key_path = Path.home() / ".ssh" / "id_rsa"

# TODO: Potentially use a custom key like `~/.ssh/id_milatools.pub` instead of
# the default.
ssh_private_key_path = Path.home() / ".ssh" / "id_rsa"
ssh_public_key_path = ssh_private_key_path.with_suffix(".pub")
assert ssh_public_key_path.exists()
if check_passwordless(cluster):
logger.info(f"Passwordless SSH access to {cluster} is already setup correctly.")
return True
Expand All @@ -310,19 +309,23 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool:

print("Please enter your password when prompted.")
if sys.platform == "win32":
# todo: the path to the key is hard-coded here.
# NOTE: This is to remove extra '^M' characters that would be added at the end
# of the file on the remote!
public_key_contents = ssh_public_key_path.read_text().replace("\r\n", "\n")
command = (
"powershell.exe",
"type",
"$env:USERPROFILE\\.ssh\\id_rsa.pub",
"|",
"ssh",
"-o",
"StrictHostKeyChecking=no",
cluster,
'"cat >> ~/.ssh/authorized_keys"',
"cat >> ~/.ssh/authorized_keys",
)
here.run(*command, check=True)
display(command)
import tempfile

with tempfile.NamedTemporaryFile("w", newline="\n") as f:
print(public_key_contents, end="", file=f)
f.seek(0)
subprocess.run(command, check=True, text=False, stdin=f)
else:
here.run("ssh-copy-id", "-o", "StrictHostKeyChecking=no", cluster, check=True)

Expand Down Expand Up @@ -410,11 +413,53 @@ def get_windows_home_path_in_wsl() -> Path:
return Path(f"/mnt/c/Users/{windows_username}")


def create_ssh_keypair(ssh_private_key_path: Path, local: Local) -> None:
local.run(
*shlex.split(
f'ssh-keygen -f {shlex.quote(str(ssh_private_key_path))} -t rsa -N=""'
def create_ssh_keypair(
    ssh_private_key_path: Path,
    local: Local | None = None,
    passphrase: str | None = "",
) -> None:
    """Generate an RSA public/private key pair at the given path with ssh-keygen.

    The handling of `passphrase` is:

    - `None`: no `-N` option is passed, so ssh-keygen prompts the user for one.
    - `""` (the default): the key is created without a passphrase.
    - any other string: passed to ssh-keygen via `-N` and used as the passphrase.

    The command is displayed before being executed; a non-zero exit status of
    ssh-keygen raises `subprocess.CalledProcessError` (because of `check=True`).
    """
    if not local:
        local = Local()

    keyfile = str(ssh_private_key_path.expanduser())
    # Only add `-N` when a passphrase (possibly empty) was explicitly chosen.
    passphrase_args = [] if passphrase is None else ["-N", passphrase]
    command = ["ssh-keygen", "-f", keyfile, "-t", "rsa", *passphrase_args]

    display(command)
    subprocess.run(command, check=True)


def has_passphrase(ssh_private_key_path: Path) -> bool:
    """Returns whether the SSH private key has a passphrase or not.

    Asks ssh-keygen to print the public key of the given private key (`-y`)
    while supplying an empty passphrase (`-P ""`). If ssh-keygen succeeds, the
    key could be read with the empty passphrase, and `False` is returned. If it
    fails with the "incorrect passphrase" error message, `True` is returned.

    Raises:
        AssertionError: if there is no file at `ssh_private_key_path`.
        NotImplementedError: if ssh-keygen fails for any other reason, in which
            case it is not possible to tell if the key has a passphrase.
    """
    assert ssh_private_key_path.exists()
    result = subprocess.run(
        args=(
            "ssh-keygen",
            "-y",
            # NOTE: The arguments are not interpreted by a shell, so the flag and
            # its (empty) value have to be two separate arguments. Passing
            # "-P=''" would make ssh-keygen try the literal passphrase `=''`.
            "-P",
            "",
            "-f",
            str(ssh_private_key_path),
        ),
        capture_output=True,
        text=True,
    )
    logger.debug(f"Result of ssh-keygen: {result}")
    if result.returncode == 0:
        return False
    elif "incorrect passphrase supplied to decrypt private key" in result.stderr:
        return True
    raise NotImplementedError(
        f"TODO: Unable to tell if the key at {ssh_private_key_path} has a passphrase "
        f"or not! (result={result})"
    )


Expand Down
8 changes: 6 additions & 2 deletions milatools/cli/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,12 @@ def check_passwordless(self, host: str):
return check_passwordless(host)


def display(split_command: list[str] | tuple[str, ...]) -> None:
print(T.bold_green("(local) $ ", shjoin(split_command)))
def display(split_command: list[str] | tuple[str, ...] | str) -> None:
    """Print a local command in bold green, prefixed with `(local) $ `.

    A string is printed as-is; a list or tuple of arguments is first joined into
    a single command line with `shjoin`.
    """
    text = split_command if isinstance(split_command, str) else shjoin(split_command)
    print(T.bold_green("(local) $ ", text))


def check_passwordless(host: str) -> bool:
Expand Down
3 changes: 2 additions & 1 deletion milatools/cli/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,9 @@ def __init__(
# assert isinstance(connection.transport, paramiko.Transport)
transport: paramiko.Transport = connection.transport # type: ignore
transport.set_keepalive(keepalive)
except paramiko.SSHException as err:
except (paramiko.SSHException, socket.gaierror) as err:
raise SSHConnectionError(node_hostname=self.hostname, error=err)

self.connection = connection
self.transforms = transforms

Expand Down
4 changes: 2 additions & 2 deletions milatools/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __str__(self):


class SSHConnectionError(paramiko.SSHException):
def __init__(self, node_hostname: str, error: paramiko.SSHException):
def __init__(self, node_hostname: str, error: Exception):
super().__init__()
self.node_hostname = node_hostname
self.error = error
Expand All @@ -158,7 +158,7 @@ def yn(prompt: str, default: bool = True) -> bool:
return qn.confirm(prompt, default=default).unsafe_ask()


def askpath(prompt, remote):
def askpath(prompt: str, remote: Remote) -> str:
while True:
pth = qn.text(prompt).unsafe_ask()
try:
Expand Down
139 changes: 128 additions & 11 deletions tests/cli/test_init_command.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import getpass
import json
import os
import shutil
Expand All @@ -12,6 +13,7 @@
from pathlib import Path
from unittest.mock import Mock

import fabric
import pytest
import pytest_mock
import questionary
Expand All @@ -26,14 +28,15 @@
_setup_ssh_config_file,
create_ssh_keypair,
get_windows_home_path_in_wsl,
has_passphrase,
setup_passwordless_ssh_access,
setup_passwordless_ssh_access_to_cluster,
setup_ssh_config,
setup_vscode_settings,
setup_windows_ssh_config_from_wsl,
)
from milatools.cli.local import Local, check_passwordless
from milatools.cli.utils import SSHConfig, running_inside_WSL
from milatools.cli.utils import SSHConfig, T, running_inside_WSL

from .common import (
in_github_CI,
Expand Down Expand Up @@ -727,20 +730,52 @@ def test_fixes_dir_permission_issues(


# takes a little longer in the CI runner (Windows in particular)
@pytest.mark.timeout(10)
def test_create_ssh_keypair(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
here = Local()
mock_run = Mock(
wraps=subprocess.run,
)
monkeypatch.setattr(subprocess, "run", mock_run)
@pytest.mark.timeout(20)
@pytest.mark.parametrize(
("passphrase", "expected"),
[("", False), ("bobobo", True), ("\n", True), (" ", True)],
)
@pytest.mark.parametrize(
"filename",
[
"bob",
"dir with spaces/somefile",
"dir_with_'single_quotes'/somefile",
pytest.param(
'dir_with_"doublequotes"/somefile',
marks=pytest.mark.xfail(
sys.platform == "win32",
strict=True,
raises=OSError,
reason="Doesn't work on Windows.",
),
),
pytest.param(
"windows_style_dir\\bob",
marks=pytest.mark.skipif(
sys.platform != "win32", reason="only runs on Windows."
),
),
],
)
def test_create_ssh_keypair(
mocker: pytest_mock.MockerFixture,
tmp_path: Path,
filename: str,
passphrase: str,
expected: bool,
):
# Wrap the subprocess.run call (but also actually execute the commands).
subprocess_run = mocker.patch("subprocess.run", wraps=subprocess.run)

fake_ssh_folder = tmp_path / "fake_ssh"
fake_ssh_folder.mkdir(mode=0o700)
ssh_private_key_path = fake_ssh_folder / "bob"
ssh_private_key_path = fake_ssh_folder / filename
ssh_private_key_path.parent.mkdir(mode=0o700, exist_ok=True, parents=True)

create_ssh_keypair(ssh_private_key_path=ssh_private_key_path, local=here)
create_ssh_keypair(ssh_private_key_path=ssh_private_key_path, passphrase=passphrase)

mock_run.assert_called_once()
subprocess_run.assert_called_once()
assert ssh_private_key_path.exists()
if not on_windows:
assert ssh_private_key_path.stat().st_mode & 0o777 == 0o600
Expand All @@ -749,6 +784,8 @@ def test_create_ssh_keypair(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
if not on_windows:
assert ssh_public_key_path.stat().st_mode & 0o777 == 0o644

assert has_passphrase(ssh_private_key_path) == expected


@pytest.fixture
def linux_ssh_config(
Expand Down Expand Up @@ -1075,6 +1112,13 @@ def test_setup_passwordless_ssh_access_to_cluster(
backup_authorized_keys_file = backup_ssh_dir / "authorized_keys"
assert backup_authorized_keys_file.exists()

ssh_private_key_path = ssh_dir / "id_rsa"
ssh_public_key_path = ssh_private_key_path.with_suffix(".pub")
if not ssh_public_key_path.exists():
create_ssh_keypair(ssh_private_key_path=ssh_private_key_path)
assert ssh_public_key_path.exists()
assert not has_passphrase(ssh_private_key_path)

if not passwordless_to_cluster_is_already_setup:
if authorized_keys_file.exists():
logger.warning(
Expand Down Expand Up @@ -1166,6 +1210,7 @@ def test_setup_passwordless_ssh_access(
f"Temporarily deleting the ssh dir (backed up at {backup_ssh_dir})"
)
shutil.rmtree(ssh_dir)
ssh_dir.mkdir(mode=0o700, exist_ok=False)

if not public_key_exists:
# There should be no ssh keys in the ssh dir before calling the function.
Expand Down Expand Up @@ -1252,3 +1297,75 @@ def test_setup_passwordless_ssh_access(
for drac_cluster in drac_clusters_in_ssh_config:
mock_setup_passwordless_ssh_access_to_cluster.assert_any_call(drac_cluster)
assert result is True


@pytest.fixture()
def cluster(request: pytest.FixtureRequest) -> str:
    """Name of the slurm cluster to run the test against.

    Uses the fixture's `param` when the test is parametrized with one, and falls
    back to the `SLURM_CLUSTER` environment variable otherwise. The test is
    skipped when neither gives a non-empty name.
    """
    from_environment = os.environ.get("SLURM_CLUSTER", None)
    selected: str | None = getattr(request, "param", from_environment)
    if not selected:
        pytest.skip(reason="Need a real slurm cluster to be specified")
    return selected


@pytest.fixture()
def authorized_keys_backup(cluster: str):
    """Back up the remote `~/.ssh/authorized_keys` file and restore it afterwards.

    If passwordless SSH access to the cluster is not setup, the test is skipped
    when running in the GitHub CI; otherwise the user is asked for a password,
    which is then used to open the connection.

    Yields the path of the backup file on the remote.
    """
    backup_authorized_keys_path = "~/.ssh/authorized_keys.backup"
    connect_kwargs = {}

    needs_password = not check_passwordless(cluster)
    if needs_password and in_github_CI:
        pytest.skip(
            f"Can't run this test because passwordless SSH access to {cluster} is "
            "not setup."
        )
    if needs_password:
        password = getpass.getpass(
            T.red("\nEnter your password for SSH-ing to the cluster\n")
        )
        connect_kwargs = {"password": password}

    remote = fabric.Connection(cluster, connect_kwargs=connect_kwargs)

    def run_on_cluster(command: str) -> None:
        # Echo the command with the cluster name as a prefix, without forwarding
        # the local stdin to the remote process.
        remote.run(
            command,
            echo=True,
            echo_format=T.bold_cyan(f"({cluster})" + " $ {command}"),
            in_stream=False,
        )

    run_on_cluster(f"cp ~/.ssh/authorized_keys {backup_authorized_keys_path}")
    try:
        yield backup_authorized_keys_path
    finally:
        # Put the original file back, even if the test failed.
        run_on_cluster("cp ~/.ssh/authorized_keys.backup ~/.ssh/authorized_keys")


@pytest.mark.timeout(None)
@pytest.mark.skipif(
    in_github_CI, reason="Can't run this in the github CI since it asks for a password."
)
@pytest.mark.skipif(
    "SLURM_CLUSTER" not in os.environ, reason="Only runs with a real cluster."
)
def test_setup_passwordless_ssh_access_to_real_cluster(
    cluster: str,
    authorized_keys_backup: str,
):
    """Integration test: sets up passwordless SSH access to a real cluster.

    If passwordless access already works, the remote `authorized_keys` file is
    removed first so that the setup has something to do. The
    `authorized_keys_backup` fixture takes care of restoring the file afterwards.
    """
    already_passwordless = check_passwordless(cluster)
    if already_passwordless:
        logger.warning(
            f"Temporarily removing the ~/.ssh/authorized_keys file on {cluster} "
            f"(backed up at {cluster}:{authorized_keys_backup})"
        )
        connection = fabric.Connection(cluster)
        connection.run(
            "rm ~/.ssh/authorized_keys",
            echo=True,
            echo_format=T.bold_cyan(f"({cluster})" + " $ {command}"),
            in_stream=False,
        )

    # Precondition: a password is now required to connect.
    assert not check_passwordless(cluster)
    setup_passwordless_ssh_access_to_cluster(cluster)
    assert check_passwordless(cluster)

0 comments on commit 68b7dde

Please sign in to comment.