Skip to content

Commit

Permalink
Merge pull request #4525 from matthewturk/bitarray_fixes
Browse files Browse the repository at this point in the history
ENH: Add a few bitarray functions and tests
  • Loading branch information
matthewturk authored Apr 26, 2024
2 parents 2a52594 + 9351f96 commit 28a963e
Show file tree
Hide file tree
Showing 3 changed files with 314 additions and 7 deletions.
70 changes: 67 additions & 3 deletions yt/utilities/lib/bitarray.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,21 @@ cimport numpy as np

cdef inline void ba_set_value(np.uint8_t *buf, np.uint64_t ind,
np.uint8_t val) noexcept nogil:
# This assumes 8 bit buffer
# This assumes 8 bit buffer. If value is greater than zero (thus allowing
# us to use 1-255 as 'True') then we identify first the index in the buffer
# we are setting. We do this by bit-shifting the index to the right by
# three places, essentially dividing it by eight (and taking the
# floor.)
# The next step is to turn *on* what we're attempting to turn on, which
# means taking our index and truncating it to the first 3 bits (which we do
# with an & operation) and then turning on the correct bit.
#
# So if we're asking for index 33 in the bitarray, we would want the uint8
# element at index 4 (the 5th element), then the 2nd bit (index 1).
#
# To turn it on, we logical *or* with that. To turn it off, we logical
# *and* with the *inverse*, which will allow everything *but* that bit to
# stay on.
if val > 0:
buf[ind >> 3] |= (1 << (ind & 7))
else:
Expand All @@ -25,13 +39,63 @@ cdef inline np.uint8_t ba_get_value(np.uint8_t *buf, np.uint64_t ind) noexcept n
if rv == 0: return 0
return 1

cdef inline void ba_set_range(np.uint8_t *buf, np.uint64_t start_ind,
                              np.uint64_t stop_ind, np.uint8_t val) noexcept nogil:
    # Set every bit in the half-open range [start_ind, stop_ind) to on
    # (val > 0) or off (val == 0).  The stop index is exclusive, to match
    # slicing semantics.  An empty or reversed range is a no-op.
    #
    # No bounds checking is done here: the caller must guarantee that
    # stop_ind does not exceed the number of bits backed by buf.
    #
    # This cannot raise, so it is declared noexcept like ba_set_value; that
    # lets nogil callers skip any exception check after the call.
    #
    # Strategy: build a mask for the (possibly partial) first byte and one for
    # the (possibly partial) last byte, and overwrite every whole byte that
    # lies strictly between them.
    if stop_ind <= start_ind: return
    cdef np.uint64_t i
    cdef np.uint64_t buf_start = start_ind >> 3
    cdef np.uint64_t buf_stop = stop_ind >> 3
    cdef np.uint8_t start_j = start_ind & 7
    cdef np.uint8_t stop_j = stop_ind & 7
    # Bits start_j..7 of the first byte.
    cdef np.uint8_t head_mask = (0xFF << start_j) & 0xFF
    # Bits 0..stop_j-1 of the last byte; zero when stop_ind is byte-aligned.
    cdef np.uint8_t tail_mask = (1 << stop_j) - 1
    cdef np.uint8_t mask, fill
    if buf_start == buf_stop:
        # Start and stop fall in the same byte: only the bits common to both
        # masks are affected.
        mask = head_mask & tail_mask
        if val > 0:
            buf[buf_start] |= mask
        else:
            buf[buf_start] &= ~mask
        return
    if val > 0:
        buf[buf_start] |= head_mask
        fill = 255
    else:
        buf[buf_start] &= ~head_mask
        fill = 0
    for i in range(buf_start + 1, buf_stop):
        buf[i] = fill
    # When stop_ind is a multiple of eight there is nothing to set in
    # buf[buf_stop], and that byte may lie one past the end of the buffer
    # (e.g. stop_ind == 64 for a 64-bit array), so it must not be touched.
    if stop_j == 0:
        return
    if val > 0:
        buf[buf_stop] |= tail_mask
    else:
        buf[buf_stop] &= ~tail_mask


cdef inline np.uint8_t _num_set_bits( np.uint8_t b ):
    # Return how many of the eight bits in b are set.
    # Technique from:
    # https://stackoverflow.com/questions/30688465/how-to-check-the-number-of-set-bits-in-an-8-bit-unsigned-char
    #
    # Step 1: each pair of bits is replaced by the count of set bits in it.
    cdef np.uint8_t pair_counts = b - ((b >> 1) & 0x55)
    # Step 2: adjacent pair counts are summed into per-nibble counts.
    cdef np.uint8_t nibble_counts = (pair_counts & 0x33) + ((pair_counts >> 2) & 0x33)
    # Step 3: the two nibble counts are added; the low four bits hold the total.
    return (nibble_counts + (nibble_counts >> 4)) & 0x0F

# C-level declaration of the bit-packed boolean array extension type.  Values
# are stored eight to a byte; the implementation lives in bitarray.pyx.
cdef class bitarray:
    # Raw pointer to the packed bytes; the constructor points this at the
    # data of `ibuf`.
    cdef np.uint8_t *buf
    # Number of entries in the array (compared between operands by the
    # logical operations) -- presumably the count of bits, not bytes.
    cdef np.uint64_t size
    # Number of uint8 elements in the buffer: size >> 3, plus one more when
    # size is not a multiple of eight.
    cdef np.uint64_t buf_size # Not exactly the same
    # Mask of the bits in the last byte that belong to the array: 255 when
    # size is a multiple of eight, otherwise the low (size & 7) bits.
    cdef np.uint8_t final_bitmask
    # The numpy uint8 array of length buf_size that holds the storage `buf`
    # points into.
    cdef public object ibuf

    # Single-bit access.
    cdef void _set_value(self, np.uint64_t ind, np.uint8_t val)
    cdef np.uint8_t _query_value(self, np.uint64_t ind)
    #cdef void set_range(self, np.uint64_t ind, np.uint64_t count, int val)
    #cdef int query_range(self, np.uint64_t ind, np.uint64_t count, int *val)
    # Set the bits in [start, stop) on (val > 0) or off.
    cdef void _set_range(self, np.uint64_t start, np.uint64_t stop, np.uint8_t val)
    # Number of set bits.
    cdef np.uint64_t _count(self)
    # Bytewise logical operations against another bitarray of the same size.
    # `result` receives the output; when omitted a new bitarray is created.
    cdef bitarray _logical_and(self, bitarray other, bitarray result = *)
    cdef bitarray _logical_or(self, bitarray other, bitarray result = *)
    cdef bitarray _logical_xor(self, bitarray other, bitarray result = *)
130 changes: 130 additions & 0 deletions yt/utilities/lib/bitarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,14 @@ cdef class bitarray:
if size != arr.size:
raise RuntimeError
self.buf_size = (size >> 3)
cdef np.uint8_t bitmask = 255
if (size & 7) != 0:
# We need an extra one if we've got any lingering bits
self.buf_size += 1
bitmask = 0
for i in range(size & 7):
bitmask |= (1<<i)
self.final_bitmask = bitmask
cdef np.ndarray[np.uint8_t] ibuf_t
ibuf_t = self.ibuf = np.zeros(self.buf_size, "uint8")
self.buf = <np.uint8_t *> ibuf_t.data
Expand Down Expand Up @@ -163,3 +168,128 @@ cdef class bitarray:
"""
return ba_get_value(self.buf, ind)

# C-level entry point for setting the bits in [start, stop) on (val > 0) or
# off (val == 0).  This forwards straight to ba_set_range on this array's
# buffer and performs no bounds checking of its own.
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cdef void _set_range(self, np.uint64_t start, np.uint64_t stop, np.uint8_t val):
    ba_set_range(self.buf, start, stop, val)

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def set_range(self, np.uint64_t start, np.uint64_t stop, np.uint8_t val):
    r"""Set a range of values to on/off.  Uses slice-style indexing.

    The range is half-open: entries ``start`` up to, but not including,
    ``stop`` are modified.  As with a slice, a ``stop`` beyond the end of
    the array is clamped to the array size, and an empty range
    (``start >= stop``) leaves the array untouched.

    No return value.

    Parameters
    ----------
    start : int
        The starting component of a slice.
    stop : int
        The ending component of a slice (exclusive).
    val : bool or uint8_t
        What to set the range to; any value greater than zero turns the
        entries on, zero turns them off.

    Examples
    --------
    >>> arr_in = np.array([True, True, False, True, True, False])
    >>> a = ba.bitarray(arr = arr_in)
    >>> a.set_range(0, 3, 0)
    >>> a.count()
    2
    """
    # The underlying helper does no bounds checking, so an out-of-range stop
    # coming from Python would write past the end of the buffer.  Clamp it
    # the way a slice would.
    if stop > self.size:
        stop = self.size
    if start >= stop:
        return
    ba_set_range(self.buf, start, stop, val)

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cdef np.uint64_t _count(self):
    # Return the number of set bits in the array.
    #
    # Side effect: any bits in the last byte that lie beyond the end of the
    # array are cleared first, so they can never be counted.
    cdef np.uint64_t count = 0
    cdef np.uint64_t i
    # An empty array has no bytes at all; buf_size - 1 would wrap around
    # (the index is unsigned) and touch memory far outside the buffer.
    if self.buf_size == 0:
        return 0
    self.buf[self.buf_size - 1] &= self.final_bitmask
    for i in range(self.buf_size):
        count += _num_set_bits(self.buf[i])
    return count


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def count(self):
    r"""Count the number of values set in the array.

    Returns
    -------
    int
        The number of entries that are turned on.

    Examples
    --------
    >>> arr_in = np.array([True, True, False, True, True, False])
    >>> a = ba.bitarray(arr = arr_in)
    >>> a.count()
    4
    """
    return self._count()

cdef bitarray _logical_and(self, bitarray other, bitarray result = None):
    # Compute the bitwise AND of this array with `other`, byte by byte.
    #
    # other  : bitarray with the same size as self.
    # result : optional bitarray to receive the output; it may be self or
    #          other for an in-place operation (each byte is read before it
    #          is written).  A new bitarray is created when omitted.
    #
    # Returns the bitarray holding the result.
    # Raises TypeError if `other` is None, and IndexError if `other` or a
    # supplied `result` does not have the same size as self.
    cdef np.uint64_t i
    # Typed arguments accept None, and reading .size from None would
    # dereference an invalid object.
    if other is None:
        raise TypeError("other must be a bitarray, not None")
    if other.size != self.size:
        raise IndexError
    if result is None:
        result = bitarray(self.size)
    elif result.size != self.size:
        # A differently-sized result buffer would be written out of bounds.
        raise IndexError
    # Nothing to combine for an empty array, and buf_size - 1 below would
    # wrap around since the index is unsigned.
    if self.buf_size == 0:
        return result
    for i in range(self.buf_size):
        result.buf[i] = other.buf[i] & self.buf[i]
    # Note that we might have trailing values beyond the end of the array in
    # the last byte; reset them.
    result.buf[self.buf_size - 1] &= self.final_bitmask
    return result

def logical_and(self, bitarray other, bitarray result = None):
    r"""Return the logical AND of this bitarray with another.

    Parameters
    ----------
    other : bitarray
        The array to combine with; must have the same size as this one.
    result : bitarray, optional
        Where to store the output.  If not supplied, a new bitarray is
        created.

    Returns
    -------
    bitarray
        The array holding the result.

    Raises
    ------
    IndexError
        If ``other`` does not have the same size as this array.
    """
    return self._logical_and(other, result)

def __and__(self, bitarray other):
    # Implements `self & other`: delegates to logical_and with no result
    # argument, so the outcome is returned in a newly created bitarray.
    return self.logical_and(other)

def __iand__(self, bitarray other):
    # Implements `self &= other`: this array is passed as the result buffer,
    # so it is updated in place and handed back to the caller.
    return self.logical_and(other, self)

cdef bitarray _logical_or(self, bitarray other, bitarray result = None):
    # Compute the bitwise OR of this array with `other`, byte by byte.
    #
    # other  : bitarray with the same size as self.
    # result : optional bitarray to receive the output; it may be self or
    #          other for an in-place operation (each byte is read before it
    #          is written).  A new bitarray is created when omitted.
    #
    # Returns the bitarray holding the result.
    # Raises TypeError if `other` is None, and IndexError if `other` or a
    # supplied `result` does not have the same size as self.
    cdef np.uint64_t i
    # Typed arguments accept None, and reading .size from None would
    # dereference an invalid object.
    if other is None:
        raise TypeError("other must be a bitarray, not None")
    if other.size != self.size:
        raise IndexError
    if result is None:
        result = bitarray(self.size)
    elif result.size != self.size:
        # A differently-sized result buffer would be written out of bounds.
        raise IndexError
    # Nothing to combine for an empty array, and buf_size - 1 below would
    # wrap around since the index is unsigned.
    if self.buf_size == 0:
        return result
    for i in range(self.buf_size):
        result.buf[i] = other.buf[i] | self.buf[i]
    # Clear any bits in the last byte that lie beyond the end of the array.
    result.buf[self.buf_size - 1] &= self.final_bitmask
    return result

def logical_or(self, bitarray other, bitarray result = None):
    r"""Return the logical OR of this bitarray with another.

    Parameters
    ----------
    other : bitarray
        The array to combine with; must have the same size as this one.
    result : bitarray, optional
        Where to store the output.  If not supplied, a new bitarray is
        created.

    Returns
    -------
    bitarray
        The array holding the result.

    Raises
    ------
    IndexError
        If ``other`` does not have the same size as this array.
    """
    return self._logical_or(other, result)

def __or__(self, bitarray other):
    # Implements `self | other`: delegates to logical_or with no result
    # argument, so the outcome is returned in a newly created bitarray.
    return self.logical_or(other)

def __ior__(self, bitarray other):
    # Implements `self |= other`: this array is passed as the result buffer,
    # so it is updated in place and handed back to the caller.
    return self.logical_or(other, self)

cdef bitarray _logical_xor(self, bitarray other, bitarray result = None):
    # Compute the bitwise XOR of this array with `other`, byte by byte.
    #
    # other  : bitarray with the same size as self.
    # result : optional bitarray to receive the output; it may be self or
    #          other for an in-place operation (each byte is read before it
    #          is written).  A new bitarray is created when omitted.
    #
    # Returns the bitarray holding the result.
    # Raises TypeError if `other` is None, and IndexError if `other` or a
    # supplied `result` does not have the same size as self.
    cdef np.uint64_t i
    # Typed arguments accept None, and reading .size from None would
    # dereference an invalid object.
    if other is None:
        raise TypeError("other must be a bitarray, not None")
    if other.size != self.size:
        raise IndexError
    if result is None:
        result = bitarray(self.size)
    elif result.size != self.size:
        # A differently-sized result buffer would be written out of bounds.
        raise IndexError
    # Nothing to combine for an empty array, and buf_size - 1 below would
    # wrap around since the index is unsigned.
    if self.buf_size == 0:
        return result
    for i in range(self.buf_size):
        result.buf[i] = other.buf[i] ^ self.buf[i]
    # Clear any bits in the last byte that lie beyond the end of the array.
    result.buf[self.buf_size - 1] &= self.final_bitmask
    return result

def logical_xor(self, bitarray other, bitarray result = None):
    r"""Return the logical XOR of this bitarray with another.

    Parameters
    ----------
    other : bitarray
        The array to combine with; must have the same size as this one.
    result : bitarray, optional
        Where to store the output.  If not supplied, a new bitarray is
        created.

    Returns
    -------
    bitarray
        The array holding the result.

    Raises
    ------
    IndexError
        If ``other`` does not have the same size as this array.
    """
    return self._logical_xor(other, result)

def __xor__(self, bitarray other):
    # Implements `self ^ other`: delegates to logical_xor with no result
    # argument, so the outcome is returned in a newly created bitarray.
    return self.logical_xor(other)

def __ixor__(self, bitarray other):
    # Implements `self ^= other`: this array is passed as the result buffer,
    # so it is updated in place and handed back to the caller.
    return self.logical_xor(other, self)
121 changes: 117 additions & 4 deletions yt/utilities/lib/tests/test_bitarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

def test_inout_bitarray():
# Check that we can do it for bitarrays that are funny-shaped
rng = np.random.default_rng()
for i in range(7):
# Check we can feed in an array
arr_in = np.random.random(32**3 + i) > 0.5
arr_in = rng.random(32**3 + i) > 0.5
b = ba.bitarray(arr=arr_in)
if i > 0:
assert_equal(b.ibuf.size, (32**3) / 8.0 + 1)
Expand All @@ -22,18 +23,73 @@ def test_inout_bitarray():
assert_equal(arr_in, arr_out)

# Try a big array
arr_in = np.random.random(32**3 + i) > 0.5
arr_in = rng.random(32**3 + i) > 0.5
b = ba.bitarray(arr=arr_in)
arr_out = b.as_bool_array()
assert_equal(arr_in, arr_out)
assert_equal(b.count(), arr_in.sum())

# Let's check we can do something interesting.
arr_in1 = np.random.random(32**3) > 0.5
arr_in2 = np.random.random(32**3) > 0.5
arr_in1 = rng.random(32**3) > 0.5
arr_in2 = rng.random(32**3) > 0.5
b1 = ba.bitarray(arr=arr_in1)
b2 = ba.bitarray(arr=arr_in2)
b3 = ba.bitarray(arr=(arr_in1 & arr_in2))
assert_equal((b1.ibuf & b2.ibuf), b3.ibuf)
assert_equal(b1.count(), arr_in1.sum())
assert_equal(b2.count(), arr_in2.sum())
# Let's check the logical and operation
b4 = b1.logical_and(b2)
assert_equal(b4.count(), b3.count())
assert_array_equal(b4.as_bool_array(), b3.as_bool_array())

b5 = b1 & b2
assert_equal(b5.count(), b3.count())
assert_array_equal(b5.as_bool_array(), b3.as_bool_array())

b1 &= b2
assert_equal(b1.count(), b4.count())
assert_array_equal(b1.as_bool_array(), b4.as_bool_array())

# Repeat this, but with the logical or operators
b1 = ba.bitarray(arr=arr_in1)
b2 = ba.bitarray(arr=arr_in2)
b3 = ba.bitarray(arr=(arr_in1 | arr_in2))
assert_equal((b1.ibuf | b2.ibuf), b3.ibuf)
assert_equal(b1.count(), arr_in1.sum())
assert_equal(b2.count(), arr_in2.sum())
# Let's check the logical or operation
b4 = b1.logical_or(b2)
assert_equal(b4.count(), b3.count())
assert_array_equal(b4.as_bool_array(), b3.as_bool_array())

b5 = b1 | b2
assert_equal(b5.count(), b3.count())
assert_array_equal(b5.as_bool_array(), b3.as_bool_array())

b1 |= b2
assert_equal(b1.count(), b4.count())
assert_array_equal(b1.as_bool_array(), b4.as_bool_array())

# Repeat this, but with the logical xor operators
b1 = ba.bitarray(arr=arr_in1)
b2 = ba.bitarray(arr=arr_in2)
b3 = ba.bitarray(arr=(arr_in1 ^ arr_in2))
assert_equal((b1.ibuf ^ b2.ibuf), b3.ibuf)
assert_equal(b1.count(), arr_in1.sum())
assert_equal(b2.count(), arr_in2.sum())
# Let's check the logical xor operation
b4 = b1.logical_xor(b2)
assert_equal(b4.count(), b3.count())
assert_array_equal(b4.as_bool_array(), b3.as_bool_array())

b5 = b1 ^ b2
assert_equal(b5.count(), b3.count())
assert_array_equal(b5.as_bool_array(), b3.as_bool_array())

b1 ^= b2
assert_equal(b1.count(), b4.count())
assert_array_equal(b1.as_bool_array(), b4.as_bool_array())

b = ba.bitarray(10)
for i in range(10):
Expand All @@ -51,3 +107,60 @@ def test_inout_bitarray():
b.set_value(2, 1)
arr = b.as_bool_array()
assert_array_equal(arr, [0, 0, 1, 0, 0, 0, 0, 1, 0, 0])


def test_set_range():
    """Check bitarray.set_range against equivalent numpy slice assignment."""
    # Each case is (array size, start, stop): turn on [start, stop) in a
    # fresh array and compare both the contents and the set-bit count with a
    # plain uint8 array given the same slice assignment.
    fresh_array_cases = [
        # Range spans several bytes, with partial first and last bytes
        (127, 4, 65),
        # Start and stop fall in the same byte
        (127, 4, 6),
        # Start and stop fall in the same byte, away from the array's ends
        (64, 33, 36),
        # Stop lands on a byte edge, but the array has 65 entries
        (65, 32, 64),
    ]
    for n_entries, lo, hi in fresh_array_cases:
        bits = ba.bitarray(n_entries)
        bits.set_range(lo, hi, 1)
        expected = np.zeros(n_entries, dtype="uint8")
        expected[lo:hi] = 1
        observed = bits.as_bool_array().astype("uint8")
        assert_array_equal(observed, expected)
        assert_equal(bits.count(), expected.sum())

    # Turn everything on, then everything off again
    bits = ba.bitarray(127)
    bits.set_range(0, 127, 1)
    assert_equal(bits.as_bool_array().all(), True)
    bits.set_range(0, 127, 0)
    assert_equal(bits.as_bool_array().any(), False)

    # Re-use the cleared array for a range that crosses a byte boundary
    bits.set_range(3, 9, 1)
    expected = np.zeros(127, dtype="uint8")
    expected[3:9] = 1
    observed = bits.as_bool_array().astype("uint8")
    assert_array_equal(observed, expected)
    assert_equal(bits.count(), expected.sum())

    # Finally, overlay some zeros on top of what was just set
    bits.set_range(7, 10, 0)
    expected[7:10] = 0
    observed = bits.as_bool_array().astype("uint8")
    assert_array_equal(observed, expected)

0 comments on commit 28a963e

Please sign in to comment.