def _interpret_enum(e: type[IntEnum], value: int | str) -> IntEnum:
    """
    Convert a string or integer into an Enum value.

    Parameters
    ----------
    e : Type[IntEnum]
        The enum to use in interpretation.
    value : int or str
        The value to convert.  Integers (including numpy integer scalars)
        are looked up by value, and simple strings are looked up by member
        name.  Multiple names can also be given joined by the pipe operator,
        in the style of flags (e.g. "Red|Octagon").

    Returns
    -------
    IntEnum or int
        The matching member for an integer or a single name.  For
        pipe-joined names, the bitwise-or of the named members, which for a
        plain IntEnum is an ordinary integer.

    Raises
    ------
    ValueError
        If an integer `value` is not a value of `e`.
    AttributeError
        If a name in a string `value` is not an attribute of `e`.
    """
    # numpy integer scalars are not subclasses of `int`; without this check
    # they would fall through to the string branch and fail on `.split`.
    if isinstance(value, (int, np.integer)):
        return e(value)
    return reduce(lambda x, y: x | y, [getattr(e, v) for v in value.split("|")])


def get_enum_name(e: type[IntEnum], value: int) -> str:
    """
    Get the name of an enum by value, or a placeholder name if not found.

    This allows for dummy placeholder names if the enum does not contain
    all consecutive values between 0 and the maximum value, inclusive.

    Parameters
    ----------
    e : Type[IntEnum]
        The enum to use in interpretation.
    value : int
        The value for which to find a name. If not found in `e`, this
        function will generate a new name as a string by prefixing `value`
        with an underscore.

    Returns
    -------
    str
    """
    # Look the value up through the public Enum call interface rather than
    # the private `_value2member_map_` mapping.
    try:
        return e(value).name
    except ValueError:
        return f"_{value}"


def int_enum_to_categorical_dtype(e: type[IntEnum]) -> pd.CategoricalDtype:
    """
    Convert an integer-valued enum to a pandas CategoricalDtype.

    The categories are the names for every integer from 0 through the
    largest value in `e`, in order, so that an enum value can be used
    directly as a category code.  Integers in that range that are not
    values of `e` get the placeholder names made by `get_enum_name`.

    Parameters
    ----------
    e : Type[IntEnum]

    Returns
    -------
    pd.CategoricalDtype
    """
    max_enum_value = int(max(e))
    categories = [get_enum_name(e, i) for i in range(max_enum_value + 1)]
    return pd.CategoricalDtype(categories=categories)


def as_int_enum(
    s: pd.Series,
    e: type[IntEnum],
    dtype: type[np.integer] | None = None,
    categorical: bool = True,
) -> pd.Series | pd.Categorical:
    """
    Encode a pandas Series as categorical, consistent with an IntEnum.

    Parameters
    ----------
    s : pd.Series
        The values to encode, given as enum values (integers) or as enum
        member names (strings).  Anything that is not already a Series is
        passed to the `pd.Series` constructor first.
    e : Type[IntEnum]
        The enum that defines the encoding.  All of its values must be
        non-negative.
    dtype : Type[np.integer], optional
        Specific dtype to use for the code point encoding. It is typically not
        necessary to give this explicitly as the function will automatically
        select the best (most efficient) bitwidth.
    categorical : bool, default True
        If set to false, the returned series will simply be integer encoded with
        no formal Categorical dtype.

    Returns
    -------
    pd.Categorical or pd.Series
        When `categorical` is true, a `pd.Categorical` built from the integer
        codes; note that this object does not carry the index of `s`.  When
        `categorical` is false, an integer-valued `pd.Series`.

    Raises
    ------
    ValueError
        If `e` contains a negative value.
    """
    min_enum_value = int(min(e))
    max_enum_value = int(max(e))
    # Enum values double as category codes, so they cannot be negative.
    # Raise explicitly instead of asserting, as asserts vanish under `-O`.
    if min_enum_value < 0:
        raise ValueError(
            f"cannot encode {e.__name__}: enum values must be non-negative, "
            f"found minimum value {min_enum_value}"
        )
    if dtype is None:
        # Values are known to be non-negative here, so only the maximum
        # decides the bitwidth.
        if max_enum_value < 256:
            dtype = np.uint8
        elif max_enum_value < 65536:
            dtype = np.uint16
        elif max_enum_value < 2_147_483_648:
            dtype = np.int32
        else:
            dtype = np.int64
    if not isinstance(s, pd.Series):
        s = pd.Series(s)
    result = s.apply(lambda x: _interpret_enum(e, x)).astype(dtype)
    if categorical:
        # Reuse the shared dtype builder so the categories always agree with
        # `int_enum_to_categorical_dtype`.
        result = pd.Categorical.from_codes(
            codes=result, dtype=int_enum_to_categorical_dtype(e)
        )
    return result


@pd.api.extensions.register_series_accessor("as_int_enum")
class _AsIntEnum:
    """
    Series accessor that encodes the Series consistent with an IntEnum.

    Calling `series.as_int_enum(e, ...)` forwards the Series and the given
    arguments to the module-level `as_int_enum` function.

    Parameters
    ----------
    e : Type[IntEnum]
    dtype : Type[np.integer], optional
        Specific dtype to use for the code point encoding. It is typically not
        necessary to give this explicitly as the function will automatically
        select the best (most efficient) bitwidth.
    categorical : bool, default True
        If set to false, the returned series will simply be integer encoded with
        no formal Categorical dtype.

    Returns
    -------
    pd.Categorical or pd.Series
        Whatever `as_int_enum` returns for the wrapped Series.
    """

    def __init__(self, pandas_obj):
        # The Series this accessor was reached from.
        self._obj = pandas_obj

    def __call__(
        self,
        e: type[IntEnum],
        dtype: type[np.integer] | None = None,
        categorical: bool = True,
    ) -> pd.Series | pd.Categorical:
        return as_int_enum(self._obj, e, dtype, categorical)
This can potentially save a large amount of memory if the DataFrame has a MultiIndex. Requires the sparse package (sparse.pydata.org). + preserve_cat : bool, default True + If true, preserve encoding of categorical columns. Xarray lacks an + official implementation of a categorical datatype, so sharrow's + dictionary-based digital encoding is applied instead. Note that in + native xarray usage, the resulting variable will look like integer + values instead of the category values. The `dataset.cat` accessor + can be used to interact with the categorical data. Returns ------- diff --git a/sharrow/tests/test_categorical.py b/sharrow/tests/test_categorical.py index 64edb8e..b500112 100644 --- a/sharrow/tests/test_categorical.py +++ b/sharrow/tests/test_categorical.py @@ -1,13 +1,15 @@ from __future__ import annotations +from enum import IntEnum + import numpy as np +import pandas as pd import xarray as xr import sharrow def test_simple_cat(tours_dataset: xr.Dataset): - tree = sharrow.DataTree(tours=tours_dataset) assert all(tours_dataset.TourMode.cat.categories == ["Bus", "Car", "Walk"]) @@ -26,7 +28,6 @@ def test_2_level_tree_cat( tours_dataset: xr.Dataset, person_dataset: xr.Dataset, ): - tree = sharrow.DataTree(tours=tours_dataset) tree.add_dataset("persons", person_dataset, "tours.person_id @ persons.person_id") @@ -53,7 +54,6 @@ def test_3_level_tree_cat( person_dataset: xr.Dataset, household_dataset: xr.Dataset, ): - tree = sharrow.DataTree(tours=tours_dataset) tree.add_dataset("persons", person_dataset, "tours.person_id @ persons.person_id") tree.add_dataset( @@ -83,7 +83,6 @@ def test_rootless_tree_cat( person_dataset: xr.Dataset, household_dataset: xr.Dataset, ): - tree = sharrow.DataTree(tours=tours_dataset, root_node_name=False) tree.add_dataset("persons", person_dataset, "tours.person_id @ persons.person_id") tree.add_dataset( @@ -97,3 +96,22 @@ def test_rootless_tree_cat( a = f.load_dataarray(dtype=np.int8) a = a.isel(expressions=0) assert all(a == 
def test_int_enum_categorical():
    """Check that string tour modes encode to categories matching an IntEnum."""

    class TourMode(IntEnum):
        Car = 1
        Bus = 2
        Walk = 3

    tour_index = pd.Index([4411, 4451, 5521, 5561, 9341], name="tour_id")
    columns = {
        "TourMode": ["Car", "Bus", "Car", "Car", "Walk"],
        "person_id": [441, 445, 552, 556, 934],
    }
    df = pd.DataFrame(columns, index=tour_index)

    # Encode through the Series accessor and store the result as a new column.
    encoded = df["TourMode"].as_int_enum(TourMode)
    df["TourMode2"] = encoded
    mode2 = df["TourMode2"]

    assert mode2.dtype == "category"
    # The enum has no member with value 0, so a placeholder category appears.
    assert all(mode2.cat.categories == ["_0", "Car", "Bus", "Walk"])
    assert all(mode2.cat.codes == [1, 2, 1, 1, 3])