Skip to content

Commit

Permalink
Catch and handle Paramiko SSH errors (#48)
Browse files Browse the repository at this point in the history
* Adds a custom exception class, SSHConnectionError.
* Wraps the connection setup in the Remote class in a try/except that re-raises paramiko SSH errors as SSHConnectionError.
* Prints a more explicit error message with possible workarounds, instead of the generic Exception message that recommends filing a bug.

Co-authored-by: Ganda Tchabana <ganda.tchabana@mila.quebec>
Co-authored-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
3 people authored Aug 23, 2023
1 parent dee7fb7 commit 330d3a6
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
4 changes: 4 additions & 0 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .utils import (
CommandNotFoundError,
MilatoolsUserError,
SSHConnectionError,
T,
get_fully_qualified_name,
qualified,
Expand All @@ -45,6 +46,9 @@ def main():
except MilatoolsUserError as exc:
# These are user errors and should not be reported
print("ERROR:", exc, file=sys.stderr)
except SSHConnectionError as err:
# These are errors coming from paramiko's failure to connect to the host
print("ERROR:", f"{err}", file=sys.stderr)
except Exception:
print(T.red(traceback.format_exc()), file=sys.stderr)
options = {
Expand Down
16 changes: 10 additions & 6 deletions milatools/cli/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from pathlib import Path
from queue import Empty, Queue

import paramiko
import questionary as qn
from fabric import Connection

from .utils import T, control_file_var, here, shjoin
from .utils import SSHConnectionError, T, control_file_var, here, shjoin

batch_template = """#!/bin/bash
#SBATCH --output={output_file}
Expand Down Expand Up @@ -78,11 +79,14 @@ def get_first_node_name(node_names_out: str) -> str:
class Remote:
def __init__(self, hostname, connection=None, transforms=(), keepalive=60):
self.hostname = hostname
if connection is None:
connection = Connection(hostname)
if keepalive:
connection.open()
connection.transport.set_keepalive(keepalive)
try:
if connection is None:
connection = Connection(hostname)
if keepalive:
connection.open()
connection.transport.set_keepalive(keepalive)
except paramiko.SSHException as err:
raise SSHConnectionError(node_hostname=self.hostname, error=err)
self.connection = connection
self.transforms = transforms

Expand Down
27 changes: 27 additions & 0 deletions milatools/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pathlib import Path

import blessed
import paramiko
import questionary as qn
from invoke.exceptions import UnexpectedExit
from sshconf import read_ssh_config
Expand Down Expand Up @@ -88,6 +89,32 @@ def __str__(self):
return message


class SSHConnectionError(paramiko.SSHException):
    """Error raised when an SSH connection to a host could not be established.

    Keeps a reference to the underlying paramiko error and renders a
    user-facing message that lists possible workarounds.

    Attributes:
        node_hostname: Name of the host the connection was attempted with.
        error: The paramiko exception that triggered this error.
    """

    def __init__(self, node_hostname: str, error: paramiko.SSHException):
        # Forward the arguments to the base class so that `args` is populated.
        # Exceptions are copied / unpickled by calling `type(self)(*self.args)`,
        # which fails with a TypeError when `args` is left empty, and an empty
        # `args` also makes `repr()` of the exception uninformative.
        super().__init__(node_hostname, error)
        self.node_hostname = node_hostname
        self.error = error

    def __str__(self):
        """Return the multi-line, user-facing description of the failure."""
        host = self.node_hostname
        # Adjacent string literals are concatenated at compile time; the
        # resulting text is identical to the previous `+`-joined message.
        return (
            f"An error happened while trying to establish a connection with {host}"
            "\n\t-The cluster might be under maintenance"
            "\n\t Check #mila-cluster for updates on the state of the cluster"
            "\n\t-Check the status of your connection to the cluster "
            "by ssh'ing onto it."
            "\n\t-Retry connecting with mila"
            f"\n\t-Try to exclude the node with -x {host} parameter"
        )


def yn(prompt: str, default: bool = True) -> bool:
    """Ask the user a yes/no confirmation question and return the answer.

    Args:
        prompt: Text of the question shown to the user.
        default: Value forwarded as `default` to `questionary.confirm`.

    Returns:
        The result of the confirmation prompt.
    """
    question = qn.confirm(prompt, default=default)
    return question.unsafe_ask()

Expand Down

0 comments on commit 330d3a6

Please sign in to comment.