diff --git a/yt/utilities/lib/bitarray.pxd b/yt/utilities/lib/bitarray.pxd index a809d4a351f..ff4f7701ca2 100644 --- a/yt/utilities/lib/bitarray.pxd +++ b/yt/utilities/lib/bitarray.pxd @@ -14,7 +14,21 @@ cimport numpy as np cdef inline void ba_set_value(np.uint8_t *buf, np.uint64_t ind, np.uint8_t val) noexcept nogil: - # This assumes 8 bit buffer + # This assumes 8 bit buffer. If value is greater than zero (thus allowing + # us to use 1-255 as 'True') then we identify first the index in the buffer + # we are setting. We do this by truncating the index by bit-shifting to + # the left three times, essentially dividing it by eight (and taking the + # floor.) + # The next step is to turn *on* what we're attempting to turn on, which + # means taking our index and truncating it to the first 3 bits (which we do + # with an & operation) and then turning on the correct bit. + # + # So if we're asking for index 33 in the bitarray, we would want the 4th + # uint8 element, then the 2nd bit (index 1). + # + # To turn it on, we logical *or* with that. To turn it off, we logical + # *and* with the *inverse*, which will allow everything *but* that bit to + # stay on. if val > 0: buf[ind >> 3] |= (1 << (ind & 7)) else: @@ -25,13 +39,63 @@ cdef inline np.uint8_t ba_get_value(np.uint8_t *buf, np.uint64_t ind) noexcept n if rv == 0: return 0 return 1 +cdef inline void ba_set_range(np.uint8_t *buf, np.uint64_t start_ind, + np.uint64_t stop_ind, np.uint8_t val) nogil: + # Should this be inclusive of both end points? I think it should not, to + # match slicing semantics. + # + # We need to figure out the first and last values, and then we just set the + # ones in-between to 255. + if stop_ind < start_ind: return + cdef np.uint64_t i + cdef np.uint8_t j, bitmask + cdef np.uint64_t buf_start = start_ind >> 3 + cdef np.uint64_t buf_stop = stop_ind >> 3 + cdef np.uint8_t start_j = start_ind & 7 + cdef np.uint8_t stop_j = stop_ind & 7 + if buf_start == buf_stop: + for j in range(start_j, stop_j): + ba_set_value(&buf[buf_start], j, val) + return + bitmask = 0 + for j in range(start_j, 8): + bitmask |= (1 << j) + if val > 0: + buf[buf_start] |= bitmask + else: + buf[buf_start] &= ~bitmask + if val > 0: + bitmask = 255 + else: + bitmask = 0 + for i in range(buf_start + 1, buf_stop): + buf[i] = bitmask + bitmask = 0 + for j in range(0, stop_j): + bitmask |= (1 << j) + if val > 0: + buf[buf_stop] |= bitmask + else: + buf[buf_stop] &= ~bitmask + + +cdef inline np.uint8_t _num_set_bits( np.uint8_t b ): + # https://stackoverflow.com/questions/30688465/how-to-check-the-number-of-set-bits-in-an-8-bit-unsigned-char + b = b - ((b >> 1) & 0x55) + b = (b & 0x33) + ((b >> 2) & 0x33) + return (((b + (b >> 4)) & 0x0F) * 0x01) + cdef class bitarray: cdef np.uint8_t *buf cdef np.uint64_t size cdef np.uint64_t buf_size # Not exactly the same + cdef np.uint8_t final_bitmask cdef public object ibuf cdef void _set_value(self, np.uint64_t ind, np.uint8_t val) cdef np.uint8_t _query_value(self, np.uint64_t ind) - #cdef void set_range(self, np.uint64_t ind, np.uint64_t count, int val) - #cdef int query_range(self, np.uint64_t ind, np.uint64_t count, int *val) + cdef void _set_range(self, np.uint64_t start, np.uint64_t stop, np.uint8_t val) + cdef np.uint64_t _count(self) + cdef bitarray _logical_and(self, bitarray other, bitarray result = *) + cdef bitarray _logical_or(self, bitarray other, bitarray result = *) + cdef bitarray _logical_xor(self, bitarray other, bitarray result = *) diff --git a/yt/utilities/lib/bitarray.pyx b/yt/utilities/lib/bitarray.pyx index 1a7f85ac85d..07092455260 100644 --- a/yt/utilities/lib/bitarray.pyx +++ b/yt/utilities/lib/bitarray.pyx @@ -54,9 +54,14 @@ cdef class bitarray: if size != arr.size: raise RuntimeError self.buf_size = (size >> 3) + cdef np.uint8_t bitmask = 255 if (size & 7) != 0: # We need an extra one if we've got any lingering bits self.buf_size += 1 + bitmask = 0 + for i in range(size & 7): + bitmask |= (1< ibuf_t.data @@ -163,3 +168,128 @@ cdef class bitarray: """ return ba_get_value(self.buf, ind) + + @cython.boundscheck(False) + @cython.wraparound(False) + @cython.cdivision(True) + cdef void _set_range(self, np.uint64_t start, np.uint64_t stop, np.uint8_t val): + ba_set_range(self.buf, start, stop, val) + + @cython.boundscheck(False) + @cython.wraparound(False) + @cython.cdivision(True) + def set_range(self, np.uint64_t start, np.uint64_t stop, np.uint8_t val): + r"""Set a range of values to on/off. Uses slice-style indexing. + + No return value. + + Parameters + ---------- + start : int + The starting component of a slice. + stop : int + The ending component of a slice. + val : bool or uint8_t + What to set the range to + + Examples + -------- + + >>> arr_in = np.array([True, True, False, True, True, False]) + >>> a = ba.bitarray(arr = arr_in) + >>> a.set_range(0, 3, 0) + + """ + ba_set_range(self.buf, start, stop, val) + + @cython.boundscheck(False) + @cython.wraparound(False) + @cython.cdivision(True) + cdef np.uint64_t _count(self): + cdef np.uint64_t count = 0 + cdef np.uint64_t i + self.buf[self.buf_size - 1] &= self.final_bitmask + for i in range(self.buf_size): + count += _num_set_bits(self.buf[i]) + return count + + + @cython.boundscheck(False) + @cython.wraparound(False) + @cython.cdivision(True) + def count(self): + r"""Count the number of values set in the array. + + Parameters + ---------- + + Examples + -------- + + >>> arr_in = np.array([True, True, False, True, True, False]) + >>> a = ba.bitarray(arr = arr_in) + >>> a.count() + + """ + return self._count() + + cdef bitarray _logical_and(self, bitarray other, bitarray result = None): + # Create a place to put it. Note that we might have trailing values, + # we actually need to reset the ending set. + if other.size != self.size: + raise IndexError + if result is None: + result = bitarray(self.size) + for i in range(self.buf_size): + result.buf[i] = other.buf[i] & self.buf[i] + result.buf[self.buf_size - 1] &= self.final_bitmask + return result + + def logical_and(self, bitarray other, bitarray result = None): + return self._logical_and(other, result) + + def __and__(self, bitarray other): + # Wrap it directly here. + return self.logical_and(other) + + def __iand__(self, bitarray other): + rv = self.logical_and(other, self) + return rv + + cdef bitarray _logical_or(self, bitarray other, bitarray result = None): + if other.size != self.size: + raise IndexError + if result is None: + result = bitarray(self.size) + for i in range(self.buf_size): + result.buf[i] = other.buf[i] | self.buf[i] + result.buf[self.buf_size - 1] &= self.final_bitmask + return result + + def logical_or(self, bitarray other, bitarray result = None): + return self._logical_or(other, result) + + def __or__(self, bitarray other): + return self.logical_or(other) + + def __ior__(self, bitarray other): + return self.logical_or(other, self) + + cdef bitarray _logical_xor(self, bitarray other, bitarray result = None): + if other.size != self.size: + raise IndexError + if result is None: + result = bitarray(self.size) + for i in range(self.buf_size): + result.buf[i] = other.buf[i] ^ self.buf[i] + result.buf[self.buf_size - 1] &= self.final_bitmask + return result + + def logical_xor(self, bitarray other, bitarray result = None): + return self._logical_xor(other, result) + + def __xor__(self, bitarray other): + return self.logical_xor(other) + + def __ixor__(self, bitarray other): + return self.logical_xor(other, self) diff --git a/yt/utilities/lib/tests/test_bitarray.py b/yt/utilities/lib/tests/test_bitarray.py index 9880652ca37..15978b45ee1 100644 --- a/yt/utilities/lib/tests/test_bitarray.py +++ b/yt/utilities/lib/tests/test_bitarray.py @@ -6,9 +6,10 @@ def test_inout_bitarray(): # Check that we can do it for bitarrays that are funny-shaped + rng = np.random.default_rng() for i in range(7): # Check we can feed in an array - arr_in = np.random.random(32**3 + i) > 0.5 + arr_in = rng.random(32**3 + i) > 0.5 b = ba.bitarray(arr=arr_in) if i > 0: assert_equal(b.ibuf.size, (32**3) / 8.0 + 1) @@ -22,18 +23,73 @@ def test_inout_bitarray(): assert_equal(arr_in, arr_out) # Try a big array - arr_in = np.random.random(32**3 + i) > 0.5 + arr_in = rng.random(32**3 + i) > 0.5 b = ba.bitarray(arr=arr_in) arr_out = b.as_bool_array() assert_equal(arr_in, arr_out) + assert_equal(b.count(), arr_in.sum()) # Let's check we can do something interesting. - arr_in1 = np.random.random(32**3) > 0.5 - arr_in2 = np.random.random(32**3) > 0.5 + arr_in1 = rng.random(32**3) > 0.5 + arr_in2 = rng.random(32**3) > 0.5 b1 = ba.bitarray(arr=arr_in1) b2 = ba.bitarray(arr=arr_in2) b3 = ba.bitarray(arr=(arr_in1 & arr_in2)) assert_equal((b1.ibuf & b2.ibuf), b3.ibuf) + assert_equal(b1.count(), arr_in1.sum()) + assert_equal(b2.count(), arr_in2.sum()) + # Let's check the logical and operation + b4 = b1.logical_and(b2) + assert_equal(b4.count(), b3.count()) + assert_array_equal(b4.as_bool_array(), b3.as_bool_array()) + + b5 = b1 & b2 + assert_equal(b5.count(), b3.count()) + assert_array_equal(b5.as_bool_array(), b3.as_bool_array()) + + b1 &= b2 + assert_equal(b1.count(), b4.count()) + assert_array_equal(b1.as_bool_array(), b4.as_bool_array()) + + # Repeat this, but with the logical or operators + b1 = ba.bitarray(arr=arr_in1) + b2 = ba.bitarray(arr=arr_in2) + b3 = ba.bitarray(arr=(arr_in1 | arr_in2)) + assert_equal((b1.ibuf | b2.ibuf), b3.ibuf) + assert_equal(b1.count(), arr_in1.sum()) + assert_equal(b2.count(), arr_in2.sum()) + # Let's check the logical and operation + b4 = b1.logical_or(b2) + assert_equal(b4.count(), b3.count()) + assert_array_equal(b4.as_bool_array(), b3.as_bool_array()) + + b5 = b1 | b2 + assert_equal(b5.count(), b3.count()) + assert_array_equal(b5.as_bool_array(), b3.as_bool_array()) + + b1 |= b2 + assert_equal(b1.count(), b4.count()) + assert_array_equal(b1.as_bool_array(), b4.as_bool_array()) + + # Repeat this, but with the logical xor operators + b1 = ba.bitarray(arr=arr_in1) + b2 = ba.bitarray(arr=arr_in2) + b3 = ba.bitarray(arr=(arr_in1 ^ arr_in2)) + assert_equal((b1.ibuf ^ b2.ibuf), b3.ibuf) + assert_equal(b1.count(), arr_in1.sum()) + assert_equal(b2.count(), arr_in2.sum()) + # Let's check the logical and operation + b4 = b1.logical_xor(b2) + assert_equal(b4.count(), b3.count()) + assert_array_equal(b4.as_bool_array(), b3.as_bool_array()) + + b5 = b1 ^ b2 + assert_equal(b5.count(), b3.count()) + assert_array_equal(b5.as_bool_array(), b3.as_bool_array()) + + b1 ^= b2 + assert_equal(b1.count(), b4.count()) + assert_array_equal(b1.as_bool_array(), b4.as_bool_array()) b = ba.bitarray(10) for i in range(10): @@ -51,3 +107,60 @@ def test_inout_bitarray(): b.set_value(2, 1) arr = b.as_bool_array() assert_array_equal(arr, [0, 0, 1, 0, 0, 0, 0, 1, 0, 0]) + + +def test_set_range(): + b = ba.bitarray(127) + # Test once where we're in the middle of start and end bits + b.set_range(4, 65, 1) + comparison_array = np.zeros(127, dtype="uint8") + comparison_array[4:65] = 1 + arr = b.as_bool_array().astype("uint8") + assert_array_equal(arr, comparison_array) + assert_equal(b.count(), comparison_array.sum()) + + # Test when we start and stop in the same byte + b = ba.bitarray(127) + b.set_range(4, 6, 1) + comparison_array = np.zeros(127, dtype="uint8") + comparison_array[4:6] = 1 + arr = b.as_bool_array().astype("uint8") + assert_array_equal(arr, comparison_array) + assert_equal(b.count(), comparison_array.sum()) + + # Test now where we're in the middle of start + b = ba.bitarray(64) + b.set_range(33, 36, 1) + comparison_array = np.zeros(64, dtype="uint8") + comparison_array[33:36] = 1 + arr = b.as_bool_array().astype("uint8") + assert_array_equal(arr, comparison_array) + assert_equal(b.count(), comparison_array.sum()) + + # Now we test when we end on a byte edge, but we have 65 entries + b = ba.bitarray(65) + b.set_range(32, 64, 1) + comparison_array = np.zeros(65, dtype="uint8") + comparison_array[32:64] = 1 + arr = b.as_bool_array().astype("uint8") + assert_array_equal(arr, comparison_array) + assert_equal(b.count(), comparison_array.sum()) + + # Let's do the inverse + b = ba.bitarray(127) + b.set_range(0, 127, 1) + assert_equal(b.as_bool_array().all(), True) + b.set_range(0, 127, 0) + assert_equal(b.as_bool_array().any(), False) + b.set_range(3, 9, 1) + comparison_array = np.zeros(127, dtype="uint8") + comparison_array[3:9] = 1 + arr = b.as_bool_array().astype("uint8") + assert_array_equal(arr, comparison_array) + assert_equal(b.count(), comparison_array.sum()) + + # Now let's overlay some zeros + b.set_range(7, 10, 0) + comparison_array[7:10] = 0 + arr = b.as_bool_array().astype("uint8") + assert_array_equal(arr, comparison_array)