Skip to content

Commit

Permalink
allow genetic_relatedness_vector to work with arbitrary nodes; closes #…
Browse files Browse the repository at this point in the history
  • Loading branch information
petrelharp authored and mergify[bot] committed Sep 27, 2024
1 parent d69d65b commit bc5a73c
Show file tree
Hide file tree
Showing 7 changed files with 352 additions and 102 deletions.
94 changes: 71 additions & 23 deletions c/tests/test_stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -2017,17 +2017,17 @@ test_empty_genetic_relatedness_vector(void)
}

ret = tsk_treeseq_genetic_relatedness_vector(
&ts, num_weights, weights, 1, windows, result, 0);
&ts, num_weights, weights, 1, windows, num_samples, ts.samples, result, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);

ret = tsk_treeseq_genetic_relatedness_vector(
&ts, num_weights, weights, 1, windows, result, TSK_STAT_NONCENTRED);
ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights, 1, windows,
num_samples, ts.samples, result, TSK_STAT_NONCENTRED);
CU_ASSERT_EQUAL_FATAL(ret, 0);

windows[0] = 0.5 * tsk_treeseq_get_sequence_length(&ts);
windows[1] = 0.75 * tsk_treeseq_get_sequence_length(&ts);
ret = tsk_treeseq_genetic_relatedness_vector(
&ts, num_weights, weights, 1, windows, result, 0);
&ts, num_weights, weights, 1, windows, num_samples, ts.samples, result, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);

tsk_treeseq_free(&ts);
Expand Down Expand Up @@ -2061,26 +2061,26 @@ verify_genetic_relatedness_vector(
}
}

ret = tsk_treeseq_genetic_relatedness_vector(
ts, num_weights, weights, num_windows, windows, result, 0);
ret = tsk_treeseq_genetic_relatedness_vector(ts, num_weights, weights, num_windows,
windows, num_samples, ts->samples, result, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);

windows[0] = windows[1] / 2;
if (num_windows > 1) {
windows[num_windows - 1]
= windows[num_windows - 2] + (L / (double) (2 * num_windows));
}
ret = tsk_treeseq_genetic_relatedness_vector(
ts, num_weights, weights, num_windows, windows, result, 0);
ret = tsk_treeseq_genetic_relatedness_vector(ts, num_weights, weights, num_windows,
windows, num_samples, ts->samples, result, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);

ret = tsk_treeseq_genetic_relatedness_vector(
ts, num_weights, weights, num_windows, windows, result, TSK_STAT_NONCENTRED);
ret = tsk_treeseq_genetic_relatedness_vector(ts, num_weights, weights, num_windows,
windows, num_samples, ts->samples, result, TSK_STAT_NONCENTRED);
CU_ASSERT_EQUAL_FATAL(ret, 0);

tsk_set_debug_stream(_devnull);
ret = tsk_treeseq_genetic_relatedness_vector(
ts, num_weights, weights, num_windows, windows, result, TSK_DEBUG);
ret = tsk_treeseq_genetic_relatedness_vector(ts, num_weights, weights, num_windows,
windows, num_samples, ts->samples, result, TSK_DEBUG);
CU_ASSERT_EQUAL_FATAL(ret, 0);
tsk_set_debug_stream(stdout);

Expand Down Expand Up @@ -2117,6 +2117,7 @@ test_paper_ex_genetic_relatedness_vector_errors(void)
tsk_size_t num_samples;
double *weights, *result;
tsk_size_t j;
tsk_size_t num_windows = 2;
tsk_size_t num_weights = 2;
double windows[] = { 0, 0, 0 };

Expand All @@ -2125,7 +2126,7 @@ test_paper_ex_genetic_relatedness_vector_errors(void)
num_samples = tsk_treeseq_get_num_samples(&ts);

weights = tsk_malloc(num_weights * num_samples * sizeof(double));
result = tsk_malloc(num_weights * num_samples * sizeof(double));
result = tsk_malloc(num_windows * num_weights * num_samples * sizeof(double));
for (j = 0; j < num_samples; j++) {
weights[j] = 1.0;
}
Expand All @@ -2135,48 +2136,93 @@ test_paper_ex_genetic_relatedness_vector_errors(void)

/* Window errors */
ret = tsk_treeseq_genetic_relatedness_vector(
&ts, 1, weights, 0, windows, result, TSK_STAT_BRANCH);
&ts, 1, weights, 0, windows, num_samples, ts.samples, result, TSK_STAT_BRANCH);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS);
ret = tsk_treeseq_genetic_relatedness_vector(
&ts, 1, weights, 0, NULL, result, TSK_STAT_BRANCH);
&ts, 1, weights, 0, NULL, num_samples, ts.samples, result, TSK_STAT_BRANCH);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_NUM_WINDOWS);

ret = tsk_treeseq_genetic_relatedness_vector(
&ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH);
&ts, 1, weights, 2, windows, num_samples, ts.samples, result, TSK_STAT_BRANCH);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS);

windows[0] = -1;
ret = tsk_treeseq_genetic_relatedness_vector(
&ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH);
&ts, 1, weights, 2, windows, num_samples, ts.samples, result, TSK_STAT_BRANCH);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS);

windows[0] = 12;
ret = tsk_treeseq_genetic_relatedness_vector(
&ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH);
&ts, 1, weights, 2, windows, num_samples, ts.samples, result, TSK_STAT_BRANCH);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS);

windows[0] = 0;
windows[2] = 12;
ret = tsk_treeseq_genetic_relatedness_vector(
&ts, 1, weights, 2, windows, result, TSK_STAT_BRANCH);
&ts, 1, weights, 2, windows, num_samples, ts.samples, result, TSK_STAT_BRANCH);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_WINDOWS);

/* unsupported mode errors */
windows[0] = 0.0;
windows[1] = 5.0;
windows[2] = 10.0;
ret = tsk_treeseq_genetic_relatedness_vector(
&ts, num_weights, weights, 2, windows, result, TSK_STAT_SITE);
ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights, 2, windows,
num_samples, ts.samples, result, TSK_STAT_SITE);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE);
ret = tsk_treeseq_genetic_relatedness_vector(
&ts, num_weights, weights, 2, windows, result, TSK_STAT_NODE);
ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights, 2, windows,
num_samples, ts.samples, result, TSK_STAT_NODE);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE);

tsk_treeseq_free(&ts);
free(weights);
free(result);
}

/*
 * Exercises the focal-node argument validation of
 * tsk_treeseq_genetic_relatedness_vector on the paper example tree sequence.
 *
 * A list of in-range node ids (given in unsorted order) must succeed, while a
 * list containing a negative id or an id that is expected to lie beyond the
 * node table must fail with TSK_ERR_NODE_OUT_OF_BOUNDS.
 */
static void
test_paper_ex_genetic_relatedness_vector_node_errors(void)
{
    int ret;
    tsk_treeseq_t ts;
    tsk_size_t num_samples;
    double *weights, *result;
    tsk_size_t j;
    tsk_size_t num_weights = 2;
    tsk_size_t num_windows = 2;
    /* num_windows + 1 breakpoints delimiting the two windows */
    double windows[] = { 1, 1.5, 2 };
    tsk_size_t num_nodes = 3;
    const tsk_id_t good_nodes[] = { 1, 0, 2 };
    /* contains a negative node id */
    const tsk_id_t bad_nodes1[] = { 1, -1, 2 };
    /* contains an id the test expects to be past the end of the node table */
    const tsk_id_t bad_nodes2[] = { 1, 100, 2 };

    tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites,
        paper_ex_mutations, paper_ex_individuals, NULL, 0);
    num_samples = tsk_treeseq_get_num_samples(&ts);

    /* The input has one entry per (sample, weight); the output has one entry
     * per (window, focal node, weight). */
    weights = tsk_malloc(num_weights * num_samples * sizeof(double));
    result = tsk_malloc(num_windows * num_weights * num_nodes * sizeof(double));
    /* Stop here rather than writing through a NULL pointer below. */
    CU_ASSERT_FATAL(weights != NULL);
    CU_ASSERT_FATAL(result != NULL);

    /* Fill the first num_samples entries with 1.0 and the next num_samples
     * entries with 0, 1, 2, ... */
    for (j = 0; j < num_samples; j++) {
        weights[j] = 1.0;
    }
    for (j = 0; j < num_samples; j++) {
        weights[j + num_samples] = (double) j;
    }

    /* node errors */
    ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights,
        num_windows, windows, num_nodes, good_nodes, result, TSK_STAT_BRANCH);
    CU_ASSERT_EQUAL_FATAL(ret, 0);
    ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights,
        num_windows, windows, num_nodes, bad_nodes1, result, TSK_STAT_BRANCH);
    CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS);
    ret = tsk_treeseq_genetic_relatedness_vector(&ts, num_weights, weights,
        num_windows, windows, num_nodes, bad_nodes2, result, TSK_STAT_BRANCH);
    CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS);

    tsk_treeseq_free(&ts);
    free(weights);
    free(result);
}

static void
test_paper_ex_Y2_errors(void)
{
Expand Down Expand Up @@ -3723,6 +3769,8 @@ main(int argc, char **argv)
test_paper_ex_genetic_relatedness_vector },
{ "test_paper_ex_genetic_relatedness_vector_errors",
test_paper_ex_genetic_relatedness_vector_errors },
{ "test_paper_ex_genetic_relatedness_vector_node_errors",
test_paper_ex_genetic_relatedness_vector_node_errors },
{ "test_paper_ex_Y2_errors", test_paper_ex_Y2_errors },
{ "test_paper_ex_Y2", test_paper_ex_Y2 },
{ "test_paper_ex_f2_errors", test_paper_ex_f2_errors },
Expand Down
40 changes: 27 additions & 13 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -9908,9 +9908,10 @@ typedef struct {
const double *weights;
tsk_size_t num_windows;
const double *windows;
tsk_size_t num_focal_nodes;
const tsk_id_t *focal_nodes;
tsk_flags_t options;
double *result;
/* tree */
tsk_tree_position_t tree_pos;
double position;
tsk_size_t num_nodes;
Expand All @@ -9929,6 +9930,7 @@ tsk_matvec_calculator_print_state(const tsk_matvec_calculator_t *self, FILE *out
fprintf(out, "Matvec state:\n");
fprintf(out, "options = %d\n", self->options);
fprintf(out, "position = %f\n", self->position);
fprintf(out, "focal nodes = %lld: [", (long long) self->num_focal_nodes);
fprintf(out, "tree_pos:\n");
tsk_tree_position_print_state(&self->tree_pos, out);
fprintf(out, "samples = %lld: [", (long long) num_samples);
Expand All @@ -9946,25 +9948,27 @@ tsk_matvec_calculator_print_state(const tsk_matvec_calculator_t *self, FILE *out
static int
tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *ts,
tsk_size_t num_weights, const double *weights, tsk_size_t num_windows,
const double *windows, tsk_flags_t options, double *result)
const double *windows, tsk_size_t num_focal_nodes, const tsk_id_t *focal_nodes,
tsk_flags_t options, double *result)
{
int ret = 0;
tsk_size_t num_samples = tsk_treeseq_get_num_samples(ts);
const tsk_size_t num_nodes = ts->tables->nodes.num_rows;
const double *row;
double *new_row;
tsk_size_t k;
tsk_id_t u, j;
tsk_id_t index, u, j;
double *weight_means = tsk_malloc(num_weights * sizeof(*weight_means));
const tsk_size_t num_trees = ts->num_trees;
const double *restrict breakpoints = ts->breakpoints;
tsk_id_t index;

self->ts = ts;
self->num_weights = num_weights;
self->weights = weights;
self->num_windows = num_windows;
self->windows = windows;
self->num_focal_nodes = num_focal_nodes;
self->focal_nodes = focal_nodes;
self->options = options;
self->result = result;
self->num_nodes = num_nodes;
Expand All @@ -9981,9 +9985,16 @@ tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *t
goto out;
}

tsk_memset(result, 0, num_windows * num_samples * num_weights * sizeof(*result));
tsk_memset(result, 0, num_windows * num_focal_nodes * num_weights * sizeof(*result));
tsk_memset(self->parent, TSK_NULL, num_nodes * sizeof(*self->parent));

for (j = 0; j < (tsk_id_t) num_focal_nodes; j++) {
if (focal_nodes[j] < 0 || (tsk_size_t) focal_nodes[j] >= num_nodes) {
ret = TSK_ERR_NODE_OUT_OF_BOUNDS;
goto out;
}
}

ret = tsk_tree_position_init(&self->tree_pos, ts, 0);
if (ret != 0) {
goto out;
Expand All @@ -10001,6 +10012,7 @@ tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *t
for (k = 0; k < num_weights; k++) {
weight_means[k] = 0.0;
}
/* centre the input */
if (!(options & TSK_STAT_NONCENTRED)) {
for (j = 0; j < (tsk_id_t) num_samples; j++) {
row = GET_2D_ROW(weights, num_weights, j);
Expand All @@ -10013,6 +10025,7 @@ tsk_matvec_calculator_init(tsk_matvec_calculator_t *self, const tsk_treeseq_t *t
}
}

/* set the initial state */
for (j = 0; j < (tsk_id_t) num_samples; j++) {
u = ts->samples[j];
row = GET_2D_ROW(weights, num_weights, j);
Expand Down Expand Up @@ -10129,7 +10142,7 @@ tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self, double *restri
int ret = 0;
tsk_id_t u;
tsk_size_t j, k;
tsk_size_t n = tsk_treeseq_get_num_samples(self->ts);
const tsk_size_t n = self->num_focal_nodes;
const tsk_size_t num_weights = self->num_weights;
const double position = self->position;
double *u_row, *out_row;
Expand All @@ -10139,7 +10152,7 @@ tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self, double *restri
double *restrict x = self->x;
double *restrict w = self->w;
double *restrict v = self->v;
const tsk_id_t *restrict samples = self->ts->samples;
const tsk_id_t *restrict focal_nodes = self->focal_nodes;

if (out_means == NULL) {
ret = TSK_ERR_NO_MEMORY;
Expand All @@ -10148,7 +10161,7 @@ tsk_matvec_calculator_write_output(tsk_matvec_calculator_t *self, double *restri

for (j = 0; j < n; j++) {
out_row = GET_2D_ROW(y, num_weights, j);
u = samples[j];
u = focal_nodes[j];
while (u != TSK_NULL) {
if (x[u] != position) {
tsk_matvec_calculator_add_z(
Expand Down Expand Up @@ -10195,7 +10208,7 @@ tsk_matvec_calculator_run(tsk_matvec_calculator_t *self)
int ret = 0;
tsk_size_t j, k, m;
tsk_id_t e, p, c;
tsk_size_t n = tsk_treeseq_get_num_samples(self->ts);
const tsk_size_t out_size = self->num_weights * self->num_focal_nodes;
const tsk_size_t num_edges = self->ts->tables->edges.num_rows;
const double *restrict edge_right = self->ts->tables->edges.right;
const double *restrict edge_left = self->ts->tables->edges.left;
Expand Down Expand Up @@ -10253,7 +10266,7 @@ tsk_matvec_calculator_run(tsk_matvec_calculator_t *self)
tsk_bug_assert(self->position < next_position);
self->position = next_position;
if (self->position == windows[m + 1]) {
out = GET_2D_ROW(self->result, self->num_weights * n, m);
out = GET_2D_ROW(self->result, out_size, m);
tsk_matvec_calculator_write_output(self, out);
m += 1;
}
Expand All @@ -10268,7 +10281,8 @@ tsk_matvec_calculator_run(tsk_matvec_calculator_t *self)

int
tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self, tsk_size_t num_weights,
const double *weights, tsk_size_t num_windows, const double *windows, double *result,
const double *weights, tsk_size_t num_windows, const double *windows,
tsk_size_t num_focal_nodes, const tsk_id_t *focal_nodes, double *result,
tsk_flags_t options)
{
int ret = 0;
Expand All @@ -10287,8 +10301,8 @@ tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self, tsk_size_t num
goto out;
}

ret = tsk_matvec_calculator_init(
&calc, self, num_weights, weights, num_windows, windows, options, result);
ret = tsk_matvec_calculator_init(&calc, self, num_weights, weights, num_windows,
windows, num_focal_nodes, focal_nodes, options, result);
if (ret != 0) {
goto out;
}
Expand Down
6 changes: 4 additions & 2 deletions c/tskit/trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -1028,12 +1028,14 @@ int tsk_treeseq_genetic_relatedness_weighted(const tsk_treeseq_t *self,
/* One way weighted stats with vector output */

typedef int weighted_vector_method(const tsk_treeseq_t *self, tsk_size_t num_weights,
const double *weights, tsk_size_t num_windows, const double *windows, double *result,
const double *weights, tsk_size_t num_windows, const double *windows,
tsk_size_t num_focal_nodes, const tsk_id_t *focal_nodes, double *result,
tsk_flags_t options);

int tsk_treeseq_genetic_relatedness_vector(const tsk_treeseq_t *self,
tsk_size_t num_weights, const double *weights, tsk_size_t num_windows,
const double *windows, double *result, tsk_flags_t options);
const double *windows, tsk_size_t num_focal_nodes, const tsk_id_t *focal_nodes,
double *result, tsk_flags_t options);

/* One way sample set stats */

Expand Down
Loading

0 comments on commit bc5a73c

Please sign in to comment.