diff --git a/testcases/consistency/test_consistency.py b/testcases/consistency/test_consistency.py index 776b9c4..dfeac79 100755 --- a/testcases/consistency/test_consistency.py +++ b/testcases/consistency/test_consistency.py @@ -11,10 +11,10 @@ import pytest import typing -test_string = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" -# Use a global test_info to get a better output when running pytest -test_info: typing.Dict[str, typing.Any] = {} +test_string = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +test_info = os.getenv("TEST_INFO_FILE") +test_info_dict = testhelper.read_yaml(test_info) def file_content_check(f: typing.IO, comp_str: str) -> bool: @@ -23,7 +23,7 @@ def file_content_check(f: typing.IO, comp_str: str) -> bool: def consistency_check(mount_point: str, ipaddr: str, share_name: str) -> None: - mount_params = testhelper.get_mount_parameters(test_info, share_name) + mount_params = testhelper.get_mount_parameters(test_info_dict, share_name) mount_params["host"] = ipaddr try: test_file = testhelper.get_tmp_file(mount_point) @@ -55,22 +55,19 @@ def consistency_check(mount_point: str, ipaddr: str, share_name: str) -> None: def generate_consistency_check( - test_info_file: typing.Optional[str], + test_info_file: dict, ) -> typing.List[typing.Tuple[str, str]]: - global test_info if not test_info_file: return [] - test_info = testhelper.read_yaml(test_info_file) arr = [] - for ipaddr in test_info["public_interfaces"]: - for share_name in test_info["exported_sharenames"]: + for ipaddr in test_info_file["public_interfaces"]: + for share_name in test_info_file["exported_sharenames"]: arr.append((ipaddr, share_name)) return arr @pytest.mark.parametrize( - "ipaddr,share_name", - generate_consistency_check(os.getenv("TEST_INFO_FILE")), + "ipaddr,share_name", generate_consistency_check(test_info_dict) ) def test_consistency(ipaddr: str, share_name: str) -> None: tmp_root = testhelper.get_tmp_root() diff --git a/testcases/mount/test_mount.py b/testcases/mount/test_mount.py index d0bba9f..320ff2a 100755 --- a/testcases/mount/test_mount.py +++ b/testcases/mount/test_mount.py @@ -12,12 +12,12 @@ from .mount_io import check_io_consistency from .mount_dbm import check_dbm_consistency -# Use a global test_info to get a better output when running pytest -test_info: typing.Dict[str, typing.Any] = {} +test_info = os.getenv("TEST_INFO_FILE") +test_info_dict = testhelper.read_yaml(test_info) def mount_check(ipaddr: str, share_name: str) -> None: - mount_params = testhelper.get_mount_parameters(test_info, share_name) + mount_params = testhelper.get_mount_parameters(test_info_dict, share_name) mount_params["host"] = ipaddr tmp_root = testhelper.get_tmp_root() mount_point = testhelper.get_tmp_mount_point(tmp_root) @@ -38,21 +38,19 @@ def mount_check(ipaddr: str, share_name: str) -> None: def generate_mount_check( - test_info_file: typing.Optional[str], + test_info_file: dict, ) -> typing.List[typing.Tuple[str, str]]: - global test_info if not test_info_file: return [] - test_info = testhelper.read_yaml(test_info_file) arr = [] - for ipaddr in test_info["public_interfaces"]: - for share_name in test_info["exported_sharenames"]: + for ipaddr in test_info_file["public_interfaces"]: + for share_name in test_info_file["exported_sharenames"]: arr.append((ipaddr, share_name)) return arr @pytest.mark.parametrize( - "ipaddr,share_name", generate_mount_check(os.getenv("TEST_INFO_FILE")) + "ipaddr,share_name", generate_mount_check(test_info_dict) ) def test_mount(ipaddr: str, share_name: str) -> None: mount_check(ipaddr, share_name) diff --git a/testcases/smbtorture/test_smbtorture.py b/testcases/smbtorture/test_smbtorture.py index 77cfa03..a98dd15 100755 --- a/testcases/smbtorture/test_smbtorture.py +++ b/testcases/smbtorture/test_smbtorture.py @@ -15,14 +15,16 @@ format_subunit_exec = script_root + "/selftest/format-subunit" smbtorture_tests_file = script_root + "/smbtorture-tests-info.yml" -test_info: typing.Dict[str, typing.Any] = {} +test_info = os.getenv("TEST_INFO_FILE") +test_info_dict = testhelper.read_yaml(test_info) + # Temp filename containing the output of run commands. output = testhelper.get_tmp_file("/tmp") def smbtorture(share_name: str, test: str, tmp_output: str) -> bool: # build smbtorture command - mount_params = testhelper.get_mount_parameters(test_info, share_name) + mount_params = testhelper.get_mount_parameters(test_info_dict, share_name) smbtorture_cmd = [ smbtorture_exec, "--fullname", @@ -48,7 +50,7 @@ def smbtorture(share_name: str, test: str, tmp_output: str) -> bool: "--expected-failures=" + script_root + "/selftest/" + filter ) flapping_list = ["flapping", "flapping.d"] - test_backend = test_info.get("test_backend") + test_backend = test_info_dict.get("test_backend") if test_backend is not None: flapping_file = "flapping." + test_backend flapping_file_path = os.path.join( @@ -114,25 +116,17 @@ def list_smbtorture_tests(): return smbtorture_info -def generate_smbtorture_tests( - test_info_file: typing.Optional[str], -) -> typing.List[typing.Tuple[str, str]]: - global test_info - if not test_info_file: - return [] - test_info = testhelper.read_yaml(test_info_file) +def generate_smbtorture_tests() -> typing.List[typing.Tuple[str, str]]: smbtorture_info = list_smbtorture_tests() arr = [] - for sharenum in range(testhelper.get_num_shares(test_info)): - share_name = testhelper.get_share(test_info, sharenum) + for sharenum in range(testhelper.get_num_shares(test_info_dict)): + share_name = testhelper.get_share(test_info_dict, sharenum) for torture_test in smbtorture_info: arr.append((share_name, torture_test)) return arr -@pytest.mark.parametrize( - "share_name,test", generate_smbtorture_tests(os.getenv("TEST_INFO_FILE")) -) +@pytest.mark.parametrize("share_name,test", generate_smbtorture_tests()) def test_smbtorture(share_name: str, test: str) -> None: ret = smbtorture(share_name, test, output) if os.path.exists(output): diff --git a/testhelper/testhelper.py b/testhelper/testhelper.py index 935ee02..a3d4876 100644 --- a/testhelper/testhelper.py +++ b/testhelper/testhelper.py @@ -3,16 +3,16 @@ import random -def read_yaml(file: str) -> dict: +def read_yaml(test_info): """Returns a dict containing the contents of the yaml file. Parameters: - arg1: filename of yaml file + test_info: filename of yaml file. Returns: - dict: parsed contents of a yaml file + dict: The parsed test information yml as a dictionary. """ - with open(file) as f: + with open(test_info) as f: test_info = yaml.load(f, Loader=yaml.FullLoader) return test_info @@ -81,7 +81,7 @@ def get_mount_parameters( share: The share for which to get the mount_params combonum: The combination number to use. """ - if combonum > get_total_mount_parameter_combinations(test_info): + if combonum >= get_total_mount_parameter_combinations(test_info): assert False, "Invalid combination number" num_public = int(combonum / len(test_info["test_users"])) num_users = combonum % len(test_info["test_users"]) @@ -113,7 +113,7 @@ def get_share(test_info: dict, share_num: int) -> str: share_num: The index within the exported sharenames list Returns: - str: exported sharename in index share_num + str: exported sharename at index share_num """ return test_info["exported_sharenames"][share_num]