diff --git a/rapidtide/tests/test_sharedmem.py b/rapidtide/tests/test_sharedmem.py index 34fad977..4c59fa57 100755 --- a/rapidtide/tests/test_sharedmem.py +++ b/rapidtide/tests/test_sharedmem.py @@ -26,40 +26,33 @@ def test_numpy2shared(debug=False): if debug: print(f"{intype=}, {sourcevector.size=}, {sourcevector.dtype=}") for outtype in [np.float32, np.float64]: + + destvector, shm = tide_util.numpy2shared(sourcevector, outtype) if debug: - print(f"\t{outtype=}") - for function in [tide_util.numpy2shared_old, tide_util.numpy2shared_new]: - destvector, shm = function(sourcevector, outtype) - if debug: - print(f"\t\t{function=}, {destvector.size=}, {destvector.dtype=}") + print(f"\t{outtype=}, {destvector.size=}, {destvector.dtype=}") - # check everything - assert destvector.dtype == outtype - assert destvector.size == sourcevector.size - np.testing.assert_almost_equal(sourcevector, destvector, 3) + # check everything + assert destvector.dtype == outtype + assert destvector.size == sourcevector.size + np.testing.assert_almost_equal(sourcevector, destvector, 3) - # clean up if needed - if shm is not None: - tide_util.cleanup_shm_new(shm) + # clean up + tide_util.cleanup_shm(shm) def test_allocshared(debug=False): datashape = (10, 10, 10) for outtype in [np.float32, np.float64]: + destarray, shm = tide_util.allocshared(datashape, outtype) if debug: - print(f"{outtype=}") - for function in [tide_util.allocshared_old, tide_util.allocshared_new]: - destarray, shm = function(datashape, outtype) - if debug: - print(f"\t{function=}, {destarray.size=}, {destarray.dtype=}") + print(f"{outtype=}, {destarray.size=}, {destarray.dtype=}") - # check everything - assert destarray.dtype == outtype - assert destarray.size == np.prod(datashape) + # check everything + assert destarray.dtype == outtype + assert destarray.size == np.prod(datashape) - # clean up if needed - if shm is not None: - tide_util.cleanup_shm_new(shm) + # clean up if needed + tide_util.cleanup_shm(shm) if __name__ == "__main__": diff --git a/rapidtide/util.py b/rapidtide/util.py index a5787d6d..9b19781c 100644 --- a/rapidtide/util.py +++ b/rapidtide/util.py @@ -1003,36 +1003,7 @@ def comparehappyruns(root1, root2, debug=False): # shared memory routines -def numpy2shared_old(inarray, thetype): - thesize = inarray.size - theshape = inarray.shape - if thetype == np.float64: - inarray_shared = RawArray("d", inarray.reshape(thesize)) - else: - inarray_shared = RawArray("f", inarray.reshape(thesize)) - inarray = np.frombuffer(inarray_shared, dtype=thetype, count=thesize) - inarray.shape = theshape - return inarray, None - - -def allocshared_old(theshape, thetype): - thesize = int(1) - if not isinstance(theshape, (list, tuple)): - thesize = theshape - else: - for element in theshape: - thesize *= int(element) - if thetype == np.float64: - outarray_shared = RawArray("d", thesize) - else: - outarray_shared = RawArray("f", thesize) - outarray = np.frombuffer(outarray_shared, dtype=thetype, count=thesize) - outarray.shape = theshape - return outarray, None - - -# shared memory routines -def numpy2shared_new(inarray, theouttype): +def numpy2shared(inarray, theouttype): # Create a shared memory block to store the array data outnbytes = np.dtype(theouttype).itemsize * inarray.size shm = shared_memory.SharedMemory(create=True, size=outnbytes) @@ -1041,7 +1012,7 @@ def numpy2shared_new(inarray, theouttype): return inarray_shared, shm # Return both the array and the shared memory object -def allocshared_new(theshape, thetype): +def allocshared(theshape, thetype): # Calculate size based on shape thesize = np.prod(theshape) # Determine the data type size @@ -1052,32 +1023,8 @@ def allocshared_new(theshape, thetype): return outarray, shm # Return both the array and the shared memory object -def cleanup_shm_new(shm): - # Cleanup - shm.close() - shm.unlink() - - -newshm = True - - -def numpy2shared(inarray, thetype): - if newshm: - return numpy2shared_new(inarray, thetype) - else: - return numpy2shared_old(inarray, thetype) - - -def allocshared(theshape, thetype): - if newshm: - return allocshared_new(theshape, thetype) - else: - return allocshared_old(theshape, thetype) - - def cleanup_shm(shm): - if newshm: - if shm is not None: - cleanup_shm_new(shm) - else: - return + # Cleanup + if shm is not None: + shm.close() + shm.unlink()