diff --git a/scarf/utils.py b/scarf/utils.py index 1195468..f89c57a 100644 --- a/scarf/utils.py +++ b/scarf/utils.py @@ -173,10 +173,20 @@ def controlled_compute(arr, nthreads): Returns: Result of computation. """ - from multiprocessing.pool import ThreadPool import dask - with dask.config.set(schedular="threads", pool=ThreadPool(nthreads)): # type: ignore + try: + # Multiprocessing may be faster, but it throws exception if SemLock is not implemented. + # For example, multiprocessing won't work on AWS Lambda, in those scenarios we switch ThreadPoolExecutor + from multiprocessing.pool import ThreadPool + + pool = ThreadPool(nthreads) + except Exception: + from concurrent.futures import ThreadPoolExecutor + + pool = ThreadPoolExecutor(nthreads) + + with dask.config.set(schedular="threads", pool=pool): # type: ignore res = arr.compute() return res