Skip to content

Commit

Permalink
make it compatible with numpy 1.* and 2.*
Browse files Browse the repository at this point in the history
  • Loading branch information
psmyth94 committed Oct 13, 2024
1 parent 5267be3 commit bd7ce01
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/biosets/data_handling/dict/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def to_dicts(self, X: Dict[str, list], **kwargs):
return [{k: v[i] for k, v in X.items()} for i in range(len(first_entry))]

def to_numpy(self, X: Dict[str, list], **kwargs):
return np.asarray(self.to_list(X, **kwargs), **np_asarray_kwargs(kwargs))
return np.asarray(self.to_list(X, **kwargs), **get_kwargs(kwargs, np.asarray))

def to_pandas(self, X: Dict[str, list], **kwargs):
return pd.DataFrame(X, **get_kwargs(kwargs, pd.DataFrame.__init__))
Expand Down
1 change: 1 addition & 0 deletions src/biosets/utils/import_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def _is_package_available(
_matplotlib_available = _is_package_available("matplotlib")
_natten_available = _is_package_available("natten")
_nltk_available = _is_package_available("nltk")
_numpy_version = importlib.metadata.version("numpy")
_onnx_available = _is_package_available("onnx")
_openai_available = _is_package_available("openai")
_optimum_available = _is_package_available("optimum")
Expand Down
28 changes: 16 additions & 12 deletions src/biosets/utils/inspect.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import inspect
import sys
import version

from biosets.utils import logging
from biosets.utils.import_util import _numpy_version

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -123,15 +125,17 @@ def ds_init_kwargs(kwargs: dict):


def np_asarray_kwargs(kwargs: dict):
dtype = kwargs.get("dtype", None)
order = kwargs.get("order", None)
device = kwargs.get("device", None)
copy = kwargs.get("copy", None)
like = kwargs.get("like", None)
return {
"dtype": dtype,
"order": order,
"device": device,
"copy": copy,
"like": like,
}
if version.parse(_numpy_version) < version.parse("2.0.0"):
return {
"dtype": kwargs.get("dtype", None),
"order": kwargs.get("order", None),
"like": kwargs.get("like", None),
}
else:
return {
"dtype": kwargs.get("dtype", None),
"order": kwargs.get("order", None),
"device": kwargs.get("device", None),
"copy": kwargs.get("copy", None),
"like": kwargs.get("like", None),
}

0 comments on commit bd7ce01

Please sign in to comment.