Skip to content

Commit

Permalink
Harden low-level sample sets
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Aug 23, 2023
1 parent 7c5d1d5 commit 11404c7
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 23 deletions.
27 changes: 12 additions & 15 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -802,11 +802,16 @@ parse_sample_sets(PyObject *sample_set_sizes, PyArrayObject **ret_sample_set_siz
}
shape = PyArray_DIMS(sample_set_sizes_array);
num_sample_sets = shape[0];

/* The sum of the lengths in sample_set_sizes must be equal to the length
* of the sample_sets array */
sum = 0;
a = PyArray_DATA(sample_set_sizes_array);
for (j = 0; j < num_sample_sets; j++) {
if (sum + a[j] < sum) {
PyErr_SetString(PyExc_ValueError, "Overflow in sample set sizes sum");
goto out;
}
sum += a[j];
}

Expand Down Expand Up @@ -9760,23 +9765,18 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd
if (TreeSequence_check_state(self) != 0) {
goto out;
}
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|OOsi", kwlist, &py_windows,
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|si", kwlist, &py_windows,
&py_sample_set_sizes, &py_sample_sets, &mode, &span_normalise)) {
goto out;
}

if (py_sample_set_sizes != Py_None && py_sample_sets != Py_None) {
if (parse_sample_sets(py_sample_set_sizes, &sample_set_sizes_array,
py_sample_sets, &sample_sets_array, &num_sample_sets)
!= 0) {
goto out;
}
sample_set_sizes = PyArray_DATA(sample_set_sizes_array);
sample_sets = PyArray_DATA(sample_sets_array);
} else {
assert(py_sample_set_sizes == Py_None && py_sample_sets == Py_None);
num_sample_sets = tsk_treeseq_get_num_samples(self->tree_sequence);
if (parse_sample_sets(py_sample_set_sizes, &sample_set_sizes_array, py_sample_sets,
&sample_sets_array, &num_sample_sets)
!= 0) {
goto out;
}
sample_set_sizes = PyArray_DATA(sample_set_sizes_array);
sample_sets = PyArray_DATA(sample_sets_array);
if (parse_windows(py_windows, &windows_array, &num_windows) != 0) {
goto out;
}
Expand All @@ -9794,9 +9794,6 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd
if (span_normalise) {
options |= TSK_STAT_SPAN_NORMALISE;
}
/* printf("num sets = %d\n", (int) num_sample_sets); */
/* printf("sizes = %p\n", (void *) sample_set_sizes); */
/* printf("sample_sets = %p\n", (void *) sample_sets); */

// clang-format off
Py_BEGIN_ALLOW_THREADS
Expand Down
28 changes: 21 additions & 7 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,27 +1533,41 @@ def test_divergence_matrix(self):
n = 10
ts = self.get_example_tree_sequence(n, random_seed=12)
windows = [0, ts.get_sequence_length()]
D = ts.divergence_matrix(windows)
ids = np.arange(n, dtype=np.int32)
sizes = np.ones(n, dtype=np.uint64)
D = ts.divergence_matrix(windows, sizes, ids)
assert D.shape == (1, n, n)
D = ts.divergence_matrix(windows, sample_set_sizes=[1, 1], sample_sets=[0, 1])
assert D.shape == (1, 2, 2)
D = ts.divergence_matrix(
windows, sample_set_sizes=[1, 1], sample_sets=[0, 1], span_normalise=True
)
assert D.shape == (1, 2, 2)
# TODO Add more tests for sample sets handling

for bad_node in [-1, -2, 1000]:
with pytest.raises(_tskit.LibraryError, match="TSK_ERR_NODE_OUT_OF_BOUNDS"):
ts.divergence_matrix(windows, [1, 1], [0, bad_node])
with pytest.raises(ValueError, match="Sum of sample_set_sizes"):
ts.divergence_matrix(windows, [1, 2], [0, 1])
with pytest.raises(ValueError, match="Overflow"):
ts.divergence_matrix(windows, [-1, 2], [0])

with pytest.raises(TypeError, match="str"):
ts.divergence_matrix(windows, span_normalise="xdf")
ts.divergence_matrix(windows, sizes, ids, span_normalise="xdf")
with pytest.raises(TypeError):
ts.divergence_matrix(windoze=[0, 1])
with pytest.raises(ValueError, match="at least 2"):
ts.divergence_matrix(windows=[0])
ts.divergence_matrix(
[0],
sizes,
ids,
)
with pytest.raises(_tskit.LibraryError, match="BAD_WINDOWS"):
ts.divergence_matrix(windows=[-1, 0, 1])
ts.divergence_matrix([-1, 0, 1], sizes, ids)
with pytest.raises(ValueError, match="Unrecognised stats mode"):
ts.divergence_matrix(windows=[0, 1], mode="sdf")
ts.divergence_matrix([0, 1], sizes, ids, mode="sdf")
with pytest.raises(_tskit.LibraryError, match="UNSUPPORTED_STAT_MODE"):
ts.divergence_matrix(windows=[0, 1], mode="node")
ts.divergence_matrix([0, 1], sizes, ids, mode="node")

def test_load_tables_build_indexes(self):
for ts in self.get_example_tree_sequences():
Expand Down
2 changes: 1 addition & 1 deletion python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -7828,7 +7828,7 @@ def _parse_stat_matrix_ids_arg(ids):
of ID lists.
"""
id_dtype = np.int32
size_dtype = np.uint32
size_dtype = np.uint64
# Exclude some types that could be specified accidentally, and
# we may want to reserve for future use.
if isinstance(ids, (str, bytes, collections.abc.Mapping, numbers.Number)):
Expand Down

0 comments on commit 11404c7

Please sign in to comment.