diff --git a/merlin/core/dispatch.py b/merlin/core/dispatch.py index 1ab11a240..e17f8b562 100644 --- a/merlin/core/dispatch.py +++ b/merlin/core/dispatch.py @@ -16,7 +16,7 @@ import enum import functools import itertools -from typing import Callable, Union +from typing import Callable, Optional, Union import dask.dataframe as dd import numpy as np @@ -311,7 +311,7 @@ def series_has_nulls(s): return s.has_nulls -def list_val_dtype(ser: SeriesLike) -> np.dtype: +def list_val_dtype(ser: SeriesLike) -> Optional[np.dtype]: """ Return the dtype of the leaves from a list or nested list @@ -322,8 +322,8 @@ def list_val_dtype(ser: SeriesLike) -> np.dtype: Returns ------- - np.dtype - The dtype of the innermost elements + Optional[np.dtype] + The dtype of the innermost elements if we find one """ if is_list_dtype(ser): if cudf is not None and isinstance(ser, cudf.Series): @@ -331,7 +331,12 @@ def list_val_dtype(ser: SeriesLike) -> np.dtype: ser = ser.list.leaves return ser.dtype elif isinstance(ser, pd.Series): - return pd.core.dtypes.cast.infer_dtype_from(next(iter(pd.core.common.flatten(ser))))[0] + try: + return pd.core.dtypes.cast.infer_dtype_from( + next(iter(pd.core.common.flatten(ser))) + )[0] + except StopIteration: + return None if isinstance(ser, np.ndarray): return ser.dtype # adds detection when in merlin column diff --git a/merlin/io/dataset.py b/merlin/io/dataset.py index 2d7eb9487..288976c73 100644 --- a/merlin/io/dataset.py +++ b/merlin/io/dataset.py @@ -1212,13 +1212,12 @@ def sample_dtypes(self, n=1, annotate_lists=False): if annotate_lists: _real_meta = self._real_meta[n] - annotated = { - col: { - "dtype": list_val_dtype(_real_meta[col]) or _real_meta[col].dtype, - "is_list": is_list_dtype(_real_meta[col]), - } - for col in _real_meta.columns - } + annotated = {} + for col in _real_meta.columns: + is_list = is_list_dtype(_real_meta[col]) + dtype = list_val_dtype(_real_meta[col]) if is_list else _real_meta[col].dtype + annotated[col] = {"dtype": dtype, "is_list": is_list} + return annotated return self._real_meta[n].dtypes diff --git a/tests/unit/io/test_dataset.py b/tests/unit/io/test_dataset.py index 2250f476b..a8452fb8c 100644 --- a/tests/unit/io/test_dataset.py +++ b/tests/unit/io/test_dataset.py @@ -49,3 +49,9 @@ def test_false_with_cudf_and_gpu(self): def test_false_missing_cudf_or_gpu(self): with pytest.raises(RuntimeError): Dataset(make_df({"a": [1, 2, 3]}), cpu=False) + + +def test_infer_list_dtype_unknown(): + df = pd.DataFrame({"col": [[], []]}) + dataset = Dataset(df, cpu=True) + assert dataset.schema["col"].dtype.element_type.value == "unknown"