diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index a358993202..fb01015999 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -22,6 +22,7 @@ """ Python implementation of the Li and Stephens forwards and backwards algorithms. """ +import io import warnings import lshmm as ls @@ -37,7 +38,8 @@ MISSING = -1 -# np.set_printoptions(linewidth=1000, precision=3) +# For debugging +np.set_printoptions(linewidth=1000, precision=3) def check_alleles(alleles, m): @@ -151,7 +153,7 @@ def node_values(self): def print_state(self): print("LsHMM state") print("match_all_nodes =", self.match_all_nodes) - print("Tree =") + print("Tree = ", self.tree.index, self.tree.interval) node_labels = {} for u, value in self.node_values().items(): label = f"{u}" @@ -434,11 +436,13 @@ def update_probabilities(self, site, haplotype_state): def process_site(self, site, haplotype_state): self.update_probabilities(site, haplotype_state) # d1 = self.node_values() + print("PRE") + self.print_state() self.compress() # d2 = self.node_values() # assert d1 == d2 - # print("AFTER COMPRESS") - # self.print_state() + print("AFTER COMPRESS") + self.print_state() s = self.compute_normalisation_factor() for st in self.T: assert st.tree_node != tskit.NULL @@ -489,8 +493,13 @@ def run(self, h): self.initialise(1 / n) while self.tree.next(): self.update_tree() + if self.tree.index != 0: + print("AFTER UPDATE TREE") + self.print_state() for site in self.tree.sites(): self.process_site(site, h[site.id]) + print("BEFORE UPDATE TREE") + self.print_state() return self.output def compute_normalisation_factor(self): @@ -1182,7 +1191,6 @@ def verify(self, ts): self.assertAllClose(ll, ll_check) -# TODO add params to run the various checks def check_viterbi( ts, h, @@ -1212,10 +1220,10 @@ def check_viterbi( cm = ls_viterbi_tree( h, ts, rho=recombination, mu=mutation, match_all_nodes=match_all_nodes ) + cm.print_state() path_tree = cm.traceback(match_all_nodes=match_all_nodes) ll_tree = np.sum(np.log10(cm.normalisation_factor)) assert np.isscalar(ll_tree) - # print(cm) # print("path tree = ", path_tree) if compare_lshmm: @@ -1437,8 +1445,8 @@ def test_match_sample(self, j): ts = self.ts() h = np.zeros(4) h[j] = 1 - # path = check_viterbi(ts, h) - # nt.assert_array_equal([j, j, j, j], path) + path = check_viterbi(ts, h) + nt.assert_array_equal([j, j, j, j], path) cm = check_forward_matrix(ts, h) check_backward_matrix(ts, h, cm) @@ -1525,6 +1533,19 @@ def test_match_sample(self, u, h): ) +def validate_match_all_nodes(ts, h, expected_path): + path = check_viterbi( + ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False + ) + nt.assert_array_equal(expected_path, path) + cm = check_forward_matrix( + ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False + ) + bm = check_backward_matrix( + ts, h, cm, match_all_nodes=True, compare_lib=False, compare_lshmm=False + ) + + class TestSingleBalancedTreeAllNodesExample: # 3.00┊ 6 ┊ # ┊ ┏━┻━┓ ┊ @@ -1540,7 +1561,6 @@ def ts(): tables.tree_sequence(), start=1, nodes=np.arange(len(tables.nodes) - 1) ) - # def test_match_sample(self, u, h): @pytest.mark.parametrize( ("h", "expected_path"), [ @@ -1558,20 +1578,99 @@ def ts(): ([0, 0, 0, 0, 0, 0], [6] * 6), ], ) - def test_match_sample(self, h, expected_path): - ts = self.ts() - path = check_viterbi( - ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False + def test_exact_match(self, h, expected_path): + validate_match_all_nodes(self.ts(), h, expected_path) + + +class TestMultiTreeExample: + # 0.84┊ 7 ┊ 7 ┊ + # ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ + # 0.42┊ ┃ ┃ ┊ 6 ┃ ┊ + # ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊ + # 0.05┊ 5 ┃ ┊ ┃ ┃ ┃ ┊ + # ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊ + # 0.04┊ ┃ 4 ┃ ┊ ┃ ┃ 4 ┊ + # ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 3 1 2 ┊ + # 0 6 7 + @staticmethod + def ts(): + nodes = """\ + is_sample time + 1 0.000000 + 1 0.000000 + 1 0.000000 + 1 0.000000 + 0 0.041304 + 0 0.045967 + 0 0.416719 + 0 0.838075 + """ + edges = """\ + left right parent child + 0.000000 7.000000 4 1 + 0.000000 7.000000 4 2 + 0.000000 6.000000 5 0 + 0.000000 6.000000 5 4 + 6.000000 7.000000 6 0 + 6.000000 7.000000 6 3 + 0.000000 6.000000 7 3 + 6.000000 7.000000 7 4 + 0.000000 6.000000 7 5 + 6.000000 7.000000 7 6 + """ + ts = tskit.load_text( + nodes=io.StringIO(nodes), edges=io.StringIO(edges), strict=False ) + return add_unique_node_mutations(ts, nodes=range(7)) + + # 0.84┊ 7 ┊ 7 ┊ + # ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ + # 0.42┊ ┃ ┃ ┊ 6 ┃ ┊ + # ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊ + # 0.05┊ 5 ┃ ┊ ┃ ┃ ┃ ┊ + # ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊ + # 0.04┊ ┃ 4 ┃ ┊ ┃ ┃ 4 ┊ + # ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 3 1 2 ┊ + # 0 6 7 + + @pytest.mark.parametrize( + ("h", "expected_path"), + [ + # Just samples + ([1, 0, 0, 0, 0, 1, 1], [0] * 7), + ([0, 1, 0, 0, 1, 1, 0], [1] * 7), + ([0, 0, 1, 0, 1, 1, 0], [2] * 7), + ([0, 0, 0, 1, 0, 0, 1], [3] * 7), + # Match root + ([0, 0, 0, 0, 0, 0, 0], [7] * 7), + ], + ) + def test_match_all_nodes(self, h, expected_path): + # print() + # print(self.ts().draw_text()) + # with open("tmp.svg", "w") as f: + # f.write(self.ts().draw_svg()) + validate_match_all_nodes(self.ts(), h, expected_path) + + @pytest.mark.parametrize( + ("h", "expected_path"), + [ + ([1, 0, 0, 0, 0, 1, 1], [0] * 7), + ([0, 1, 0, 0, 1, 1, 0], [1] * 7), + ([0, 0, 1, 0, 1, 1, 0], [2] * 7), + ([0, 0, 0, 1, 0, 0, 1], [3] * 7), + # Switch between each of the samples + ([1, 1, 1, 1, 0, 0, 1], [0, 1, 2, 3, 3, 3, 3]), + ], + ) + def test_match_samples(self, h, expected_path): + ts = self.ts() + path = check_viterbi(ts, h) nt.assert_array_equal(expected_path, path) - cm = check_forward_matrix( - ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False - ) - print(cm.decode()) - bm = check_backward_matrix( - ts, h, cm, match_all_nodes=True, compare_lib=False, compare_lshmm=False - ) - print(bm.decode()) + cm = check_forward_matrix(ts, h) + check_backward_matrix(ts, h, cm) class TestSimulationExamples: