diff --git a/pysages/backends/utils.py b/pysages/backends/utils.py index 8e5eb1b6..c5191234 100644 --- a/pysages/backends/utils.py +++ b/pysages/backends/utils.py @@ -9,7 +9,7 @@ from numpy.ctypeslib import as_ctypes_type from pysages.typing import JaxArray -from pysages.utils import dispatch +from pysages.utils import dispatch, unsafe_buffer_pointer def cupy_helpers(): @@ -38,7 +38,7 @@ def view(array: JaxArray): # NOTE: We need a more general strategy to handle # `SharedDeviceArray`s and `GlobalDeviceArray`s. ptype = ctypes.POINTER(as_ctypes_type(array.dtype)) - addr = array.device_buffer.unsafe_buffer_pointer() + addr = unsafe_buffer_pointer(array) ptr = ctypes.cast(ctypes.c_void_p(addr), ptype) return numba.carray(ptr, array.shape) diff --git a/pysages/utils/__init__.py b/pysages/utils/__init__.py index ffde6466..e52f8df2 100644 --- a/pysages/utils/__init__.py +++ b/pysages/utils/__init__.py @@ -17,6 +17,7 @@ prod, solve_pos_def, try_import, + unsafe_buffer_pointer, ) from .core import ( ToCPU, diff --git a/pysages/utils/compat.py b/pysages/utils/compat.py index 443b70f5..e4382646 100644 --- a/pysages/utils/compat.py +++ b/pysages/utils/compat.py @@ -36,11 +36,24 @@ def prod(iterable, start=1): return result -# Compatibility for jax >=0.4.27 +# Compatibility for jax >=0.4.22 -# https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0427-may-7-2024 +# https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0422-dec-13-2023 +if _jax_version_tuple < (0, 4, 22): -if _jax_version_tuple < (0, 4, 27): + def unsafe_buffer_pointer(array): + return array.device_buffer.unsafe_buffer_pointer() + +else: + + def unsafe_buffer_pointer(array): + return array.unsafe_buffer_pointer() + + +# Compatibility for jax >=0.4.21 + +# https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0421-dec-4-2023 +if _jax_version_tuple < (0, 4, 21): def device_platform(array): return array.device().platform