Use of, and alternatives to, a "global" (i.e. node-agnostic) prior #292
Here's an interesting way to plot the expected node times in a 4-tip tree sequence, depending on the mixture of sample descendants. We can use a barycentric (i.e. triangular) plot to represent the ratios of the spans over which a node has 2 : 3 : 4 descendants, and then show separate plots for different total node span lengths. Something like this, for 100 replicate simulations of 100Mb simulated tree sequences:

```python
import itertools

import msprime
import numpy as np
import tqdm
from matplotlib import pyplot as plt
import scipy.stats

reps = 100
# Save into a massive matrix with 4 columns:
# 2span 3span 4span time
nodes = None
for ts in tqdm.tqdm(
    msprime.sim_ancestry(
        4, ploidy=1, population_size=1e4, sequence_length=1e8,
        recombination_rate=1e-8, num_replicates=reps,
    )
):
    node_arr = np.zeros((ts.num_nodes - ts.num_samples, ts.num_samples))
    node_arr[:, 3] = ts.nodes_time[ts.num_samples:]
    for tree in ts.trees():
        for u in tree.nodes():
            if tree.is_internal(u):
                node_arr[u - ts.num_samples][tree.num_samples(u) - 2] += tree.interval.span
    if nodes is None:
        nodes = node_arr
    else:
        nodes = np.vstack((nodes, node_arr))

print("Total data (nodes, spans + time)", nodes.shape)
print("Mean node time is", np.mean(nodes[:, 3]), ts.time_units)

a = nodes[:, 0]
b = nodes[:, 1]
c = nodes[:, 2]
tot = a + b + c
times = nodes[:, 3]

def get_cartesian_from_barycentric(b, t):
    return t.dot(b)

t = np.transpose(np.array([[0, 0], [1, 0], [1/2, np.sqrt(3)/2]]))  # Triangle vertices
abc = np.array([a/tot, b/tot, c/tot])
xy = get_cartesian_from_barycentric(abc, t)
x = xy[0, :]
y = xy[1, :]
z = np.log(times)

fig, axes = plt.subplots(nrows=3, figsize=(7, 15))
# Use the same contour levels in all three panels, based on the full dataset
zz = scipy.stats.binned_statistic_2d(x, y, np.log(times), bins=50, statistic="mean")[0]
levels = np.linspace(np.nanmin(np.exp(zz)), np.nanmax(np.exp(zz)), 10)
for ax, (lower, upper) in zip(
    axes,
    itertools.pairwise(np.quantile(tot, [0, 0.7, 0.9, 1])),
):
    use = np.logical_and(tot >= lower, tot < upper)
    bins = 50
    zz = scipy.stats.binned_statistic_2d(x[use], y[use], z[use], bins=bins, statistic="mean")[0].T
    cntr1 = ax.contourf(np.exp(zz), levels=levels)
    ax.set_ylim(-bins * 0.1, bins * 1.1)
    ax.text(
        *(t[:, 0] / np.max(t, axis=1) * bins),
        f"100% 2 sample\ndescendants\n(single tree theor: 2500 {ts.time_units[:3]})",
        ha="left", va="top")
    ax.text(
        *(t[:, 1] / np.max(t, axis=1) * bins),
        f"100% 3 sample\ndescendants\n(single tree theor: 5000 {ts.time_units[:3]})",
        ha="right", va="top")
    ax.text(
        *(t[:, 2] / np.max(t, axis=1) * bins),
        f"100% 4 sample descendants (single tree theor: 15000 {ts.time_units[:3]})",
        ha="center")
    ax.set_axis_off()
    ax.set_title(f"{np.sum(use)} nodes spanning {lower/1000:.5g}-{upper/1000:.5g} kb")
    cbar = fig.colorbar(cntr1, ax=ax)
    cbar.ax.set_ylabel(f"Mean node time ({ts.time_units})")
plt.show()
```
We can see what our (poorly performing) conditional coalescent mixture looks like by substituting the expected times for the real times, i.e. using:

```python
import tsdate

nm = {n: i for i, n in enumerate(tsdate.prior.PriorParams._fields)}
cc = tsdate.prior.ConditionalCoalescentTimes(100)
cc.add(4)

def mixture_expect_and_var(mixture, cond_coal):
    """
    Return the expectation and variance of a coalescent mixture.
    mixture is a dict of the form {N: {'descendant_tips': [tips], 'weight': [weights]}}
    """
    expectation = 0
    first = secnd = 0
    for N, tip_dict in mixture.items():
        # assert 1 not in tip_dict.descendant_tips
        mean = cond_coal[N][tip_dict["descendant_tips"], nm["mean"]]
        var = cond_coal[N][tip_dict["descendant_tips"], nm["var"]]
        # Mixture expectation
        expectation += np.sum(mean * tip_dict["weight"])
        # Mixture variance, via the law of total variance
        first += np.sum(var * tip_dict["weight"])
        secnd += np.sum(mean**2 * tip_dict["weight"])
    mean = expectation
    var = first + secnd - (expectation**2)
    return mean, var

expected_time = np.zeros(len(times))
for i, n in enumerate(nodes):
    params = {4: {'descendant_tips': [2, 3, 4], 'weight': n[:3] / np.sum(n[:3])}}
    expected_time[i] = mixture_expect_and_var(params, cc)[0] * 1e4
```

This gives an equivalent set of plots, which are identical to each other (as the mixture prior does not account for node spans). It looks like this:
It looks to me that for a 4-tip tree, the pattern (sloping diagonal) in the mixture prior is correct, but the expectation is being completely swamped by the effect of node span. It should be relatively easy to fit a statistical model to find a decent set of predictors, I think (especially since we can generate as much data as we like). We should do the same for variance, of course. As a start, we could fit a simple linear model relating observed time to the mixture expectation and node span. This would be easy to test as a statistical model on much larger numbers of samples. I think @a-ignatieva has some predictions linking edge span (which she calls "duration") with time, so we might be able to use that, although it's not a direct measure of node span.
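As a concrete illustration, here's a minimal sketch of such a fit via ordinary least squares, reusing the `times`, `expected_time` and `tot` arrays from the code above; the choice of log-transformed predictors is my assumption, not something established here:

```python
# Sketch: regress log observed node time on the log mixture-prior expectation
# and log node span. Assumes times, expected_time and tot from the blocks above.
X = np.column_stack([
    np.ones(len(times)),    # intercept
    np.log(expected_time),  # log mixture-prior expectation
    np.log(tot),            # log node span
])
coef, residuals, rank, sv = np.linalg.lstsq(X, np.log(times), rcond=None)
print("intercept, log-expectation and log-span coefficients:", coef)
```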
Here's a useful plot, showing how we deviate from the mixture prior given the length of the nodes. As you can see, the longer the node span, the younger the observed time relative to what we expect. We can repeat this for larger tree sequences, of course:

```python
import matplotlib

fig, axes = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 5))
axes[0].hexbin(np.log(times) - np.log(expected_time), tot, yscale="log", norm=matplotlib.colors.LogNorm(1, 5000))
axes[0].set_ylabel("Node span (bp)")
axes[0].set_xlabel("Deviation from log expected time")
axes[1].scatter(np.log(times) - np.log(expected_time), tot, alpha=0.01)
axes[1].set_yscale("log")
axes[1].set_xlabel("Deviation from log expected time")
```
Very nice. Something else that might be useful: we can take a matrix …
Hmm, interesting thought. Also, if our problem is mostly to do not with the mean prior assigned to each node but with its variance, then we should try to do a similar plot for the variance too. Here's one attempt, using data from simulated tree sequences of 100 samples. If I've got the code right (below), then it looks like the variance problem is more serious, as you thought, @nspope. In particular, we estimate much too high a variance for nodes with long spans, and too low a variance (which is presumably worse) for nodes with short spans.

```python
import itertools

import matplotlib
import msprime
import numpy as np
import tqdm
from matplotlib import pyplot as plt

import tsdate

nm = {n: i for i, n in enumerate(tsdate.prior.PriorParams._fields)}
cc = tsdate.prior.ConditionalCoalescentTimes(1000)

def mixture_expect_and_var(mixture, cond_coal):
    """
    Return the expectation and variance of a coalescent mixture.
    mixture is a dict of the form {N: {'descendant_tips': [tips], 'weight': [weights]}}
    """
    expectation = 0
    first = secnd = 0
    for N, tip_dict in mixture.items():
        # assert 1 not in tip_dict.descendant_tips
        mean = cond_coal[N][tip_dict["descendant_tips"], nm["mean"]]
        var = cond_coal[N][tip_dict["descendant_tips"], nm["var"]]
        # Mixture expectation
        expectation += np.sum(mean * tip_dict["weight"])
        # Mixture variance, via the law of total variance
        first += np.sum(var * tip_dict["weight"])
        secnd += np.sum(mean**2 * tip_dict["weight"])
    mean = expectation
    var = first + secnd - (expectation**2)
    return mean, var

reps = 10
n = 100
cc.add(n)
nstat = None
# Save into a massive matrix with 7 columns:
for ts in tqdm.tqdm(
    msprime.sim_ancestry(
        n, ploidy=1, population_size=1e4, sequence_length=5e7,
        recombination_rate=1e-8, num_replicates=reps,
    )
):
    n_internal = ts.num_nodes - ts.num_samples
    node_arr = np.zeros((n_internal, ts.num_samples - 1))
    for tree in ts.trees():
        for u in tree.nodes():
            if tree.is_internal(u):
                node_arr[u - ts.num_samples, tree.num_samples(u) - 2] += tree.interval.span
    stat = np.array([
        ts.nodes_time[ts.num_samples:],  # 0 actual time
        np.sum(node_arr, axis=1),        # 1 node span
        np.zeros(n_internal),            # 2 will be expected (mixture) time
        np.zeros(n_internal),            # 3 will be expected (mixture) variance
        np.zeros(n_internal),            # 4 will be mean ntips
        np.zeros(n_internal),            # 5 will be smallest nonzero span ntips
        np.zeros(n_internal),            # 6 will be largest nonzero span ntips
    ]).T
    ntips = np.arange(2, n + 1)
    for u in np.arange(n_internal):
        weights = node_arr[u, :] / np.sum(node_arr[u, :])
        params = {n: {'descendant_tips': ntips, 'weight': weights}}
        mean_var = mixture_expect_and_var(params, cc)
        stat[u, 2] = mean_var[0] * 1e4
        stat[u, 3] = mean_var[1] * 1e4 * 1e4
        stat[u, 4] = np.sum(ntips * weights)
        stat[u, 5] = np.flatnonzero(node_arr[u, :])[0] + 2
        stat[u, 6] = np.flatnonzero(node_arr[u, :])[-1] + 2
    if nstat is None:
        nstat = stat
    else:
        nstat = np.vstack((nstat, stat))

fig, axes = plt.subplots(1, 2, sharey=True, figsize=(10, 5))
axes[0].hexbin(np.log(nstat[:, 0]) - np.log(nstat[:, 2]), nstat[:, 1], yscale="log", norm=matplotlib.colors.LogNorm(1, 100))
axes[0].set_ylabel("Node span")
axes[0].set_xlabel("Deviation from log expected time")
for lower, upper in itertools.pairwise(np.quantile(nstat[:, 1], np.linspace(0, 1, 100))):
    use = np.logical_and(nstat[:, 1] >= lower, nstat[:, 1] < upper)
    var = np.var(nstat[:, 0][use], ddof=1)
    expected_var = np.mean(nstat[:, 3][use])
    axes[1].plot(np.log(var) - np.log(expected_var), (lower + upper) / 2, "bo")
axes[1].set_xlabel("Deviation from log expected variance in time")
```
And here's the same with the average number of descendant samples on the Y axis, rather than the node span (i.e. using `nstat[:, 4]` rather than `nstat[:, 1]`):
So far we have been calculating the mixtures using the span over which a node has each number of descendant samples as the weight. But I wonder if it is better to use the log of the span as a weight instead, or even not to weight at all.
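For concreteness, here's a minimal sketch of what those alternative weightings might look like for a single node, reusing a row of `node_arr` from the code above (the use of `log1p` to keep zero spans at zero weight is my assumption, not something specified here):

```python
# Hypothetical alternative mixture weightings for one node (row u of node_arr).
spans = node_arr[u, :]                     # span over which this node has each tip count
span_w = spans / spans.sum()               # current behaviour: weight by span
log_w = np.log1p(spans)                    # alternative: weight by log span (log1p(0) == 0)
log_w = log_w / log_w.sum()
flat_w = (spans > 0) / (spans > 0).sum()   # alternative: unweighted over observed tip counts
```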
I've tried without weighting by span at all -- this still introduces artefacts. I think the fact that there is a dependence between span and age means that you'd need a correction that depends on both of these (like in the linear model you've spelled out above).
I think that's right, but a linear model that incorporates the (log) length of the node might still fit better with a log-weighted (or unweighted) mixture component. By the way, from a few statistical fits I have tried, it seems that we don't really need an interaction term between log(span) and the mixture value. Anyway, this is possibly getting into the weeds a little, if we are happy to use a global prior!
Well, I think this is great to keep in mind for future improvements. For now, the global prior seems to work reasonably well. But it'd be great to come up with a way to generate the global prior without the very costly variance calculation for the mixture. I think, for example, if we just calculate the first moment (linear complexity in number of tips) for each node, then calculate the mean/variance across all nodes, we'll get a reasonable global prior. This would work for trees with 10s or 100s of thousands of tips, without interpolating.
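A minimal sketch of that idea, reusing `node_arr`, `n_internal`, `ntips`, `n`, `cc` and `mixture_expect_and_var` from the code above; the moment-matched gamma at the end is an illustrative choice of family, not necessarily what tsdate would use:

```python
# Sketch: build a global prior from per-node first moments only, skipping the
# expensive per-node mixture variance calculation.
node_means = np.zeros(n_internal)
for u in range(n_internal):
    weights = node_arr[u, :] / np.sum(node_arr[u, :])
    params = {n: {'descendant_tips': ntips, 'weight': weights}}
    node_means[u] = mixture_expect_and_var(params, cc)[0]  # first moment only

# Summarise across all nodes, then moment-match a single distribution to use
# as the global prior (gamma shape/rate from the mean and variance).
m, v = np.mean(node_means), np.var(node_means)
alpha, beta = m**2 / v, m / v
```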
FWIW, I have just tested the log-weighting, and it makes both the mean time and the variance in time fit better. The mixture variances cluster more closely to the observed variances, and the fit of expectations to observed times is better, both in a linear model with node span and in one without. So I think it is worth testing this on a full tsdate simulation anyway. There's no harm in weighting by the log of the span, and I think that node spans are closer to being exponentially distributed than linear. I assume there is some theory about the sizes of chunks when a line is repeatedly cut up at random positions - it feels somewhat exponential to me.
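A quick toy check of that intuition (a hypothetical simulation, not derived from the tree sequences above): cut a unit interval at uniform random positions and compare the chunk lengths to an exponential. The spacings between uniform order statistics are approximately exponential with mean 1/(k+1) for large k:

```python
# Toy check: chunk lengths from cutting [0, 1) at k uniform random positions.
rng = np.random.default_rng(42)
k = 999
cuts = np.sort(rng.uniform(size=k))
chunks = np.diff(np.concatenate([[0.0], cuts, [1.0]]))
print(np.mean(chunks), 1 / (k + 1))  # mean chunk length matches 1/(k+1)
print(np.std(chunks))                # for an exponential, std is close to the mean
```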
There's something I'm not quite understanding here, because my plots above imply that the mixture prior assigns too great a variance (and slightly too great an expected mean) for old, long nodes. But the plot at #257 (comment) implies that we set the time of those old nodes too low when using the vanilla mixture prior. Have I got things back-to-front somehow?
@nspope has found problems with the conditional coalescent mixture priors used by tsdate. Surprisingly, better results are gained by averaging the priors over all the nodes, and using this single distribution as a prior for all nodes (e.g. see #257 (comment)).
I have opened up this issue as a place to discuss what the best prior strategy is. We should be able to do better than the naive global approach, but I'm not sure how!
(update Summer 2024: rather than constructing a global prior from the coalescent, we now use a "flat" prior for all except the root nodes - although this is actually not truly flat, as internal nodes are constrained by the position of the roots, see #425. This makes tsdate much more robust to weird demographies etc, and allows for historical samples. Although for simple demographies with all-contemporary samples, setting coalescent-based priors on node times might give more power, for the moment we get better and more robust results without making coalescent assumptions, so this issue is more of a "nice-to-have")