Skip to content

Commit

Permalink
Merge pull request #36 from jpn--/int-enum
Browse files Browse the repository at this point in the history
categorical from IntEnum
  • Loading branch information
jpn-- authored Jul 14, 2023
2 parents 11cc7f1 + 2f40a25 commit 5d07145
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 8 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
repos:

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.4.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
Expand All @@ -14,18 +14,18 @@ repos:
- id: nbstripout

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.257
rev: v0.0.274
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]

- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]

- repo: https://github.com/psf/black
rev: 22.10.0
rev: 23.3.0
hooks:
- id: black
152 changes: 152 additions & 0 deletions sharrow/categorical.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from __future__ import annotations

from enum import IntEnum
from functools import reduce

import numpy as np
import pandas as pd
import xarray as xr


Expand Down Expand Up @@ -38,3 +42,151 @@ def category_array(self) -> np.ndarray:

def is_categorical(self) -> bool:
    """
    Check whether the wrapped array carries a dictionary-based encoding.

    Returns
    -------
    bool
        True when the ``digital_encoding`` entry of the array's attrs
        contains a ``"dictionary"`` key; False otherwise, including when
        there is no ``digital_encoding`` entry at all.
    """
    # A missing entry is treated the same as an empty encoding.
    encoding = self.dataarray.attrs.get("digital_encoding", {})
    return "dictionary" in encoding


def _interpret_enum(e: type[IntEnum], value: int | str) -> IntEnum:
    """
    Convert a string or integer into an Enum value.

    Parameters
    ----------
    e : Type[IntEnum]
        The enum to use in interpretation.
    value : int or str
        The value to convert.  Integer values (including numpy integer
        scalars) are looked up by value, and simple strings are looked up
        by member name.  Multiple names can also be given joined by the
        pipe operator, in the style of flags (e.g. "Red|Octagon").

    Returns
    -------
    IntEnum
        The member of `e` matching an integer or a single name.  When
        several names are joined with pipes, the result is whatever the
        ``|`` operator yields when the named members are combined.

    Raises
    ------
    AttributeError
        If a given name is not an attribute of `e`.
    """
    # numpy integer scalars are not instances of `int`; without the explicit
    # check they would fall through to the string branch and fail on the
    # missing `split` method.
    if isinstance(value, (int, np.integer)):
        # Normalise to a builtin int so the enum's by-value lookup sees the
        # same kind of key regardless of where the integer came from.
        return e(int(value))
    members = [getattr(e, name) for name in value.split("|")]
    return reduce(lambda x, y: x | y, members)


def get_enum_name(e: type[IntEnum], value: int) -> str:
    """
    Look up the member name for a value, with a fallback for gaps.

    An enum need not define every consecutive value between 0 and its
    maximum value, inclusive.  For values it does not define, a dummy
    placeholder name is produced instead.

    Parameters
    ----------
    e : Type[IntEnum]
        The enum to use in interpretation.
    value : int
        The value for which to find a name.

    Returns
    -------
    str
        The name of the member of `e` with this value, or `value` prefixed
        with an underscore when `e` has no such member.
    """
    member = e._value2member_map_.get(value)
    if member is None:
        # No member carries this value: synthesize the placeholder name.
        return f"_{value}"
    return member.name


def int_enum_to_categorical_dtype(e: type[IntEnum]) -> pd.CategoricalDtype:
    """
    Build a pandas CategoricalDtype mirroring an integer-valued enum.

    One category is created for every integer from 0 through the largest
    value in `e`, in order, each named by `get_enum_name`.

    Parameters
    ----------
    e : Type[IntEnum]
        The enum supplying the category names.

    Returns
    -------
    pd.CategoricalDtype
    """
    n_categories = int(max(e)) + 1
    names = [get_enum_name(e, code) for code in range(n_categories)]
    return pd.CategoricalDtype(categories=names)


def as_int_enum(
    s: pd.Series,
    e: type[IntEnum],
    dtype: type[np.integer] | None = None,
    categorical: bool = True,
) -> pd.Series:
    """
    Encode a pandas Series as categorical, consistent with an IntEnum.

    Parameters
    ----------
    s : pd.Series
        The values to encode, each an integer value or a member name (or
        pipe-joined member names) of `e`.  Anything that is not already a
        Series is first passed to the ``pd.Series`` constructor.
    e : Type[IntEnum]
        The enum that defines the encoding.  All of its values must be
        non-negative.
    dtype : Type[np.integer], optional
        Specific dtype to use for the code point encoding. It is typically not
        necessary to give this explicitly as the function will automatically
        select the best (most efficient) bitwidth.
    categorical : bool, default True
        If set to false, the returned series will simply be integer encoded with
        no formal Categorical dtype.

    Returns
    -------
    pd.Series
        When `categorical` is false, an integer-encoded Series.  When it is
        true, the object built by ``pd.Categorical.from_codes`` from those
        integer codes, with one category for each integer from 0 through
        the largest value in `e`.

    Raises
    ------
    ValueError
        If `e` contains a negative value.
    """
    min_enum_value = int(min(e))
    max_enum_value = int(max(e))
    # Raised explicitly rather than asserted, so the validation is not
    # stripped when Python runs with optimizations enabled (`-O`).
    if min_enum_value < 0:
        raise ValueError(
            f"cannot encode {e.__name__}: enum values must be non-negative, "
            f"found minimum {min_enum_value}"
        )
    if dtype is None:
        # Pick the narrowest integer type that spans the enum's range.  The
        # signed 8- and 16-bit branches are kept from the original chain but
        # cannot currently be selected, as negative values are rejected above.
        if max_enum_value < 256 and min_enum_value >= 0:
            dtype = np.uint8
        elif max_enum_value < 128 and min_enum_value >= -128:
            dtype = np.int8
        elif max_enum_value < 65536 and min_enum_value >= 0:
            dtype = np.uint16
        elif max_enum_value < 32768 and min_enum_value >= -32768:
            dtype = np.int16
        elif max_enum_value < 2_147_483_648 and min_enum_value >= -2_147_483_648:
            dtype = np.int32
        else:
            dtype = np.int64
    if not isinstance(s, pd.Series):
        s = pd.Series(s)
    result = s.apply(lambda x: _interpret_enum(e, x)).astype(dtype)
    if categorical:
        # Codes index directly into the category list, so the list must cover
        # every integer from 0 up to the maximum, including unnamed gaps.
        categories = [get_enum_name(e, i) for i in range(max_enum_value + 1)]
        result = pd.Categorical.from_codes(codes=result, categories=categories)
    return result


@pd.api.extensions.register_series_accessor("as_int_enum")
class _AsIntEnum:
    """
    Series accessor to encode a Series consistent with an IntEnum.

    Registered on pandas Series under the name ``as_int_enum``.  Calling the
    accessor forwards the Series it was created from, together with the
    call arguments, to the module-level `as_int_enum` function and returns
    that function's result.

    Call Parameters
    ---------------
    e : Type[IntEnum]
    dtype : Type[np.integer], optional
        Specific dtype to use for the code point encoding. It is typically not
        necessary to give this explicitly as the function will automatically
        select the best (most efficient) bitwidth.
    categorical : bool, default True
        If set to false, the returned series will simply be integer encoded with
        no formal Categorical dtype.
    """

    def __init__(self, pandas_obj):
        # The object the accessor was reached from; pandas supplies it.
        self._obj = pandas_obj

    def __call__(
        self,
        e: type[IntEnum],
        dtype: type[np.integer] | None = None,
        categorical: bool = True,
    ):
        return as_int_enum(self._obj, e, dtype, categorical)
7 changes: 7 additions & 0 deletions sharrow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ def dataset_from_dataframe_fast(
If true, create sparse arrays instead of dense numpy arrays. This
can potentially save a large amount of memory if the DataFrame has
a MultiIndex. Requires the sparse package (sparse.pydata.org).
preserve_cat : bool, default True
If true, preserve encoding of categorical columns. Xarray lacks an
official implementation of a categorical datatype, so sharrow's
dictionary-based digital encoding is applied instead. Note that in
native xarray usage, the resulting variable will look like integer
values instead of the category values. The `dataset.cat` accessor
can be used to interact with the categorical data.
Returns
-------
Expand Down
26 changes: 22 additions & 4 deletions sharrow/tests/test_categorical.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

from enum import IntEnum

import numpy as np
import pandas as pd
import xarray as xr

import sharrow


def test_simple_cat(tours_dataset: xr.Dataset):

tree = sharrow.DataTree(tours=tours_dataset)

assert all(tours_dataset.TourMode.cat.categories == ["Bus", "Car", "Walk"])
Expand All @@ -26,7 +28,6 @@ def test_2_level_tree_cat(
tours_dataset: xr.Dataset,
person_dataset: xr.Dataset,
):

tree = sharrow.DataTree(tours=tours_dataset)
tree.add_dataset("persons", person_dataset, "tours.person_id @ persons.person_id")

Expand All @@ -53,7 +54,6 @@ def test_3_level_tree_cat(
person_dataset: xr.Dataset,
household_dataset: xr.Dataset,
):

tree = sharrow.DataTree(tours=tours_dataset)
tree.add_dataset("persons", person_dataset, "tours.person_id @ persons.person_id")
tree.add_dataset(
Expand Down Expand Up @@ -83,7 +83,6 @@ def test_rootless_tree_cat(
person_dataset: xr.Dataset,
household_dataset: xr.Dataset,
):

tree = sharrow.DataTree(tours=tours_dataset, root_node_name=False)
tree.add_dataset("persons", person_dataset, "tours.person_id @ persons.person_id")
tree.add_dataset(
Expand All @@ -97,3 +96,22 @@ def test_rootless_tree_cat(
a = f.load_dataarray(dtype=np.int8)
a = a.isel(expressions=0)
assert all(a == np.asarray([0, 1, 0, 0, 0]))


def test_int_enum_categorical():
    """Encode a column of mode names through the ``as_int_enum`` accessor."""

    class TourMode(IntEnum):
        Car = 1
        Bus = 2
        Walk = 3

    tour_ids = pd.Index([4411, 4451, 5521, 5561, 9341], name="tour_id")
    columns = {
        "TourMode": ["Car", "Bus", "Car", "Car", "Walk"],
        "person_id": [441, 445, 552, 556, 934],
    }
    df = pd.DataFrame(columns, index=tour_ids)
    df["TourMode2"] = df["TourMode"].as_int_enum(TourMode)

    encoded = df["TourMode2"]
    assert encoded.dtype == "category"
    # The enum defines no member for 0, so a placeholder category fills it.
    assert all(encoded.cat.categories == ["_0", "Car", "Bus", "Walk"])
    assert all(encoded.cat.codes == [1, 2, 1, 1, 3])

0 comments on commit 5d07145

Please sign in to comment.