diff --git a/examples/test.py b/examples/test.py index 36be7a6..4c95c1d 100644 --- a/examples/test.py +++ b/examples/test.py @@ -13,7 +13,8 @@ shape = (int(np.max(im1) + 1), int(np.max(im2) + 1)) out = overlap_parallel(im1.astype(np.int32), im2.astype(np.int32), shape) print(out.sum()) -out_serial = overlap(im1.astype(np.int32), im2.astype(np.int32), shape) +# out_serial = overlap(im1.astype(np.int32), im2.astype(np.int32), shape) +out_serial = overlap(im1.astype(np.uint16), im2.astype(np.uint16), shape) # from IPython import embed # embed(colors="Linux") diff --git a/fast_overlap/fast_overlap.pyx b/fast_overlap/fast_overlap.pyx index 6f99e20..728c3d1 100644 --- a/fast_overlap/fast_overlap.pyx +++ b/fast_overlap/fast_overlap.pyx @@ -25,16 +25,28 @@ cpdef overlap_parallel(int [:,::1] prev, int[:,::1] curr, shape): cdef np.ndarray[int, ndim=2, mode="c"] output = np.zeros(shape, dtype=np.dtype("i")) cdef Py_ssize_t ncols = shape[1] - print('about to nogil') with nogil: overlap_parallel_cpp(&prev[0,0], &curr[0,0], prev.shape, &output[0,0], ncols) return output + +from libc cimport stdint + +ctypedef fused ints: + stdint.uint8_t + stdint.uint16_t + stdint.uint32_t + stdint.uint64_t + stdint.int8_t + stdint.int16_t + stdint.int32_t + stdint.int64_t + # @cython.boundscheck(False) @cython.wraparound(False) @cython.nonecheck(False) -cpdef overlap(int[:, :] prev, int[:,:] curr, shape): +cpdef overlap(ints[:, :] prev, ints[:,:] curr, shape): """ Calculate the pairwise overlap the labels for two arrays. diff --git a/tests/test_overlap.py b/tests/test_overlap.py index 2486f5e..2f16916 100644 --- a/tests/test_overlap.py +++ b/tests/test_overlap.py @@ -1,6 +1,7 @@ from pathlib import Path import numpy as np +import pytest import fast_overlap @@ -9,8 +10,10 @@ shape = (int(np.max(ims[0]) + 1), int(np.max(ims[1]) + 1)) -def test_overlap(): - out = fast_overlap.overlap(ims[0].astype(np.int32), ims[1].astype(np.int32), shape) +# test a few different types but not all +@pytest.mark.parametrize("type", [np.uint16, np.uint64, np.int32, np.int64]) +def test_overlap(type): + out = fast_overlap.overlap(ims[0].astype(type), ims[1].astype(type), shape) assert np.all(out == expected)