Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

categorical from IntEnum #36

Merged
merged 1 commit into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
repos:

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.4.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
Expand All @@ -14,18 +14,18 @@ repos:
- id: nbstripout

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.257
rev: v0.0.274
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]

- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]

- repo: https://github.com/psf/black
rev: 22.10.0
rev: 23.3.0
hooks:
- id: black
152 changes: 152 additions & 0 deletions sharrow/categorical.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from __future__ import annotations

from enum import IntEnum
from functools import reduce

import numpy as np
import pandas as pd
import xarray as xr


Expand Down Expand Up @@ -38,3 +42,151 @@ def category_array(self) -> np.ndarray:

def is_categorical(self) -> bool:
    """
    Report whether the wrapped data array carries a categorical dictionary.

    Returns
    -------
    bool
        True when the array's ``digital_encoding`` attribute (treated as
        empty if absent) contains a ``"dictionary"`` entry.
    """
    encoding = self.dataarray.attrs.get("digital_encoding", {})
    return "dictionary" in encoding


def _interpret_enum(e: type[IntEnum], value: int | str) -> IntEnum:
    """
    Convert a string or integer into an Enum value.

    Parameters
    ----------
    e : Type[IntEnum]
        The enum to use in interpretation.
    value : int or str
        The value to convert.  Integer values (including numpy integer
        scalars such as ``np.int64``) are looked up by value, and simple
        string values are looked up by member name.  Multiple names can also
        be given joined by the pipe operator, in the style of flags
        (e.g. "Red|Octagon"), in which case the named members are combined
        with bitwise-or.

    Returns
    -------
    IntEnum
        The matching member of `e`.  When several names are joined with a
        pipe, the result is whatever ``|`` yields for those members.

    Raises
    ------
    ValueError
        If `value` is an integer that is not a valid value of `e`.
    AttributeError
        If `value` is a string containing a name that is not an attribute
        of `e`.
    """
    # numpy integer scalars are not instances of the builtin int, so they are
    # checked for explicitly and normalized to int before the value lookup.
    if isinstance(value, (int, np.integer)):
        return e(int(value))
    return reduce(lambda x, y: x | y, [getattr(e, v) for v in value.split("|")])


def get_enum_name(e: type[IntEnum], value: int) -> str:
    """
    Look up the member name for a value, falling back to a placeholder.

    Placeholder names make it possible to name every integer between 0 and
    the maximum value of the enum, even when the enum does not define a
    member for each of those values.

    Parameters
    ----------
    e : Type[IntEnum]
        The enum to use in interpretation.
    value : int
        The value for which to find a name.

    Returns
    -------
    str
        The name of the member of `e` whose value is `value`, or the string
        formed by prefixing `value` with an underscore when `e` has no such
        member.
    """
    member = e._value2member_map_.get(value)
    if member is None:
        # no member carries this value, so synthesize a placeholder name
        return f"_{value}"
    return member.name


def int_enum_to_categorical_dtype(e: type[IntEnum]) -> pd.CategoricalDtype:
    """
    Build a pandas CategoricalDtype whose categories mirror an IntEnum.

    One category is created for every integer from 0 through the largest
    value in `e`, in order.  Each category is named by `get_enum_name`, so
    integers with no matching member receive a placeholder name.

    Parameters
    ----------
    e : Type[IntEnum]
        The enum from which to derive the categories.

    Returns
    -------
    pd.CategoricalDtype
    """
    n_codes = int(max(e)) + 1
    names = [get_enum_name(e, code) for code in range(n_codes)]
    return pd.CategoricalDtype(categories=names)


def as_int_enum(
    s: pd.Series,
    e: type[IntEnum],
    dtype: type[np.integer] | None = None,
    categorical: bool = True,
) -> pd.Series | pd.Categorical:
    """
    Encode a pandas Series as categorical, consistent with an IntEnum.

    Parameters
    ----------
    s : pd.Series
        The values to encode.  Anything that is not already a Series is
        first wrapped in one.  Each element is interpreted as an integer
        value or member name of `e` (see `_interpret_enum`).
    e : Type[IntEnum]
        The enum that defines the encoding.  All of its values must be
        non-negative.
    dtype : Type[np.integer], optional
        Specific dtype to use for the code point encoding. It is typically not
        necessary to give this explicitly as the function will automatically
        select a compact bitwidth based on the largest value in `e`.
    categorical : bool, default True
        If set to false, the returned series will simply be integer encoded with
        no formal Categorical dtype.

    Returns
    -------
    pd.Series or pd.Categorical
        An integer-encoded Series when `categorical` is false; otherwise a
        pandas Categorical built from those codes, with one category for each
        integer from 0 through the largest value in `e`.

    Raises
    ------
    ValueError
        If `e` contains a negative value.
    """
    min_enum_value = int(min(e))
    max_enum_value = int(max(e))
    # Explicit check rather than `assert`, so the validation is still applied
    # when Python runs with assertions disabled (-O).
    if min_enum_value < 0:
        raise ValueError(
            f"enum values must be non-negative, found minimum {min_enum_value}"
        )
    if dtype is None:
        # Values are known to be non-negative here, so only the maximum
        # matters when choosing the bitwidth.
        if max_enum_value < 256:
            dtype = np.uint8
        elif max_enum_value < 65536:
            dtype = np.uint16
        elif max_enum_value < 2_147_483_648:
            dtype = np.int32
        else:
            dtype = np.int64
    if not isinstance(s, pd.Series):
        s = pd.Series(s)
    result = s.apply(lambda x: _interpret_enum(e, x)).astype(dtype)
    if categorical:
        # Codes index directly into the category list, so every integer from
        # zero to the maximum needs a (possibly placeholder) category name.
        categories = [get_enum_name(e, i) for i in range(max_enum_value + 1)]
        result = pd.Categorical.from_codes(codes=result, categories=categories)
    return result


@pd.api.extensions.register_series_accessor("as_int_enum")
class _AsIntEnum:
    """
    Series accessor to encode a Series consistent with an IntEnum.

    This class is registered on pandas Series under the name ``as_int_enum``.
    Calling ``series.as_int_enum(e, ...)`` forwards to the module-level
    `as_int_enum` function, with the series itself supplied as the first
    argument.

    Call Parameters
    ---------------
    e : Type[IntEnum]
        The enum that defines the encoding.
    dtype : Type[np.integer], optional
        Specific dtype to use for the code point encoding. It is typically not
        necessary to give this explicitly as the function will automatically
        select a suitable bitwidth.
    categorical : bool, default True
        If set to false, the result will simply be integer encoded with
        no formal Categorical dtype.

    Call Returns
    ------------
    The result of `as_int_enum` for the wrapped series.
    """

    def __init__(self, pandas_obj):
        # pandas passes in the Series on which the accessor was requested
        self._obj = pandas_obj

    def __call__(
        self,
        e: type[IntEnum],
        dtype: type[np.integer] | None = None,
        categorical: bool = True,
    ):
        return as_int_enum(self._obj, e, dtype, categorical)
7 changes: 7 additions & 0 deletions sharrow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ def dataset_from_dataframe_fast(
If true, create sparse arrays instead of dense numpy arrays. This
can potentially save a large amount of memory if the DataFrame has
a MultiIndex. Requires the sparse package (sparse.pydata.org).
preserve_cat : bool, default True
If true, preserve encoding of categorical columns. Xarray lacks an
official implementation of a categorical datatype, so sharrow's
dictionary-based digital encoding is applied instead. Note that in
native xarray usage, the resulting variable will look like integer
values instead of the category values. The `dataset.cat` accessor
can be used to interact with the categorical data.

Returns
-------
Expand Down
26 changes: 22 additions & 4 deletions sharrow/tests/test_categorical.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

from enum import IntEnum

import numpy as np
import pandas as pd
import xarray as xr

import sharrow


def test_simple_cat(tours_dataset: xr.Dataset):

tree = sharrow.DataTree(tours=tours_dataset)

assert all(tours_dataset.TourMode.cat.categories == ["Bus", "Car", "Walk"])
Expand All @@ -26,7 +28,6 @@ def test_2_level_tree_cat(
tours_dataset: xr.Dataset,
person_dataset: xr.Dataset,
):

tree = sharrow.DataTree(tours=tours_dataset)
tree.add_dataset("persons", person_dataset, "tours.person_id @ persons.person_id")

Expand All @@ -53,7 +54,6 @@ def test_3_level_tree_cat(
person_dataset: xr.Dataset,
household_dataset: xr.Dataset,
):

tree = sharrow.DataTree(tours=tours_dataset)
tree.add_dataset("persons", person_dataset, "tours.person_id @ persons.person_id")
tree.add_dataset(
Expand Down Expand Up @@ -83,7 +83,6 @@ def test_rootless_tree_cat(
person_dataset: xr.Dataset,
household_dataset: xr.Dataset,
):

tree = sharrow.DataTree(tours=tours_dataset, root_node_name=False)
tree.add_dataset("persons", person_dataset, "tours.person_id @ persons.person_id")
tree.add_dataset(
Expand All @@ -97,3 +96,22 @@ def test_rootless_tree_cat(
a = f.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([0, 1, 0, 0, 0]))


def test_int_enum_categorical():
    """Check that the ``as_int_enum`` accessor encodes names as enum codes."""

    class TourMode(IntEnum):
        Car = 1
        Bus = 2
        Walk = 3

    tour_ids = pd.Index([4411, 4451, 5521, 5561, 9341], name="tour_id")
    df = pd.DataFrame(
        {
            "TourMode": ["Car", "Bus", "Car", "Car", "Walk"],
            "person_id": [441, 445, 552, 556, 934],
        },
        index=tour_ids,
    )
    df["TourMode2"] = df["TourMode"].as_int_enum(TourMode)

    encoded = df["TourMode2"]
    assert encoded.dtype == "category"
    # TourMode has no member with value 0, so a placeholder "_0" is expected
    assert all(encoded.cat.categories == ["_0", "Car", "Bus", "Walk"])
    assert all(encoded.cat.codes == [1, 2, 1, 1, 3])
Loading