Investigate effect of polytomies / topological uncertainty on posteriors #359

Open
hyanwong opened this issue Jan 8, 2024 · 15 comments

@hyanwong commented Jan 8, 2024

Hannes had an interesting idea: how do polytomies affect the variation in posterior times for a node? We could test this by taking a known topology and collapsing some of the nodes into polytomies, then dating, and looking at how the posterior distribution of times of the component nodes compares to the posterior distribution estimated for the collapsed polytomy.

@nspope commented Jan 8, 2024

Another consideration is that a polytomy implies a greater total edge span than the original binary topologies, which I'd think would introduce bias. IIRC we don't see a relationship between arity and bias, however.

@hyanwong commented Apr 15, 2024

Now that we have a decent routine to create polytomies (tskit-dev/tskit#2926), we can test out the effect of polytomies on dating. Here's an example, using the true topologies, without and with induced polytomies (where edges without a mutation on them are removed to make a polytomy, see the trees below the plot). It appears as if making polytomies like this biases the mutation dates to younger times as the sample size increases:

[Plot: true vs. inferred mutation times, without and with induced polytomies, for increasing sample sizes]

Example first tree (original, then with induced polytomies):

[Screenshot: the first tree, original and with induced polytomies]

FWIW, the pattern doesn't change much if we use the metadata-stored mutation times instead.

Click to reveal code to reproduce the plot above:
import itertools
import collections
import numpy as np

def remove_edges(ts, edge_id_remove_list):
    edges_to_remove_by_child = collections.defaultdict(list)
    edge_id_remove_list = set(edge_id_remove_list)
    for remove_edge in edge_id_remove_list:
        e = ts.edge(remove_edge)
        edges_to_remove_by_child[e.child].append(e)

    # sort left-to-right for each child
    for k, v in edges_to_remove_by_child.items():
        edges_to_remove_by_child[k] = sorted(v, key=lambda e: e.left)
        # check no overlaps
        for e1, e2 in zip(edges_to_remove_by_child[k], edges_to_remove_by_child[k][1:]):
            assert e1.right <= e2.left

    # Sanity check: this means the topmost node will deal with modified edges left at the end
    assert ts.edge(-1).parent not in edges_to_remove_by_child
    
    new_edges = collections.defaultdict(list)
    tables = ts.dump_tables()
    tables.edges.clear()
    samples = set(ts.samples())
    # Edges are sorted by parent time, youngest first, so we can iterate over
    # nodes-as-parents visiting children before parents by using itertools.groupby
    for parent_id, ts_edges in itertools.groupby(ts.edges(), lambda e: e.parent):
        # Iterate through the ts edges *plus* the polytomy edges we created in previous steps.
        # This allows us to re-edit polytomy edges when the edges_to_remove are stacked
        edges = list(ts_edges)
        if parent_id in new_edges:
            edges += new_edges.pop(parent_id)
        if parent_id in edges_to_remove_by_child:
            for e in edges:
                assert parent_id == e.parent
                l = -1
                if e.id in edge_id_remove_list:
                    continue
                # NB: we go left to right along the target edges, reducing edge e as required
                for target_edge in edges_to_remove_by_child[parent_id]:
                    # As we go along the target_edges, gradually split e into chunks.
                    # If edge e is in the target_edge region, change the edge parent
                    assert target_edge.left > l
                    l = target_edge.left
                    if e.left >= target_edge.right:
                        # This target edge is entirely to the LHS of edge e, with no overlap
                        continue
                    elif e.right <= target_edge.left:
                        # This target edge is entirely to the RHS of edge e with no overlap.
                        # Since target edges are sorted by left coord, all other target edges
                        # are to RHS too, and we are finished dealing with edge e
                        tables.edges.append(e)
                        e = None
                        break
                    else:
                        # Edge e must overlap with current target edge somehow
                        if e.left < target_edge.left:
                            # Edge had region to LHS of target
                            # Add the left hand section (change the edge right coord)
                            tables.edges.add_row(left=e.left, right=target_edge.left, parent=e.parent, child=e.child)
                            e = e.replace(left=target_edge.left)
                        if e.right > target_edge.right:
                            # Edge continues after RHS of target
                            assert e.left < target_edge.right
                            new_edges[target_edge.parent].append(
                                e.replace(right=target_edge.right, parent=target_edge.parent)
                            )
                            e = e.replace(left=target_edge.right)
                        else:
                            # No more of edge e to RHS
                            assert e.left < e.right
                            new_edges[target_edge.parent].append(e.replace(parent=target_edge.parent))
                            e = None
                            break
                if e is not None:
                    # Need to add any remaining regions of edge back in 
                    tables.edges.append(e)
        else:
            # NB: sanity check at top means that the oldest node will have no edges above,
            # so the last iteration should hit this branch
            for e in edges:
                if e.id not in edge_id_remove_list:
                    tables.edges.append(e)
    assert len(new_edges) == 0
    tables.sort()
    return tables.tree_sequence()


def unsupported_edges(ts, per_interval=False):
    """
    Return the internal edges that are unsupported by a mutation.
    If ``per_interval`` is True, each interval needs to be supported,
    otherwise, a mutation on an edge (even if there are multiple intervals
    per edge) will result in all intervals on that edge being treated
    as supported.
    """
    edges_to_remove = np.ones(ts.num_edges, dtype="bool")
    edges_to_remove[[m.edge for m in ts.mutations()]] = False
    # We don't remove edges above samples
    edges_to_remove[np.isin(ts.edges_child, ts.samples())] = False

    if per_interval:
        return np.where(edges_to_remove)[0]
    else:
        # Keep all intervals of any (parent, child) pair that has a mutation somewhere
        keep = ~edges_to_remove
        for p, c in zip(ts.edges_parent[keep], ts.edges_child[keep]):
            edges_to_remove[np.logical_and(ts.edges_parent == p, ts.edges_child == c)] = False
        return np.where(edges_to_remove)[0]

###########

from matplotlib import pyplot as plt
import stdpopsim
import tsdate

print(f"Using tsdate {tsdate.__version__}")
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("AmericanAdmixture_4B11")
contig = species.get_contig("chr20", mutation_rate=model.mutation_rate, length_multiplier=0.1)
engine = stdpopsim.get_engine("msprime")

sizes = (1, 10, 100, 1000, 10000)
fig, axes = plt.subplots(len(sizes), 2, figsize=(10, 5*len(sizes)))
axes[0][0].set_title("True topology")
axes[0][1].set_title("Topology with induced polytomies")
axes[-1][0].set_xlabel("True mutation times")
axes[-1][1].set_xlabel("True mutation times")
for ax, s in zip(axes, sizes):
    samples = {'AFR': s, 'EUR': s, 'ASIA': s, 'ADMIX': s}
    ts = engine.simulate(model, contig, samples, seed=123)
    print(ts.num_trees, "trees,", ts.num_sites, "sites,", ts.num_edges, "edges")
    poly_ts = remove_edges(ts, unsupported_edges(ts))
    print(poly_ts.num_edges, "edges in unresolved ts")
    # Check it is doing the right thing
    dated_ts = tsdate.variational_gamma(ts.simplify(), mutation_rate=model.mutation_rate, normalisation_intervals=100)
    dated_poly_ts = tsdate.variational_gamma(poly_ts.simplify(), mutation_rate=model.mutation_rate, normalisation_intervals=100)
    x = [max(m.time for m in s.mutations) for s in ts.sites()]
    y = [max(m.time for m in s.mutations) for s in dated_ts.sites()]    
    y_poly = [max(m.time for m in s.mutations) for s in dated_poly_ts.sites()] 
    ax[0].set_ylabel(f"Inferred times ({ts.num_samples} samples)")

    ax[0].hexbin(x, y, bins="log", xscale="log", yscale="log")
    ax[0].plot(np.logspace(1, 5), np.logspace(1, 5), "-", c="red")

    ax[1].hexbin(x, y_poly, bins="log", xscale="log", yscale="log")
    ax[1].plot(np.logspace(1, 5), np.logspace(1, 5), "-", c="red")

resulting in the plot above, and outputting:

Using tsdate 0.1.dev885+g36d81f4
7367 trees, 13103 sites, 22725 edges
30148 edges in unresolved ts
19649 trees, 31256 sites, 71580 edges
131054 edges in unresolved ts
48806 trees, 72224 sites, 190253 edges
417203 edges in unresolved ts
107988 trees, 154211 sites, 444165 edges
1222924 edges in unresolved ts
201490 trees, 284048 sites, 971433 edges

@jeromekelleher

Nice!

@nspope commented Apr 15, 2024

This is great @hyanwong, thanks. I can think of a few things to try that might reduce bias-- will report back.

@hyanwong commented Apr 15, 2024

Thanks @nspope: from a few tests it appears that the bias is less pronounced in tsinfer-inferred tree sequences. Plots below; the right-hand column is tsinfer run on the same data:

[Plot: as above, with an additional right-hand column showing results on tsinfer-inferred tree sequences from the same data]

@hyanwong commented Apr 16, 2024

As an aside, I wondered if reducing to only the topology present at each variable site would change the bias, but it doesn't seem to change it very much:
[Plot: true vs. inferred mutation times for the true topology, induced polytomies, site-topology-reduced, and tsinferred tree sequences]

Click to reveal code (uses the code above, plus the following):
import json
import numpy as np
import scipy

import tsinfer
# Warning - this takes a long time (e.g. 10 hours)
its = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(ts), progress_monitor=True, num_threads=8)
dated_its = tsdate.variational_gamma(its.simplify(filter_sites=False), mutation_rate=model.mutation_rate, normalisation_intervals=100)

dated_reduced_ts = tsdate.variational_gamma(
    remove_edges(ts, unsupported_edges(ts)).simplify(reduce_to_site_topology=True),
    mutation_rate=model.mutation_rate,
    normalisation_intervals=100
)

fig, axes = plt.subplots(1, 4, figsize=(20, 5))
axes[0].set_title("True topology")
axes[1].set_title("Topology with induced polytomies")
axes[2].set_title("Topology reduced to variable_sites & polytomies")
axes[3].set_title("Tsinferred")
axes[0].set_xlabel("True mutation times")
axes[1].set_xlabel("True mutation times")
axes[2].set_xlabel("True mutation times")
axes[3].set_xlabel("True mutation times")

x = [max(m.time for m in s.mutations) for s in ts.sites()]
y = [max(m.time for m in s.mutations) for s in dated_ts.sites()]    
y_poly = [max(m.time for m in s.mutations) for s in dated_poly_ts.sites()] 
y_red = np.array([
    max((json.loads(m.metadata.decode())["mn"] for m in s.mutations), default = np.nan) for s in dated_reduced_ts.sites()
])
y_inf = np.array([
   max((json.loads(m.metadata.decode())["mn"] for m in s.mutations), default = np.nan) for s in dated_its.sites()
])

axes[0].set_ylabel(f"Inferred times ({ts.num_samples} samples)")

axes[0].hexbin(x, y, bins="log", xscale="log", yscale="log")
axes[0].plot(np.logspace(0, 5), np.logspace(0, 5), "-", c="red")
bias = np.mean(np.log(y) - np.log(x))
rho = scipy.stats.spearmanr(np.log(x), np.log(y)).statistic
axes[0].text(1e-4, 2e4, f"Rho: {rho:.5f}\nBias: {bias:.5f}")
    
axes[1].hexbin(x, y_poly, bins="log", xscale="log", yscale="log")
axes[1].plot(np.logspace(0, 5), np.logspace(0, 5), "-", c="red")
bias = np.mean(np.log(y_poly) - np.log(x))
rho = scipy.stats.spearmanr(np.log(x), np.log(y_poly)).statistic
axes[1].text(1e-4, 2e4, f"Rho: {rho:.5f}\nBias: {bias:.5f}")

axes[2].hexbin(x, y_red, bins="log", xscale="log", yscale="log")
axes[2].plot(np.logspace(0, 5), np.logspace(0, 5), "-", c="red")
bias = np.mean(np.log(y_red) - np.log(x))
rho = scipy.stats.spearmanr(np.log(x), np.log(y_red)).statistic
axes[2].text(1e-4, 2e4, f"Rho: {rho:.5f}\nBias: {bias:.5f}")

axes[3].hexbin(x, y_inf, bins="log", xscale="log", yscale="log")
axes[3].plot(np.logspace(0, 5), np.logspace(0, 5), "-", c="red")
bias = np.nanmean(np.log(y_inf) - np.log(x))
rho = scipy.stats.spearmanr(np.log(x), np.log(y_inf), nan_policy="omit").statistic
axes[3].text(1e-4, 2e4, f"Rho: {rho:.5f}\nBias: {bias:.5f}")

@nspope commented Apr 26, 2024

Looking first at node ages ... the reason there's bias in dating nodes after introducing polytomies is that there's more mutational area than there was in the original binary trees. That is, we're increasing the total branch length, which means that when we match moments using segregating sites we end up shrinking the timescale.

To be a bit more precise: the current normalisation strategy calculates the total edge area and the total number of mutations, then rescales time so that the expected number of mutations matches the observed total.

Instead, consider doing the following: for each tree, sample a path from a randomly selected leaf to the root, and only accumulate edge area and mutations on the sampled paths. This should be unbiased, because the "path length" is the same regardless of the presence of polytomies. In fact, we can do this sampling deterministically, because the probability that a randomly selected path passes through a given edge is proportional to the number of samples subtended by that edge. That is, we normalise as before but weight each edge by the number of samples it subtends.
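
As a rough illustration of this deterministic path weighting, here's a minimal sketch that computes a single global rescaling factor with tskit. This is not the tsdate implementation (which normalises piecewise over time intervals); the function name and the use of one global factor are illustrative assumptions only.

import tskit

def path_weighted_rescaling_factor(ts, mutation_rate):
    # Compare the expected number of mutations on a random leaf-to-root path
    # (under the current node times) with the path-weighted count of observed
    # mutations; their ratio gives a global time rescaling factor.
    n = ts.num_samples
    expected = 0.0
    observed = 0.0
    for tree in ts.trees():
        span = tree.interval.span
        for node in tree.nodes():
            if tree.parent(node) == tskit.NULL:
                continue  # no edge above a root
            # Probability that a random leaf-to-root path uses this edge
            weight = tree.num_samples(node) / n
            expected += weight * tree.branch_length(node) * span * mutation_rate
        for mut in tree.mutations():
            if tree.parent(mut.node) != tskit.NULL:
                observed += tree.num_samples(mut.node) / n
    # Multiplying node times by this factor matches expected and observed
    # path-weighted mutation counts; polytomies don't inflate the expectation,
    # because each path has the same length with or without them.
    return observed / expected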

Using this alternative "path normalisation" strategy seems to greatly help with bias (1000 samples, 10 Mb):

[Plot norm_poly-nodes: node-age accuracy under the old vs. new "path" normalisation, with and without polytomies]

This more-or-less carries over for mutations:

[Plot norm_poly-muts: the same comparison for mutation ages]

@hyanwong commented Apr 26, 2024

Oh wow. This is amazing. Thanks Nate.

Does it cause any overcorrection problems for tsinferred tree sequences? I assume it shouldn't...

@nspope commented Apr 27, 2024

Another way to phrase this is that we're moment matching against a different summary statistic (rather than segregating sites), namely the expected number of differences between a single sample and the root. In my opinion this summary statistic is a conceptually more straightforward way to measure time in terms of mutational density.

@nspope commented Apr 27, 2024

I did a quick check on inferred simulated tree sequences -- the original routine was more or less unbiased (as Yan observed above) and the new routine does about the same. Would be interesting to compare the two on real data. Regardless, this new routine seems like the right approach.

@hyanwong

the new routine does about the same

That's great.

Regardless, this new routine seems like the right approach.

Absolutely. We should go with the new approach. I wonder how both approaches perform on reinference? I can check this once there are instructions for how to run the new version.

@nspope commented Apr 27, 2024

The API is exactly the same, with the new normalisation scheme used by default. The old normalisation scheme can be toggled if you pass match_segregating_sites=True to date.
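
For example, continuing from the simulation code above (I'm assuming the keyword is also accepted by variational_gamma, since the comment above refers to passing it to date):

# Default: new path-based normalisation
dated_default = tsdate.variational_gamma(
    ts.simplify(), mutation_rate=model.mutation_rate, normalisation_intervals=100
)
# Old behaviour: moment-match against segregating sites
dated_segsites = tsdate.variational_gamma(
    ts.simplify(), mutation_rate=model.mutation_rate, normalisation_intervals=100,
    match_segregating_sites=True,
)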

@hyanwong

Great, thanks for the info. Is it currently much slower than the old version? It seems maybe not?

@nspope commented Apr 27, 2024

It shouldn't be, but it would be good to check (if you enable logging it'll print out the time spent during normalisation). There's an additional pass over the edges, but this is done in numba. It might add a few minutes on GEL- or UKBB-sized data, so it would be good to enable logging there to get a sense of the overhead.
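
A quick way to surface those messages is the standard library logging module; the level at which tsdate reports normalisation timing is an assumption here (try DEBUG if INFO doesn't show it):

import logging

# Show timing/progress messages from tsdate, including the normalisation step
logging.basicConfig(level=logging.INFO)

dated_ts = tsdate.variational_gamma(
    ts.simplify(), mutation_rate=model.mutation_rate, normalisation_intervals=100
)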

@nspope commented Apr 28, 2024

I wonder how both approaches perform on reinference

Actually, I don't think it'll change reinference at all -- ancestor building just uses the ordering of mutations, right? Normalisation won't change the order, just the inter-node time differences.
