diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 821f168cd1..1eda339aa7 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -802,11 +802,16 @@ parse_sample_sets(PyObject *sample_set_sizes, PyArrayObject **ret_sample_set_siz } shape = PyArray_DIMS(sample_set_sizes_array); num_sample_sets = shape[0]; + /* The sum of the lengths in sample_set_sizes must be equal to the length * of the sample_sets array */ sum = 0; a = PyArray_DATA(sample_set_sizes_array); for (j = 0; j < num_sample_sets; j++) { + if (sum + a[j] < sum) { + PyErr_SetString(PyExc_ValueError, "Overflow in sample set sizes sum"); + goto out; + } sum += a[j]; } @@ -9760,23 +9765,18 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|OOsi", kwlist, &py_windows, + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|si", kwlist, &py_windows, &py_sample_set_sizes, &py_sample_sets, &mode, &span_normalise)) { goto out; } - if (py_sample_set_sizes != Py_None && py_sample_sets != Py_None) { - if (parse_sample_sets(py_sample_set_sizes, &sample_set_sizes_array, - py_sample_sets, &sample_sets_array, &num_sample_sets) - != 0) { - goto out; - } - sample_set_sizes = PyArray_DATA(sample_set_sizes_array); - sample_sets = PyArray_DATA(sample_sets_array); - } else { - assert(py_sample_set_sizes == Py_None && py_sample_sets == Py_None); - num_sample_sets = tsk_treeseq_get_num_samples(self->tree_sequence); + if (parse_sample_sets(py_sample_set_sizes, &sample_set_sizes_array, py_sample_sets, + &sample_sets_array, &num_sample_sets) + != 0) { + goto out; } + sample_set_sizes = PyArray_DATA(sample_set_sizes_array); + sample_sets = PyArray_DATA(sample_sets_array); if (parse_windows(py_windows, &windows_array, &num_windows) != 0) { goto out; } @@ -9794,9 +9794,6 @@ TreeSequence_divergence_matrix(TreeSequence *self, PyObject *args, PyObject *kwd if (span_normalise) { options |= TSK_STAT_SPAN_NORMALISE; } - /* printf("num sets = %d\n", (int) num_sample_sets); */ - /* printf("sizes = %p\n", (void *) sample_set_sizes); */ - /* printf("sample_sets = %p\n", (void *) sample_sets); */ // clang-format off Py_BEGIN_ALLOW_THREADS diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index f86a0ff8bd..cc2af2a4a6 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1533,7 +1533,9 @@ def test_divergence_matrix(self): n = 10 ts = self.get_example_tree_sequence(n, random_seed=12) windows = [0, ts.get_sequence_length()] - D = ts.divergence_matrix(windows) + ids = np.arange(n, dtype=np.int32) + sizes = np.ones(n, dtype=np.uint64) + D = ts.divergence_matrix(windows, sizes, ids) assert D.shape == (1, n, n) D = ts.divergence_matrix(windows, sample_set_sizes=[1, 1], sample_sets=[0, 1]) assert D.shape == (1, 2, 2) @@ -1541,19 +1543,31 @@ def test_divergence_matrix(self): windows, sample_set_sizes=[1, 1], sample_sets=[0, 1], span_normalise=True ) assert D.shape == (1, 2, 2) - # TODO Add moer tests for sample sets handling + + for bad_node in [-1, -2, 1000]: + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_NODE_OUT_OF_BOUNDS"): + ts.divergence_matrix(windows, [1, 1], [0, bad_node]) + with pytest.raises(ValueError, match="Sum of sample_set_sizes"): + ts.divergence_matrix(windows, [1, 2], [0, 1]) + with pytest.raises(ValueError, match="Overflow"): + ts.divergence_matrix(windows, [-1, 2], [0]) + with pytest.raises(TypeError, match="str"): - ts.divergence_matrix(windows, span_normalise="xdf") + ts.divergence_matrix(windows, sizes, ids, span_normalise="xdf") with pytest.raises(TypeError): ts.divergence_matrix(windoze=[0, 1]) with pytest.raises(ValueError, match="at least 2"): - ts.divergence_matrix(windows=[0]) + ts.divergence_matrix( + [0], + sizes, + ids, + ) with pytest.raises(_tskit.LibraryError, match="BAD_WINDOWS"): - ts.divergence_matrix(windows=[-1, 0, 1]) + ts.divergence_matrix([-1, 0, 1], sizes, ids) with pytest.raises(ValueError, match="Unrecognised stats mode"): - ts.divergence_matrix(windows=[0, 1], mode="sdf") + ts.divergence_matrix([0, 1], sizes, ids, mode="sdf") with pytest.raises(_tskit.LibraryError, match="UNSUPPORTED_STAT_MODE"): - ts.divergence_matrix(windows=[0, 1], mode="node") + ts.divergence_matrix([0, 1], sizes, ids, mode="node") def test_load_tables_build_indexes(self): for ts in self.get_example_tree_sequences(): diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 6e22158865..b86654339b 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7828,7 +7828,7 @@ def _parse_stat_matrix_ids_arg(ids): of ID lists. """ id_dtype = np.int32 - size_dtype = np.uint32 + size_dtype = np.uint64 # Exclude some types that could be specified accidentally, and # we may want to reserve for future use. if isinstance(ids, (str, bytes, collections.abc.Mapping, numbers.Number)):