diff --git a/src/biosets/data_handling/dict/dict.py b/src/biosets/data_handling/dict/dict.py index 7e23fff..6cb6414 100644 --- a/src/biosets/data_handling/dict/dict.py +++ b/src/biosets/data_handling/dict/dict.py @@ -60,7 +60,7 @@ def to_dicts(self, X: Dict[str, list], **kwargs): return [{k: v[i] for k, v in X.items()} for i in range(len(first_entry))] def to_numpy(self, X: Dict[str, list], **kwargs): - return np.asarray(self.to_list(X, **kwargs), **np_asarray_kwargs(kwargs)) + return np.asarray(self.to_list(X, **kwargs), **get_kwargs(kwargs, np.asarray)) def to_pandas(self, X: Dict[str, list], **kwargs): return pd.DataFrame(X, **get_kwargs(kwargs, pd.DataFrame.__init__)) diff --git a/src/biosets/utils/import_util.py b/src/biosets/utils/import_util.py index 83f2378..2145f18 100644 --- a/src/biosets/utils/import_util.py +++ b/src/biosets/utils/import_util.py @@ -130,6 +130,7 @@ def _is_package_available( _matplotlib_available = _is_package_available("matplotlib") _natten_available = _is_package_available("natten") _nltk_available = _is_package_available("nltk") +_numpy_version = importlib.metadata.version("numpy") _onnx_available = _is_package_available("onnx") _openai_available = _is_package_available("openai") _optimum_available = _is_package_available("optimum") diff --git a/src/biosets/utils/inspect.py b/src/biosets/utils/inspect.py index b7c8263..ac317fd 100644 --- a/src/biosets/utils/inspect.py +++ b/src/biosets/utils/inspect.py @@ -1,7 +1,9 @@ import inspect import sys +import version from biosets.utils import logging +from biosets.utils.import_util import _numpy_version logger = logging.get_logger(__name__) @@ -123,15 +125,17 @@ def ds_init_kwargs(kwargs: dict): def np_asarray_kwargs(kwargs: dict): - dtype = kwargs.get("dtype", None) - order = kwargs.get("order", None) - device = kwargs.get("device", None) - copy = kwargs.get("copy", None) - like = kwargs.get("like", None) - return { - "dtype": dtype, - "order": order, - "device": device, - "copy": copy, - "like": like, - } + if version.parse(_numpy_version) < version.parse("2.0.0"): + return { + "dtype": kwargs.get("dtype", None), + "order": kwargs.get("order", None), + "like": kwargs.get("like", None), + } + else: + return { + "dtype": kwargs.get("dtype", None), + "order": kwargs.get("order", None), + "device": kwargs.get("device", None), + "copy": kwargs.get("copy", None), + "like": kwargs.get("like", None), + }