From 3663507ee35fcf6602ba9599dc354b00430b0fec Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Thu, 13 Oct 2022 19:15:48 +0200 Subject: [PATCH] ssh: Decouple `password` and `passphrase`. Introduce new options `passphrase` and `ask_passphrase`. Depends on updating config_schema on `dvc` repo. --- dvc_ssh/__init__.py | 20 +++++++++++--------- dvc_ssh/tests/test_prepare_credentials.py | 15 +++++++++++++++ 2 files changed, 26 insertions(+), 9 deletions(-) create mode 100644 dvc_ssh/tests/test_prepare_credentials.py diff --git a/dvc_ssh/__init__.py b/dvc_ssh/__init__.py index d16954e..46a4672 100644 --- a/dvc_ssh/__init__.py +++ b/dvc_ssh/__init__.py @@ -12,10 +12,9 @@ @wrap_with(threading.Lock()) @memoize -def ask_password(host, user, port): +def ask_password(host, user, port, desc): return getpass.getpass( - "Enter a private key passphrase or a password for " - f"host '{host}' port '{port}' user '{user}':\n" + f"Enter a {desc} for " f"host '{host}' port '{port}' user '{user}':\n" ) @@ -62,13 +61,16 @@ def _prepare_credentials(self, **config): or DEFAULT_PORT ) - if config.get("ask_password") and config.get("password") is None: - config["password"] = ask_password( - login_info["host"], login_info["username"], login_info["port"] - ) + for option in ("password", "passphrase"): + login_info[option] = config.get(option, None) - login_info["password"] = config.get("password") - login_info["passphrase"] = config.get("password") + if config.get(f"ask_{option}") and login_info[option] is None: + login_info[option] = ask_password( + login_info["host"], + login_info["username"], + login_info["port"], + option, + ) raw_keys = [] if config.get("keyfile"): diff --git a/dvc_ssh/tests/test_prepare_credentials.py b/dvc_ssh/tests/test_prepare_credentials.py new file mode 100644 index 0000000..97841d3 --- /dev/null +++ b/dvc_ssh/tests/test_prepare_credentials.py @@ -0,0 +1,15 @@ +import pytest + +from dvc_ssh import SSHFileSystem + + +@pytest.mark.parametrize("password", [None, "foo"]) +@pytest.mark.parametrize("passphrase", [None, "bar"]) +def test_passphrase(mocker, password, passphrase): + connect = mocker.patch("asyncssh.connect") + + kwargs = {"password": password, "passphrase": passphrase} + _ = SSHFileSystem(host="foo", **kwargs).fs + + assert connect.call_args[1]["password"] == password + assert connect.call_args[1]["passphrase"] == passphrase