Skip to content

Commit

Permalink
Site divmat version with O(n^2) per mutation
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Aug 4, 2023
1 parent bef0b4a commit 4945dd8
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 89 deletions.
7 changes: 7 additions & 0 deletions c/tests/test_stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -2041,6 +2041,13 @@ test_simplest_divergence_matrix(void)
ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS);

sample_ids[0] = 1;
ret = tsk_treeseq_divergence_matrix(&ts, 2, sample_ids, 0, NULL, 0, result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE);
ret = tsk_treeseq_divergence_matrix(
&ts, 2, sample_ids, 0, NULL, TSK_STAT_BRANCH, result);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE);

tsk_treeseq_free(&ts);
}

Expand Down
179 changes: 123 additions & 56 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -6597,77 +6597,112 @@ tsk_treeseq_divergence_matrix_branch(const tsk_treeseq_t *self, tsk_size_t num_s
return ret;
}

static tsk_size_t
count_mutations_on_path(tsk_id_t u, tsk_id_t v, const tsk_id_t *restrict parent,
const double *restrict time, const tsk_size_t *restrict mutations_per_node)
static void
increment_divergence_matrix_pairs(const tsk_size_t len_A, const tsk_id_t *restrict A,
const tsk_size_t len_B, const tsk_id_t *restrict B, double *restrict D)
{
double tu, tv;
tsk_size_t count = 0;

tu = time[u];
tv = time[v];
while (u != v) {
if (tu < tv) {
count += mutations_per_node[u];
u = parent[u];
if (u == TSK_NULL) {
break;
}
tu = time[u];
} else {
count += mutations_per_node[v];
v = parent[v];
if (v == TSK_NULL) {
break;
tsk_id_t u, v;
tsk_size_t j, k;
const tsk_id_t n = (tsk_id_t)(len_A + len_B);

for (j = 0; j < len_A; j++) {
for (k = 0; k < len_B; k++) {
u = A[j];
v = B[k];
/* Only increment the upper triangle to (hopefully) improve memory
* access patterns */
if (u > v) {
v = A[j];
u = B[k];
}
tv = time[v];
D[u * n + v]++;
}
}
if (u != v) {
while (u != TSK_NULL) {
count += mutations_per_node[u];
u = parent[u];
}

static void
update_site_divergence(const tsk_tree_t *tree, tsk_id_t node,
const tsk_id_t *sample_index_map, tsk_size_t num_samples, tsk_id_t *restrict stack,
int8_t *restrict descending_bitset, tsk_id_t *restrict descending_list,
tsk_id_t *restrict not_descending_list, double *D)
{
const tsk_id_t *restrict left_child = tree->left_child;
const tsk_id_t *restrict right_sib = tree->right_sib;
int stack_top;
tsk_id_t a, u, v;
tsk_size_t j, num_descending, num_not_descending;

tsk_memset(descending_bitset, 0, num_samples * sizeof(*descending_bitset));

stack_top = 0;
stack[stack_top] = node;
while (stack_top >= 0) {
u = stack[stack_top];
stack_top--;
a = sample_index_map[u];
if (a != TSK_NULL) {
descending_bitset[a] = 1;
}
while (v != TSK_NULL) {
count += mutations_per_node[v];
v = parent[v];
for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) {
stack_top++;
stack[stack_top] = v;
}
}

num_descending = 0;
num_not_descending = 0;
for (j = 0; j < num_samples; j++) {
if (descending_bitset[j]) {
descending_list[num_descending] = (tsk_id_t) j;
num_descending++;
} else {
not_descending_list[num_not_descending] = (tsk_id_t) j;
num_not_descending++;
}
}
return count;
tsk_bug_assert(num_descending + num_not_descending == num_samples);

increment_divergence_matrix_pairs(
num_descending, descending_list, num_not_descending, not_descending_list, D);
}

static int
tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_samples,
const tsk_id_t *restrict samples, tsk_size_t num_windows,
const double *restrict windows, tsk_flags_t TSK_UNUSED(options),
tsk_size_t num_windows, const double *restrict windows,
tsk_flags_t TSK_UNUSED(options), const tsk_id_t *restrict sample_index_map,
double *restrict result)
{
int ret = 0;
tsk_tree_t tree;
const tsk_size_t n = num_samples;
const tsk_size_t num_nodes = self->tables->nodes.num_rows;
const double *restrict nodes_time = self->tables->nodes.time;
tsk_size_t i, j, k, tree_site, tree_mut;
tsk_size_t i, tree_site, tree_mut;
tsk_site_t site;
tsk_mutation_t mut;
tsk_id_t u, v;
double left, right, span_left, span_right;
double *restrict D;
tsk_size_t *mutations_per_node = tsk_malloc(num_nodes * sizeof(*mutations_per_node));
const tsk_size_t num_nodes = self->tables->nodes.num_rows;
int8_t *descending_bitset = tsk_malloc(num_samples * sizeof(*descending_bitset));
tsk_id_t *descending_list = tsk_malloc(num_samples * sizeof(*descending_list));
tsk_id_t *not_descending_list
= tsk_malloc(num_samples * sizeof(*not_descending_list));
/* Do *not* use tsk_tree_get_size bound here because it gives a per-tree
* bound, not a global one! */
tsk_id_t *stack = tsk_malloc(num_nodes * sizeof(*stack));

ret = tsk_tree_init(&tree, self, 0);
if (ret != 0) {
goto out;
}
if (mutations_per_node == NULL) {

if (descending_bitset == NULL || descending_list == NULL
|| not_descending_list == NULL || stack == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
}

for (i = 0; i < num_windows; i++) {
left = windows[i];
right = windows[i + 1];
D = result + i * n * n;
D = result + i * num_samples * num_samples;
ret = tsk_tree_seek(&tree, left, 0);
if (ret != 0) {
goto out;
Expand All @@ -6676,29 +6711,18 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam
span_left = TSK_MAX(tree.interval.left, left);
span_right = TSK_MIN(tree.interval.right, right);

/* NOTE: we could avoid this full memset across all nodes by doing
* the same loops again and decrementing at the end of the main
* tree-loop. It's probably not worth it though, because of the
* overwhelming O(n^2) below */
tsk_memset(mutations_per_node, 0, num_nodes * sizeof(*mutations_per_node));
for (tree_site = 0; tree_site < tree.sites_length; tree_site++) {
site = tree.sites[tree_site];
if (span_left <= site.position && site.position < span_right) {
for (tree_mut = 0; tree_mut < site.mutations_length; tree_mut++) {
mut = site.mutations[tree_mut];
mutations_per_node[mut.node]++;
update_site_divergence(&tree, mut.node, sample_index_map,
num_samples, stack, descending_bitset, descending_list,
not_descending_list, D);
}
}
}

for (j = 0; j < n; j++) {
u = samples[j];
for (k = j + 1; k < n; k++) {
v = samples[k];
D[j * n + k] += (double) count_mutations_on_path(
u, v, tree.parent, nodes_time, mutations_per_node);
}
}
ret = tsk_tree_next(&tree);
if (ret < 0) {
goto out;
Expand All @@ -6708,7 +6732,42 @@ tsk_treeseq_divergence_matrix_site(const tsk_treeseq_t *self, tsk_size_t num_sam
ret = 0;
out:
tsk_tree_free(&tree);
tsk_safe_free(mutations_per_node);
tsk_safe_free(descending_bitset);
tsk_safe_free(descending_list);
tsk_safe_free(not_descending_list);
tsk_safe_free(stack);
return ret;
}

static int
get_sample_index_map(const tsk_size_t num_nodes, const tsk_size_t num_samples,
const tsk_id_t *restrict samples, tsk_id_t **ret_sample_index_map)
{
int ret = 0;
tsk_size_t j;
tsk_id_t u;
tsk_id_t *sample_index_map = tsk_malloc(num_nodes * sizeof(*sample_index_map));

if (sample_index_map == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
}
/* Assign the output pointer here so that it will be freed in the case
* of an error raised in the input checking */
*ret_sample_index_map = sample_index_map;

for (j = 0; j < num_nodes; j++) {
sample_index_map[j] = TSK_NULL;
}
for (j = 0; j < num_samples; j++) {
u = samples[j];
if (sample_index_map[u] != TSK_NULL) {
ret = TSK_ERR_DUPLICATE_SAMPLE;
goto out;
}
sample_index_map[u] = (tsk_id_t) j;
}
out:
return ret;
}

Expand Down Expand Up @@ -6739,9 +6798,11 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
const tsk_id_t *samples = self->samples;
tsk_size_t n = self->num_samples;
const double default_windows[] = { 0, self->tables->sequence_length };
const tsk_size_t num_nodes = self->tables->nodes.num_rows;
bool stat_site = !!(options & TSK_STAT_SITE);
bool stat_branch = !!(options & TSK_STAT_BRANCH);
bool stat_node = !!(options & TSK_STAT_NODE);
tsk_id_t *sample_index_map = NULL;

if (stat_node) {
ret = TSK_ERR_UNSUPPORTED_STAT_MODE;
Expand Down Expand Up @@ -6785,6 +6846,11 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
}
}

ret = get_sample_index_map(num_nodes, n, samples, &sample_index_map);
if (ret != 0) {
goto out;
}

tsk_memset(result, 0, num_windows * n * n * sizeof(*result));

if (stat_branch) {
Expand All @@ -6793,13 +6859,14 @@ tsk_treeseq_divergence_matrix(const tsk_treeseq_t *self, tsk_size_t num_samples,
} else {
tsk_bug_assert(stat_site);
ret = tsk_treeseq_divergence_matrix_site(
self, n, samples, num_windows, windows, options, result);
self, n, num_windows, windows, options, sample_index_map, result);
}
if (ret != 0) {
goto out;
}
fill_lower_triangle(result, n, num_windows);

out:
tsk_safe_free(sample_index_map);
return ret;
}
50 changes: 17 additions & 33 deletions python/tests/test_divmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
"""
Test cases for divergence matrix based pairwise stats
"""
import collections

import msprime
import numpy as np
import pytest
Expand Down Expand Up @@ -279,6 +277,9 @@ def site_divergence_matrix(ts, windows=None, samples=None):
samples = ts.samples() if samples is None else samples

n = len(samples)
sample_index_map = np.zeros(ts.num_nodes, dtype=int) - 1
sample_index_map[samples] = np.arange(n)
is_descendant = np.zeros(n, dtype=bool)
D = np.zeros((num_windows, n, n))
tree = tskit.Tree(ts)
for i in range(num_windows):
Expand All @@ -289,32 +290,22 @@ def site_divergence_matrix(ts, windows=None, samples=None):
while tree.interval.left < right and tree.index != -1:
span_left = max(tree.interval.left, left)
span_right = min(tree.interval.right, right)
mutations_per_node = collections.Counter()
for site in tree.sites():
if span_left <= site.position < span_right:
for mutation in site.mutations:
mutations_per_node[mutation.node] += 1
for j in range(n):
u = samples[j]
for k in range(j + 1, n):
v = samples[k]
w = tree.mrca(u, v)
if w != tskit.NULL:
wu = w
wv = w
else:
wu = local_root(tree, u)
wv = local_root(tree, v)
du = sum(mutations_per_node[x] for x in rootward_path(tree, u, wu))
dv = sum(mutations_per_node[x] for x in rootward_path(tree, v, wv))
# NOTE: we're just accumulating the raw mutation counts, not
# multiplying by span
D[i, j, k] += du + dv
descendants = []
for u in tree.nodes(mutation.node):
if sample_index_map[u] != -1:
is_descendant[sample_index_map[u]] = True

descendants = np.where(is_descendant)[0]
not_descendants = np.where(np.logical_not(is_descendant))[0]
for j in descendants:
for k in not_descendants:
D[i, j, k] += 1
D[i, k, j] += 1
is_descendant[:] = False
tree.next()
# Fill out symmetric triangle in the matrix
for j in range(n):
for k in range(j + 1, n):
D[i, k, j] = D[i, j, k]
if not windows_specified:
D = D[0]
return D
Expand Down Expand Up @@ -511,15 +502,8 @@ def test_single_tree_duplicate_samples(self, mode):
# 0 1
ts = tskit.Tree.generate_balanced(4).tree_sequence
ts = tsutil.insert_branch_sites(ts)
D1 = check_divmat(ts, samples=[0, 0, 1], compare_stats_api=False, mode=mode)
D2 = np.array(
[
[0.0, 0.0, 2.0],
[0.0, 0.0, 2.0],
[2.0, 2.0, 0.0],
]
)
np.testing.assert_array_equal(D1, D2)
with pytest.raises(tskit.LibraryError, match="TSK_ERR_DUPLICATE_SAMPLE"):
ts.divergence_matrix(samples=[0, 0, 1], mode=mode)

@pytest.mark.parametrize("mode", DIVMAT_MODES)
def test_single_tree_multiroot(self, mode):
Expand Down

0 comments on commit 4945dd8

Please sign in to comment.