Plot "edge-node time consistency" #953

Open
hyanwong opened this issue Aug 22, 2024 · 0 comments

hyanwong commented Aug 22, 2024

For improving inference accuracy, including decent ancestor reconstruction, we don't really care about the absolute times of the nodes under each mutation. Rather, we want to know that the "local" node order is correct. In fact, the only thing we really want is for the parent node and child node of each edge in the true tree sequence to be in the right order in the inferred tree sequence.

Therefore, to measure the accuracy of our inference (in order to improve it), we can match nodes between the true and inferred tree sequences (perhaps on the basis of the node under each mutation?), then take each edge in the true tree sequence and ask whether the inferred (and dated) tree sequence has the equivalent nodes in the right order (i.e. parent time > child time). This should give us something to aim for when improving inference.
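To make the statistic concrete, here is a minimal pure-Python sketch on toy data. The function name `edge_order_consistency`, the `node_map` dictionary and the toy node IDs and times are all hypothetical, chosen for illustration only; they are not part of tskit, tsinfer or tsdate.

```python
def edge_order_consistency(true_edges, node_map, inferred_times):
    """Count true (child, parent) edges whose matched inferred nodes are ordered correctly.

    true_edges: iterable of (child, parent) node IDs from the true tree sequence
    node_map: dict mapping true node ID -> inferred node ID (only for matched nodes)
    inferred_times: dict mapping inferred node ID -> time
    """
    unique_edges = set(true_edges)  # the same pair can appear on several edges
    compared = 0
    consistent = 0
    for child, parent in unique_edges:
        new_child = node_map.get(child)
        new_parent = node_map.get(parent)
        if new_child is None or new_parent is None:
            continue  # no matched node (e.g. no mutation above it), so skip
        compared += 1
        if inferred_times[new_child] < inferred_times[new_parent]:
            consistent += 1
    return consistent, compared, len(unique_edges)


# Toy example (hypothetical values): node 3 and node 6 have no match,
# and the inferred times of nodes 12 and 13 are in the wrong order.
true_edges = [(0, 4), (1, 4), (4, 5), (2, 5), (4, 5), (3, 6)]
node_map = {0: 10, 1: 11, 2: 14, 4: 12, 5: 13}
inferred_times = {10: 0.0, 11: 0.0, 12: 2.0, 13: 1.5, 14: 0.0}

consistent, compared, total = edge_order_consistency(true_edges, node_map, inferred_times)
print(f"{consistent}/{compared} comparable edges are consistent; "
      f"{total - compared} of {total} unique edges could not be compared")
```

Here the statistic is reported alongside the number of edges that could not be compared, since a high score over a small comparable subset would be less informative.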

I'll figure out some plots to show improvement in this stat, but meanwhile I think this is a reasonable way to test:

import tsinfer
import tskit
import tsdate
import numpy as np
import msprime

# Simulate
sim_ts = msprime.sim_ancestry(50, sequence_length=1e6, population_size=1e4, recombination_rate=1e-8, random_seed=1)
sim_ts = msprime.sim_mutations(sim_ts, rate=1e-8, random_seed=1)
print("Simulated", sim_ts.num_sites, "sites", sim_ts.num_trees, "trees")

# Infer
use_sites_time = False
info = "With true site times " if use_sites_time else ""
ts = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(
    sim_ts,
    use_sites_time=use_sites_time,
))

# Date
pts = tsdate.preprocess_ts(ts)
dts = tsdate.date(pts, mutation_rate=1e-8)

def edge_node_compat(orig_ts, new_ts):
    # map the node in the original to a node in the new one, if possible
    corresponding_nodes = np.full(orig_ts.num_nodes, -1, dtype=orig_ts.edges_child.dtype)
    # find mutations below each node
    assert new_ts.num_sites == orig_ts.num_sites
    for new_site, orig_site in zip(new_ts.sites(), orig_ts.sites()):
        if len(new_site.mutations) and len(orig_site.mutations):
            # first mutation is always oldest, by tskit definition
            corresponding_nodes[orig_site.mutations[0].node] = new_site.mutations[0].node
    
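    # unique (child, parent) node pairs among the true edges
    unique_child_parent = np.unique(np.array([orig_ts.edges_child, orig_ts.edges_parent]).T, axis=0)
    # append a sentinel time of -1 so that unmatched nodes (index -1) get a negative time
    nodes_time = np.concatenate((new_ts.nodes_time, [-1]))
    # unique_child_parent already holds node IDs, so index corresponding_nodes with them directly
    child_times_in_new = nodes_time[corresponding_nodes[unique_child_parent[:, 0]]]
    parent_times_in_new = nodes_time[corresponding_nodes[unique_child_parent[:, 1]]]
    # only compare edges where both child and parent were matched to a node in new_ts
    used = np.logical_and(child_times_in_new >= 0, parent_times_in_new >= 0)
    compat = child_times_in_new[used] < parent_times_in_new[used]
    return compat, used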


good, use = edge_node_compat(sim_ts, ts)
print(f"{info}{sum(good) / len(good) * 100:.2f}% ({sum(good)}) of true edges have inferred parent time older than inferred child time")
print(f"(but {sum(use==0) / len(use) * 100:.2f}% of nodes in the true ts have no associated mutation for comparison)")

good, use = edge_node_compat(sim_ts, pts)
print(f"After preprocessing, {sum(good) / len(good) * 100:.2f}%  ({sum(good)}) of true edges have inferred parent time older than inferred child time")
print(f"(but {sum(use==0) / len(use) * 100:.2f}% of nodes in the true ts have no associated mutation)")

good, use = edge_node_compat(sim_ts, dts)
print(f"After dating, {sum(good) / len(good) * 100:.2f}%  ({sum(good)}) of true edges have inferred parent time older than inferred child time")
print(f"(but {sum(use==0) / len(use) * 100:.2f}% of nodes in the true ts have no associated mutation)")

Giving (in this simplest simulation example):

Simulated 2047 sites 1758 trees
86.53% (2236) of true edges have inferred parent time older than inferred child time
(but 54.88% of nodes in the true ts have no associated mutation for comparison)
After preprocessing, 87.85%  (2270) of true edges have inferred parent time older than inferred child time
(but 54.88% of nodes in the true ts have no associated mutation)
After dating, 95.82%  (2476) of true edges have inferred parent time older than inferred child time
(but 54.88% of nodes in the true ts have no associated mutation)

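Weirdly, setting use_sites_time=True gives somewhat under 100%, and I'm not sure why. Inferring and then dating does worse than using the true site times (95.82% vs 98.92%), and dating the tree sequence inferred with true site times also lowers the score (from 99.69% after preprocessing to 96.56%).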

Simulated 2047 sites 1758 trees
With true site times 98.92% (2556) of true edges have inferred parent time older than inferred child time
(but 54.88% of nodes in the true ts have no associated mutation for comparison)
After preprocessing, 99.69%  (2576) of true edges have inferred parent time older than inferred child time
(but 54.88% of nodes in the true ts have no associated mutation)
After dating, 96.56%  (2495) of true edges have inferred parent time older than inferred child time
(but 54.88% of nodes in the true ts have no associated mutation)