diff --git a/altair/utils/core.py b/altair/utils/core.py
index c2b0634cc..ea275d589 100644
--- a/altair/utils/core.py
+++ b/altair/utils/core.py
@@ -37,7 +37,7 @@
 else:
     from typing_extensions import ParamSpec
 
-from typing import Literal, Protocol, TYPE_CHECKING
+from typing import Literal, Protocol, TYPE_CHECKING, runtime_checkable
 
 if TYPE_CHECKING:
     from pandas.core.interchange.dataframe_protocol import Column as PandasColumn
@@ -46,6 +46,7 @@
 P = ParamSpec("P")
 
 
+@runtime_checkable
 class DataFrameLike(Protocol):
     def __dataframe__(
         self, nan_as_null: bool = False, allow_copy: bool = True
diff --git a/altair/utils/data.py b/altair/utils/data.py
index 0e9071209..e4b135d38 100644
--- a/altair/utils/data.py
+++ b/altair/utils/data.py
@@ -105,7 +105,7 @@ def raise_max_rows_error():
             # mypy gets confused as it doesn't see Dict[Any, Any]
             # as equivalent to TDataType
             return data  # type: ignore[return-value]
-    elif hasattr(data, "__dataframe__"):
+    elif isinstance(data, DataFrameLike):
         pa_table = arrow_table_from_dfi_dataframe(data)
         if max_rows is not None and pa_table.num_rows > max_rows:
             raise_max_rows_error()
@@ -141,7 +141,7 @@
         else:
             # Maybe this should raise an error or return something useful?
             return None
-    elif hasattr(data, "__dataframe__"):
+    elif isinstance(data, DataFrameLike):
         pa_table = arrow_table_from_dfi_dataframe(data)
         if not n:
             if frac is None:
@@ -229,7 +229,7 @@
         if "values" not in data:
             raise KeyError("values expected in data dict, but not present.")
         return data
-    elif hasattr(data, "__dataframe__"):
+    elif isinstance(data, DataFrameLike):
         pa_table = sanitize_arrow_table(arrow_table_from_dfi_dataframe(data))
         return {"values": pa_table.to_pylist()}
     else:
@@ -272,7 +272,7 @@ def _data_to_json_string(data: DataType) -> str:
         if "values" not in data:
             raise KeyError("values expected in data dict, but not present.")
         return json.dumps(data["values"], sort_keys=True)
-    elif hasattr(data, "__dataframe__"):
+    elif isinstance(data, DataFrameLike):
         pa_table = arrow_table_from_dfi_dataframe(data)
         return json.dumps(pa_table.to_pylist())
     else:
@@ -296,7 +296,7 @@ def _data_to_csv_string(data: Union[dict, pd.DataFrame, DataFrameLike]) -> str:
         if "values" not in data:
             raise KeyError("values expected in data dict, but not present")
         return pd.DataFrame.from_dict(data["values"]).to_csv(index=False)
-    elif hasattr(data, "__dataframe__"):
+    elif isinstance(data, DataFrameLike):
         # experimental interchange dataframe support
         import pyarrow as pa
         import pyarrow.csv as pa_csv