diff --git a/python/tests/test_imputation.py b/python/tests/test_imputation.py index 3a9b9b10e9..e87dc5b6a6 100644 --- a/python/tests/test_imputation.py +++ b/python/tests/test_imputation.py @@ -226,18 +226,18 @@ def get_test_data(matrix_text, par): x = convert_to_numpy(matrix_text) if par == "switch": # Switch probability, one per site - return x[:, 2].reshape((4, 4))[:, 2] + return x[:, 2].reshape((4, 4))[:, 0] elif par == "mismatch": - # Mismatch probability - return x[:, 2].reshape((4, 4))[:, 4] + # Mismatch probability, one per site + return x[:, 4].reshape((4, 4))[:, 0] elif par == "ref_hap_allele": # Allele in haplotype in reference panel # 0 = ref allele, 1 = alt allele - return x[:, 2].reshape((4, 4))[:, 6] + return x[:, 6].reshape((4, 4)) elif par == "query_hap_allele": # Allele in haplotype in query # 0 = ref allele, 1 = alt allele - return x[:, 2].reshape((4, 4))[:, 7] + return x[:, 7].reshape((4, 4))[:, 0] elif par == "shift": # Shift factor # TODO @@ -248,6 +248,6 @@ def get_test_data(matrix_text, par): pass elif par == "sum": # Sum of values over haplotypes - return x[:, 2].reshape((4, 4))[:, 10] + return x[:, 10].reshape((4, 4))[:, 0] else: raise ValueError(f"Unknown parameter: {par}")