From 46578bdc28211230ad7815b116564e57c2ba88e7 Mon Sep 17 00:00:00 2001 From: lkirk Date: Fri, 23 Aug 2024 02:27:32 -0500 Subject: [PATCH] Create C api for two-locus branch stats Adds a C implementation of two-locus branch statistics. It mirrors the python implementation except where we iterate over edge differences and collect them for updating the stat. We use tree_seek_index to seek to arbitrary positions and tsk_tree_next to move from tree to tree. The tricky part was getting backwards iteration correct. All tests agree with the python prototype. In addition, I had to fix a bug in the python implementation where node ids were being added to our TreeState object instead of sample ids (encoded in the sample index map). The python tests have also been updated to remove the slow naive version (after validating that it agrees with the python and c implementation -- on test cases where the runtime was reasonable). Python tests have been trimmed for runtime. The CPython code has been updated to parse positions in addition to sites. I also found the need to clean up some of the bounds checking code to return reasonable error messages to the user (also updated in the two-locus site stats). Finally, I added a few unbiased statistics for use in validating this code. --- c/tests/test_stats.c | 391 +++++++++++++++-- c/tskit/core.c | 21 + c/tskit/core.h | 21 + c/tskit/trees.c | 761 +++++++++++++++++++++++++++++++-- c/tskit/trees.h | 53 ++- python/_tskitmodule.c | 155 +++++-- python/tests/test_ld_matrix.py | 285 ++++-------- python/tests/test_lowlevel.py | 166 +++++-- python/tskit/trees.py | 72 +++- 9 files changed, 1545 insertions(+), 380 deletions(-) diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 251ddaccdf..6c1073414b 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -2281,7 +2281,7 @@ test_paper_ex_two_site(void) result_size = num_sites * num_sites; tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size * num_sample_sets, result, truth_one_set); @@ -2295,7 +2295,7 @@ test_paper_ex_two_site(void) tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size * num_sample_sets, result, truth_two_sets); @@ -2309,7 +2309,7 @@ test_paper_ex_two_site(void) tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal_nan( @@ -2320,6 +2320,128 @@ test_paper_ex_two_site(void) tsk_safe_free(col_sites); } +static void +test_paper_ex_two_branch(void) +{ + int ret; + tsk_treeseq_t ts; + double result[27]; + tsk_size_t i, result_size, num_sample_sets; + tsk_flags_t options = 0; + double truth_one_set[9] + = { 0.001066666666666695, -0.00012666666666665688, -0.0001266666666666534, + -0.00012666666666665688, 6.016666666665456e-05, 6.016666666665629e-05, + -0.0001266666666666534, 6.016666666665629e-05, 6.016666666665629e-05 }; + double truth_two_sets[18] + = { 0.001066666666666695, 0.001066666666666695, -0.00012666666666665688, + -0.00012666666666665688, -0.0001266666666666534, -0.0001266666666666534, + -0.00012666666666665688, -0.00012666666666665688, 6.016666666665456e-05, + 6.016666666665456e-05, 6.016666666665629e-05, 6.016666666665629e-05, + -0.0001266666666666534, -0.0001266666666666534, 6.016666666665629e-05, + 6.016666666665629e-05, 6.016666666665629e-05, 6.016666666665629e-05 }; + double truth_three_sets[27] = { 0.001066666666666695, 0.001066666666666695, NAN, + -0.00012666666666665688, -0.00012666666666665688, NAN, -0.0001266666666666534, + -0.0001266666666666534, NAN, -0.00012666666666665688, -0.00012666666666665688, + NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN, 6.016666666665629e-05, + 6.016666666665629e-05, NAN, -0.0001266666666666534, -0.0001266666666666534, NAN, + 6.016666666665629e-05, 6.016666666665629e-05, NAN, 6.016666666665629e-05, + 6.016666666665629e-05, NAN }; + double truth_positions_subset_1[12] = { 0.001066666666666695, 0.001066666666666695, + NAN, 0.001066666666666695, 0.001066666666666695, NAN, 0.001066666666666695, + 0.001066666666666695, NAN, 0.001066666666666695, 0.001066666666666695, NAN }; + double truth_positions_subset_2[12] = { 6.016666666665456e-05, 6.016666666665456e-05, + NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN, 6.016666666665456e-05, + 6.016666666665456e-05, NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN }; + double truth_positions_subset_3[12] = { 6.016666666665456e-05, 6.016666666665456e-05, + NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN, 6.016666666665456e-05, + 6.016666666665456e-05, NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN }; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + + tsk_size_t sample_set_sizes[3]; + tsk_id_t sample_sets[ts.num_samples * 3]; + tsk_size_t num_trees = ts.num_trees; + double *row_positions = tsk_malloc(num_trees * sizeof(*row_positions)); + double *col_positions = tsk_malloc(num_trees * sizeof(*col_positions)); + double positions_subset_1[2] = { 0., 0.1 }; + double positions_subset_2[2] = { 2., 6. }; + double positions_subset_3[2] = { 9., 9.999 }; + + // First sample set contains all of the samples + sample_set_sizes[0] = ts.num_samples; + num_sample_sets = 1; + for (i = 0; i < ts.num_samples; i++) { + sample_sets[i] = (tsk_id_t) i; + } + for (i = 0; i < num_trees; i++) { + row_positions[i] = ts.breakpoints[i]; + col_positions[i] = ts.breakpoints[i]; + } + + options |= TSK_STAT_BRANCH; + + result_size = num_trees * num_trees * num_sample_sets; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_trees, NULL, row_positions, num_trees, NULL, col_positions, options, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_one_set); + + // Second sample set contains all of the samples + sample_set_sizes[1] = ts.num_samples; + num_sample_sets = 2; + for (i = ts.num_samples; i < ts.num_samples * 2; i++) { + sample_sets[i] = (tsk_id_t) i - (tsk_id_t) ts.num_samples; + } + + result_size = num_trees * num_trees * num_sample_sets; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_trees, NULL, row_positions, num_trees, NULL, col_positions, options, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_two_sets); + + // Third sample set contains the first two samples + sample_set_sizes[2] = 2; + num_sample_sets = 3; + for (i = ts.num_samples * 2; i < (ts.num_samples * 3) - 2; i++) { + sample_sets[i] = (tsk_id_t) i - (tsk_id_t) ts.num_samples * 2; + } + + result_size = num_trees * num_trees * num_sample_sets; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_trees, NULL, row_positions, num_trees, NULL, col_positions, options, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal_nan(result_size, result, truth_three_sets); + + result_size = 4 * num_sample_sets; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, + NULL, positions_subset_1, 2, NULL, positions_subset_1, options, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal_nan(result_size, result, truth_positions_subset_1); + + result_size = 4 * num_sample_sets; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, + NULL, positions_subset_2, 2, NULL, positions_subset_2, options, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal_nan(result_size, result, truth_positions_subset_2); + + result_size = 4 * num_sample_sets; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, + NULL, positions_subset_3, 2, NULL, positions_subset_3, options, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal_nan(result_size, result, truth_positions_subset_3); + + tsk_treeseq_free(&ts); + tsk_safe_free(row_positions); + tsk_safe_free(col_positions); +} + static void test_two_site_correlated_multiallelic(void) { @@ -2401,43 +2523,43 @@ test_two_site_correlated_multiallelic(void) tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D2); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r2); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D_prime(&ts, num_sample_sets, sample_set_sizes, sample_sets, - num_sites, row_sites, num_sites, col_sites, 0, result); + num_sites, row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D_prime); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_Dz(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_Dz); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_pi2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_pi2); @@ -2532,43 +2654,43 @@ test_two_site_uncorrelated_multiallelic(void) tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D2); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r2); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D_prime(&ts, num_sample_sets, sample_set_sizes, sample_sets, - num_sites, row_sites, num_sites, col_sites, 0, result); + num_sites, row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D_prime); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_Dz(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_Dz); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_pi2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_pi2); @@ -2637,7 +2759,7 @@ test_two_site_backmutation(void) tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r2); @@ -2646,6 +2768,150 @@ test_two_site_backmutation(void) tsk_safe_free(col_sites); } +static void +test_two_locus_site_all_stats(void) +{ + int ret; + tsk_treeseq_t ts; + double result[16]; + tsk_size_t result_size = 16; + tsk_id_t sample_sets[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + tsk_size_t sample_set_sizes[1] = { 10 }; + double positions[4] = { 0.0, 2.0, 5.0, 6.0 }; + + const char *nodes + = "1 0 -1\n1 0 -1\n1 0 -1\n1 0 -1\n1 0 -1\n1 0 -1\n1 0 -1\n1 0 -1\n" + "1 0 -1\n1 0 -1\n0 0.02 -1\n0 0.06 -1\n0 0.08 -1\n0 0.09 -1\n0 0.21 -1\n" + "0 0.35 -1\n0 0.44 -1\n0 0.69 -1\n0 0.79 -1\n0 0.80 -1\n0 0.84 -1\n" + "0 1.26 -1\n"; + const char *edges + = "0 10 10 0,8\n0 10 11 4,7\n0 10 12 3,9\n0 10 13 6,11\n0 10 14 1,2\n" + "5 10 15 5,10\n0 5 16 5,10\n6 10 17 12,14\n2 6 18 14\n5 10 18 15\n" + "2 5 18 16\n6 10 18 17\n0 6 19 12\n0 2 19 14\n2 6 19 18\n" + "0 2 20 13\n0 2 20 16\n2 10 21 13\n6 10 21 18\n0 6 21 19\n" + "0 2 21 20\n"; + + double truth_D[16] = { -6.938893903907228e-18, 5.551115123125783e-17, + 4.85722573273506e-17, 2.7755575615628914e-17, 1.0408340855860843e-17, + 8.326672684688674e-17, 7.979727989493313e-17, 6.938893903907228e-17, + -2.42861286636753e-17, 4.163336342344337e-17, 2.42861286636753e-17, + 4.163336342344337e-17, 1.3877787807814457e-17, 5.551115123125783e-17, + 2.0816681711721685e-17, 2.7755575615628914e-17 }; + double truth_D2[16] = { 0.21949755999999998, 0.1867003599999999, 0.18798699999999988, + 0.18941379999999983, 0.18670035999999995, 0.21159555999999993, + 0.21257979999999996, 0.21222580000000005, 0.187987, 0.21257979999999996, + 0.21380379999999996, 0.2134714, 0.18941379999999994, 0.21222579999999996, + 0.21347139999999992, 0.21377299999999996 }; + double truth_r2[16] = { 6.286870108969513, 5.742220038107836, 5.7080225607835695, + 5.623290389581752, 5.742220038107832, 6.3274209876543175, 6.291288603867465, + 6.195658345930953, 5.708022560783573, 6.291288603867472, 6.266256220080618, + 6.170677280171318, 5.623290389581758, 6.195658345930966, 6.170677280171324, + 6.094109054547737 }; + double truth_D_prime[16] = { -9.6552, -9.44459999999999, -9.136799999999988, + -8.680999999999989, -9.444599999999998, -9.240699999999984, -8.937399999999977, + -8.488499999999984, -9.136799999999996, -8.93739999999999, -8.658399999999984, + -8.219399999999993, -8.68099999999999, -8.488499999999991, -8.21939999999999, + -7.814699999999995 }; + double truth_r[16] = { 0.023193673439522472, 0.023272634599981495, + 0.021243465874728862, 0.01919099466703808, 0.023272634599981454, + 0.023358527073393587, 0.021370047752011, 0.019268461077492888, + 0.021243465874728862, 0.021370047752011012, 0.020359977803327087, + 0.01793842604857987, 0.019190994667037817, 0.019268461077492804, + 0.017938426048579773, 0.0160605735196305 }; + double truth_Dz[16] = { 0.01958895999999996, -0.007941440000000037, + -0.007572800000000046, -0.010558400000000029, -0.007941440000000022, + 0.01385535999999997, 0.014569599999999966, 0.015529599999999963, + -0.007572800000000024, 0.01456959999999996, 0.015426399999999951, + 0.016271199999999948, -0.010558400000000011, 0.01552959999999999, + 0.016271199999999986, 0.017607999999999985 }; + double truth_pi2[16] = { 0.7201219600000001, 0.6895723600000001, 0.6865174000000006, + 0.6780314000000008, 0.6895723600000002, 0.6603187600000002, 0.6573934000000002, + 0.6492674000000002, 0.6865174000000002, 0.6573934000000003, 0.6544810000000003, + 0.6463910000000003, 0.6780314000000002, 0.6492674000000004, 0.6463910000000005, + 0.6384010000000007 }; + double truth_Dz_unbiased[16] = { -0.06387380952380949, -0.09312571428571428, + -0.09361428571428566, -0.10075682539682536, -0.09312571428571428, + -0.0734419047619048, -0.0730733333333334, -0.07171301587301597, + -0.0936142857142857, -0.07307333333333343, -0.07261476190476202, + -0.07147730158730167, -0.10075682539682543, -0.07171301587301596, + -0.07147730158730159, -0.06988666666666674 }; + double truth_D2_unbiased[16] = { 0.19576484126984134, 0.1586769841269842, + 0.16093412698412704, 0.16485253968253985, 0.15867698412698414, + 0.1949926984126984, 0.19673555555555555, 0.19734825396825403, + 0.16093412698412699, 0.1967355555555555, 0.19879341269841264, + 0.19945182539682532, 0.16485253968253968, 0.19734825396825395, + 0.1994518253968253, 0.20091222222222213 }; + double truth_pi2_unbiased[16] = { 0.8910765079365083, 0.8571103174603181, + 0.853337460317461, 0.8434880952380959, 0.8571103174603178, 0.8182193650793657, + 0.8145322222222225, 0.8043504761904768, 0.8533374603174609, 0.8145322222222225, + 0.8108450793650795, 0.800729047619048, 0.8434880952380955, 0.8043504761904766, + 0.8007290476190477, 0.7906733333333332 }; + + tsk_treeseq_from_text(&ts, 10, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); + + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D(&ts, 1, sample_set_sizes, sample_sets, 4, NULL, positions, 4, + NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D); + + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2(&ts, 1, sample_set_sizes, sample_sets, 4, NULL, positions, 4, + NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2); + + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_r2(&ts, 1, sample_set_sizes, sample_sets, 4, NULL, positions, 4, + NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_r2); + + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D_prime(&ts, 1, sample_set_sizes, sample_sets, 4, NULL, positions, + 4, NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D_prime); + + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_r(&ts, 1, sample_set_sizes, sample_sets, 4, NULL, positions, 4, + NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_r); + + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_Dz(&ts, 1, sample_set_sizes, sample_sets, 4, NULL, positions, 4, + NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_Dz); + + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_pi2(&ts, 1, sample_set_sizes, sample_sets, 4, NULL, positions, 4, + NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_pi2); + + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_Dz_unbiased(&ts, 1, sample_set_sizes, sample_sets, 4, NULL, + positions, 4, NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_Dz_unbiased); + + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_unbiased(&ts, 1, sample_set_sizes, sample_sets, 4, NULL, + positions, 4, NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2_unbiased); + + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_pi2_unbiased(&ts, 1, sample_set_sizes, sample_sets, 4, NULL, + positions, 4, NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_pi2_unbiased); + + tsk_treeseq_free(&ts); +} + static void test_paper_ex_two_site_subset(void) { @@ -2675,7 +2941,7 @@ test_paper_ex_two_site_subset(void) result_size = 2 * 2; tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, - row_sites, 2, col_sites, 0, result); + row_sites, NULL, 2, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size * num_sample_sets, result, result_truth_1); @@ -2683,7 +2949,7 @@ test_paper_ex_two_site_subset(void) tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); col_sites[0] = 2; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 1, - row_sites, 1, col_sites, 0, result); + row_sites, NULL, 1, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size * num_sample_sets, result, result_truth_2); @@ -2694,7 +2960,7 @@ test_paper_ex_two_site_subset(void) col_sites[0] = 0; col_sites[1] = 1; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, - row_sites, 2, col_sites, 0, result); + row_sites, NULL, 2, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size * num_sample_sets, result, result_truth_3); @@ -2716,8 +2982,9 @@ test_two_locus_stat_input_errors(void) tsk_size_t sample_set_sizes[1] = { ts.num_samples }; tsk_size_t num_sample_sets = 1; tsk_id_t sample_sets[ts.num_samples]; - tsk_size_t result_size = num_sites * num_sites; - double result[result_size]; + double positions[10] = { 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9 }; + double bad_col_positions[2] = { 0., 0. }; // used in 1 test to cover column check + double result[100]; tsk_size_t s; for (s = 0; s < ts.num_samples; s++) { @@ -2736,66 +3003,104 @@ test_two_locus_stat_input_errors(void) sample_sets[1] = 0; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_DUPLICATE_SAMPLE); sample_sets[1] = 1; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, TSK_STAT_SITE | TSK_STAT_BRANCH, result); + row_sites, NULL, num_sites, col_sites, NULL, TSK_STAT_SITE | TSK_STAT_BRANCH, + result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MULTIPLE_STAT_MODES); - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, TSK_STAT_BRANCH, result); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); - ret = tsk_treeseq_r2(&ts, 0, sample_set_sizes, sample_sets, num_sites, row_sites, - num_sites, col_sites, 0, result); + NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_SAMPLE_SETS); sample_set_sizes[0] = 0; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_EMPTY_SAMPLE_SET); sample_set_sizes[0] = ts.num_samples; sample_sets[1] = 10; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); sample_sets[1] = 1; row_sites[0] = 1000; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); row_sites[0] = 0; col_sites[num_sites - 1] = (tsk_id_t) num_sites; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS); col_sites[num_sites - 1] = (tsk_id_t) num_sites - 1; row_sites[0] = 1; row_sites[1] = 0; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSORTED_SITES); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_UNSORTED_SITES); row_sites[0] = 0; row_sites[1] = 1; row_sites[0] = 1; row_sites[1] = 1; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, num_sites, col_sites, 0, result); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSORTED_SITES); + row_sites, NULL, num_sites, col_sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_DUPLICATE_SITES); row_sites[0] = 0; row_sites[1] = 1; - // Not an error condition, but we want to record this behavior - ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, 0, - NULL, 0, result); - CU_ASSERT_EQUAL_FATAL(ret, 0); + // Not an error condition, but we want to record this behavior. The method is robust + // to zero-length site/position inputs. + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, + NULL, 0, NULL, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 0, NULL, + NULL, 0, NULL, NULL, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + positions[9] = 1; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 10, NULL, + positions, 10, NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_POSITION_OUT_OF_BOUNDS); + positions[9] = 0.9; + + positions[0] = -0.1; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 10, NULL, + positions, 10, NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_POSITION_OUT_OF_BOUNDS); + positions[0] = 0; + + positions[0] = 0.1; + positions[1] = 0; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 10, NULL, + positions, 10, NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_UNSORTED_POSITIONS); + positions[0] = 0; + positions[1] = 0.1; + + // rows always fail first, check columns + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 10, NULL, + positions, 2, NULL, bad_col_positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_DUPLICATE_POSITIONS); + + positions[0] = 0; + positions[1] = 0; + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 10, NULL, + positions, 10, NULL, positions, TSK_STAT_BRANCH, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_STAT_DUPLICATE_POSITIONS); + positions[0] = 0; + positions[1] = 0.1; + + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 10, NULL, + positions, 10, NULL, positions, TSK_STAT_NODE, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); tsk_treeseq_free(&ts); tsk_safe_free(row_sites); @@ -3101,11 +3406,13 @@ main(int argc, char **argv) { "test_ld_silent_mutations", test_ld_silent_mutations }, { "test_paper_ex_two_site", test_paper_ex_two_site }, + { "test_paper_ex_two_branch", test_paper_ex_two_branch }, { "test_two_site_correlated_multiallelic", test_two_site_correlated_multiallelic }, { "test_two_site_uncorrelated_multiallelic", test_two_site_uncorrelated_multiallelic }, { "test_two_site_backmutation", test_two_site_backmutation }, + { "test_two_locus_site_all_stats", test_two_locus_site_all_stats }, { "test_paper_ex_two_site_subset", test_paper_ex_two_site_subset }, { "test_two_locus_stat_input_errors", test_two_locus_stat_input_errors }, diff --git a/c/tskit/core.c b/c/tskit/core.c index 32120ba2f8..609979edd8 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -226,6 +226,9 @@ tsk_strerror_internal(int err) ret = "One of the kept rows in the table refers to a deleted row. " "(TSK_ERR_KEEP_ROWS_MAP_TO_DELETED)"; break; + case TSK_ERR_POSITION_OUT_OF_BOUNDS: + ret = "Position out of bounds. (TSK_ERR_POSITION_OUT_OF_BOUNDS)"; + break; /* Edge errors */ case TSK_ERR_NULL_PARENT: @@ -502,6 +505,24 @@ tsk_strerror_internal(int err) ret = "Times must be strictly increasing. (TSK_ERR_UNSORTED_TIMES)"; break; + /* Two locus errors */ + case TSK_ERR_STAT_UNSORTED_POSITIONS: + ret = "The provided positions are not sorted in strictly increasing " + "order. (TSK_ERR_STAT_UNSORTED_POSITIONS)"; + break; + case TSK_ERR_STAT_DUPLICATE_POSITIONS: + ret = "The provided positions contain duplicates. " + "(TSK_ERR_STAT_DUPLICATE_POSITIONS)"; + break; + case TSK_ERR_STAT_UNSORTED_SITES: + ret = "The provided sites are not sorted in strictly increasing position " + "order. (TSK_ERR_STAT_UNSORTED_SITES)"; + break; + case TSK_ERR_STAT_DUPLICATE_SITES: + ret = "The provided sites contain duplicated entries. " + "(TSK_ERR_STAT_DUPLICATE_SITES)"; + break; + /* Mutation mapping errors */ case TSK_ERR_GENOTYPES_ALL_MISSING: ret = "Must provide at least one non-missing genotype. " diff --git a/c/tskit/core.h b/c/tskit/core.h index 641400a44c..93e407fe68 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -370,6 +370,11 @@ One of the rows in the retained table refers to a row that has been deleted. */ #define TSK_ERR_KEEP_ROWS_MAP_TO_DELETED -212 +/** +A genomic position was less than zero or greater equal to the sequence +length +*/ +#define TSK_ERR_POSITION_OUT_OF_BOUNDS -213 /** @} */ @@ -710,6 +715,22 @@ The vector of quantiles is out of bounds or in nonascending order. Times are not in ascending order */ #define TSK_ERR_UNSORTED_TIMES -917 +/* +The provided positions are not provided in strictly increasing order +*/ +#define TSK_ERR_STAT_UNSORTED_POSITIONS -918 +/** +The provided positions are not unique +*/ +#define TSK_ERR_STAT_DUPLICATE_POSITIONS -919 +/** +The provided sites are not provided in strictly increasing position order +*/ +#define TSK_ERR_STAT_UNSORTED_SITES -920 +/** +The provided sites are not unique +*/ +#define TSK_ERR_STAT_DUPLICATE_SITES -921 /** @} */ /** diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 1432d257ee..5f55f470fb 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -808,7 +808,7 @@ tsk_treeseq_get_individuals_time(const tsk_treeseq_t *self, double *output) /* Stats functions */ -#define GET_2D_ROW(array, row_len, row) (array + (((size_t)(row_len)) * (size_t) row)) +#define GET_2D_ROW(array, row_len, row) (array + (((size_t)(row_len)) * (size_t)(row))) static inline double * GET_3D_ROW(double *base, tsk_size_t num_nodes, tsk_size_t output_dim, @@ -2621,9 +2621,12 @@ check_sites(const tsk_id_t *sites, tsk_size_t num_sites, tsk_size_t num_site_row ret = TSK_ERR_SITE_OUT_OF_BOUNDS; goto out; } - if (sites[i] >= sites[i + 1]) { - // TODO: this checks no repeats, but error is ambiguous - ret = TSK_ERR_UNSORTED_SITES; + if (sites[i] > sites[i + 1]) { + ret = TSK_ERR_STAT_UNSORTED_SITES; + goto out; + } + if (sites[i] == sites[i + 1]) { + ret = TSK_ERR_STAT_DUPLICATE_SITES; goto out; } } @@ -2636,12 +2639,571 @@ check_sites(const tsk_id_t *sites, tsk_size_t num_sites, tsk_size_t num_site_row return ret; } +static int +check_positions( + const double *positions, tsk_size_t num_positions, double sequence_length) +{ + int ret = 0; + tsk_size_t i; + + if (num_positions == 0) { + return ret; // No need to verify positions if there aren't any + } + + for (i = 0; i < num_positions - 1; i++) { + if (positions[i] < 0 || positions[i] >= sequence_length) { + ret = TSK_ERR_POSITION_OUT_OF_BOUNDS; + goto out; + } + if (positions[i] > positions[i + 1]) { + ret = TSK_ERR_STAT_UNSORTED_POSITIONS; + goto out; + } + if (positions[i] == positions[i + 1]) { + ret = TSK_ERR_STAT_DUPLICATE_POSITIONS; + goto out; + } + } + // check bounds of last value + if (positions[i] < 0 || positions[i] >= sequence_length) { + ret = TSK_ERR_POSITION_OUT_OF_BOUNDS; + goto out; + } +out: + return ret; +} + +static int +positions_to_tree_indexes(const tsk_treeseq_t *ts, const double *positions, + tsk_size_t num_positions, tsk_id_t **tree_indexes) +{ + int ret = 0; + tsk_id_t tree_index = 0; + tsk_size_t i, num_trees = ts->num_trees; + + // This is tricky. If there are 0 positions, we calloc a size of 1 + // we must calloc, because memset will have no effect when called with size 0 + *tree_indexes = tsk_calloc(num_positions, sizeof(*tree_indexes)); + if (tree_indexes == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + tsk_memset(*tree_indexes, TSK_NULL, num_positions * sizeof(**tree_indexes)); + for (i = 0; i < num_positions; i++) { + while (ts->breakpoints[tree_index + 1] <= positions[i]) { + tree_index++; + } + (*tree_indexes)[i] = tree_index; + } + tsk_bug_assert(tree_index <= (tsk_id_t)(num_trees - 1)); + +out: + return ret; +} + +static int +get_index_counts( + const tsk_id_t *indexes, tsk_size_t num_indexes, tsk_size_t **out_counts) +{ + int ret = 0; + tsk_id_t index = indexes[0]; + tsk_size_t count, i; + tsk_size_t *counts = tsk_calloc( + (tsk_size_t)(indexes[num_indexes ? num_indexes - 1 : 0] - indexes[0] + 1), + sizeof(*counts)); + if (counts == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + count = 1; + for (i = 1; i < num_indexes; i++) { + if (indexes[i] == indexes[i - 1]) { + count++; + } else { + counts[index - indexes[0]] = count; + count = 1; + index = indexes[i]; + } + } + counts[index - indexes[0]] = count; + *out_counts = counts; +out: + return ret; +} + +typedef struct { + tsk_tree_t tree; + tsk_bit_array_t *node_samples; + tsk_id_t *parent; + tsk_id_t *edges_out; + tsk_id_t *edges_in; + double *branch_len; + tsk_size_t n_edges_out; + tsk_size_t n_edges_in; +} iter_state; + +static int +iter_state_init(iter_state *self, const tsk_treeseq_t *ts, tsk_size_t state_dim) +{ + int ret = 0; + const tsk_size_t num_nodes = ts->tables->nodes.num_rows; + + ret = tsk_tree_init(&self->tree, ts, TSK_NO_SAMPLE_COUNTS); + if (ret != 0) { + goto out; + } + self->node_samples = tsk_calloc(1, sizeof(*self->node_samples)); + if (self->node_samples == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + ret = tsk_bit_array_init(self->node_samples, ts->num_samples, state_dim * num_nodes); + if (ret != 0) { + goto out; + } + self->parent = tsk_malloc(num_nodes * sizeof(*self->parent)); + self->edges_out = tsk_malloc(num_nodes * sizeof(*self->edges_out)); + self->edges_in = tsk_malloc(num_nodes * sizeof(*self->edges_in)); + self->branch_len = tsk_calloc(num_nodes, sizeof(*self->branch_len)); + if (self->parent == NULL || self->edges_out == NULL || self->edges_in == NULL + || self->branch_len == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } +out: + return ret; +} + +static int +get_node_samples(const tsk_treeseq_t *ts, tsk_size_t state_dim, + const tsk_bit_array_t *sample_sets, tsk_bit_array_t *node_samples) +{ + int ret = 0; + tsk_size_t n, k; + tsk_bit_array_t sample_set_row, node_samples_row; + tsk_size_t num_nodes = ts->tables->nodes.num_rows; + tsk_bit_array_value_t sample; + const tsk_id_t *restrict sample_index_map = ts->sample_index_map; + const tsk_flags_t *restrict flags = ts->tables->nodes.flags; + + ret = tsk_bit_array_init(node_samples, ts->num_samples, num_nodes * state_dim); + if (ret != 0) { + goto out; + } + for (k = 0; k < state_dim; k++) { + tsk_bit_array_get_row(sample_sets, k, &sample_set_row); + for (n = 0; n < num_nodes; n++) { + if (flags[n] & TSK_NODE_IS_SAMPLE) { + sample = (tsk_bit_array_value_t) sample_index_map[n]; + if (tsk_bit_array_contains(&sample_set_row, sample)) { + tsk_bit_array_get_row( + node_samples, (state_dim * n) + k, &node_samples_row); + tsk_bit_array_add_bit(&node_samples_row, sample); + } + } + } + } +out: + return ret; +} + +static void +iter_state_clear(iter_state *self, tsk_size_t state_dim, tsk_size_t num_nodes, + const tsk_bit_array_t *node_samples) +{ + self->n_edges_out = 0; + self->n_edges_in = 0; + tsk_tree_clear(&self->tree); + tsk_memset(self->parent, TSK_NULL, num_nodes * sizeof(*self->parent)); + tsk_memset(self->edges_out, TSK_NULL, num_nodes * sizeof(*self->edges_out)); + tsk_memset(self->edges_in, TSK_NULL, num_nodes * sizeof(*self->edges_in)); + tsk_memset(self->branch_len, 0, num_nodes * sizeof(*self->branch_len)); + tsk_memcpy(self->node_samples->data, node_samples->data, + node_samples->size * state_dim * num_nodes * sizeof(*node_samples->data)); +} + +static void +iter_state_free(iter_state *self) +{ + tsk_tree_free(&self->tree); + tsk_bit_array_free(self->node_samples); + tsk_safe_free(self->node_samples); + tsk_safe_free(self->parent); + tsk_safe_free(self->edges_out); + tsk_safe_free(self->edges_in); + tsk_safe_free(self->branch_len); +} + +static int +advance_collect_edges(iter_state *s, tsk_id_t index) +{ + int ret = 0; + tsk_id_t j, e; + tsk_size_t i; + double left, right; + tsk_tree_position_t pos; + tsk_tree_t *tree = &s->tree; + const double *restrict edge_left = tree->tree_sequence->tables->edges.left; + const double *restrict edge_right = tree->tree_sequence->tables->edges.right; + + // Either we're seeking forward one step from some nonzero position in the tree, or + // from the beginning of the tree sequence. + if (tree->index != TSK_NULL || index == 0) { + ret = tsk_tree_next(tree); + if (ret < 0) { + goto out; + } + pos = tree->tree_pos; + i = 0; + for (j = pos.out.start; j != pos.out.stop; j++) { + s->edges_out[i] = pos.out.order[j]; + i++; + } + s->n_edges_out = i; + i = 0; + for (j = pos.in.start; j != pos.in.stop; j++) { + s->edges_in[i] = pos.in.order[j]; + i++; + } + s->n_edges_in = i; + } else { + // Seek from an arbitrary nonzero position from an uninitialized tree. + tsk_bug_assert(tree->index == -1); + ret = tsk_tree_seek_index(tree, index, 0); + if (ret < 0) { + goto out; + } + pos = tree->tree_pos; + i = 0; + if (pos.direction == TSK_DIR_FORWARD) { + left = pos.interval.left; + for (j = pos.in.start; j != pos.in.stop; j++) { + e = pos.in.order[j]; + if (edge_left[e] <= left && left < edge_right[e]) { + s->edges_in[i] = pos.in.order[j]; + i++; + } + } + } else { + right = pos.interval.right; + for (j = pos.in.start; j != pos.in.stop; j--) { + e = pos.in.order[j]; + if (edge_right[e] >= right && right > edge_left[e]) { + s->edges_in[i] = pos.in.order[j]; + i++; + } + } + } + s->n_edges_out = 0; + s->n_edges_in = i; + } + ret = 0; +out: + return ret; +} + +static int +compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, + tsk_bit_array_t *child_samples, const iter_state *A_state, const iter_state *B_state, + tsk_size_t state_dim, tsk_size_t result_dim, int sign, general_stat_func_t *f, + sample_count_stat_params_t *f_params, double *result) +{ + int ret = 0; + double a_len, b_len; + double *restrict B_branch_len = B_state->branch_len; + double *weights = NULL, *weights_row, *result_tmp = NULL; + tsk_size_t n, k, a_row, b_row; + tsk_bit_array_t A_samples, B_samples, AB_samples, B_samples_tmp; + const double *restrict A_branch_len = A_state->branch_len; + const tsk_bit_array_t *restrict A_state_samples = A_state->node_samples; + const tsk_bit_array_t *restrict B_state_samples = B_state->node_samples; + tsk_size_t num_samples = ts->num_samples; + tsk_size_t num_nodes = ts->tables->nodes.num_rows; + + tsk_memset(&AB_samples, 0, sizeof(AB_samples)); + tsk_memset(&B_samples_tmp, 0, sizeof(B_samples_tmp)); + + weights = tsk_calloc(3 * state_dim, sizeof(*weights)); + result_tmp = tsk_calloc(result_dim, sizeof(*result_tmp)); + if (weights == NULL || result_tmp == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + ret = tsk_bit_array_init(&AB_samples, num_samples, 1); + if (ret != 0) { + goto out; + } + ret = tsk_bit_array_init(&B_samples_tmp, num_samples, 1); + if (ret != 0) { + goto out; + } + b_len = B_branch_len[c] * sign; + for (n = 0; n < num_nodes; n++) { + a_len = A_branch_len[n]; + if (a_len == 0) { + continue; + } + for (k = 0; k < state_dim; k++) { + a_row = (state_dim * n) + k; + // TODO: what if c is TSK_NULL? + b_row = (state_dim * (tsk_size_t) c) + k; + tsk_bit_array_get_row(A_state_samples, a_row, &A_samples); + tsk_bit_array_get_row(B_state_samples, b_row, &B_samples); + tsk_bit_array_intersect(&A_samples, &B_samples, &AB_samples); + weights_row = GET_2D_ROW(weights, 3, k); + weights_row[0] = (double) tsk_bit_array_count(&AB_samples); // w_AB + weights_row[1] + = (double) tsk_bit_array_count(&A_samples) - weights_row[0]; // w_Ab + weights_row[2] + = (double) tsk_bit_array_count(&B_samples) - weights_row[0]; // w_aB + } + ret = f(state_dim, weights, result_dim, result_tmp, f_params); + if (ret != 0) { + goto out; + } + for (k = 0; k < result_dim; k++) { + result[k] += result_tmp[k] * a_len * b_len; + } + + if (child_samples != NULL) { + for (k = 0; k < state_dim; k++) { + a_row = (state_dim * n) + k; + // TODO: what if c is TSK_NULL? + b_row = (state_dim * (tsk_size_t) c) + k; + tsk_bit_array_get_row(B_state_samples, b_row, &B_samples); + tsk_bit_array_add(&B_samples_tmp, &B_samples); + tsk_bit_array_subtract(&B_samples_tmp, child_samples); + tsk_bit_array_get_row(A_state_samples, a_row, &A_samples); + tsk_bit_array_intersect(&A_samples, &B_samples_tmp, &AB_samples); + weights_row = GET_2D_ROW(weights, 3, k); + weights_row[0] = (double) tsk_bit_array_count(&AB_samples); // w_AB + weights_row[1] + = (double) tsk_bit_array_count(&A_samples) - weights_row[0]; // w_Ab + weights_row[2] = (double) tsk_bit_array_count(&B_samples_tmp) + - weights_row[0]; // w_aB + } + ret = f(state_dim, weights, result_dim, result_tmp, f_params); + if (ret != 0) { + goto out; + } + for (k = 0; k < result_dim; k++) { + result[k] -= result_tmp[k] * a_len * b_len; + } + } + } +out: + tsk_safe_free(weights); + tsk_safe_free(result_tmp); + tsk_bit_array_free(&AB_samples); + tsk_bit_array_free(&B_samples_tmp); + return ret; +} + +static int +compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, + iter_state *r_state, general_stat_func_t *f, sample_count_stat_params_t *f_params, + tsk_size_t result_dim, tsk_size_t state_dim, double *result) +{ + int ret = 0; + tsk_id_t e, c, p; + tsk_size_t j, k; + tsk_bit_array_t child_samples, child_samples_row, samples_row, *in_parent; + const double *restrict time = ts->tables->nodes.time; + const tsk_id_t *restrict edges_child = ts->tables->edges.child; + const tsk_id_t *restrict edges_parent = ts->tables->edges.parent; + tsk_bit_array_t *r_samples = r_state->node_samples; + + tsk_memset(&child_samples, 0, sizeof(child_samples)); + ret = tsk_bit_array_init(&child_samples, ts->num_samples, state_dim); + if (ret != 0) { + goto out; + } + for (j = 0; j < r_state->n_edges_out; j++) { + e = r_state->edges_out[j]; + c = edges_child[e]; + p = edges_parent[e]; + tsk_memset(child_samples.data, 0, + child_samples.size * state_dim * sizeof(tsk_bit_array_value_t)); + tsk_bug_assert(c != TSK_NULL); // TODO: are these checks necessary? + tsk_bug_assert(p != TSK_NULL); + for (k = 0; k < state_dim; k++) { + tsk_bit_array_get_row( + r_samples, (state_dim * (tsk_size_t) c) + k, &samples_row); + tsk_bit_array_get_row(&child_samples, k, &child_samples_row); + tsk_bit_array_add(&child_samples_row, &samples_row); + } + in_parent = NULL; + while (p != TSK_NULL) { + compute_two_tree_branch_state_update(ts, c, in_parent, l_state, r_state, + state_dim, result_dim, -1, f, f_params, result); + if (in_parent != NULL) { + for (k = 0; k < state_dim; k++) { + tsk_bit_array_get_row( + r_samples, (state_dim * (tsk_size_t) c) + k, &samples_row); + tsk_bit_array_get_row(&child_samples, k, &child_samples_row); + tsk_bit_array_subtract(&samples_row, &child_samples_row); + } + } + in_parent = &child_samples; + c = p; + p = r_state->parent[p]; + } + for (k = 0; k < state_dim; k++) { + tsk_bit_array_get_row( + r_samples, (state_dim * (tsk_size_t) c) + k, &samples_row); + tsk_bit_array_get_row(&child_samples, k, &child_samples_row); + tsk_bit_array_subtract(&samples_row, &child_samples_row); + } + c = edges_child[e]; + r_state->branch_len[c] = 0; + r_state->parent[c] = TSK_NULL; + } + for (j = 0; j < r_state->n_edges_in; j++) { + e = r_state->edges_in[j]; + c = edges_child[e]; + p = edges_parent[e]; + tsk_memset(child_samples.data, 0, + child_samples.size * state_dim * sizeof(tsk_bit_array_value_t)); + for (k = 0; k < state_dim; k++) { + tsk_bit_array_get_row( + r_samples, (state_dim * (tsk_size_t) c) + k, &samples_row); + tsk_bit_array_get_row(&child_samples, k, &child_samples_row); + tsk_bit_array_add(&child_samples_row, &samples_row); + } + r_state->branch_len[c] = time[p] - time[c]; + r_state->parent[c] = p; + + in_parent = NULL; + while (p != TSK_NULL) { + for (k = 0; k < state_dim; k++) { + tsk_bit_array_get_row( + r_samples, (state_dim * (tsk_size_t) p) + k, &samples_row); + tsk_bit_array_get_row(&child_samples, k, &child_samples_row); + tsk_bit_array_add(&samples_row, &child_samples_row); + } + compute_two_tree_branch_state_update(ts, c, in_parent, l_state, r_state, + state_dim, result_dim, +1, f, f_params, result); + in_parent = &child_samples; + c = p; + p = r_state->parent[p]; + } + } +out: + tsk_bit_array_free(&child_samples); + return ret; +} + +static int +tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, + const tsk_bit_array_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + sample_count_stat_params_t *f_params, norm_func_t *TSK_UNUSED(norm_f), + tsk_size_t n_rows, const double *row_positions, tsk_size_t n_cols, + const double *col_positions, tsk_flags_t TSK_UNUSED(options), double *result) +{ + int ret = 0; + int r, c; + tsk_id_t *row_indexes = NULL, *col_indexes = NULL; + tsk_size_t i, j, k, row, col, *row_repeats = NULL, *col_repeats = NULL; + tsk_bit_array_t node_samples; + iter_state l_state, r_state; + double *result_tmp = NULL, *result_row; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + + tsk_memset(&node_samples, 0, sizeof(node_samples)); + tsk_memset(&l_state, 0, sizeof(l_state)); + tsk_memset(&r_state, 0, sizeof(r_state)); + result_tmp = tsk_malloc(result_dim * sizeof(*result_tmp)); + if (result_tmp == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + ret = iter_state_init(&l_state, self, state_dim); + if (ret != 0) { + goto out; + } + ret = iter_state_init(&r_state, self, state_dim); + if (ret != 0) { + goto out; + } + ret = positions_to_tree_indexes(self, row_positions, n_rows, &row_indexes); + if (ret != 0) { + goto out; + } + ret = positions_to_tree_indexes(self, col_positions, n_cols, &col_indexes); + if (ret != 0) { + goto out; + } + ret = get_index_counts(row_indexes, n_rows, &row_repeats); + if (ret != 0) { + goto out; + } + ret = get_index_counts(col_indexes, n_cols, &col_repeats); + if (ret != 0) { + goto out; + } + ret = get_node_samples(self, state_dim, sample_sets, &node_samples); + if (ret != 0) { + goto out; + } + iter_state_clear(&l_state, state_dim, num_nodes, &node_samples); + row = 0; + for (r = 0; r < (row_indexes[n_rows ? n_rows - 1U : 0] - row_indexes[0] + 1); r++) { + tsk_memset(result_tmp, 0, result_dim * sizeof(*result_tmp)); + iter_state_clear(&r_state, state_dim, num_nodes, &node_samples); + ret = advance_collect_edges(&l_state, (tsk_id_t) r + row_indexes[0]); + if (ret != 0) { + goto out; + } + result_row = GET_2D_ROW(result, result_dim * n_cols, row); + ret = compute_two_tree_branch_stat( + self, &r_state, &l_state, f, f_params, result_dim, state_dim, result_tmp); + if (ret != 0) { + goto out; + } + col = 0; + for (c = 0; c < (col_indexes[n_cols ? n_cols - 1 : 0] - col_indexes[0] + 1); + c++) { + ret = advance_collect_edges(&r_state, (tsk_id_t) c + col_indexes[0]); + if (ret != 0) { + goto out; + } + ret = compute_two_tree_branch_stat(self, &l_state, &r_state, f, f_params, + result_dim, state_dim, result_tmp); + if (ret != 0) { + goto out; + } + for (i = 0; i < row_repeats[r]; i++) { + for (j = 0; j < col_repeats[c]; j++) { + result_row = GET_2D_ROW(result, result_dim * n_cols, row + i); + for (k = 0; k < result_dim; k++) { + result_row[col + (j * result_dim) + k] = result_tmp[k]; + } + } + } + col += (col_repeats[c] * result_dim); + } + row += row_repeats[r]; + } +out: + tsk_safe_free(result_tmp); + tsk_safe_free(row_indexes); + tsk_safe_free(col_indexes); + tsk_safe_free(row_repeats); + tsk_safe_free(col_repeats); + iter_state_free(&l_state); + iter_state_free(&r_state); + tsk_bit_array_free(&node_samples); + return ret; +} + static int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, - tsk_size_t out_cols, const tsk_id_t *col_sites, tsk_flags_t options, double *result) + const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result) { // TODO: generalize this function if we ever decide to do weighted two_locus stats. // We only implement count stats and therefore we don't handle weights. @@ -2658,6 +3220,11 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl tsk_memset(&sample_sets_bits, 0, sizeof(sample_sets_bits)); + // We do not support two-locus node stats + if (!!(options & TSK_STAT_NODE)) { + ret = TSK_ERR_UNSUPPORTED_STAT_MODE; + goto out; + } // If no mode is specified, we default to site mode if (!(stat_site || stat_branch)) { stat_site = true; @@ -2696,8 +3263,20 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl ret = tsk_treeseq_two_site_count_stat(self, state_dim, &sample_sets_bits, result_dim, f, &f_params, norm_f, out_rows, row_sites, out_cols, col_sites, options, result); - } else { - ret = TSK_ERR_UNSUPPORTED_STAT_MODE; + } else if (stat_branch) { + ret = check_positions( + row_positions, out_rows, tsk_treeseq_get_sequence_length(self)); + if (ret != 0) { + goto out; + } + ret = check_positions( + col_positions, out_cols, tsk_treeseq_get_sequence_length(self)); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_branch_count_stat(self, state_dim, &sample_sets_bits, + result_dim, f, &f_params, norm_f, out_rows, row_positions, out_cols, + col_positions, options, result); } out: @@ -3527,13 +4106,15 @@ D_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_D(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { options |= TSK_STAT_POLARISED; // TODO: allow user to pick? return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D_summary_func, norm_total_weighted, - num_rows, row_sites, num_cols, col_sites, options, result); + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); } static int @@ -3564,12 +4145,14 @@ D2_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_D2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D2_summary_func, norm_total_weighted, - num_rows, row_sites, num_cols, col_sites, options, result); + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); } static int @@ -3602,12 +4185,13 @@ r2_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_r2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, r2_summary_func, norm_hap_weighted, num_rows, - row_sites, num_cols, col_sites, options, result); + row_sites, row_positions, num_cols, col_sites, col_positions, options, result); } static int @@ -3643,13 +4227,15 @@ D_prime_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_D_prime(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { options |= TSK_STAT_POLARISED; // TODO: allow user to pick? return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D_prime_summary_func, norm_hap_weighted, - num_rows, row_sites, num_cols, col_sites, options, result); + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); } static int @@ -3682,13 +4268,15 @@ r_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_r(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { options |= TSK_STAT_POLARISED; // TODO: allow user to pick? return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, r_summary_func, norm_total_weighted, - num_rows, row_sites, num_cols, col_sites, options, result); + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); } static int @@ -3720,12 +4308,14 @@ Dz_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_Dz(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, Dz_summary_func, norm_total_weighted, - num_rows, row_sites, num_cols, col_sites, options, result); + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); } static int @@ -3754,12 +4344,127 @@ pi2_summary_func(tsk_size_t state_dim, const double *state, int tsk_treeseq_pi2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) { return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, pi2_summary_func, norm_total_weighted, - num_rows, row_sites, num_cols, col_sites, options, result); + num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, + result); +} + +static int +D2_unbiased_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double *state_row; + tsk_size_t j; + + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + double w_AB = state_row[0]; + double w_Ab = state_row[1]; + double w_aB = state_row[2]; + double w_ab = n - (w_AB + w_Ab + w_aB); + result[j] = (1 / (n * (n - 1) * (n - 2) * (n - 3))) + * ((w_aB * w_aB * (w_Ab - 1) * w_Ab) + + ((w_ab - 1) * w_ab * (w_AB - 1) * w_AB) + - (w_aB * w_Ab * (w_Ab + (2 * w_ab * w_AB) - 1))); + } + return 0; +} + +int +tsk_treeseq_D2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, D2_unbiased_summary_func, + norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); +} + +static int +Dz_unbiased_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double *state_row; + tsk_size_t j; + + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + double w_AB = state_row[0]; + double w_Ab = state_row[1]; + double w_aB = state_row[2]; + double w_ab = n - (w_AB + w_Ab + w_aB); + result[j] = (1 / (n * (n - 1) * (n - 2) * (n - 3))) + * ((((w_AB * w_ab) - (w_Ab * w_aB)) * (w_aB + w_ab - w_AB - w_Ab) + * (w_Ab + w_ab - w_AB - w_aB)) + - ((w_AB * w_ab) * (w_AB + w_ab - w_Ab - w_aB - 2)) + - ((w_Ab * w_aB) * (w_Ab + w_aB - w_AB - w_ab - 2))); + } + return 0; +} + +int +tsk_treeseq_Dz_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, Dz_unbiased_summary_func, + norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); +} + +static int +pi2_unbiased_summary_func(tsk_size_t state_dim, const double *state, + tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double *state_row; + tsk_size_t j; + + for (j = 0; j < state_dim; j++) { + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + double w_AB = state_row[0]; + double w_Ab = state_row[1]; + double w_aB = state_row[2]; + double w_ab = n - (w_AB + w_Ab + w_aB); + result[j] + = (1 / (n * (n - 1) * (n - 2) * (n - 3))) + * (((w_AB + w_Ab) * (w_aB + w_ab) * (w_AB + w_aB) * (w_Ab + w_ab)) + - ((w_AB * w_ab) * (w_AB + w_ab + (3 * w_Ab) + (3 * w_aB) - 1)) + - ((w_Ab * w_aB) * (w_Ab + w_aB + (3 * w_AB) + (3 * w_ab) - 1))); + } + return 0; +} + +int +tsk_treeseq_pi2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, pi2_unbiased_summary_func, + norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); } /*********************************** diff --git a/c/tskit/trees.h b/c/tskit/trees.h index d7b64d0701..b23fa55320 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1073,36 +1073,59 @@ int tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, typedef int two_locus_count_stat_method(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, const tsk_id_t *row_sites, - tsk_size_t num_cols, const tsk_id_t *col_sites, tsk_flags_t options, double *result); + const double *row_positions, tsk_size_t num_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result); int tsk_treeseq_D(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result); + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); int tsk_treeseq_D2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result); + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); int tsk_treeseq_r2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result); + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); int tsk_treeseq_D_prime(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result); + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); int tsk_treeseq_r(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result); + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); int tsk_treeseq_Dz(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result); + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); int tsk_treeseq_pi2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, - const tsk_id_t *row_sites, tsk_size_t num_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result); + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_D2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_Dz_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_pi2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); /* Three way sample set stats */ int tsk_treeseq_Y3(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index c403b12af6..3972caf447 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -10135,24 +10135,71 @@ TreeSequence_pair_coalescence_quantiles( return ret; } +static PyArrayObject * +parse_sites(TreeSequence *self, PyObject *sites, npy_intp *out_dim) +{ + PyArrayObject *array; + tsk_size_t num_sites = tsk_treeseq_get_num_sites(self->tree_sequence); + + if (sites == Py_None) { + array = (PyArrayObject *) PyArray_Arange(0, num_sites, 1, NPY_INT32); + if (array == NULL) { + goto out; + } + *out_dim = PyArray_DIM(array, 0); + } else { + array = (PyArrayObject *) PyArray_FROMANY( + sites, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); + if (array == NULL) { + goto out; + } + *out_dim = PyArray_DIM(array, 0); + } + +out: + return array; +} + +static PyArrayObject * +parse_positions(TreeSequence *self, PyObject *positions, npy_intp *out_dim) +{ + PyArrayObject *array; + + if (positions == Py_None) { + array = (PyArrayObject *) TreeSequence_get_breakpoints(self); + if (array == NULL) { + goto out; + } + *out_dim = PyArray_DIM(array, 0) - 1; // NB the last element must be truncated + } else { + array = (PyArrayObject *) PyArray_FROMANY( + positions, NPY_FLOAT64, 1, 1, NPY_ARRAY_IN_ARRAY); + if (array == NULL) { + goto out; + } + *out_dim = PyArray_DIM(array, 0); + } +out: + return array; +} + static PyObject * TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, two_locus_count_stat_method *method) { PyObject *ret = NULL; - static char *kwlist[] - = { "sample_set_sizes", "sample_sets", "row_sites", "col_sites", "mode", NULL }; + static char *kwlist[] = { "sample_set_sizes", "sample_sets", "row_sites", + "col_sites", "row_positions", "column_positions", "mode", NULL }; - PyObject *row_sites = NULL; - PyObject *col_sites = NULL; - PyObject *sample_set_sizes = NULL; - PyObject *sample_sets = NULL; - PyArrayObject *sample_set_sizes_array = NULL; - PyArrayObject *sample_sets_array = NULL; - PyArrayObject *row_sites_array = NULL; - PyArrayObject *col_sites_array = NULL; - PyArrayObject *result_matrix = NULL; - npy_intp result_shape[3]; + PyObject *row_sites = NULL, *col_sites = NULL, *row_positions = NULL, + *col_positions = NULL, *sample_set_sizes = NULL, *sample_sets = NULL; + PyArrayObject *row_sites_array = NULL, *col_sites_array = NULL, + *row_positions_array = NULL, *col_positions_array = NULL, + *sample_sets_array = NULL, *sample_set_sizes_array = NULL, + *result_matrix = NULL; + tsk_id_t *row_sites_parsed = NULL, *col_sites_parsed = NULL; + double *row_positions_parsed = NULL, *col_positions_parsed = NULL; + npy_intp result_dim[3] = { 0, 0, 0 }; char *mode = NULL; tsk_size_t num_sample_sets; tsk_flags_t options = 0; @@ -10161,8 +10208,9 @@ TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOs", kwlist, &sample_set_sizes, - &sample_sets, &row_sites, &col_sites, &mode)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOOO|s", kwlist, &sample_set_sizes, + &sample_sets, &row_sites, &col_sites, &row_positions, &col_positions, + &mode)) { goto out; } if (parse_stats_mode(mode, &options) != 0) { @@ -10173,22 +10221,37 @@ TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, != 0) { goto out; } - row_sites_array = (PyArrayObject *) PyArray_FROMANY( - row_sites, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); - if (row_sites_array == NULL) { - goto out; - } - col_sites_array = (PyArrayObject *) PyArray_FROMANY( - col_sites, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); - if (col_sites_array == NULL) { - goto out; + + if (options & TSK_STAT_SITE) { + if (row_positions != Py_None || col_positions != Py_None) { + PyErr_SetString(PyExc_ValueError, "Cannot specify positions in site mode"); + goto out; + } + row_sites_array = parse_sites(self, row_sites, &(result_dim[0])); + col_sites_array = parse_sites(self, col_sites, &(result_dim[1])); + if (row_sites_array == NULL || col_sites_array == NULL) { + goto out; + } + row_sites_parsed = PyArray_DATA(row_sites_array); + col_sites_parsed = PyArray_DATA(col_sites_array); + } else if (options & TSK_STAT_BRANCH) { + if (row_sites != Py_None || col_sites != Py_None) { + PyErr_SetString(PyExc_ValueError, "Cannot specify sites in branch mode"); + goto out; + } + row_positions_array = parse_positions(self, row_positions, &(result_dim[0])); + col_positions_array = parse_positions(self, col_positions, &(result_dim[1])); + if (col_positions_array == NULL || row_positions_array == NULL) { + goto out; + } + row_positions_parsed = PyArray_DATA(row_positions_array); + col_positions_parsed = PyArray_DATA(col_positions_array); } - result_shape[0] = PyArray_DIM(row_sites_array, 0); - result_shape[1] = PyArray_DIM(col_sites_array, 0); - result_shape[2] = num_sample_sets; - result_matrix = (PyArrayObject *) PyArray_ZEROS(3, result_shape, NPY_FLOAT64, 0); + result_dim[2] = num_sample_sets; + result_matrix = (PyArrayObject *) PyArray_ZEROS(3, result_dim, NPY_FLOAT64, 0); if (result_matrix == NULL) { + PyErr_NoMemory(); goto out; } @@ -10196,8 +10259,8 @@ TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, Py_BEGIN_ALLOW_THREADS err = method(self->tree_sequence, num_sample_sets, PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), - result_shape[0], PyArray_DATA(row_sites_array), result_shape[1], - PyArray_DATA(col_sites_array), options, PyArray_DATA(result_matrix)); + result_dim[0], row_sites_parsed, row_positions_parsed, result_dim[1], + col_sites_parsed, col_positions_parsed, options, PyArray_DATA(result_matrix)); Py_END_ALLOW_THREADS // clang-format on @@ -10211,8 +10274,10 @@ TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, out: Py_XDECREF(row_sites_array); Py_XDECREF(col_sites_array); - Py_XDECREF(sample_set_sizes_array); + Py_XDECREF(row_positions_array); + Py_XDECREF(col_positions_array); Py_XDECREF(sample_sets_array); + Py_XDECREF(sample_set_sizes_array); Py_XDECREF(result_matrix); return ret; } @@ -10259,6 +10324,24 @@ TreeSequence_pi2_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_pi2); } +static PyObject * +TreeSequence_pi2_unbiased_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_pi2_unbiased); +} + +static PyObject * +TreeSequence_D2_unbiased_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_D2_unbiased); +} + +static PyObject * +TreeSequence_Dz_unbiased_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_Dz_unbiased); +} + static PyObject * TreeSequence_get_num_mutations(TreeSequence *self) { @@ -11023,6 +11106,18 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_pi2_matrix, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes the pi2 matrix." }, + { .ml_name = "D2_unbiased_matrix", + .ml_meth = (PyCFunction) TreeSequence_D2_unbiased_matrix, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the unbiased D2 matrix." }, + { .ml_name = "Dz_unbiased_matrix", + .ml_meth = (PyCFunction) TreeSequence_Dz_unbiased_matrix, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the unbiased Dz matrix." }, + { .ml_name = "pi2_unbiased_matrix", + .ml_meth = (PyCFunction) TreeSequence_pi2_unbiased_matrix, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the unbiased pi2 matrix." }, { NULL } /* Sentinel */ }; diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 761658224e..3926767f8b 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -25,7 +25,6 @@ import contextlib import io from dataclasses import dataclass -from itertools import combinations from itertools import combinations_with_replacement from itertools import permutations from itertools import product @@ -631,14 +630,14 @@ def get_index_repeats(indices): def two_branch_count_stat( ts: tskit.TreeSequence, func: Callable[[int, np.ndarray, np.ndarray, Dict[str, Any]], None], - norm_func, # TODO: might need for polarisation + norm_func, num_sample_sets: int, sample_set_sizes: np.ndarray, sample_sets: BitSet, sample_index_map: np.ndarray, row_trees: np.ndarray, col_trees: np.ndarray, - polarised: bool, # TODO: polarisation + polarised: bool, ) -> np.ndarray: """ Compute a tree X tree LD matrix by walking along the tree sequence and @@ -1090,8 +1089,8 @@ def d2_unbiased( "D_prime": D_prime_summary_func, "pi2": pi2_summary_func, "Dz": Dz_summary_func, - "d2_unbiased": d2_unbiased, - "dz_unbiased": dz_unbiased, + "D2_unbiased": d2_unbiased, + "Dz_unbiased": dz_unbiased, "pi2_unbiased": pi2_unbiased, } @@ -1103,9 +1102,9 @@ def d2_unbiased( pi2_summary_func: norm_total_weighted, r_summary_func: norm_total_weighted, r2_summary_func: norm_hap_weighted, - d2_unbiased: None, - dz_unbiased: None, - pi2_unbiased: None, + d2_unbiased: norm_total_weighted, + dz_unbiased: norm_total_weighted, + pi2_unbiased: norm_total_weighted, } POLARIZATION = { @@ -1276,7 +1275,11 @@ def test_subset_positions(partition): bp = ts.breakpoints(as_array=True) mid = (bp[1:] + bp[:-1]) / 2 np.testing.assert_allclose( - ld_matrix(ts, mode="branch", stat="d2_unbiased", positions=[mid[a], mid[b]]), + ld_matrix(ts, mode="branch", stat="D2_unbiased", positions=[mid[a], mid[b]]), + PAPER_EX_BRANCH_TRUTH_MATRIX[a[0] : a[-1] + 1, b[0] : b[-1] + 1], + ) + np.testing.assert_allclose( + ts.ld_matrix(mode="branch", stat="D2_unbiased", positions=[mid[a], mid[b]]), PAPER_EX_BRANCH_TRUTH_MATRIX[a[0] : a[-1] + 1, b[0] : b[-1] + 1], ) @@ -1321,7 +1324,13 @@ def test_subset_positions_one_list(tree_index): bp = ts.breakpoints(as_array=True) mid = (bp[1:] + bp[:-1]) / 2 np.testing.assert_allclose( - ld_matrix(ts, mode="branch", stat="d2_unbiased", positions=[mid[tree_index]]), + ld_matrix(ts, mode="branch", stat="D2_unbiased", positions=[mid[tree_index]]), + PAPER_EX_BRANCH_TRUTH_MATRIX[ + tree_index[0] : tree_index[-1] + 1, tree_index[0] : tree_index[-1] + 1 + ], + ) + np.testing.assert_allclose( + ts.ld_matrix(mode="branch", stat="D2_unbiased", positions=[mid[tree_index]]), PAPER_EX_BRANCH_TRUTH_MATRIX[ tree_index[0] : tree_index[-1] + 1, tree_index[0] : tree_index[-1] + 1 ], @@ -1363,7 +1372,11 @@ def test_repeated_position_elements(tree_index): np.testing.assert_allclose( truth, - ld_matrix(ts, mode="branch", stat="d2_unbiased", positions=[l_pos, r_pos]), + ld_matrix(ts, mode="branch", stat="D2_unbiased", positions=[l_pos, r_pos]), + ) + np.testing.assert_allclose( + truth, + ts.ld_matrix(mode="branch", stat="D2_unbiased", positions=[l_pos, r_pos]), ) @@ -1378,7 +1391,7 @@ def test_sample_sets(partition): :param partition: length 2 list of [ss_1, ss_2]. """ ts = get_paper_ex_ts() - np.testing.assert_array_almost_equal( + np.testing.assert_allclose( ld_matrix(ts, sample_sets=partition), ts.ld_matrix(sample_sets=partition) ) @@ -1394,7 +1407,7 @@ def test_compare_to_ld_calculator(): @pytest.mark.parametrize( "stat", - sorted(SUMMARY_FUNCS.keys() - {"d2_unbiased", "dz_unbiased", "pi2_unbiased"}), + sorted(SUMMARY_FUNCS.keys()), ) def test_multiallelic_with_back_mutation(stat): ts = msprime.sim_ancestry( @@ -1417,7 +1430,7 @@ def test_multiallelic_with_back_mutation(stat): # TODO: port unbiased summary functions @pytest.mark.parametrize( "stat", - sorted(SUMMARY_FUNCS.keys() - {"d2_unbiased", "dz_unbiased", "pi2_unbiased"}), + sorted(SUMMARY_FUNCS.keys()), ) def test_ld_matrix(ts, stat): np.testing.assert_array_almost_equal( @@ -1432,21 +1445,35 @@ def test_ld_matrix(ts, stat): def test_ld_empty_examples(ts): with pytest.raises(ValueError, match="at least one element"): ts.ld_matrix() + with pytest.raises(ValueError, match="at least one element"): + ts.ld_matrix(mode="branch") def test_input_validation(): ts = get_paper_ex_ts() with pytest.raises(ValueError, match="Unknown two-locus statistic"): ts.ld_matrix(stat="bad_stat") + with pytest.raises(ValueError, match="must be a list of"): ts.ld_matrix(sites=["abc"]) with pytest.raises(ValueError, match="must be a list of"): ts.ld_matrix(sites=[1, 2, 3]) with pytest.raises(ValueError, match="must be a length 1 or 2 list"): ts.ld_matrix(sites=[[1, 2], [2, 3], [3, 4]]) + with pytest.raises(ValueError, match="must be a length 1 or 2 list"): + ts.ld_matrix(sites=[[1, 2], [2, 3], [3, 4]]) with pytest.raises(ValueError, match="must be a length 1 or 2 list"): ts.ld_matrix(sites=[]) + with pytest.raises(ValueError, match="must be a list of"): + ts.ld_matrix(positions=["abc"], mode="branch") + with pytest.raises(ValueError, match="must be a list of"): + ts.ld_matrix(positions=[1.0, 2.0, 3.0], mode="branch") + with pytest.raises(ValueError, match="must be a length 1 or 2 list"): + ts.ld_matrix(positions=[[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], mode="branch") + with pytest.raises(ValueError, match="must be a length 1 or 2 list"): + ts.ld_matrix(positions=[], mode="branch") + @dataclass class TreeState: @@ -1479,7 +1506,9 @@ def __init__(self, ts, sample_sets, num_sample_sets, sample_index_map): for n in range(ts.num_nodes): for k in range(num_sample_sets): if sample_sets.contains(k, sample_index_map[n]): - self.node_samples.add((num_sample_sets * n) + k, n) + self.node_samples.add( + (num_sample_sets * n) + k, sample_index_map[n] + ) # these are empty for the uninitialized state (index = -1) self.edges_in = [] self.edges_out = [] @@ -1730,188 +1759,51 @@ def compute_branch_stat(ts, stat_func, stat, params, state_dim, l_state, r_state return stat, r_state -# What follows is an implementation of two-locus statistics as described in -# McVean 2002 (https://doi.org/10.1093/genetics/162.2.987). We compute the -# covariance between coalescent times to produce expectations of coalescent -# times between three sampling patterns of samples. These expectations can be -# compined to produce D2, Dz, and pi2. These are for testing and to demonstrate -# conceptual parity between our method and McVean's method. - - -def tmrca(tr, x, y): - """ - Mirror the functionality in the branch two-locus stats. We want to compute - the contribution of each subset of samples. If there is no most recent common - ancestor, we walk up the tree and find each sample's individual MRCA (which - as written is realy just the root of the tree). This is to work around the case - of empty, gapped, and decapitated trees. - """ - try: - # First, we try to get the tmrca - return tr.tmrca(x, y) - except ValueError as e: - # If we cannot, crawl up as far as the sample is connected - x_mrca, y_mrca = -1, -1 - if "not share a common ancestor" not in str(e): - raise e - for r in tr.roots: - if x in set(tr.samples(r)): - x_mrca = r - if y in set(tr.samples(r)): - y_mrca = r - if x_mrca == -1 or y_mrca == -1: - raise ValueError - return (tr.time(x_mrca) + tr.time(y_mrca)) / 2 - - -def compute_D2(x, y, ij, ijk, ijkl): - E_ijij = 0 - E_ijik = 0 - E_ijkl = 0 - if len(ij) == 0 or len(ijk) == 0 or len(ijkl) == 0: - # this method requires at least 4 samples - return float("nan") - for i, j in ij: - i_time = x.time(i) - j_time = x.time(j) - ij_time = (i_time + j_time) / 2 - E_ijij += (tmrca(x, i, j) - ij_time) * (tmrca(y, i, j) - ij_time) - for i, j, k in ijk: - i_time = x.time(i) - j_time = x.time(j) - k_time = x.time(k) - ij_time = (i_time + j_time) / 2 - ik_time = (i_time + k_time) / 2 - E_ijik += (tmrca(x, i, j) - ij_time) * (tmrca(y, i, k) - ik_time) - for i, j, k, l in ijkl: - i_time = x.time(i) - j_time = x.time(j) - k_time = x.time(k) - l_time = x.time(l) - ij_time = (i_time + j_time) / 2 - kl_time = (k_time + l_time) / 2 - E_ijkl += (tmrca(x, i, j) - ij_time) * (tmrca(y, k, l) - kl_time) - E_ijij = E_ijij / len(ij) - E_ijik = E_ijik / len(ijk) - E_ijkl = E_ijkl / len(ijkl) - return E_ijij - 2 * E_ijik + E_ijkl - - -def compute_Dz(x, y, ij, ijk, ijkl): - E_ijik = 0 - E_ijkl = 0 - if len(ijk) == 0 or len(ijkl) == 0: - # this method requires at least 4 samples - return float("nan") - for i, j, k in ijk: - i_time = x.time(i) - j_time = x.time(j) - k_time = x.time(k) - ij_time = (i_time + j_time) / 2 - ik_time = (i_time + k_time) / 2 - E_ijik += (tmrca(x, i, j) - ij_time) * (tmrca(y, i, k) - ik_time) - for i, j, k, l in ijkl: - i_time = x.time(i) - j_time = x.time(j) - k_time = x.time(k) - l_time = x.time(l) - ij_time = (i_time + j_time) / 2 - kl_time = (k_time + l_time) / 2 - E_ijkl += (tmrca(x, i, j) - ij_time) * (tmrca(y, k, l) - kl_time) - E_ijik = E_ijik / len(ijk) - E_ijkl = E_ijkl / len(ijkl) - return 4 * (E_ijik - E_ijkl) - - -def compute_pi2(x, y, ij, ijk, ijkl): - E_ijkl = 0 - if len(ijkl) == 0: - # this method requires at least 4 samples - return float("nan") - for i, j, k, l in ijkl: - i_time = x.time(i) - j_time = x.time(j) - k_time = x.time(k) - l_time = x.time(l) - ij_time = (i_time + j_time) / 2 - kl_time = (k_time + l_time) / 2 - E_ijkl += (tmrca(x, i, j) - ij_time) * (tmrca(y, k, l) - kl_time) - E_ijkl = E_ijkl / len(ijkl) - return E_ijkl - - -def combine(samples): - # All combinations where i != j - ij = list(combinations(samples, 2)) - # All combinations where i != {j,k} and j != k - ijk = [ - (i, j, k) - for i, j, k in product(samples, repeat=3) - if i != k and i != j and j != k - ] - # All combinations where i != {k,l} and j != {k,l} - ijkl = [ - (i, j, samples[k], samples[l]) - for i, j in combinations(samples, 2) - for k in range(len(samples)) - for l in range(k + 1, len(samples)) # noqa: E741 - if i != samples[k] and j != samples[k] and samples[l] != i and samples[l] != j - ] - return ij, ijk, ijkl - - -def naive_matrix(ts, stat_func, sample_set=None): - """Compute a tree x tree LD matrix for a given tree sequence and two-locus - statistic. This produces a matrix of LD that is generated from the - covariance in gene genealogies, as described in McVean 2002. - - :param ts: Tree sequence to gather data from. - :param stat_func: Function to compute a two-locus statistic from two - materialized trees and sample combinations. - :returns: Pairwise branch LD matrix for an entire tree sequence. - """ - result = np.zeros((ts.num_trees, ts.num_trees), dtype=np.float64) - # These stats require at least 4 samples in the tree - ij, ijk, ijkl = combine(sample_set or ts.samples()) - for i, j in combinations_with_replacement(range(ts.num_trees), 2): - val = stat_func(ts.at_index(i), ts.at_index(j), ij, ijk, ijkl) - result[i, j] = val - tri_idx = np.tril_indices(len(result), k=-1) - result[tri_idx] = result.T[tri_idx] - return result - - @pytest.mark.parametrize( "ts", [ ts for ts in get_example_tree_sequences() - # no_samples and empty_ts aren't handled here. if ts.id - in { - # We only perform tests on a useful subset of the example trees due to - # runtime constraints of the naive McVean implementation. We plan to expand - # coverage to more examples after implementing the C version - "all_nodes_samples", - "internal_nodes_samples", - "mixed_internal_leaf_samples", - "n=2_m=32_rho=0.5", - "bottleneck_n=10_mutated", - "rev_node_order", - "decapitate", + not in { + "no_samples", + "empty_ts", + # We must skip these cases so that tests run in a reasonable + # amount of time. To get more complete testing, these filters + # can be commented out. (runtime ~1hr) + "gap_0", + "gap_0.1", + "gap_0.5", + "gap_0.75", + "n=2_m=32_rho=0", + "n=10_m=1_rho=0", + "n=10_m=1_rho=0.1", + "n=10_m=2_rho=0", + "n=10_m=2_rho=0.1", + "n=10_m=32_rho=0", + "n=10_m=32_rho=0.1", + "n=10_m=32_rho=0.5", + # we keep one n=100 case to ensure bit arrays are working + "n=100_m=1_rho=0.1", + "n=100_m=1_rho=0.5", + "n=100_m=2_rho=0", + "n=100_m=2_rho=0.1", + "n=100_m=2_rho=0.5", + "n=100_m=32_rho=0", + "n=100_m=32_rho=0.1", + "n=100_m=32_rho=0.5", + "all_fields", + "back_mutations", + "multichar", + "multichar_no_metadata", + "bottleneck_n=100_mutated", } ], ) -@pytest.mark.parametrize( - "stat,stat_func", - zip( - ["d2_unbiased", "dz_unbiased", "pi2_unbiased"], - [compute_D2, compute_Dz, compute_pi2], - ), -) -def test_branch_ld_matrix(ts, stat, stat_func): +@pytest.mark.parametrize("stat", sorted(SUMMARY_FUNCS.keys())) +def test_branch_ld_matrix(ts, stat): np.testing.assert_array_almost_equal( - ld_matrix(ts, stat=stat, mode="branch"), naive_matrix(ts, stat_func) + ts.ld_matrix(stat=stat, mode="branch"), ld_matrix(ts, stat=stat, mode="branch") ) @@ -1933,23 +1825,16 @@ def get_test_branch_sample_set_test_cases(): [[1, 2, 4, 9]], id="bottleneck_n=10_mutated", ), - pytest.param( - p_dict["multichar"].values[0], [[10, 11, 12, 13, 14, 15]], id="multichar" - ), pytest.param(p_dict["gap_at_end"].values[0], [[1, 3, 5, 8]], id="gap_at_end"), ] @pytest.mark.parametrize("ts,sample_set", get_test_branch_sample_set_test_cases()) -@pytest.mark.parametrize( - "stat,stat_func", - zip( - ["d2_unbiased", "dz_unbiased", "pi2_unbiased"], - [compute_D2, compute_Dz, compute_pi2], - ), -) -def test_branch_ld_matrix_sample_sets(ts, sample_set, stat, stat_func): +@pytest.mark.parametrize("stat", sorted(SUMMARY_FUNCS.keys())) +def test_branch_ld_matrix_sample_sets(ts, sample_set, stat): np.testing.assert_array_almost_equal( - ld_matrix(ts, stat=stat, mode="branch", sample_sets=sample_set), - naive_matrix(ts, stat_func, sample_set[0]), + np.expand_dims( + ld_matrix(ts, stat=stat, mode="branch", sample_sets=sample_set), axis=0 + ), + ts.ld_matrix(stat=stat, mode="branch", sample_sets=sample_set), ) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index d130ba03b6..f04b317d53 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1522,101 +1522,173 @@ def test_extend_edges_bad_args(self): "r_matrix", "Dz_matrix", "pi2_matrix", + "D2_unbiased_matrix", + "Dz_unbiased_matrix", + "pi2_unbiased_matrix", ], ) def test_ld_matrix(self, stat_method_name): ts = self.get_example_tree_sequence(10) stat_method = getattr(ts, stat_method_name) - mode = "site" - sample_sets = ts.get_samples() - sample_set_sizes = np.array([len(sample_sets)], dtype=np.uint32) + ss = ts.get_samples() # sample sets + ss_sizes = np.array([len(ss)], dtype=np.uint32) row_sites = np.arange(ts.get_num_sites(), dtype=np.int32) col_sites = row_sites + row_pos = ts.get_breakpoints()[:-1] + col_pos = row_pos + row_pos_list = list(map(float, ts.get_breakpoints()[:-1])) + col_pos_list = row_pos_list row_sites_list = list(range(ts.get_num_sites())) col_sites_list = row_sites_list # happy path - a = stat_method(sample_set_sizes, sample_sets, row_sites, col_sites, mode) + a = stat_method(ss_sizes, ss, row_sites, col_sites, None, None, "site") assert a.shape == (10, 10, 1) - a = stat_method( - sample_set_sizes, sample_sets, row_sites_list, col_sites_list, mode + ss_sizes, ss, row_sites_list, col_sites_list, None, None, "site" ) assert a.shape == (10, 10, 1) + a = stat_method(ss_sizes, ss, None, None, None, None, "site") + assert a.shape == (10, 10, 1) + + a = stat_method(ss_sizes, ss, None, None, row_pos, col_pos, "branch") + assert a.shape == (2, 2, 1) + a = stat_method(ss_sizes, ss, None, None, row_pos_list, col_pos_list, "branch") + assert a.shape == (2, 2, 1) + a = stat_method(ss_sizes, ss, None, None, None, None, "branch") + assert a.shape == (2, 2, 1) # CPython API errors with pytest.raises(ValueError, match="Sum of sample_set_sizes"): - bad_sample_sets = np.array([], dtype=np.int32) - stat_method(sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode) + bad_ss = np.array([], dtype=np.int32) + stat_method(ss_sizes, bad_ss, row_sites, col_sites, None, None, "site") with pytest.raises(TypeError, match="cast array data"): - bad_sample_sets = np.array(ts.get_samples(), dtype=np.uint32) - stat_method(sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode) + bad_ss = np.array(ts.get_samples(), dtype=np.uint32) + stat_method(ss_sizes, bad_ss, row_sites, col_sites, None, None, "site") with pytest.raises(ValueError, match="Unrecognised stats mode"): - stat_method(sample_set_sizes, sample_sets, row_sites, col_sites, "bla") + stat_method(ss_sizes, ss, row_sites, col_sites, None, None, "bla") with pytest.raises(TypeError, match="at most"): - stat_method( - sample_set_sizes, sample_sets, row_sites, col_sites, mode, "abc" - ) + stat_method(ss_sizes, ss, row_sites, col_sites, None, None, "site", "abc") with pytest.raises(ValueError, match="invalid literal"): bad_sites = ["abadsite", 0, 3, 2] - stat_method(sample_set_sizes, sample_sets, bad_sites, col_sites, mode) + stat_method(ss_sizes, ss, bad_sites, col_sites, None, None, "site") with pytest.raises(TypeError): bad_sites = [None, 0, 3, 2] - stat_method(sample_set_sizes, sample_sets, bad_sites, col_sites, mode) + stat_method(ss_sizes, ss, bad_sites, col_sites, None, None, "site") with pytest.raises(TypeError): bad_sites = [{}, 0, 3, 2] - stat_method(sample_set_sizes, sample_sets, bad_sites, col_sites, mode) + stat_method(ss_sizes, ss, bad_sites, col_sites, None, None, "site") with pytest.raises(TypeError, match="Cannot cast array data"): bad_sites = np.array([0, 1, 2], dtype=np.uint32) - stat_method(sample_set_sizes, sample_sets, bad_sites, col_sites, mode) + stat_method(ss_sizes, ss, bad_sites, col_sites, None, None, "site") with pytest.raises(ValueError, match="invalid literal"): bad_sites = ["abadsite", 0, 3, 2] - stat_method(sample_set_sizes, sample_sets, row_sites, bad_sites, mode) + stat_method(ss_sizes, ss, row_sites, bad_sites, None, None, "site") with pytest.raises(TypeError): bad_sites = [None, 0, 3, 2] - stat_method(sample_set_sizes, sample_sets, row_sites, bad_sites, mode) + stat_method(ss_sizes, ss, row_sites, bad_sites, None, None, "site") with pytest.raises(TypeError): bad_sites = [{}, 0, 3, 2] - stat_method(sample_set_sizes, sample_sets, row_sites, bad_sites, mode) + stat_method(ss_sizes, ss, row_sites, bad_sites, None, None, "site") with pytest.raises(TypeError, match="Cannot cast array data"): bad_sites = np.array([0, 1, 2], dtype=np.uint32) - stat_method(sample_set_sizes, sample_sets, row_sites, bad_sites, mode) + stat_method(ss_sizes, ss, row_sites, bad_sites, None, None, "site") + with pytest.raises(ValueError): + bad_pos = ["abadpos", 0.1, 0.2, 2.0] + stat_method(ss_sizes, ss, None, None, bad_pos, col_pos, "branch") + with pytest.raises(TypeError): + bad_pos = [{}, 0.1, 0.2, 2.0] + stat_method(ss_sizes, ss, None, None, bad_pos, col_pos, "branch") + with pytest.raises(ValueError): + bad_pos = ["abadpos", 0, 3, 2] + stat_method(ss_sizes, ss, None, None, row_pos, bad_pos, "branch") + with pytest.raises(TypeError): + bad_pos = [{}, 0, 3, 2] + stat_method(ss_sizes, ss, None, None, row_pos, bad_pos, "branch") + with pytest.raises(ValueError, match="Cannot specify sites in branch mode"): + stat_method(ss_sizes, ss, row_sites, col_sites, None, None, "branch") + with pytest.raises(ValueError, match="Cannot specify positions in site mode"): + stat_method(ss_sizes, ss, None, None, row_pos, col_pos, "site") # C API errors - with pytest.raises(tskit.LibraryError, match="TSK_ERR_UNSORTED_SITES"): + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): bad_sites = np.array([1, 0, 2], dtype=np.int32) - stat_method(sample_set_sizes, sample_sets, bad_sites, col_sites, mode) - with pytest.raises(tskit.LibraryError, match="TSK_ERR_UNSORTED_SITES"): + stat_method(ss_sizes, ss, bad_sites, col_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): bad_sites = np.array([1, 0, 2], dtype=np.int32) - stat_method(sample_set_sizes, sample_sets, row_sites, bad_sites, mode) + stat_method(ss_sizes, ss, row_sites, bad_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_SITES"): + bad_sites = np.array([1, 1, 2], dtype=np.int32) + stat_method(ss_sizes, ss, bad_sites, col_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_SITES"): + bad_sites = np.array([1, 1, 2], dtype=np.int32) + stat_method(ss_sizes, ss, row_sites, bad_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_SITE_OUT_OF_BOUNDS"): + bad_sites = np.array([-1, 0, 2], dtype=np.int32) + stat_method(ss_sizes, ss, bad_sites, col_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_SITE_OUT_OF_BOUNDS"): + bad_sites = np.array([-1, 0, 2], dtype=np.int32) + stat_method(ss_sizes, ss, row_sites, bad_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_POSITIONS"): + bad_pos = np.array([0.7, 0, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, None, None, bad_pos, col_pos, "branch") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_POSITIONS"): + bad_pos = np.array([0.7, 0, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, None, None, row_pos, bad_pos, "branch") + with pytest.raises( + tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_POSITIONS" + ): + bad_pos = np.array([0.7, 0.7, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, None, None, bad_pos, col_pos, "branch") + with pytest.raises( + tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_POSITIONS" + ): + bad_pos = np.array([0.7, 0.7, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, None, None, row_pos, bad_pos, "branch") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_POSITION_OUT_OF_BOUNDS"): + bad_pos = np.array([-0.1, 0.7, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, None, None, bad_pos, col_pos, "branch") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_POSITION_OUT_OF_BOUNDS"): + bad_pos = np.array([-0.1, 0.7, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, None, None, row_pos, bad_pos, "branch") with pytest.raises( _tskit.LibraryError, match="TSK_ERR_INSUFFICIENT_SAMPLE_SETS" ): - bad_sample_sets = np.array([], dtype=np.int32) - bad_sample_set_sizes = np.array([], dtype=np.uint32) - stat_method( - bad_sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode - ) + bad_ss = np.array([], dtype=np.int32) + bad_ss_sizes = np.array([], dtype=np.uint32) + stat_method(bad_ss_sizes, bad_ss, row_sites, col_sites, None, None, "site") + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_INSUFFICIENT_SAMPLE_SETS" + ): + bad_ss = np.array([], dtype=np.int32) + bad_ss_sizes = np.array([], dtype=np.uint32) + stat_method(bad_ss_sizes, bad_ss, None, None, row_pos, col_pos, "branch") with pytest.raises(_tskit.LibraryError, match="TSK_ERR_EMPTY_SAMPLE_SET"): - bad_sample_sets = np.array([], dtype=np.int32) - bad_sample_set_sizes = np.array([0], dtype=np.uint32) - stat_method( - bad_sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode - ) + bad_ss = np.array([], dtype=np.int32) + bad_ss_sizes = np.array([0], dtype=np.uint32) + stat_method(bad_ss_sizes, bad_ss, row_sites, col_sites, None, None, "site") + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_EMPTY_SAMPLE_SET"): + bad_ss = np.array([], dtype=np.int32) + bad_ss_sizes = np.array([0], dtype=np.uint32) + stat_method(bad_ss_sizes, bad_ss, None, None, row_pos, col_pos, "branch") with pytest.raises(_tskit.LibraryError, match="TSK_ERR_NODE_OUT_OF_BOUNDS"): - bad_sample_sets = np.array([1000], dtype=np.int32) - bad_sample_set_sizes = np.array([1], dtype=np.uint32) - stat_method( - bad_sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode - ) + bad_ss = np.array([1000], dtype=np.int32) + bad_ss_sizes = np.array([1], dtype=np.uint32) + stat_method(bad_ss_sizes, bad_ss, row_sites, col_sites, None, None, "site") + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_NODE_OUT_OF_BOUNDS"): + bad_ss = np.array([1000], dtype=np.int32) + bad_ss_sizes = np.array([1], dtype=np.uint32) + stat_method(bad_ss_sizes, bad_ss, None, None, row_pos, col_pos, "branch") with pytest.raises(_tskit.LibraryError, match="TSK_ERR_DUPLICATE_SAMPLE"): - bad_sample_sets = np.array([2, 2], dtype=np.int32) - bad_sample_set_sizes = np.array([2], dtype=np.uint32) - stat_method( - bad_sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode - ) + bad_ss = np.array([2, 2], dtype=np.int32) + bad_ss_sizes = np.array([2], dtype=np.uint32) + stat_method(bad_ss_sizes, bad_ss, row_sites, col_sites, None, None, "site") + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_DUPLICATE_SAMPLE"): + bad_ss = np.array([2, 2], dtype=np.int32) + bad_ss_sizes = np.array([2], dtype=np.uint32) + stat_method(bad_ss_sizes, bad_ss, None, None, row_pos, col_pos, "branch") with pytest.raises(_tskit.LibraryError, match="TSK_ERR_UNSUPPORTED_STAT_MODE"): - stat_method(sample_set_sizes, sample_sets, row_sites, col_sites, "branch") + stat_method(ss_sizes, ss, col_sites, row_sites, None, None, "node") def test_kc_distance_errors(self): ts1 = self.get_example_tree_sequence(10) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 36706bdf50..942c9021fc 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7576,31 +7576,53 @@ def __one_way_sample_set_stat( stat = stat[()] return stat + def parse_sites(self, sites): + row_sites, col_sites = None, None + if sites is not None: + if any(not hasattr(a, "__getitem__") or isinstance(a, str) for a in sites): + raise ValueError("Sites must be a list of lists, tuples, or ndarrays") + if len(sites) == 2: + row_sites, col_sites = sites + elif len(sites) == 1: + row_sites = col_sites = sites[0] + else: + raise ValueError( + f"Sites must be a length 1 or 2 list, got a length {len(sites)} list" + ) + return row_sites, col_sites + + def parse_positions(self, positions): + row_positions, col_positions = None, None + if positions is not None: + if any( + not hasattr(a, "__getitem__") or isinstance(a, str) for a in positions + ): + raise ValueError( + "Positions must be a list of lists, tuples, or ndarrays" + ) + if len(positions) == 2: + row_positions, col_positions = positions + elif len(positions) == 1: + row_positions = col_positions = positions[0] + else: + raise ValueError( + "Positions must be a length 1 or 2 list, " + f"got a length {len(positions)} list" + ) + return row_positions, col_positions + def __two_locus_sample_set_stat( self, ll_method, sample_sets, sites=None, + positions=None, mode=None, ): if sample_sets is None: sample_sets = self.samples() - if sites is not None and any( - not hasattr(a, "__getitem__") or isinstance(a, str) for a in sites - ): - raise ValueError("Sites must be a list of lists, tuples, or ndarrays") - - if sites is None: - row_sites = np.arange(self.num_sites, dtype=np.int32) - col_sites = np.arange(self.num_sites, dtype=np.int32) - elif len(sites) == 2: - row_sites, col_sites = sites - elif len(sites) == 1: - row_sites = col_sites = sites[0] - else: - raise ValueError( - f"Sites must be a length 1 or 2 list, got a length {len(sites)} list" - ) + row_sites, col_sites = self.parse_sites(sites) + row_positions, col_positions = self.parse_positions(positions) # First try to convert to a 1D numpy array. If we succeed, then we strip off # the corresponding dimension from the output. @@ -7624,7 +7646,15 @@ def __two_locus_sample_set_stat( flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32) - result = ll_method(sample_set_sizes, flattened, row_sites, col_sites, mode) + result = ll_method( + sample_set_sizes, + flattened, + row_sites, + col_sites, + row_positions, + col_positions, + mode, + ) if drop_dimension: result = result.reshape(result.shape[:2]) @@ -9522,7 +9552,9 @@ def impute_unknown_mutations_time( mutations_time[unknown] = self.nodes_time[self.mutations_node[unknown]] return mutations_time - def ld_matrix(self, sample_sets=None, sites=None, mode="site", stat="r2"): + def ld_matrix( + self, sample_sets=None, sites=None, positions=None, mode="site", stat="r2" + ): stats = { "D": self._ll_tree_sequence.D_matrix, "D2": self._ll_tree_sequence.D2_matrix, @@ -9531,6 +9563,9 @@ def ld_matrix(self, sample_sets=None, sites=None, mode="site", stat="r2"): "r": self._ll_tree_sequence.r_matrix, "Dz": self._ll_tree_sequence.Dz_matrix, "pi2": self._ll_tree_sequence.pi2_matrix, + "Dz_unbiased": self._ll_tree_sequence.Dz_unbiased_matrix, + "D2_unbiased": self._ll_tree_sequence.D2_unbiased_matrix, + "pi2_unbiased": self._ll_tree_sequence.pi2_unbiased_matrix, } try: @@ -9544,6 +9579,7 @@ def ld_matrix(self, sample_sets=None, sites=None, mode="site", stat="r2"): two_locus_stat, sample_sets, sites=sites, + positions=positions, mode=mode, )