Skip to content

Commit

Permalink
Merge pull request #7 from Hekstra-Lab/fused
Browse files Browse the repository at this point in the history
Add fused types to non-parallel version
  • Loading branch information
ianhi committed Nov 1, 2021
2 parents bc26c8e + 19351d6 commit 2e8f5cd
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
3 changes: 2 additions & 1 deletion examples/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
shape = (int(np.max(im1) + 1), int(np.max(im2) + 1))
out = overlap_parallel(im1.astype(np.int32), im2.astype(np.int32), shape)
print(out.sum())
out_serial = overlap(im1.astype(np.int32), im2.astype(np.int32), shape)
# out_serial = overlap(im1.astype(np.int32), im2.astype(np.int32), shape)
out_serial = overlap(im1.astype(np.uint16), im2.astype(np.uint16), shape)

# from IPython import embed
# embed(colors="Linux")
Expand Down
16 changes: 14 additions & 2 deletions fast_overlap/fast_overlap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,28 @@ cpdef overlap_parallel(int [:,::1] prev, int[:,::1] curr, shape):
cdef np.ndarray[int, ndim=2, mode="c"] output = np.zeros(shape, dtype=np.dtype("i"))
cdef Py_ssize_t ncols = shape[1]

print('about to nogil')
with nogil:
overlap_parallel_cpp(&prev[0,0], &curr[0,0], prev.shape, &output[0,0], ncols)
return output



from libc cimport stdint

ctypedef fused ints:
stdint.uint8_t
stdint.uint16_t
stdint.uint32_t
stdint.uint64_t
stdint.int8_t
stdint.int16_t
stdint.int32_t
stdint.int64_t

# @cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef overlap(int[:, :] prev, int[:,:] curr, shape):
cpdef overlap(ints[:, :] prev, ints[:,:] curr, shape):
"""
Calculate the pairwise overlap the labels for two arrays.
Expand Down
7 changes: 5 additions & 2 deletions tests/test_overlap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

import numpy as np
import pytest

import fast_overlap

Expand All @@ -9,8 +10,10 @@
shape = (int(np.max(ims[0]) + 1), int(np.max(ims[1]) + 1))


def test_overlap():
out = fast_overlap.overlap(ims[0].astype(np.int32), ims[1].astype(np.int32), shape)
# test a few different types but not all
@pytest.mark.parametrize("type", [np.uint16, np.uint64, np.int32, np.int64])
def test_overlap(type):
out = fast_overlap.overlap(ims[0].astype(type), ims[1].astype(type), shape)
assert np.all(out == expected)


Expand Down

0 comments on commit 2e8f5cd

Please sign in to comment.