diff --git a/python/tests/test_imputation.py b/python/tests/test_imputation.py new file mode 100644 index 0000000000..0ca3cc2639 --- /dev/null +++ b/python/tests/test_imputation.py @@ -0,0 +1,272 @@ +""" +Tests for genotype imputation (forward and Baum-Welsh algorithms). +""" +import io + +import numpy as np +import pandas as pd + +import _tskit +import tskit + + +# A tree sequence containing 3 diploid individuals with 5 sites and 5 mutations +# (one per site). The first 2 individuals are used as reference panel, +# the last one is the target individual. + +toy_ts_nodes_text = """\ +id is_sample time population individual metadata +0 1 0.000000 0 0 +1 1 0.000000 0 0 +2 1 0.000000 0 1 +3 1 0.000000 0 1 +4 1 0.000000 0 2 +5 1 0.000000 0 2 +6 0 0.029768 0 -1 +7 0 0.133017 0 -1 +8 0 0.223233 0 -1 +9 0 0.651586 0 -1 +10 0 0.698831 0 -1 +11 0 2.114867 0 -1 +12 0 4.322031 0 -1 +13 0 7.432311 0 -1 +""" + +toy_ts_edges_text = """\ +left right parent child metadata +0.000000 1000000.000000 6 0 +0.000000 1000000.000000 6 3 +0.000000 1000000.000000 7 2 +0.000000 1000000.000000 7 5 +0.000000 1000000.000000 8 1 +0.000000 1000000.000000 8 4 +0.000000 781157.000000 9 6 +0.000000 781157.000000 9 7 +0.000000 505438.000000 10 8 +0.000000 505438.000000 10 9 +505438.000000 549484.000000 11 8 +505438.000000 549484.000000 11 9 +781157.000000 1000000.000000 12 6 +781157.000000 1000000.000000 12 7 +549484.000000 1000000.000000 13 8 +549484.000000 781157.000000 13 9 +781157.000000 1000000.000000 13 12 +""" + +toy_ts_sites_text = """\ +position ancestral_state metadata +200000.000000 A +300000.000000 C +520000.000000 G +600000.000000 T +900000.000000 A +""" + +toy_ts_mutations_text = """\ +site node time derived_state parent metadata +0 9 unknown G -1 +1 8 unknown A -1 +2 9 unknown T -1 +3 9 unknown C -1 +4 12 unknown C -1 +""" + +toy_ts_individuals_text = """\ +flags +0 +0 +0 +""" + + +def get_toy_data(): + """ + Returns toy data contained in the toy tree sequence in text format above. + + :param: None + :return: Reference panel tree sequence and query haplotypes. + :rtype: list + """ + ts = tskit.load_text( + nodes=io.StringIO(toy_ts_nodes_text), + edges=io.StringIO(toy_ts_edges_text), + sites=io.StringIO(toy_ts_sites_text), + mutations=io.StringIO(toy_ts_mutations_text), + individuals=io.StringIO(toy_ts_individuals_text), + strict=False, + ) + ref_ts = ts.simplify(samples=np.arange(2 * 2), filter_sites=False) + query_ts = ts.simplify(samples=[5, 6], filter_sites=False) + query_h = query_ts.genotype_matrix().T + return [ref_ts, query_h] + + +# BEAGLE 4.1 was run on the toy data set above using default parameters. +# +# In the query VCF, the site at position 520,000 was redacted and then imputed. +# Note that the ancestral allele in the simulated tree sequence is +# treated as the REF in the VCFs. +# +# The following are the forward probability matrices and backward probability +# matrices calculated when imputing into the third individual above. There are +# two sets of matrices, one for each haplotype. +# +# Notes about calculations: +# n = number of haplotypes in ref. panel +# M = number of markers +# m = index of marker (site) +# h = index of haplotype in ref. panel +# +# In forward probability matrix, +# fwd[m][h] = emission prob., if m = 0 (first marker) +# fwd[m][h] = emission prob. * (scale * fwd[m - 1][h] + shift), otherwise +# where scale = (1 - switch prob.)/sum of fwd[m - 1], +# and shift = switch prob./n. +# +# In backward probability matrix, +# bwd[m][h] = 1, if m = M - 1 (last marker) // DON'T SEE THIS IN BEAGLE +# unadj. bwd[m][h] = emission prob. / n +# bwd[m][h] = (unadj. bwd[m][h] + shift) * scale, otherwise +# where scale = (1 - switch prob.)/sum of unadj. bwd[m], +# and shift = switch prob./n. +# +# For each site, the sum of backward value over all haplotypes is calculated +# before scaling and shifting. + +beagle_forward_matrix_text_1 = """ +m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val +0,0,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,0.000100,0.000100 +0,1,0.000000,1.000000,0.999900,0.000100,0,0,0.000000,1.000000,1.000000,0.999900 +0,2,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,1.000100,0.000100 +0,3,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,1.000200,0.000100 +1,0,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.000025,0.000025 +1,1,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.250000,0.249975 +1,2,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.250025,0.000025 +1,3,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.250050,0.000025 +2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.000025,0.000025 +2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.250000,0.249975 +2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250025,0.000025 +2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250050,0.000025 +3,0,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.000025,0.000025 +3,1,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.250000,0.249975 +3,2,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250025,0.000025 +3,3,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250050,0.000025 +""" + +beagle_backward_matrix_text_1 = """ +m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val +3,0,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000 +3,1,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000 +3,2,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000 +3,3,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000 +2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.250050,0.250000 +2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.250050,0.250000 +2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.250050,0.250000 +2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.250050,0.250000 +1,0,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.250050,0.250000 +1,1,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.250050,0.250000 +1,2,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.250050,0.250000 +1,3,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.250050,0.250000 +0,0,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.250050,0.250000 +0,1,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.250050,0.250000 +0,2,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.250050,0.250000 +0,3,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.250050,0.250000 +""" + +beagle_forward_matrix_text_2 = """ +m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val +0,0,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,0.999900,0.999900 +0,1,0.000000,1.000000,0.999900,0.000100,0,1,0.000000,1.000000,1.000000,0.000100 +0,2,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,1.999900,0.999900 +0,3,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,2.999800,0.999900 +1,0,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.249975,0.249975 +1,1,1.000000,0.000000,0.999900,0.000100,1,0,0.250000,0.000000,0.250000,0.000025 +1,2,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.499975,0.249975 +1,3,1.000000,0.000000,0.999900,0.000100,0,0,0.250000,0.000000,0.749950,0.249975 +2,0,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.249975,0.249975 +2,1,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.250000,0.000025 +2,2,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.499975,0.249975 +2,3,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.749950,0.249975 +3,0,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.249975,0.249975 +3,1,1.000000,0.000000,0.999900,0.000100,0,1,0.250000,0.000000,0.250000,0.000025 +3,2,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.499975,0.249975 +3,3,1.000000,0.000000,0.999900,0.000100,1,1,0.250000,0.000000,0.749950,0.249975 +""" + +beagle_backward_matrix_text_2 = """ +m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val +3,0,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000 +3,1,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000 +3,2,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000 +3,3,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000 +2,0,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.749950,0.250000 +2,1,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.749950,0.250000 +2,2,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.749950,0.250000 +2,3,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.749950,0.250000 +1,0,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.749950,0.250000 +1,1,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.250000,0.749950,0.250000 +1,2,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.749950,0.250000 +1,3,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.250000,0.749950,0.250000 +0,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.749950,0.250000 +0,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.250000,0.749950,0.250000 +0,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.749950,0.250000 +0,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.250000,0.749950,0.250000 +""" + + +def convert_to_numpy(matrix_text): + """Converts a forward or backward matrix in text format to numpy.""" + df = pd.read_csv(io.StringIO(matrix_text)) + # Check that switch and non-switch probabilities sum to 1 + assert np.all(np.isin(df.probRec + df.probNoRec, [1, -2])) + # Check that non-mismatch and mismatch probabilities sum to 1 + assert np.all(np.isin(df.noErrProb + df.errProb, [1, -2])) + return df.val.to_numpy().reshape((4, 4)) + + +def get_beagle_forward_backward_matrices(): + fwd_matrix_1 = convert_to_numpy(beagle_forward_matrix_text_1) + bwd_matrix_1 = convert_to_numpy(beagle_backward_matrix_text_1) + fwd_matrix_2 = convert_to_numpy(beagle_forward_matrix_text_2) + bwd_matrix_2 = convert_to_numpy(beagle_backward_matrix_text_2) + return [fwd_matrix_1, bwd_matrix_1, fwd_matrix_2, bwd_matrix_2] + + +def get_beagle_data(matrix_text, data_type): + """Extracts data to check forward or backward probability matrix calculations.""" + df = pd.read_csv(io.StringIO(matrix_text)) + if data_type == "switch": + # Switch probability, one per site + return df.probRec.to_numpy().reshape((4, 4))[:, 0] + elif data_type == "mismatch": + # Mismatch probability, one per site + return df.errProb.to_numpy().reshape((4, 4))[:, 0] + elif data_type == "ref_hap_allele": + # Allele in haplotype in reference panel + # 0 = ref allele, 1 = alt allele + return df.refAl.to_numpy().reshape((4, 4)) + elif data_type == "query_hap_allele": + # Allele in haplotype in query + # 0 = ref allele, 1 = alt allele + return df.queryAl.to_numpy().reshape((4, 4))[:, 0] + elif data_type == "shift": + # Shift factor, one per site + return df.shiftFac.to_numpy().reshape((4, 4))[:, 0] + elif data_type == "scale": + # Scale factor, one per site + return df.scaleFac.to_numpy().reshape((4, 4))[:, 0] + elif data_type == "sum": + # Sum of values over haplotypes + return df.sumSite.to_numpy().reshape((4, 4))[:, 0] + else: + raise ValueError(f"Unknown data type: {data_type}") + + +def get_tskit_forward_backward_matrices(ts, h): + m = ts.num_sites + fm = _tskit.CompressedMatrix(ts) + bm = _tskit.CompressedMatrix(ts) + ls_hmm = _tskit.LsHmm(ts, np.zeros(m) + 0.1, np.zeros(m) + 0.1) + ls_hmm.forward_matrix(h, fm) + ls_hmm.backward_matrix(h, fm.normalisation_factor, bm) + return [fm, bm]