def has_rdma(cmdline_args, operation):
    """
    determine whether rdma is enabled in the efa device of the server host

    The check runs fi_efa_rdma_checker on cmdline_args.server_id through ssh
    and reports whether that command finished with exit status 0.

    cmdline_args: command line arguments of the test run; the attributes
                  binpath, timeout, environments and server_id are read
    operation: rdma operation name, allowed values are read and write
    return: a boolean
    raises: ValueError if operation is neither read nor write;
            SshConnectionError if has_ssh_connection_err_msg() flags the
            output of the ssh command
    """
    # Validate explicitly rather than with assert: asserts are stripped under
    # "python -O", and operation is interpolated into a shell command below.
    if operation not in ("read", "write"):
        raise ValueError(
            "operation must be 'read' or 'write', got {!r}".format(operation))

    # Only the executable name is a path component; its arguments are
    # appended afterwards instead of being passed through os.path.join.
    checker = os.path.join(cmdline_args.binpath or "", "fi_efa_rdma_checker")
    cmd = f"timeout {cmdline_args.timeout} {checker} -o {operation}"
    if cmdline_args.environments:
        # environments is prepended verbatim, so it is expected to already be
        # in "NAME=value ..." shell form.
        cmd = cmdline_args.environments + " " + cmd

    # stderr is merged into stdout so that a single stream can be scanned
    # for ssh connection errors.
    proc = subprocess.run("ssh {} {}".format(cmdline_args.server_id, cmd),
                          stdout=subprocess.PIPE,
                          stderr=subprocess.STDOUT,
                          shell=True,
                          universal_newlines=True)
    if has_ssh_connection_err_msg(proc.stdout):
        raise SshConnectionError()

    return proc.returncode == 0
import efa_run_client_server_test, efa_retrieve_hw_counter_value + if cntrl_env_var == "FI_EFA_INTER_MIN_READ_WRITE_SIZE" and has_rdma(cmdline_args, "write"): + pytest.skip("FI_EFA_INTER_MIN_READ_WRITE_SIZE is only applied to emulated write protocols") + if cmdline_args.server_id == cmdline_args.client_id: pytest.skip("No read for intra-node communication")