From 9696cf99f3ccf0e8b4b44da7a5ab51bd52acb8e5 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 11 Oct 2023 13:28:11 +0100 Subject: [PATCH] Working all-nodes check for special case --- python/tests/test_haplotype_matching.py | 364 ++++++++++++++++++------ 1 file changed, 276 insertions(+), 88 deletions(-) diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index dcc1d684fb..c3a9e3134c 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -99,7 +99,15 @@ class LsHmmAlgorithm: """ def __init__( - self, ts, rho, mu, alleles, n_alleles, precision=10, scale_mutation=False + self, + ts, + rho, + mu, + alleles, + n_alleles, + precision=10, + scale_mutation=False, + match_all_nodes=False, ): self.ts = ts self.mu = mu @@ -109,8 +117,6 @@ def __init__( self.T = [] # indexes in to the T array for each node. self.T_index = np.zeros(ts.num_nodes, dtype=int) - 1 - # The number of nodes underneath each element in the T array. - self.N = np.zeros(ts.num_nodes, dtype=int) # Efficiently compute the allelic state at a site self.allelic_state = np.zeros(ts.num_nodes, dtype=int) - 1 # TreePosition so we can can update T and T_index between trees. @@ -122,6 +128,41 @@ def __init__( self.n_alleles = n_alleles self.alleles = alleles self.scale_mutation_based_on_n_alleles = scale_mutation + self.match_all_nodes = match_all_nodes + + def node_values(self): + """ + Return the current mapping of node->value for each node in the + tree. 
+ """ + d = {} + mapping = {st.tree_node: st.value for st in self.T if st.tree_node != -1} + for u in self.tree.nodes(): + v = u + while v not in mapping: + assert v != -1 + v = self.tree.parent(v) + d[u] = mapping[v] + return d + + def print_state(self): + print("LsHMM state") + print("match_all_nodes =", self.match_all_nodes) + print("Tree =") + node_labels = {} + for u, value in self.node_values().items(): + label = f"{u}" + if self.tree.is_sample(u): + label = f"*{u}*" + label += f":{value:.2g}" + node_labels[u] = label + print(self.tree.draw_text(node_labels=node_labels)) + print("T =") + for vt in self.T: + print("\t", vt) + print("T_index:") + for u in range(self.ts.num_nodes): + print(f"\t{u}\t{self.T_index[u]}") def check_integrity(self): M = [st.tree_node for st in self.T if st.tree_node != -1] @@ -134,6 +175,45 @@ def check_integrity(self): assert j == self.T_index[st.tree_node] def compress(self): + if self.match_all_nodes: + self._compress_tsinfer() + else: + self._compress_parsimony() + # self.print_state() + self.check_integrity() + + def _compress_tsinfer(self): + tree = self.tree + T = self.T + T_index = self.T_index + + T_old = [st.copy() for st in T] + T.clear() + + for st in T_old: + u = st.tree_node + if u != -1: + # We need to find the likelihood of the parent of u. If this is + # the same as u, we can delete it. + v = tree.parent(u) + while v != -1 and T_index[v] == -1: + v = tree.parent(v) + keep = True + if v != -1: + if st.value == T_old[T_index[v]].value: + keep = False + if keep: + T.append(st) + T_index[u] = -1 + + # Sort by decreasing time to ensure postorder. 
This is used by the + # compressed matrix, downstream + self.T.sort(key=lambda st: -tree.time(st.tree_node)) + + for j, st in enumerate(self.T): + self.T_index[st.tree_node] = j + + def _compress_parsimony(self): tree = self.tree T = self.T T_index = self.T_index @@ -190,13 +270,14 @@ def compute(u, parent_state): T_old = [st.copy() for st in T] T.clear() - T_parent = [] + # Removing T_parent as it's not needed currently, see note on N[j] below + # T_parent = [] old_state = T_old[T_index[tree.root]].value_index new_state = np.argmax(optimal_set[tree.root]) T.append(ValueTransition(tree_node=tree.root, value=values[new_state])) - T_parent.append(-1) + # T_parent.append(-1) stack = [(tree.root, old_state, new_state, 0)] while len(stack) > 0: u, old_state, new_state, t_parent = stack.pop() @@ -211,14 +292,14 @@ def compute(u, parent_state): if optimal_set[v, new_state] == 0: new_child_state = np.argmax(optimal_set[v]) child_t_parent = len(T) - T_parent.append(t_parent) + # T_parent.append(t_parent) T.append( ValueTransition(tree_node=v, value=values[new_child_state]) ) stack.append((v, old_child_state, new_child_state, child_t_parent)) else: if old_child_state != new_state: - T_parent.append(t_parent) + # T_parent.append(t_parent) T.append( ValueTransition(tree_node=v, value=values[old_child_state]) ) @@ -228,10 +309,13 @@ def compute(u, parent_state): T_index[st.tree_node] = -1 for j, st in enumerate(T): T_index[st.tree_node] = j - self.N[j] = tree.num_samples(st.tree_node) - for j in range(len(T)): - if T_parent[j] != -1: - self.N[T_parent[j]] -= self.N[j] + + # NOTE: we only use the N values in the forward matrix at the moment, + # so simplifying here by calculating them on the fly where needed. 
+ # self.N[j] = tree.num_samples(st.tree_node) + # for j in range(len(T)): + # if T_parent[j] != -1: + # self.N[T_parent[j]] -= self.N[j] def update_tree(self, direction=tskit.FORWARD): """ @@ -333,11 +417,11 @@ def update_probabilities(self, site, haplotype_state): while allelic_state[v] == -1: v = tree.parent(v) assert v != -1 - match = ( + is_match = ( haplotype_state == MISSING or haplotype_state == allelic_state[v] ) # Note that the node u is used only by Viterbi - st.value = self.compute_next_probability(site.id, st.value, match, u) + st.value = self.compute_next_probability(site.id, st.value, is_match, u) # Unset the states allelic_state[tree.root] = -1 @@ -346,7 +430,12 @@ def update_probabilities(self, site, haplotype_state): def process_site(self, site, haplotype_state): self.update_probabilities(site, haplotype_state) + # d1 = self.node_values() self.compress() + # d2 = self.node_values() + # assert d1 == d2 + # print("AFTER COMPRESS") + # self.print_state() s = self.compute_normalisation_factor() for st in self.T: assert st.tree_node != tskit.NULL @@ -413,26 +502,27 @@ class ForwardAlgorithm(LsHmmAlgorithm): The Li and Stephens forward algorithm. 
""" - def __init__( - self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 - ): - super().__init__( - ts, - rho, - mu, - alleles, - n_alleles, - precision=precision, - scale_mutation=scale_mutation, - ) + def __init__(self, ts, *args, **kwargs): + super().__init__(ts, *args, **kwargs) self.output = CompressedMatrix(ts) def compute_normalisation_factor(self): + d = {st.tree_node: st for st in self.T} + N = np.zeros(self.ts.num_nodes, dtype=int) + for u in self.tree.nodes(order="preorder"): + if u in d: + N[u] = self.tree.num_samples(u) + # Subtract this value from everything above + v = self.tree.parent(u) + while v != -1 and v not in d: + v = self.tree.parent(v) + if v != -1: + N[v] -= N[u] s = 0 - for j, st in enumerate(self.T): + for st in self.T: assert st.tree_node != tskit.NULL - # assert self.N[j] > 0 - s += self.N[j] * st.value + assert N[st.tree_node] > 0 + s += N[st.tree_node] * st.value return s def compute_next_probability(self, site_id, p_last, is_match, node): @@ -489,18 +579,8 @@ class ViterbiAlgorithm(LsHmmAlgorithm): Runs the Li and Stephens Viterbi algorithm. 
""" - def __init__( - self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 - ): - super().__init__( - ts, - rho, - mu, - alleles, - n_alleles, - precision=precision, - scale_mutation=scale_mutation, - ) + def __init__(self, ts, *args, **kwargs): + super().__init__(ts, *args, **kwargs) self.output = ViterbiMatrix(ts) def compute_normalisation_factor(self): @@ -570,6 +650,16 @@ def store_site(self, site, normalisation_factor, value_transitions): self.normalisation_factor[site] = normalisation_factor self.value_transitions[site] = value_transitions + def print_state(self): + print("Compressed matrix state") + for site in range(self.num_sites): + print( + site, + self.normalisation_factor[site], + self.value_transitions[site], + sep="\t", + ) + # Expose the same API as the low-level classes @property @@ -633,12 +723,14 @@ def choose_sample(self, site_id, tree): def traceback(self): # Run the traceback. m = self.ts.num_sites - match = np.zeros(m, dtype=int) + matched = np.zeros(m, dtype=int) recombination_tree = np.zeros(self.ts.num_nodes, dtype=int) - 1 tree = tskit.Tree(self.ts) tree.last() current_node = -1 + # self.print_state() + rr_index = len(self.recombination_required) - 1 for site in reversed(self.ts.sites()): while tree.interval.left > site.position: @@ -654,7 +746,7 @@ def traceback(self): if current_node == -1: current_node = self.choose_sample(site.id, tree) - match[site.id] = current_node + matched[site.id] = current_node # Now traverse up the tree from the current node. The first marked node # we meet tells us whether we need to recombine. @@ -664,6 +756,8 @@ def traceback(self): assert u != -1 if recombination_tree[u] == 1: + # print("recomb_tree = ", recombination_tree) + # print("SWITCHING AT ", site) # Need to switch at the next site. current_node = -1 # Reset the nodes in the recombination tree. 
@@ -674,7 +768,8 @@ def traceback(self): j -= 1 rr_index = j - return match + # print("MATCHED = ", matched) + return matched def get_site_alleles(ts, h, alleles): @@ -701,7 +796,14 @@ def get_site_alleles(ts, h, alleles): def ls_forward_tree( - h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False + h, + ts, + rho, + mu, + precision=30, + alleles=None, + scale_mutation_based_on_n_alleles=False, + match_all_nodes=False, ): alleles, n_alleles = get_site_alleles(ts, h, alleles) fa = ForwardAlgorithm( @@ -712,11 +814,21 @@ def ls_forward_tree( n_alleles, precision=precision, scale_mutation=scale_mutation_based_on_n_alleles, + match_all_nodes=match_all_nodes, ) return fa.run(h) -def ls_backward_tree(h, ts, rho, mu, normalisation_factor, precision=30, alleles=None): +def ls_backward_tree( + h, + ts, + rho, + mu, + normalisation_factor, + precision=30, + alleles=None, + match_all_nodes=False, +): alleles, n_alleles = get_site_alleles(ts, h, alleles) ba = BackwardAlgorithm( ts, @@ -725,12 +837,20 @@ def ls_backward_tree(h, ts, rho, mu, normalisation_factor, precision=30, alleles alleles, n_alleles, precision=precision, + match_all_nodes=match_all_nodes, ) return ba.run(h, normalisation_factor) def ls_viterbi_tree( - h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False + h, + ts, + rho, + mu, + precision=30, + alleles=None, + scale_mutation_based_on_n_alleles=False, + match_all_nodes=False, ): alleles, n_alleles = get_site_alleles(ts, h, alleles) va = ViterbiAlgorithm( @@ -741,6 +861,7 @@ def ls_viterbi_tree( n_alleles, precision=precision, scale_mutation=scale_mutation_based_on_n_alleles, + match_all_nodes=match_all_nodes, ) return va.run(h) @@ -798,8 +919,7 @@ def example_parameters_haplotypes(self, ts, seed=42): # yield n, H, s, r, mu def assertAllClose(self, A, B): - """Assert that all entries of two matrices are 'close'""" - assert np.allclose(A, B, rtol=1e-5, atol=1e-8) + np.testing.assert_allclose(A, B, 
rtol=1e-5, atol=1e-8) # Define a bunch of very small tree-sequences for testing a collection # of parameters on @@ -1028,6 +1148,8 @@ def verify(self, ts): # Now, need to ensure that the likelihood of the preferred path is # the same as ll_tree (and ll). path_tree = cm.traceback() + # print(path) + # print(path_tree) ll_check = ls.path_ll( H, s, @@ -1040,7 +1162,9 @@ def verify(self, ts): # TODO add params to run the various checks -def check_viterbi(ts, h, recombination=None, mutation=None): +def check_viterbi( + ts, h, recombination=None, mutation=None, match_all_nodes=False, compare_lib=True +): h = np.array(h).astype(np.int8) m = ts.num_sites assert len(h) == m @@ -1060,11 +1184,12 @@ def check_viterbi(ts, h, recombination=None, mutation=None): scale_mutation_based_on_n_alleles=False, ) assert np.isscalar(ll) + # print() + # print("ls path = ", path) - cm = ls_viterbi_tree(h, ts, rho=recombination, mu=mutation) - ll_tree = np.sum(np.log10(cm.normalisation_factor)) - assert np.isscalar(ll_tree) - nt.assert_allclose(ll_tree, ll) + cm = ls_viterbi_tree( + h, ts, rho=recombination, mu=mutation, match_all_nodes=match_all_nodes + ) # Check that the likelihood of the preferred path is # the same as ll_tree (and ll). 
@@ -1077,24 +1202,33 @@ def check_viterbi(ts, h, recombination=None, mutation=None): p_mutation=mutation, scale_mutation_based_on_n_alleles=False, ) - nt.assert_allclose(ll_check, ll) + # print(cm) + # print("path tree = ", path_tree) - ll_ts = ts._ll_tree_sequence - ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) - cm_lib = _tskit.ViterbiMatrix(ll_ts) - ls_hmm.viterbi_matrix(h, cm_lib) - path_lib = cm_lib.traceback() + ll_tree = np.sum(np.log10(cm.normalisation_factor)) + assert np.isscalar(ll_tree) + nt.assert_allclose(ll_tree, ll) - # Not true in general, but let's see how far it goes - nt.assert_array_equal(path_lib, path_tree) + if compare_lib: + nt.assert_allclose(ll_check, ll) + ll_ts = ts._ll_tree_sequence + ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) + cm_lib = _tskit.ViterbiMatrix(ll_ts) + ls_hmm.viterbi_matrix(h, cm_lib) + path_lib = cm_lib.traceback() - nt.assert_allclose(cm_lib.normalisation_factor, cm.normalisation_factor) + # Not true in general, but let's see how far it goes + nt.assert_array_equal(path_lib, path_tree) - return path + nt.assert_allclose(cm_lib.normalisation_factor, cm.normalisation_factor) + + return path_tree # TODO add params to run the various checks -def check_forward_matrix(ts, h, recombination=None, mutation=None): +def check_forward_matrix( + ts, h, recombination=None, mutation=None, match_all_nodes=False, compare_lib=True +): precision = 22 h = np.array(h).astype(np.int8) n = ts.num_samples @@ -1118,28 +1252,44 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None): assert np.isscalar(ll) cm = ls_forward_tree( - h, ts, recombination, mutation, scale_mutation_based_on_n_alleles=False + h, + ts, + recombination, + mutation, + scale_mutation_based_on_n_alleles=False, + match_all_nodes=match_all_nodes, ) F2 = cm.decode() + # print(F) + # print(F2) nt.assert_allclose(F, F2) nt.assert_allclose(c, cm.normalisation_factor) ll_tree = 
np.sum(np.log10(cm.normalisation_factor)) nt.assert_allclose(ll_tree, ll) - ll_ts = ts._ll_tree_sequence - ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) - cm_lib = _tskit.CompressedMatrix(ll_ts) - ls_hmm.forward_matrix(h, cm_lib) - F3 = cm_lib.decode() + if compare_lib: + ll_ts = ts._ll_tree_sequence + ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) + cm_lib = _tskit.CompressedMatrix(ll_ts) + ls_hmm.forward_matrix(h, cm_lib) + F3 = cm_lib.decode() - assert_compressed_matrices_equal(cm, cm_lib) + assert_compressed_matrices_equal(cm, cm_lib) - nt.assert_allclose(F, F3) - nt.assert_allclose(c, cm_lib.normalisation_factor) - return cm_lib + nt.assert_allclose(F, F3) + nt.assert_allclose(c, cm_lib.normalisation_factor) + return cm -def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): +def check_backward_matrix( + ts, + h, + forward_cm, + recombination=None, + mutation=None, + match_all_nodes=False, + compare_lib=True, +): precision = 22 h = np.array(h).astype(np.int8) m = ts.num_sites @@ -1166,22 +1316,23 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): mutation, forward_cm.normalisation_factor, precision=precision, + match_all_nodes=match_all_nodes, ) nt.assert_array_equal( backward_cm.normalisation_factor, forward_cm.normalisation_factor ) + if compare_lib: + ll_ts = ts._ll_tree_sequence + ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) + cm_lib = _tskit.CompressedMatrix(ll_ts) + ls_hmm.backward_matrix(h, forward_cm.normalisation_factor, cm_lib) - ll_ts = ts._ll_tree_sequence - ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) - cm_lib = _tskit.CompressedMatrix(ll_ts) - ls_hmm.backward_matrix(h, forward_cm.normalisation_factor, cm_lib) - - assert_compressed_matrices_equal(backward_cm, cm_lib) + assert_compressed_matrices_equal(backward_cm, cm_lib) - B_lib = cm_lib.decode() - B_tree = 
backward_cm.decode() - nt.assert_allclose(B_tree, B_lib) - nt.assert_allclose(B, B_lib) + B_lib = cm_lib.decode() + B_tree = backward_cm.decode() + nt.assert_allclose(B_tree, B_lib) + nt.assert_allclose(B, B_lib) def add_unique_sample_mutations(ts, start=0): @@ -1221,8 +1372,8 @@ def test_match_sample(self, j): ts = self.ts() h = np.zeros(4) h[j] = 1 - path = check_viterbi(ts, h) - nt.assert_array_equal([j, j, j, j], path) + # path = check_viterbi(ts, h) + # nt.assert_array_equal([j, j, j, j], path) cm = check_forward_matrix(ts, h) check_backward_matrix(ts, h, cm) @@ -1262,11 +1413,48 @@ def test_switch_each_sample_missing_middle(self): h[1:3] = -1 path = check_viterbi(ts, h) # Implementation of Viterbi switches at right-most position - nt.assert_array_equal([0, 3, 3, 3], path) + nt.assert_array_equal([0, 0, 0, 3], path) cm = check_forward_matrix(ts, h) check_backward_matrix(ts, h, cm) +class TestSingleBalancedTreeAllSamplesExample: + # 3.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 2.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 1.00┊ 0 1 2 3 ┊ + # 0 8 + + @staticmethod + def ts(): + tables = tskit.Tree.generate_balanced(4, span=14).tree_sequence.dump_tables() + flags = tables.nodes.flags + flags[:] = 1 + tables.nodes.flags = flags + return add_unique_sample_mutations(tables.tree_sequence(), start=1) + + @pytest.mark.parametrize( + ("u", "h"), + [ + (0, [1, 0, 0, 0, 1, 0, 1]), + (1, [0, 1, 0, 0, 1, 0, 1]), + (2, [0, 0, 1, 0, 0, 1, 1]), + (3, [0, 0, 0, 1, 0, 1, 1]), + (4, [0, 0, 0, 0, 1, 0, 1]), + (5, [0, 0, 0, 0, 0, 1, 1]), + (6, [0, 0, 0, 0, 0, 0, 1]), + ], + ) + def test_match_sample(self, u, h): + np.set_printoptions(linewidth=1000, precision=3) + ts = self.ts() + path = check_viterbi(ts, h, match_all_nodes=True, compare_lib=False) + nt.assert_array_equal([u] * 7, path) + cm = check_forward_matrix(ts, h, match_all_nodes=True, compare_lib=False) + check_backward_matrix(ts, h, cm, match_all_nodes=True, compare_lib=False) + + class TestSimulationExamples: @pytest.mark.parametrize("n", [3, 10, 50]) 
@pytest.mark.parametrize("L", [1, 10, 100])