diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 8d90bf7d45..a40817e7a9 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -2017,17 +2017,17 @@ test_empty_genetic_relatedness_vector(void) } ret = tsk_treeseq_genetic_relatedness_vector( - &ts, num_weights, weights, 1, windows, result, 0); + &ts, num_weights, weights, 1, windows, num_samples, ts.samples, result, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_treeseq_genetic_relatedness_vector( - &ts, num_weights, weights, 1, windows, result, TSK_STAT_NONCENTRED); + ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights, 1, windows, + num_samples, ts.samples, result, TSK_STAT_NONCENTRED); CU_ASSERT_EQUAL_FATAL(ret, 0); windows[0] = 0.5 * tsk_treeseq_get_sequence_length(&ts); windows[1] = 0.75 * tsk_treeseq_get_sequence_length(&ts); ret = tsk_treeseq_genetic_relatedness_vector( - &ts, num_weights, weights, 1, windows, result, 0); + &ts, num_weights, weights, 1, windows, num_samples, ts.samples, result, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); tsk_treeseq_free(&ts); @@ -2061,8 +2061,8 @@ verify_genetic_relatedness_vector( } } - ret = tsk_treeseq_genetic_relatedness_vector( - ts, num_weights, weights, num_windows, windows, result, 0); + ret = tsk_treeseq_genetic_relatedness_vector(ts, num_weights, weights, num_windows, + windows, num_samples, ts->samples, result, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); windows[0] = windows[1] / 2; @@ -2070,17 +2070,17 @@ verify_genetic_relatedness_vector( windows[num_windows - 1] = windows[num_windows - 2] + (L / (double) (2 * num_windows)); } - ret = tsk_treeseq_genetic_relatedness_vector( - ts, num_weights, weights, num_windows, windows, result, 0); + ret = tsk_treeseq_genetic_relatedness_vector(ts, num_weights, weights, num_windows, + windows, num_samples, ts->samples, result, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_treeseq_genetic_relatedness_vector( - ts, num_weights, weights, num_windows, windows, result, TSK_STAT_NONCENTRED); + ret = tsk_treeseq_genetic_relatedness_vector(ts, num_weights, weights, num_windows, + windows, num_samples, ts->samples, result, TSK_STAT_NONCENTRED); CU_ASSERT_EQUAL_FATAL(ret, 0); tsk_set_debug_stream(_devnull); - ret = tsk_treeseq_genetic_relatedness_vector( - ts, num_weights, weights, num_windows, windows, result, TSK_DEBUG); + ret = tsk_treeseq_genetic_relatedness_vector(ts, num_weights, weights, num_windows, + windows, num_samples, ts->samples, result, TSK_DEBUG); CU_ASSERT_EQUAL_FATAL(ret, 0); tsk_set_debug_stream(stdout); @@ -2117,6 +2117,7 @@ test_paper_ex_genetic_relatedness_vector_errors(void) tsk_size_t num_samples; double *weights, *result; tsk_size_t j; + tsk_size_t num_windows = 2; tsk_size_t num_weights = 2; double windows[] = { 0, 0, 0 }; @@ -2125,7 +2126,7 @@ test_paper_ex_genetic_relatedness_vector_errors(void) num_samples = tsk_treeseq_get_num_samples(&ts); weights = tsk_malloc(num_weights * num_samples * sizeof(double)); - result = tsk_malloc(num_weights * num_samples * sizeof(double)); + result = tsk_malloc(num_windows * num_weights * num_samples * sizeof(double)); for (j = 0; j < num_samples; j++) { weights[j] = 1.0; } @@ -2135,41 +2136,41 @@ test_paper_ex_genetic_relatedness_vector_errors(void) /* Window errors */ ret = tsk_treeseq_genetic_relatedness_vector( - &ts, 1, weights, 0, windows, result, TSK_STAT_BRANCH); + &ts, 1, weights, 0, windows, num_samples, ts.samples, result, TSK_STAT_BRANCH); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS); ret = tsk_treeseq_genetic_relatedness_vector( - &ts, 1, weights, 0, NULL, result, TSK_STAT_BRANCH); + &ts, 1, weights, 0, NULL, num_samples, ts.samples, result, TSK_STAT_BRANCH); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS); ret = tsk_treeseq_genetic_relatedness_vector( - &ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH); + &ts, 1, weights, 2, windows, num_samples, ts.samples, result, TSK_STAT_BRANCH); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); windows[0] = -1; ret = tsk_treeseq_genetic_relatedness_vector( - &ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH); + &ts, 1, weights, 2, windows, num_samples, ts.samples, result, TSK_STAT_BRANCH); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); windows[0] = 12; ret = tsk_treeseq_genetic_relatedness_vector( - &ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH); + &ts, 1, weights, 2, windows, num_samples, ts.samples, result, TSK_STAT_BRANCH); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); windows[0] = 0; windows[2] = 12; ret = tsk_treeseq_genetic_relatedness_vector( - &ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH); + &ts, 1, weights, 2, windows, num_samples, ts.samples, result, TSK_STAT_BRANCH); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS); /* unsupported mode errors */ windows[0] = 0.0; windows[1] = 5.0; windows[2] = 10.0; - ret = tsk_treeseq_genetic_relatedness_vector( - &ts, num_weights, weights, 2, windows, result, TSK_STAT_SITE); + ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights, 2, windows, + num_samples, ts.samples, result, TSK_STAT_SITE); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); - ret = tsk_treeseq_genetic_relatedness_vector( - &ts, num_weights, weights, 2, windows, result, TSK_STAT_NODE); + ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights, 2, windows, + num_samples, ts.samples, result, TSK_STAT_NODE); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); tsk_treeseq_free(&ts); @@ -2177,6 +2178,51 @@ test_paper_ex_genetic_relatedness_vector_errors(void) free(result); } +static void +test_paper_ex_genetic_relatedness_vector_node_errors(void) +{ + int ret; + tsk_treeseq_t ts; + tsk_size_t num_samples; + double *weights, *result; + tsk_size_t j; + tsk_size_t num_weights = 2; + tsk_size_t num_windows = 2; + double windows[] = { 1, 1.5, 2 }; + tsk_size_t num_nodes = 3; + const tsk_id_t good_nodes[] = { 1, 0, 2 }; + const tsk_id_t bad_nodes1[] = { 1, -1, 2 }; + const tsk_id_t bad_nodes2[] = { 1, 100, 2 }; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + num_samples = tsk_treeseq_get_num_samples(&ts); + + weights = tsk_malloc(num_weights * num_samples * sizeof(double)); + result = tsk_malloc(num_windows * num_weights * num_nodes * sizeof(double)); + for (j = 0; j < num_samples; j++) { + weights[j] = 1.0; + } + for (j = 0; j < num_samples; j++) { + weights[j + num_samples] = (float) j; + } + + /* node errors */ + ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights, 2, windows, + num_nodes, good_nodes, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights, 2, windows, + num_nodes, bad_nodes1, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights, 2, windows, + num_nodes, bad_nodes2, result, TSK_STAT_BRANCH); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + + tsk_treeseq_free(&ts); + free(weights); + free(result); +} + static void test_paper_ex_Y2_errors(void) { @@ -3723,6 +3769,8 @@ main(int argc, char **argv) test_paper_ex_genetic_relatedness_vector }, { "test_paper_ex_genetic_relatedness_vector_errors", test_paper_ex_genetic_relatedness_vector_errors }, + { "test_paper_ex_genetic_relatedness_vector_node_errors", + test_paper_ex_genetic_relatedness_vector_node_errors }, { "test_paper_ex_Y2_errors", test_paper_ex_Y2_errors }, { "test_paper_ex_Y2", test_paper_ex_Y2 }, { "test_paper_ex_f2_errors", test_paper_ex_f2_errors }, diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 74305902bb..a554cfad55 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -9908,9 +9908,10 @@ typedef struct { const double *weights; tsk_size_t num_windows; const double *windows; + tsk_size_t num_focal_nodes; + const tsk_id_t *focal_nodes; tsk_flags_t options; double *result; - /* tree */ tsk_tree_position_t tree_pos; double position; tsk_size_t num_nodes; @@ -9929,6 +9930,7 @@ tsk_matvec_calculator_print_state(const tsk_matvec_calculator_t *self, FILE *out fprintf(out, "Matvec state:\n"); fprintf(out, "options = %d\n", self->options); fprintf(out, "position = %f\n", self->position); + fprintf(out, "focal nodes = %lld: [", (long long) self->num_focal_nodes); fprintf(out, "tree_pos:\n"); tsk_tree_position_print_state(&self->tree_pos, out); fprintf(out, "samples = %lld: [", (long long) num_samples); @@ -9946,7 +9948,8 @@ tsk_matvec_calculator_print_state(const tsk_matvec_calculator_t *self, FILE *out static int tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *ts, tsk_size_t num_weights, const double *weights, tsk_size_t num_windows, - const double *windows, tsk_flags_t options, double *result) + const double *windows, tsk_size_t num_focal_nodes, const tsk_id_t *focal_nodes, + tsk_flags_t options, double *result) { int ret = 0; tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts); @@ -9954,17 +9957,18 @@ tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *t const double *row; double *new_row; tsk_size_t k; - tsk_id_t u, j; + tsk_id_t index, u, j; double *weight_means = tsk_malloc(num_weights * sizeof(*weight_means)); const tsk_size_t num_trees = ts->num_trees; const double *restrict breakpoints = ts->breakpoints; - tsk_id_t index; self->ts = ts; self->num_weights = num_weights; self->weights = weights; self->num_windows = num_windows; self->windows = windows; + self->num_focal_nodes = num_focal_nodes; + self->focal_nodes = focal_nodes; self->options = options; self->result = result; self->num_nodes = num_nodes; @@ -9981,9 +9985,16 @@ tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *t goto out; } - tsk_memset(result, 0, num_windows * num_samples * num_weights * sizeof(*result)); + tsk_memset(result, 0, num_windows * num_focal_nodes * num_weights * sizeof(*result)); tsk_memset(self->parent, TSK_NULL, num_nodes * sizeof(*self->parent)); + for (j = 0; j < (tsk_id_t) num_focal_nodes; j++) { + if (focal_nodes[j] < 0 || (tsk_size_t) focal_nodes[j] >= num_nodes) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + } + ret = tsk_tree_position_init(&self->tree_pos, ts, 0); if (ret != 0) { goto out; @@ -10001,6 +10012,7 @@ tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *t for (k = 0; k < num_weights; k++) { weight_means[k] = 0.0; } + /* centre the input */ if (!(options & TSK_STAT_NONCENTRED)) { for (j = 0; j < (tsk_id_t) num_samples; j++) { row = GET_2D_ROW(weights, num_weights, j); @@ -10013,6 +10025,7 @@ tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *t } } + /* set the initial state */ for (j = 0; j < (tsk_id_t) num_samples; j++) { u = ts->samples[j]; row = GET_2D_ROW(weights, num_weights, j); @@ -10129,7 +10142,7 @@ tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self, double *restri int ret = 0; tsk_id_t u; tsk_size_t j, k; - tsk_size_t n = tsk_treeseq_get_num_samples(self->ts); + const tsk_size_t n = self->num_focal_nodes; const tsk_size_t num_weights = self->num_weights; const double position = self->position; double *u_row, *out_row; @@ -10139,7 +10152,7 @@ tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self, double *restri double *restrict x = self->x; double *restrict w = self->w; double *restrict v = self->v; - const tsk_id_t *restrict samples = self->ts->samples; + const tsk_id_t *restrict focal_nodes = self->focal_nodes; if (out_means == NULL) { ret = TSK_ERR_NO_MEMORY; @@ -10148,7 +10161,7 @@ tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self, double *restri for (j = 0; j < n; j++) { out_row = GET_2D_ROW(y, num_weights, j); - u = samples[j]; + u = focal_nodes[j]; while (u != TSK_NULL) { if (x[u] != position) { tsk_matvec_calculator_add_z( @@ -10195,7 +10208,7 @@ tsk_matvec_calculator_run(tsk_matvec_calculator_t *self) int ret = 0; tsk_size_t j, k, m; tsk_id_t e, p, c; - tsk_size_t n = tsk_treeseq_get_num_samples(self->ts); + const tsk_size_t out_size = self->num_weights * self->num_focal_nodes; const tsk_size_t num_edges = self->ts->tables->edges.num_rows; const double *restrict edge_right = self->ts->tables->edges.right; const double *restrict edge_left = self->ts->tables->edges.left; @@ -10253,7 +10266,7 @@ tsk_matvec_calculator_run(tsk_matvec_calculator_t *self) tsk_bug_assert(self->position < next_position); self->position = next_position; if (self->position == windows[m + 1]) { - out = GET_2D_ROW(self->result, self->num_weights * n, m); + out = GET_2D_ROW(self->result, out_size, m); tsk_matvec_calculator_write_output(self, out); m += 1; } @@ -10268,7 +10281,8 @@ tsk_matvec_calculator_run(tsk_matvec_calculator_t *self) int tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self, tsk_size_t num_weights, - const double *weights, tsk_size_t num_windows, const double *windows, double *result, + const double *weights, tsk_size_t num_windows, const double *windows, + tsk_size_t num_focal_nodes, const tsk_id_t *focal_nodes, double *result, tsk_flags_t options) { int ret = 0; @@ -10287,8 +10301,8 @@ tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self, tsk_size_t num goto out; } - ret = tsk_matvec_calculator_init( - &calc, self, num_weights, weights, num_windows, windows, options, result); + ret = tsk_matvec_calculator_init(&calc, self, num_weights, weights, num_windows, + windows, num_focal_nodes, focal_nodes, options, result); if (ret != 0) { goto out; } diff --git a/c/tskit/trees.h b/c/tskit/trees.h index df9cf92850..bef944fff3 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1028,12 +1028,14 @@ int tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self, /* One way weighted stats with vector output */ typedef int weighted_vector_method(const tsk_treeseq_t *self, tsk_size_t num_weights, - const double *weights, tsk_size_t num_windows, const double *windows, double *result, + const double *weights, tsk_size_t num_windows, const double *windows, + tsk_size_t num_focal_nodes, const tsk_id_t *focal_nodes, double *result, tsk_flags_t options); int tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self, tsk_size_t num_weights, const double *weights, tsk_size_t num_windows, - const double *windows, double *result, tsk_flags_t options); + const double *windows, tsk_size_t num_focal_nodes, const tsk_id_t *focal_nodes, + double *result, tsk_flags_t options); /* One way sample set stats */ diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 6d275a499a..8663bd8695 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -9652,14 +9652,17 @@ TreeSequence_weighted_stat_vector_method( { PyObject *ret = NULL; static char *kwlist[] - = { "weights", "windows", "mode", "span_normalise", "centre", NULL }; + = { "weights", "windows", "mode", "span_normalise", "centre", "nodes", NULL }; PyObject *weights = NULL; PyObject *windows = NULL; + PyObject *focal_nodes = NULL; PyArrayObject *weights_array = NULL; PyArrayObject *windows_array = NULL; PyArrayObject *result_array = NULL; + PyArrayObject *focal_nodes_array = NULL; tsk_size_t num_windows; - npy_intp *w_shape, result_shape[3]; + tsk_size_t num_focal_nodes; + npy_intp *focal_nodes_shape, *w_shape, result_shape[3]; tsk_flags_t options = 0; tsk_size_t num_samples; char *mode = NULL; @@ -9670,8 +9673,8 @@ TreeSequence_weighted_stat_vector_method( if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|sii", kwlist, &weights, &windows, - &mode, &span_normalise, ¢re)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|siiO", kwlist, &weights, &windows, + &mode, &span_normalise, ¢re, &focal_nodes)) { goto out; } if (parse_stats_mode(mode, &options) != 0) { @@ -9697,16 +9700,24 @@ TreeSequence_weighted_stat_vector_method( PyErr_SetString(PyExc_ValueError, "First dimension must be num_samples"); goto out; } + focal_nodes_array = (PyArrayObject *) PyArray_FROMANY( + focal_nodes, NPY_INT32, 1, 1, NPY_ARRAY_IN_ARRAY); + if (focal_nodes_array == NULL) { + goto out; + } + focal_nodes_shape = PyArray_DIMS(focal_nodes_array); + num_focal_nodes = focal_nodes_shape[0]; result_shape[0] = num_windows; - result_shape[1] = num_samples; + result_shape[1] = num_focal_nodes; result_shape[2] = w_shape[1]; result_array = (PyArrayObject *) PyArray_SimpleNew(3, result_shape, NPY_FLOAT64); if (result_array == NULL) { goto out; } err = method(self->tree_sequence, w_shape[1], PyArray_DATA(weights_array), - num_windows, PyArray_DATA(windows_array), PyArray_DATA(result_array), options); + num_windows, PyArray_DATA(windows_array), num_focal_nodes, + PyArray_DATA(focal_nodes_array), PyArray_DATA(result_array), options); if (err != 0) { handle_library_error(err); goto out; @@ -9716,6 +9727,7 @@ TreeSequence_weighted_stat_vector_method( out: Py_XDECREF(weights_array); Py_XDECREF(windows_array); + Py_XDECREF(focal_nodes_array); Py_XDECREF(result_array); return ret; } diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 5093d3a027..c07f975f83 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -2682,6 +2682,7 @@ def get_example(self, num_weights=2): (num_samples, num_weights) ), "windows": [0, ts.get_sequence_length()], + "nodes": list(ts.get_samples()), } return ts, params @@ -2690,18 +2691,38 @@ def get_example(self, num_weights=2): def test_basic_example(self, mode, num_weights): ts, params = self.get_example(num_weights) ns = ts.get_num_samples() - result = ts.genetic_relatedness_vector( - params["weights"], params["windows"], mode, True, True - ) - assert result.shape == (1, ns, num_weights) - result = ts.genetic_relatedness_vector( - params["weights"], params["windows"], mode, True, False - ) - assert result.shape == (1, ns, num_weights) - result = ts.genetic_relatedness_vector( - params["weights"], params["windows"], mode, False, True - ) - assert result.shape == (1, ns, num_weights) + params["mode"] = mode + for a, b in ([True, True], [True, False], [False, True]): + params["span_normalise"] = a + params["centre"] = b + result = ts.genetic_relatedness_vector(**params) + assert result.shape == (1, ns, num_weights) + + @pytest.mark.parametrize("mode", ["branch"]) + def test_good_nodes(self, mode): + num_weights = 2 + ts, params = self.get_example(num_weights) + params["mode"] = mode + for nodes in [ + list(ts.get_samples())[:3], + list(ts.get_samples())[:1], + [0, ts.get_num_nodes() - 1], + ]: + params["nodes"] = nodes + result = ts.genetic_relatedness_vector(**params) + assert result.shape == (1, len(nodes), num_weights) + + def test_bad_nodes(self): + ts, params = self.get_example() + params["mode"] = "branch" + for nodes in ["abc", [[1, 2]]]: + params["nodes"] = nodes + with pytest.raises(ValueError, match="desired array"): + ts.genetic_relatedness_vector(**params) + for nodes in [[-1, 3], [3, 2 * ts.get_num_nodes()]]: + params["nodes"] = nodes + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_NODE_OUT_OF_BOUNDS"): + ts.genetic_relatedness_vector(**params) def test_bad_args(self): ts, params = self.get_example() @@ -2727,10 +2748,9 @@ def test_bad_args(self): @pytest.mark.parametrize("mode", ["site", "node"]) def test_modes_not_supported(self, mode): ts, params = self.get_example() + params["mode"] = mode with pytest.raises(_tskit.LibraryError): - ts.genetic_relatedness_vector( - params["weights"], params["windows"], mode, True, True - ) + ts.genetic_relatedness_vector(**params) @pytest.mark.parametrize("mode", ["branch"]) def test_bad_weights(self, mode): diff --git a/python/tests/test_relatedness_vector.py b/python/tests/test_relatedness_vector.py index 82a8d80119..f765c75c9f 100644 --- a/python/tests/test_relatedness_vector.py +++ b/python/tests/test_relatedness_vector.py @@ -46,6 +46,7 @@ def __init__( windows, num_nodes, samples, + focal_nodes, nodes_time, edges_left, edges_right, @@ -70,6 +71,7 @@ def __init__( self.sequence_length = sequence_length self.nodes_time = nodes_time self.samples = samples + self.focal_nodes = focal_nodes self.tree_pos = tree_pos self.position = windows[0] self.x = np.zeros(N, dtype=np.float64) @@ -192,9 +194,9 @@ def write_output(self): Compute and return the current state, zero-ing out all contributions (used for switching between windows). """ - n = len(self.samples) + n = len(self.focal_nodes) out = np.zeros((n, self.num_weights)) - for j, c in enumerate(self.samples): + for j, c in enumerate(self.focal_nodes): while c != tskit.NULL: if self.x[c] != self.position: self.v[c] += self.get_z(c) @@ -210,9 +212,9 @@ def current_state(self): """ if self.verbosity > 2: print("---------------") - n = len(self.samples) + n = len(self.focal_nodes) out = np.zeros((n, self.num_weights)) - for j, a in enumerate(self.samples): + for j, a in enumerate(self.focal_nodes): # edges on the path up from a pa = a while pa != tskit.NULL: @@ -234,7 +236,9 @@ def run(self): in_order = tree_pos.in_range.order out_order = tree_pos.out_range.order num_windows = len(self.windows) - 1 - out = np.zeros((num_windows,) + self.sample_weights.shape) + out = np.zeros( + (num_windows, len(self.focal_nodes), self.sample_weights.shape[1]) + ) m = 0 self.position = self.windows[0] @@ -284,9 +288,11 @@ def run(self): return out -def relatedness_vector(ts, sample_weights, windows=None, **kwargs): +def relatedness_vector(ts, sample_weights, windows=None, nodes=None, **kwargs): if len(sample_weights.shape) == 1: sample_weights = sample_weights[:, np.newaxis] + if nodes is None: + nodes = np.fromiter(ts.samples(), dtype=np.int32) drop_dimension = windows is None if drop_dimension: windows = [0, ts.sequence_length] @@ -303,6 +309,7 @@ def relatedness_vector(ts, sample_weights, windows=None, **kwargs): windows, ts.num_nodes, samples=ts.samples(), + focal_nodes=nodes, nodes_time=ts.nodes_time, edges_left=ts.edges_left, edges_right=ts.edges_right, @@ -319,7 +326,24 @@ def relatedness_vector(ts, sample_weights, windows=None, **kwargs): return out -def relatedness_matrix(ts, windows, centre): +def relatedness_matrix(ts, windows, centre, nodes=None): + if nodes is None: + keep_rows = np.arange(ts.num_samples) + keep_cols = np.arange(ts.num_samples) + else: + orig_samples = list(ts.samples()) + extra_nodes = set(nodes).difference(set(orig_samples)) + tables = ts.dump_tables() + tables.nodes.clear() + for n in ts.nodes(): + if n.id in extra_nodes: + n = n.replace(flags=n.flags | tskit.NODE_IS_SAMPLE) + tables.nodes.append(n) + ts = tables.tree_sequence() + all_samples = list(ts.samples()) + keep_rows = np.array([all_samples.index(i) for i in nodes]) + keep_cols = np.array([all_samples.index(i) for i in orig_samples]) + use_windows = windows drop_first = windows is not None and windows[0] > 0 if drop_first: @@ -341,14 +365,17 @@ def relatedness_matrix(ts, windows, centre): Sigma = Sigma[1:] if drop_last: Sigma = Sigma[:-1] - shape = (len(windows) - 1, ts.num_samples, ts.num_samples) - else: - shape = (ts.num_samples, ts.num_samples) - return Sigma.reshape(shape) + nwin = 1 if windows is None else len(windows) - 1 + shape = (nwin, ts.num_samples, ts.num_samples) + Sigma = Sigma.reshape(shape) + out = np.array([x[np.ix_(keep_rows, keep_cols)] for x in Sigma]) + if windows is None: + out = out[0] + return out def verify_relatedness_vector( - ts, w, windows, *, internal_checks=False, verbosity=0, centre=True + ts, w, windows, *, internal_checks=False, verbosity=0, centre=True, nodes=None ): R1 = relatedness_vector( ts, @@ -357,35 +384,48 @@ def verify_relatedness_vector( internal_checks=internal_checks, verbosity=verbosity, centre=centre, + nodes=nodes, ) + nrows = ts.num_samples if nodes is None else len(nodes) wvec = w if len(w.shape) > 1 else w[:, np.newaxis] - Sigma = relatedness_matrix(ts, windows=windows, centre=centre) + Sigma = relatedness_matrix(ts, windows=windows, centre=centre, nodes=nodes) if windows is None: R2 = Sigma.dot(wvec) else: - R2 = np.zeros((len(windows) - 1, ts.num_samples, wvec.shape[1])) + R2 = np.zeros((len(windows) - 1, nrows, wvec.shape[1])) for k in range(len(windows) - 1): R2[k] = Sigma[k].dot(wvec) - R3 = ts.genetic_relatedness_vector(w, windows=windows, mode="branch", centre=centre) + R3 = ts.genetic_relatedness_vector( + w, windows=windows, mode="branch", centre=centre, nodes=nodes + ) if verbosity > 0: print(ts.draw_text()) print("weights:", w) print("windows:", windows) + print("centre:", centre) print("here:", R1) print("with ts:", R2) print("with lib:", R3) print("Sigma:", Sigma) if windows is None: - assert R1.shape == (ts.num_samples, wvec.shape[1]) + assert R1.shape == (nrows, wvec.shape[1]) else: - assert R1.shape == (len(windows) - 1, ts.num_samples, wvec.shape[1]) - np.testing.assert_allclose(R1, R2, atol=1e-13) - np.testing.assert_allclose(R1, R3, atol=1e-13) + assert R1.shape == (len(windows) - 1, nrows, wvec.shape[1]) + np.testing.assert_allclose(R1, R2, atol=1e-10) + np.testing.assert_allclose(R1, R3, atol=1e-10) return R1 def check_relatedness_vector( - ts, n=2, num_windows=0, *, internal_checks=False, verbosity=0, seed=123, centre=True + ts, + n=2, + num_windows=0, + *, + internal_checks=False, + verbosity=0, + seed=123, + centre=True, + do_nodes=True, ): rng = np.random.default_rng(seed=seed) if num_windows == 0: @@ -396,20 +436,27 @@ def check_relatedness_vector( ) else: windows = np.linspace(0, ts.sequence_length, num_windows + 1) - for k in range(n): - if k == 0: - w = rng.normal(size=ts.num_samples) + num_nodes_list = (0,) if (centre or not do_nodes) else (0, 3) + for num_nodes in num_nodes_list: + if num_nodes == 0: + nodes = None else: - w = rng.normal(size=ts.num_samples * k).reshape((ts.num_samples, k)) - w = np.round(len(w) * w) - R = verify_relatedness_vector( - ts, - w, - windows, - internal_checks=internal_checks, - verbosity=verbosity, - centre=centre, - ) + nodes = rng.choice(ts.num_nodes, num_nodes, replace=False) + for k in range(n): + if k == 0: + w = rng.normal(size=ts.num_samples) + else: + w = rng.normal(size=ts.num_samples * k).reshape((ts.num_samples, k)) + w = np.round(len(w) * w) + R = verify_relatedness_vector( + ts, + w, + windows, + internal_checks=internal_checks, + verbosity=verbosity, + centre=centre, + nodes=nodes, + ) return R @@ -441,6 +488,83 @@ def test_bad_windows(self): np.ones(ts.num_samples), windows=bad_w, mode="branch" ) + def test_nodes_centred_error(self): + ts = msprime.sim_ancestry( + 5, + ploidy=2, + sequence_length=10, + random_seed=123, + ) + with pytest.raises(ValueError, match="must have centre"): + ts.genetic_relatedness_vector( + np.ones(ts.num_samples), mode="branch", centre=True, nodes=[0, 1] + ) + + def test_bad_nodes(self): + n = 5 + ts = msprime.sim_ancestry( + n, + ploidy=2, + sequence_length=10, + random_seed=123, + ) + for bad_nodes in ([[]], "foo"): + with pytest.raises(ValueError): + ts.genetic_relatedness_vector( + np.ones(ts.num_samples), + mode="branch", + centre=False, + nodes=bad_nodes, + ) + for bad_nodes in ([-1, 10], [3, 2 * ts.num_nodes]): + with pytest.raises(tskit.LibraryError, match="TSK_ERR_NODE_OUT_OF_BOUNDS"): + ts.genetic_relatedness_vector( + np.ones(ts.num_samples), + mode="branch", + centre=False, + nodes=bad_nodes, + ) + + def test_good_nodes(self): + n = 5 + ts = msprime.sim_ancestry( + n, + ploidy=2, + sequence_length=10, + random_seed=123, + ) + V0 = ts.genetic_relatedness_vector( + np.ones(ts.num_samples), mode="branch", centre=False + ) + V = ts.genetic_relatedness_vector( + np.ones(ts.num_samples), + mode="branch", + centre=False, + nodes=list(ts.samples()), + ) + np.testing.assert_allclose(V0, V, atol=1e-13) + V = ts.genetic_relatedness_vector( + np.ones(ts.num_samples), + mode="branch", + centre=False, + nodes=np.fromiter(ts.samples(), dtype=np.int32), + ) + np.testing.assert_allclose(V0, V, atol=1e-13) + V = ts.genetic_relatedness_vector( + np.ones(ts.num_samples), + mode="branch", + centre=False, + nodes=np.fromiter(ts.samples(), dtype=np.int64), + ) + np.testing.assert_allclose(V0, V, atol=1e-13) + V = ts.genetic_relatedness_vector( + np.ones(ts.num_samples), + mode="branch", + centre=False, + nodes=list(ts.samples())[:2], + ) + np.testing.assert_allclose(V0[:2], V, atol=1e-13) + @pytest.mark.parametrize("n", [2, 3, 5]) @pytest.mark.parametrize("seed", range(1, 4)) @pytest.mark.parametrize("centre", (True, False)) @@ -565,7 +689,7 @@ def test_dangling_on_samples(self, n): # Adding non sample branches below the samples does not alter # the overall divergence *between* the samples ts1 = tskit.Tree.generate_balanced(n).tree_sequence - D1 = check_relatedness_vector(ts1) + D1 = check_relatedness_vector(ts1, do_nodes=False) tables = ts1.dump_tables() for u in ts1.samples(): v = tables.nodes.add_row(time=-1) @@ -573,7 +697,7 @@ def test_dangling_on_samples(self, n): tables.sort() tables.build_index() ts2 = tables.tree_sequence() - D2 = check_relatedness_vector(ts2, internal_checks=True) + D2 = check_relatedness_vector(ts2, internal_checks=True, do_nodes=False) np.testing.assert_array_almost_equal(D1, D2) @pytest.mark.parametrize("n", [2, 3, 10]) @@ -582,7 +706,7 @@ def test_dangling_on_all(self, n, centre): # Adding non sample branches below the samples does not alter # the overall divergence *between* the samples ts1 = tskit.Tree.generate_balanced(n).tree_sequence - D1 = check_relatedness_vector(ts1, centre=centre) + D1 = check_relatedness_vector(ts1, centre=centre, do_nodes=False) tables = ts1.dump_tables() for u in range(ts1.num_nodes): v = tables.nodes.add_row(time=-1) @@ -590,7 +714,9 @@ def test_dangling_on_all(self, n, centre): tables.sort() tables.build_index() ts2 = tables.tree_sequence() - D2 = check_relatedness_vector(ts2, internal_checks=True, centre=centre) + D2 = check_relatedness_vector( + ts2, internal_checks=True, centre=centre, do_nodes=False + ) np.testing.assert_array_almost_equal(D1, D2) @pytest.mark.parametrize("centre", (True, False)) @@ -598,7 +724,7 @@ def test_disconnected_non_sample_topology(self, centre): # Adding non sample branches below the samples does not alter # the overall divergence *between* the samples ts1 = tskit.Tree.generate_balanced(5).tree_sequence - D1 = check_relatedness_vector(ts1, centre=centre) + D1 = check_relatedness_vector(ts1, centre=centre, do_nodes=False) tables = ts1.dump_tables() # Add an extra bit of disconnected non-sample topology u = tables.nodes.add_row(time=0) @@ -607,5 +733,7 @@ def test_disconnected_non_sample_topology(self, centre): tables.sort() tables.build_index() ts2 = tables.tree_sequence() - D2 = check_relatedness_vector(ts2, internal_checks=True, centre=centre) + D2 = check_relatedness_vector( + ts2, internal_checks=True, centre=centre, do_nodes=False + ) np.testing.assert_array_almost_equal(D1, D2) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index e617c7ea8d..3d4637829d 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7858,10 +7858,20 @@ def __weighted_vector_stat( mode=None, span_normalise=True, centre=True, + nodes=None, ): W = np.asarray(W) if len(W.shape) == 1: W = W.reshape(W.shape[0], 1) + if nodes is None: + nodes = list(self.samples()) + else: + if centre: + raise ValueError("If `nodes` is provided, must have centre=False.") + try: + nodes = util.safe_np_int_cast(nodes, np.int32) + except Exception: + raise ValueError("Could not interpret `nodes` as a list of node IDs.") stat = self.__run_windowed_stat( windows, ll_method, @@ -7869,6 +7879,7 @@ def __weighted_vector_stat( mode=mode, span_normalise=span_normalise, centre=centre, + nodes=nodes, ) return stat @@ -8518,20 +8529,31 @@ def genetic_relatedness_vector( mode="site", span_normalise=True, centre=True, + nodes=None, ): r""" Computes the product of the genetic relatedness matrix and a vector of weights (one per sample). The output is a (num windows) x (num samples) x (num weights) - array whose :math:`(i,j)`-th element is :math:`\sum_{b} W_{bj} C_{ib}`, + array whose :math:`(w,i,j)`-th element is :math:`\sum_{b} W_{bj} C_{ib}`, where :math:`W` is the matrix of weights, and :math:`C_{ab}` is the :meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>` between sample - a and sample b, and the sum is over all samples in the tree sequence. - Like other statistics, if windows is None, the first dimension in the output is - dropped. + `a` and sample `b` in window `w`, and the sum is over all samples in the tree + sequence. Like other statistics, if windows is None, the first dimension in + the output is dropped. The relatedness used here corresponds to `polarised=True`; no unpolarised option is available for this method. + Optionally, you may provide a list of focal nodes that modifies the behavior + as follows. If `nodes` is a list of `n` node IDs (that do not need to be + samples), then the output will have dimension (num windows) x n x (num weights), + and the matrix :math:`C` used in the definition above is the rectangular matrix + with :math:`C_{ij}` the relatedness between `nodes[i]` and `samples[j]`. This + can only be used with `centre=False`; if relatedness between uncentred nodes + and centred samples is desired, then simply subtract column means from `W` first. + The default is `nodes=None`, which is equivalent to setting `nodes` equal to + `ts.samples()`. + :param numpy.ndarray W: An array of values with one row for each sample node and one column for each set of weights. :param list windows: An increasing list of breakpoints between the windows @@ -8542,6 +8564,8 @@ def genetic_relatedness_vector( window (defaults to True). :param bool centre: Whether to use the *centred* relatedness matrix or not: see :meth:`genetic_relatedness <.TreeSequence.genetic_relatedness>`. + :param list nodes: Optionally, a list of focal nodes as described above + (default: None). :return: A ndarray with shape equal to (num windows, num samples, num weights), or (num samples, num weights) if windows is None. """ @@ -8549,6 +8573,7 @@ def genetic_relatedness_vector( raise ValueError( "First weight dimension must be equal to number of samples." ) + out = self.__weighted_vector_stat( self._ll_tree_sequence.genetic_relatedness_vector, W, @@ -8556,6 +8581,7 @@ def genetic_relatedness_vector( mode=mode, span_normalise=span_normalise, centre=centre, + nodes=nodes, ) return out