Skip to content

Commit

Permalink
test: use single SSH connection for lifetime of microvm
Browse files Browse the repository at this point in the history
Instead of creating new SSH connections every time we want to run a
command inside the microvm, open a single ssh connection in the
constructor of `SSHConnection`, and reuse it until we kill the microvm.

Use the `fabric` SSH library to achieve this.

Signed-off-by: Patrick Roy <roypat@amazon.co.uk>
  • Loading branch information
roypat committed Dec 10, 2024
1 parent 9a8c5a8 commit 354b67e
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 87 deletions.
12 changes: 10 additions & 2 deletions tests/framework/microvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ def __init__(
self.mem_size_bytes = None
self.cpu_template_name = None

self._ssh_connections = []

self._pre_cmd = []
if numa_node:
node_str = str(numa_node)
Expand Down Expand Up @@ -282,6 +284,10 @@ def kill(self):
for monitor in self.monitors:
monitor.stop()

# Cleanup all SSH connections
for conn in self._ssh_connections:
conn.close()

# We start with vhost-user backends,
# because if we stop Firecracker first, the backend will want
# to exit as well and this will cause a race condition.
Expand Down Expand Up @@ -1007,13 +1013,15 @@ def ssh_iface(self, iface_idx=0):
"""Return a cached SSH connection on a given interface id."""
guest_ip = list(self.iface.values())[iface_idx]["iface"].guest_ip
self.ssh_key = Path(self.ssh_key)
return net_tools.SSHConnection(
netns=self.netns.id,
connection = net_tools.SSHConnection(
netns_=self.netns.id,
ssh_key=self.ssh_key,
user="root",
host=guest_ip,
on_error=self._dump_debug_information,
)
self._ssh_connections.append(connection)
return connection

@property
def ssh(self):
Expand Down
127 changes: 47 additions & 80 deletions tests/host_tools/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
import ipaddress
import random
import string
import subprocess
from dataclasses import dataclass, field
from io import BytesIO
from pathlib import Path

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
import netns
from fabric import Connection
from tenacity import retry, stop_after_attempt, wait_fixed

from framework import utils
from framework.utils import CommandReturn


class SSHConnection:
Expand All @@ -22,13 +25,14 @@ class SSHConnection:
the hostname obtained from the MAC address, the username for logging into
the image and the path of the ssh key.
This translates into an SSH connection as follows:
ssh -i ssh_key_path username@hostname
Uses the fabric library to establish a single connection once, and then
keep it alive for the lifetime of the microvm, to avoid spurious failures
due to reestablishing SSH connections for every single command sent.
"""

def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None):
def __init__(self, netns_, ssh_key: Path, host, user, *, on_error=None):
"""Instantiate a SSH client and connect to a microVM."""
self.netns = netns
self.netns = netns_
self.ssh_key = ssh_key
# check that the key exists and the permissions are 0o400
# This saves a lot of debugging time.
Expand All @@ -40,26 +44,23 @@ def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None):

self._on_error = None

self.options = [
"-o",
"LogLevel=ERROR",
"-o",
"ConnectTimeout=1",
"-o",
"StrictHostKeyChecking=no",
"-o",
"UserKnownHostsFile=/dev/null",
"-o",
"PreferredAuthentications=publickey",
"-i",
str(self.ssh_key),
]
self._connection = Connection(
host,
user,
connect_timeout=1,
connect_kwargs={
"key_filename": str(self.ssh_key),
"banner_timeout": 1,
"auth_timeout": 1,
},
)

# _init_connection loops until it can connect to the guest
# dumping debug state on every iteration is not useful or wanted, so
# only dump it once if _all_ iterations fail.
try:
self._init_connection()
with netns.NetNS(netns_):
self._init_connection()
except Exception as exc:
if on_error:
on_error(exc)
Expand All @@ -68,35 +69,19 @@ def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None):

self._on_error = on_error

@property
def user_host(self):
"""remote address for in SSH format <user>@<IP>"""
return f"{self.user}@{self.host}"

def remote_path(self, path):
"""Convert a path to remote"""
return f"{self.user_host}:{path}"

def _scp(self, path1, path2, options):
"""Copy files to/from the VM using scp."""
self._exec(["scp", *options, path1, path2], check=True)

def scp_put(self, local_path, remote_path, recursive=False):
def scp_put(self, local_path, remote_path):
"""Copy files to the VM using scp."""
opts = self.options.copy()
if recursive:
opts.append("-r")
self._scp(local_path, self.remote_path(remote_path), opts)
self._connection.put(local_path, remote_path)

def scp_get(self, remote_path, local_path, recursive=False):
def scp_get(self, remote_path, local_path):
"""Copy files from the VM using scp."""
opts = self.options.copy()
if recursive:
opts.append("-r")
self._scp(self.remote_path(remote_path), local_path, opts)
self._connection.get(remote_path, local_path)

@retry(
retry=retry_if_exception_type(ChildProcessError),
wait=wait_fixed(0.5),
stop=stop_after_attempt(20),
reraise=True,
Expand All @@ -106,61 +91,43 @@ def _init_connection(self):
Since we're connecting to a microVM we just started, we'll probably
have to wait for it to boot up and start the SSH server.
We'll keep trying to execute a remote command that can't fail
(`/bin/true`), until we get a successful (0) exit code.
We'll keep trying to open the connection in a loop for 20 attempts with 0.5s
delay. Each connection attempt has a timeout of 1s.
"""
self.check_output("true", timeout=100, debug=True)
self._connection.open()

def run(self, cmd_string, timeout=None, *, check=False, debug=False):
def run(self, cmd_string, timeout=None, *, check=False):
"""
Execute the command passed as a string in the ssh context.
If `debug` is set, pass `-vvv` to `ssh`. Note that this will clobber stderr.
"""
command = ["ssh", *self.options, self.user_host, cmd_string]

if debug:
command.insert(1, "-vvv")

return self._exec(command, timeout, check=check)
return self._exec(cmd_string, timeout, check=check)

def check_output(self, cmd_string, timeout=None, *, debug=False):
def check_output(self, cmd_string, timeout=None):
"""Same as `run`, but raises an exception on non-zero return code of remote command"""
return self.run(cmd_string, timeout, check=True, debug=debug)
return self.run(cmd_string, timeout, check=True)

def close(self):
"""Closes this SSHConnection"""
self._connection.close()

def _exec(self, cmd, timeout=None, check=False):
"""Private function that handles the ssh client invocation."""
if self.netns is not None:
cmd = ["ip", "netns", "exec", self.netns] + cmd

try:
return utils.run_cmd(cmd, check=check, timeout=timeout)
# - warn=True means "do not raise exception on non-zero exit code, instead just log", e.g.
# it's the inverse of our "check" argument.
# - hide=True means "do not always log stdout/stderr"
# - in_stream=BytesIO(b"") is needed to immediately close stdin of the remote command
# without this, command that only exit after their stdin is closed would hang forever
# and this hang would bypass the pytest timeout.
result = self._connection.run(
cmd, timeout=timeout, warn=not check, hide=True, in_stream=BytesIO(b"")
)
except Exception as exc:
if self._on_error:
self._on_error(exc)

raise

# pylint:disable=invalid-name
def Popen(
self,
cmd: str,
stdin=subprocess.DEVNULL,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
**kwargs,
) -> subprocess.Popen:
"""Execute the command in the guest and return a Popen object.
pop = uvm.ssh.Popen("while true; do echo $(date -Is) $RANDOM; sleep 1; done")
pop.stdout.read(16)
"""
cmd = ["ssh", *self.options, self.user_host, cmd]
if self.netns is not None:
cmd = ["ip", "netns", "exec", self.netns] + cmd
return subprocess.Popen(
cmd, stdin=stdin, stdout=stdout, stderr=stderr, **kwargs
)
return CommandReturn(result.exited, result.stdout, result.stderr)


def mac_from_ip(ip_address):
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/functional/test_balloon.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import logging
import time
from subprocess import TimeoutExpired

import pytest
from invoke import CommandTimedOut
from tenacity import retry, stop_after_attempt, wait_fixed

from framework.utils import check_output, get_free_mem_ssh
Expand Down Expand Up @@ -74,7 +74,7 @@ def make_guest_dirty_memory(ssh_connection, amount_mib=32):
logger.error("while running: %s", cmd)
logger.error("stdout: %s", stdout)
logger.error("stderr: %s", stderr)
except TimeoutExpired:
except CommandTimedOut:
# It's ok if this expires. Sometimes the SSH connection
# gets killed by the OOM killer *after* the fillmem program
# started. As a result, we can ignore timeouts here.
Expand Down
8 changes: 5 additions & 3 deletions tests/integration_tests/functional/test_pause_resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,18 @@ def test_pause_resume(uvm_nano):
# Flush and reset metrics as they contain pre-pause data.
microvm.flush_metrics()

# Verify guest is no longer active.
with pytest.raises(ChildProcessError):
# Verify guest is no longer active (by observing a failure to reconnect)
with pytest.raises(TimeoutError):
microvm.ssh.close()
microvm.ssh.check_output("true")

# Verify emulation was indeed paused and no events from either
# guest or host side were handled.
verify_net_emulation_paused(microvm.flush_metrics())

# Verify guest is no longer active.
with pytest.raises(ChildProcessError):
with pytest.raises(TimeoutError):
microvm.ssh.close()
microvm.ssh.check_output("true")

# Pausing the microVM when it is already `Paused` is allowed
Expand Down

0 comments on commit 354b67e

Please sign in to comment.