diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 7ec36fc..79a20ba 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -46,12 +46,12 @@ jobs: run: | conda info -a conda list - - name: Lint with flake8 + - name: Lint with Ruff run: | # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + ruff check . --select=E9,F63,F7,F82 --statistics + # exit-zero treats all errors as warnings. + ruff check . --exit-zero --statistics - name: Test with pytest run: | python -m pytest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 333916f..dcb1561 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: check-yaml - id: end-of-file-fixer @@ -13,18 +13,24 @@ repos: hooks: - id: nbstripout +- repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.0.274 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort args: ["--profile", "black", "--filter-files"] - repo: https://github.com/psf/black - rev: 22.10.0 + rev: 23.3.0 hooks: - id: black - repo: https://github.com/PyCQA/flake8 - rev: 5.0.4 + rev: 4.0.1 hooks: - id: flake8 diff --git a/docs/_script/hide_test_cells.py b/docs/_script/hide_test_cells.py index 0a4b5c6..a8e6bf3 100644 --- a/docs/_script/hide_test_cells.py +++ b/docs/_script/hide_test_cells.py @@ -11,10 +11,10 @@ # Text to look for in adding tags text_search_dict = { - "# TEST": "remove_cell", # Remove the whole cell - "# HIDDEN": "remove_cell", # Remove the whole cell - "# NO CODE": "remove_input", # Remove only the input - "# HIDE CODE": "hide_input", # Hide the input w/ a button to show + "# TEST": "remove-cell", # Remove the whole cell + "# HIDDEN": "remove-cell", # Remove the whole cell + "# NO CODE": "remove-input", # Remove only the input + "# HIDE CODE": "hide-input", # Hide the input w/ a button to show } # Search through each notebook and look for th text, add a tag if necessary diff --git a/docs/walkthrough/encoding.ipynb b/docs/walkthrough/encoding.ipynb index f8f3e2d..7f07159 100644 --- a/docs/walkthrough/encoding.ipynb +++ b/docs/walkthrough/encoding.ipynb @@ -16,7 +16,7 @@ "id": "f17c8818", "metadata": { "tags": [ - "remove_cell" + "remove-cell" ] }, "outputs": [], @@ -30,7 +30,9 @@ "cell_type": "code", "execution_count": null, "id": "d4e7246c", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "import numpy as np\n", @@ -217,7 +219,7 @@ "id": "2ad591bb", "metadata": { "tags": [ - "remove_cell" + "remove-cell" ] }, "outputs": [], @@ -296,7 +298,7 @@ "id": "bd549b5e", "metadata": { "tags": [ - "remove_cell" + "remove-cell" ] }, "outputs": [], @@ -377,7 +379,7 @@ "id": "7d74e53e", "metadata": { "tags": [ - "remove_cell" + "remove-cell" ] }, "outputs": [], @@ -477,7 +479,7 @@ "id": "a016d30f", "metadata": { "tags": [ - "remove_cell" + "remove-cell" ] }, "outputs": [], @@ -525,7 +527,7 @@ "id": "28afb335", "metadata": { "tags": [ - "remove_cell" + "remove-cell" ] }, "outputs": [], @@ -690,6 +692,195 @@ " _name='WLK_LOC_WLK_FAR'\n", ").to_series() == [0,152,474]).all()" ] + }, + { + "cell_type": 
"markdown", + "id": "cb219dc3-dd66-44cd-a7c5-2a1da4bc1467", + "metadata": { + "tags": [] + }, + "source": [ + "# Pandas Categorical Dtype\n", + "\n", + "Dictionary encoding is very similar to the approach used for the pandas Categorical dtype, and\n", + "can be used to achieve some of the efficiencies of categorical data, even though xarray lacks\n", + "a formal native categorical data representation. Sharrow's `construct` function for creating\n", + "Dataset objects will automatically use dictionary encoding for \"category\" data. \n", + "\n", + "To demonstrate, we'll load some household data and create a categorical data column." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b765919-69a4-4fb0-b805-9d3b5fed7897", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "hh = sh.example_data.get_households()\n", + "hh[\"income_grp\"] = pd.cut(hh.income, bins=[-np.inf,30000,60000,np.inf], labels=['Low', \"Mid\", \"High\"])\n", + "hh = hh[[\"income\",\"income_grp\"]]\n", + "hh.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "312faa0b-13cf-4649-9835-7a53b5e81a0b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "hh.info()" + ] + }, + { + "cell_type": "markdown", + "id": "c51a88d2-02b1-4502-9f4b-271fbb126699", + "metadata": {}, + "source": [ + "We'll then create a Dataset using construct." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd1c2fd5-59c6-48cb-bd6e-d2f9dde2aa36", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "hh_dataset = sh.dataset.construct(hh[[\"income\",\"income_grp\"]])\n", + "hh_dataset" + ] + }, + { + "cell_type": "markdown", + "id": "033b3629-a16b-47a4-bb18-10af9c7c4f07", + "metadata": {}, + "source": [ + "Note that the \"income\" variable remains an integer as expected, but the \"income_grp\" variable, \n", + "which had been a \"category\" dtype in pandas, is now stored as an `int8`, giving the \n", + "category _index_ of each element (it would be an `int16` or larger if needed, but that's\n", + "not necessary with only 3 categories). The information about the labels for the categories is \n", + "retained not in the data itself but in the `digital_encoding`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "369442af-1c69-41eb-b530-ea398d6eac7a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "hh_dataset[\"income_grp\"].digital_encoding" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58db6505-1c90-475e-8d91-0e2e89ec0f0e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# TESTING\n", + "assert hh_dataset[\"income_grp\"].dtype == \"int8\"\n", + "assert hh_dataset[\"income_grp\"].digital_encoding.keys() == {'dictionary', 'ordered'}\n", + "assert all(hh_dataset[\"income_grp\"].digital_encoding['dictionary'] == np.array(['Low', 'Mid', 'High'], dtype='=5.7.1 - nbmake diff --git a/envs/testing.yml b/envs/testing.yml index deb0879..76d1928 100644 --- a/envs/testing.yml +++ b/envs/testing.yml @@ -14,7 +14,7 @@ dependencies: - numexpr - sparse - filelock - - flake8 + - ruff # required for testing - pytest - pytest-cov diff --git a/pyproject.toml b/pyproject.toml index c4bee96..dea5508 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,15 @@ float_to_top = true default_section = "THIRDPARTY" known_first_party = "sharrow" +[tool.ruff] +# Enable flake8-bugbear (`B`) and pyupgrade ('UP') rules. 
+select = ["E", "F", "B", "UP"] +fix = true +ignore-init-module-imports = true +line-length = 120 +ignore = ["B905"] +target-version = "py39" + [tool.pytest.ini_options] minversion = "6.0" addopts = "-v --nbmake --disable-warnings" diff --git a/sharrow/__init__.py b/sharrow/__init__.py index df6c3b8..03d66a4 100644 --- a/sharrow/__init__.py +++ b/sharrow/__init__.py @@ -3,7 +3,29 @@ from . import dataset, example_data, selectors, shared_memory, sparse from ._infer_version import __version__, __version_tuple__ from .dataset import Dataset +from .datastore import DataStore from .digital_encoding import array_decode, array_encode from .flows import CacheMissWarning, Flow from .relationships import DataTree, Relationship from .table import Table, concat_tables + +__all__ = [ + "__version__", + "__version_tuple__", + "DataArray", + "Dataset", + "DataStore", + "DataTree", + "Relationship", + "Table", + "CacheMissWarning", + "Flow", + "example_data", + "array_decode", + "array_encode", + "concat_tables", + "dataset", + "selectors", + "shared_memory", + "sparse", +] diff --git a/sharrow/aster.py b/sharrow/aster.py index 4fb8f24..c85f0f9 100755 --- a/sharrow/aster.py +++ b/sharrow/aster.py @@ -1,7 +1,6 @@ import ast import io import logging -import sys import tokenize try: @@ -9,7 +8,9 @@ except ImportError: from astunparse import unparse as _unparse - unparse = lambda *args: _unparse(*args).strip("\n") + def unparse(*args): + return _unparse(*args).strip("\n") + logger = logging.getLogger("sharrow.aster") @@ -23,22 +24,21 @@ def unparse_(*args): raise -if sys.version_info >= (3, 8): - ast_Constant_Type = ast.Constant - ast_String_value = lambda x: x.value if isinstance(x, ast.Str) else x - ast_TupleIndex_Type = ast.Tuple - ast_Index_Value = lambda x: x - ast_Constant = ast.Constant -else: - ast_Constant_Type = (ast.Index, ast.Constant, ast.Str, ast.Num) - ast_String_value = ( - lambda x: x.s - if isinstance(x, ast.Str) - else (ast_String_value(x.value) if isinstance(x, ast.Index) else x) - ) - ast_TupleIndex_Type = (ast.Index, ast.Tuple) - ast_Index_Value = lambda x: x.value if isinstance(x, ast.Index) else x - ast_Constant = lambda x: ast.Constant(x, kind=None) +ast_Constant_Type = ast.Constant + + +def ast_String_value(x): + return x.value if isinstance(x, ast.Str) else x + + +ast_TupleIndex_Type = ast.Tuple + + +def ast_Index_Value(x): + return x + + +ast_Constant = ast.Constant def _isNone(c): @@ -69,7 +69,6 @@ def extract_names(command): def extract_names_2(command): - if not isinstance(command, str): return set(), dict(), dict() @@ -364,7 +363,8 @@ def log_event(self, tag, node1=None, node2=None): except: # noqa: E722 unparsed = f"{type(node1)} not unparseable" logger.debug( - f"RewriteForNumba({self.spacename}|{self.rawalias}).{tag} [{type(node1).__name__}]= {unparsed}", + f"RewriteForNumba({self.spacename}|{self.rawalias}).{tag} " + f"[{type(node1).__name__}]= {unparsed}", ) else: try: @@ -376,7 +376,9 @@ def log_event(self, tag, node1=None, node2=None): except: # noqa: E722 unparsed2 = f"{type(node2).__name__} not unparseable" logger.debug( - f"RewriteForNumba({self.spacename}|{self.rawalias}).{tag} [{type(node1).__name__},{type(node2).__name__}]= {unparsed1} => {unparsed2}", + f"RewriteForNumba({self.spacename}|{self.rawalias}).{tag} " + f"[{type(node1).__name__},{type(node2).__name__}]= " + f"{unparsed1} => {unparsed2}", ) def generic_visit(self, node): @@ -398,10 +400,12 @@ def _replacement( if self.spacevars is not None: if attr not in self.spacevars: - if topname == pref_topname and 
not self.swallow_errors: + if self.get_default or ( + topname == pref_topname and not self.swallow_errors + ): raise KeyError(f"{topname}..{attr}") - # we originally raised a KeyError here regardless, but what if we just - # give back the original node, and see if other spaces, + # we originally raised a KeyError here regardless, but what if + # we just give back the original node, and see if other spaces, # possibly fallback spaces, might work? If nothing works then # it will still eventually error out when compiling? # The swallow errors option allows us to continue processing @@ -446,16 +450,9 @@ def _maybe_transpose_first_two_args(_slice): if isinstance(n, int): elts.append(ast.Name(id=f"_arg{n:02}", ctx=ast.Load())) elif isinstance(n, dict): - if sys.version_info >= (3, 8): - elts.append( - ast.Constant(n=n[missing_dim_value], ctx=ast.Load()) - ) - else: - elts.append( - ast.Constant( - n[missing_dim_value], kind=None, ctx=ast.Load() - ) - ) + elts.append( + ast.Constant(n=n[missing_dim_value], ctx=ast.Load()) + ) else: elts.append(n) logger.debug(f"ELT {unparse_(elts[-1])}") @@ -487,7 +484,6 @@ def _maybe_transpose_first_two_args(_slice): digital_encoding = self.digital_encodings.get(attr, None) if digital_encoding is not None: - dictionary = digital_encoding.get("dictionary", None) offset_source = digital_encoding.get("offset_source", None) if dictionary is not None: @@ -604,7 +600,8 @@ def visit_Subscript(self, node): result, ) return result - # for XXX[YYY,ZZZ], XXX is a space name and YYY is a literal value and ZZZ is a literal value: skims['SOV_TIME','MD'] + # for XXX[YYY,ZZZ], XXX is a space name and YYY is a literal value and + # ZZZ is a literal value: skims['SOV_TIME','MD'] if ( node.value.id == self.spacename and isinstance(ast_Index_Value(node.slice), ast.Tuple) @@ -649,8 +646,10 @@ def visit_Subscript(self, node): def visit_Attribute(self, node): if isinstance(node.value, ast.Name): + # for XXX.YYY, XXX is a space name and YYY is a literal value: skims.DIST if node.value.id == self.spacename: return self._replacement(node.attr, node.ctx, node) + # for ____.YYY, handles unadorned values in the top level if node.value.id == self.rawalias and node.attr in self.spacevars: result = ast.Subscript( value=ast.Name(id=self.rawname, ctx=ast.Load()), @@ -659,6 +658,7 @@ def visit_Attribute(self, node): ) self.log_event(f"visit_Attribute(Raw {node.attr})", node, result) return result + # for YYY.ZZZ, where YYY is a variable in the root and ZZZ is anything if self.spacename == "" and node.value.id in self.spacevars: result = ast.Attribute( value=self.visit(node.value), @@ -669,6 +669,7 @@ def visit_Attribute(self, node): return result return node else: + # pass through result = ast.Attribute( value=self.visit(node.value), attr=node.attr, @@ -720,7 +721,6 @@ def visit_BinOp(self, node): if self.bool_wrapping and isinstance( node.op, (ast.BitAnd, ast.BitOr, ast.BitXor) ): - result = ast.BinOp( left=bool_wrap(left), op=node.op, @@ -737,7 +737,6 @@ def visit_BinOp(self, node): return result def visit_Call(self, node): - result = None # implement ActivitySim's "reverse" skims if ( @@ -912,7 +911,7 @@ def visit_Call(self, node): keywords=[], ) - # implement x.get("y",z) + # implement x.get("y",z) where x is the spacename if ( isinstance(node.func, ast.Attribute) and node.func.attr == "get" diff --git a/sharrow/categorical.py b/sharrow/categorical.py new file mode 100644 index 0000000..abddae2 --- /dev/null +++ b/sharrow/categorical.py @@ -0,0 +1,192 @@ +from __future__ import annotations + 
+from enum import IntEnum +from functools import reduce + +import numpy as np +import pandas as pd +import xarray as xr + + +class ArrayIsNotCategoricalError(TypeError): + """The array is not an encoded categorical array.""" + + +@xr.register_dataarray_accessor("cat") +class _Categorical: + """ + Accessor for pseudo-categorical arrays. + """ + + __slots__ = ("dataarray",) + + def __init__(self, dataarray: xr.DataArray): + self.dataarray = dataarray + + @property + def categories(self): + try: + return self.dataarray.attrs["digital_encoding"]["dictionary"] + except KeyError: + raise ArrayIsNotCategoricalError() from None + + @property + def ordered(self): + return self.dataarray.attrs["digital_encoding"].get("ordered", True) + + def category_array(self) -> np.ndarray: + arr = np.asarray(self.categories) + if arr.dtype.kind == "O": + arr = arr.astype(str) + return arr + + def is_categorical(self) -> bool: + return "dictionary" in self.dataarray.attrs.get("digital_encoding", {}) + + +def _interpret_enum(e: type[IntEnum], value: int | str) -> IntEnum: + """ + Convert a string or integer into an Enum value. + + The + + Parameters + ---------- + e : Type[IntEnum] + The enum to use in interpretation. + value: int or str + The value to convert. Integer and simple string values are converted + to their corresponding value. Multiple string values can also be given + joined by the pipe operator, in the style of flags (e.g. "Red|Octagon"). + """ + if isinstance(value, int): + return e(value) + return reduce(lambda x, y: x | y, [getattr(e, v) for v in value.split("|")]) + + +def get_enum_name(e: type[IntEnum], value: int) -> str: + """ + Get the name of an enum by value, or a placeholder name if not found. + + This allows for dummy placeholder names is the enum is does not contain + all consecutive values between 0 and the maximum value, inclusive. + + Parameters + ---------- + e : Type[IntEnum] + The enum to use in interpretation. + value : int + The value for which to find a name. If not found in `e`, this + function will generate a new name as a string by prefixing `value` + with an underscore. + + Returns + ------- + str + """ + result = e._value2member_map_.get(value, f"_{value}") + try: + return result.name + except AttributeError: + return result + + +def int_enum_to_categorical_dtype(e: type[IntEnum]) -> pd.CategoricalDtype: + """ + Convert an integer-valued enum to a pandas CategoricalDtype. + + Parameters + ---------- + e : Type[IntEnum] + + Returns + ------- + pd.CategoricalDtype + """ + max_enum_value = int(max(e)) + categories = [get_enum_name(e, i) for i in range(max_enum_value + 1)] + return pd.CategoricalDtype(categories=categories) + + +def as_int_enum( + s: pd.Series, + e: type[IntEnum], + dtype: type[np.integer] | None = None, + categorical: bool = True, +) -> pd.Series: + """ + Encode a pandas Series as categorical, consistent with an IntEnum. + + Parameters + ---------- + s : pd.Series + e : Type[IntEnum] + dtype : Type[np.integer], optional + Specific dtype to use for the code point encoding. It is typically not + necessary to give this explicitly as the function will automatically + select the best (most efficient) bitwidth. + categorical : bool, default True + If set to false, the returned series will simply be integer encoded with + no formal Categorical dtype. 
+ + Returns + ------- + pd.Series + """ + min_enum_value = int(min(e)) + max_enum_value = int(max(e)) + assert min_enum_value >= 0 + if dtype is None: + if max_enum_value < 256 and min_enum_value >= 0: + dtype = np.uint8 + elif max_enum_value < 128 and min_enum_value >= -128: + dtype = np.int8 + elif max_enum_value < 65536 and min_enum_value >= 0: + dtype = np.uint16 + elif max_enum_value < 32768 and min_enum_value >= -32768: + dtype = np.int16 + elif max_enum_value < 2_147_483_648 and min_enum_value >= -2_147_483_648: + dtype = np.int32 + else: + dtype = np.int64 + if not isinstance(s, pd.Series): + s = pd.Series(s) + result = s.apply(lambda x: _interpret_enum(e, x)).astype(dtype) + if categorical: + categories = [get_enum_name(e, i) for i in range(max_enum_value + 1)] + result = pd.Categorical.from_codes(codes=result, categories=categories) + return result + + +@pd.api.extensions.register_series_accessor("as_int_enum") +class _AsIntEnum: + """ + Encode a pandas Series as categorical, consistent with an IntEnum. + + Parameters + ---------- + s : pd.Series + e : Type[IntEnum] + dtype : Type[np.integer], optional + Specific dtype to use for the code point encoding. It is typically not + necessary to give this explicitly as the function will automatically + select the best (most efficient) bitwidth. + categorical : bool, default True + If set to false, the returned series will simply be integer encoded with + no formal Categorical dtype. + + Returns + ------- + pd.Series + """ + + def __init__(self, pandas_obj): + self._obj = pandas_obj + + def __call__( + self: pd.Series, + e: type[IntEnum], + dtype: type[np.integer] | None = None, + categorical: bool = True, + ): + return as_int_enum(self._obj, e, dtype, categorical) diff --git a/sharrow/dataset.py b/sharrow/dataset.py index 68be2be..6ffab94 100755 --- a/sharrow/dataset.py +++ b/sharrow/dataset.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import ast import base64 import hashlib import logging import re -from typing import Any, Hashable, Mapping, Sequence +from collections.abc import Hashable, Mapping, Sequence +from typing import Any import numpy as np import pandas as pd @@ -13,6 +16,7 @@ from .accessors import register_dataset_method from .aster import extract_all_name_tokens +from .categorical import _Categorical # noqa from .table import Table logger = logging.getLogger("sharrow") @@ -92,21 +96,21 @@ def construct(source): if isinstance(source, pd.DataFrame): source = dataset_from_dataframe_fast(source) # xarray default can be slow elif isinstance(source, (Table, pa.Table)): - source = xr.Dataset.from_table(source) - elif isinstance(source, (pa.Table)): - source = xr.Dataset.from_table(source) + source = from_table(source) elif isinstance(source, xr.Dataset): pass # don't do the superclass things elif isinstance(source, Sequence) and all(isinstance(i, str) for i in source): - source = xr.Dataset.from_table(pa.table({i: [] for i in source})) + source = from_table(pa.table({i: [] for i in source})) else: source = xr.Dataset(source) return source def dataset_from_dataframe_fast( - dataframe: pd.DataFrame, sparse: bool = False -) -> "Dataset": + dataframe: pd.DataFrame, + sparse: bool = False, + preserve_cat: bool = True, +) -> Dataset: """Convert a pandas.DataFrame into an xarray.Dataset Each column will be converted into an independent variable in the @@ -125,6 +129,13 @@ def dataset_from_dataframe_fast( If true, create a sparse arrays instead of dense numpy arrays. 
This can potentially save a large amount of memory if the DataFrame has a MultiIndex. Requires the sparse package (sparse.pydata.org). + preserve_cat : bool, default True + If true, preserve encoding of categorical columns. Xarray lacks an + official implementation of a categorical datatype, so sharrow's + dictionary-based digital encoding is applied instead. Note that in + native xarray usage, the resulting variable will look like integer + values instead of the category values. The `dataset.cat` accessor + can be used to interact with the categorical data. Returns ------- @@ -175,11 +186,26 @@ def dataset_from_dataframe_fast( index_name = idx.name if idx.name is not None else "index" # Cast to a NumPy array first, in case the Series is a pandas Extension # array (which doesn't have a valid NumPy dtype) - arrays = { - name: ([index_name], np.asarray(dataframe[name].values)) - for name in dataframe.columns - if name != index_name - } + arrays = {} + for name in dataframe.columns: + if name != index_name: + if dataframe[name].dtype == "category" and preserve_cat: + cat = dataframe[name].cat + categories = np.asarray(cat.categories) + if categories.dtype.kind == "O": + categories = categories.astype(str) + arrays[name] = ( + [index_name], + np.asarray(cat.codes), + { + "digital_encoding": { + "dictionary": categories, + "ordered": cat.ordered, + } + }, + ) + else: + arrays[name] = ([index_name], np.asarray(dataframe[name].values)) return Dataset(arrays, coords={index_name: (index_name, dataframe.index.values)}) @@ -197,7 +223,8 @@ def from_table( Table from which to use data and indices. index_name : str, default 'index' This name will be given to the default dimension index, if - none is given. Ignored if `index` is given explicitly. + none is given. Ignored if `index` is given explicitly and + it already has a name. index : Index-like, optional Use this index instead of a default RangeIndex. 
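# --- Illustrative sketch (not part of the patch): how `from_table` treats a pyarrow
# dictionary-typed column after this change. The column is stored as its integer codes,
# and the label dictionary is kept in the variable's `digital_encoding` attribute rather
# than in the data itself. The column name "color" and its values are invented here.
import pyarrow as pa
from sharrow.dataset import from_table

colors = pa.array(["red", "blue", "red", "green"]).dictionary_encode()
ds = from_table(pa.table({"color": colors}))
print(ds["color"].values)                     # the integer codes
print(ds["color"].attrs["digital_encoding"])  # {'dictionary': ..., 'ordered': False}
print(ds["color"].cat.categories)             # labels, via the new `.cat` accessor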
@@ -218,10 +245,21 @@ def from_table( raise ValueError( "cannot attach a non-unique MultiIndex and convert into xarray" ) - arrays = [ - (tbl.column_names[n], np.asarray(tbl.column(n))) - for n in range(len(tbl.column_names)) - ] + arrays = [] + metadata = {} + for n in range(len(tbl.column_names)): + c = tbl.column(n) + if isinstance(c.type, pa.DictionaryType): + cc = c.combine_chunks() + arrays.append((tbl.column_names[n], np.asarray(cc.indices))) + metadata[tbl.column_names[n]] = { + "digital_encoding": { + "dictionary": cc.dictionary, + "ordered": cc.type.ordered, + } + } + else: + arrays.append((tbl.column_names[n], np.asarray(c))) result = xr.Dataset() if isinstance(index, pd.MultiIndex): dims = tuple( @@ -231,11 +269,17 @@ def from_table( for dim, lev in zip(dims, index.levels): result[dim] = (dim, lev) else: - index_name = index.name if index.name is not None else "index" + try: + if index.name is not None: + index_name = index.name + except AttributeError: + pass dims = (index_name,) result[index_name] = (dims, index) result._set_numpy_data_from_dataframe(index, arrays, dims) + for k, v in metadata.items(): + result[k].attrs.update(v) return result @@ -627,27 +671,11 @@ def from_zarr_with_attr(*args, **kwargs): for k in obj: attrs = {} for aname, avalue in obj[k].attrs.items(): - if ( - isinstance(avalue, str) - and avalue.startswith(" {") - and avalue.endswith("} ") - ): - avalue = ast.literal_eval(avalue[1:-1]) - if isinstance(avalue, str) and avalue == " < None > ": - avalue = None - attrs[aname] = avalue + attrs[aname] = _from_evalable_string(avalue) obj[k] = obj[k].assign_attrs(attrs) attrs = {} for aname, avalue in obj.attrs.items(): - if ( - isinstance(avalue, str) - and avalue.startswith(" {") - and avalue.endswith("} ") - ): - avalue = ast.literal_eval(avalue[1:-1]) - if isinstance(avalue, str) and avalue == " < None > ": - avalue = None - attrs[aname] = avalue + attrs[aname] = _from_evalable_string(avalue) obj = obj.assign_attrs(attrs) return obj @@ -673,7 +701,7 @@ class _SingleDim: __slots__ = ("dataset", "dim_name") - def __init__(self, dataset: "Dataset"): + def __init__(self, dataset: Dataset): self.dataset = dataset if len(self.dataset.dims) != 1: raise ValueError("single_dim implies a single dimension dataset") @@ -691,10 +719,122 @@ def index(self): def size(self): return self.dataset.dims[self.dim_name] - def to_pyarrow(self): - return pa.Table.from_pydict( - {k: pa.array(v.to_numpy()) for k, v in self.dataset.variables.items()} + def _to_pydict(self): + columns = [k for k in self.dataset.variables if k != self.dim_name] + data = [] + for k in columns: + a = self.dataset._variables[k] + if ( + "digital_encoding" in a.attrs + and "dictionary" in a.attrs["digital_encoding"] + ): + de = a.attrs["digital_encoding"] + data.append( + pd.Categorical.from_codes( + a.values, + de["dictionary"], + de.get("ordered"), + ) + ) + else: + data.append(a.values) + return dict(zip(columns, data)) + + def to_pyarrow(self) -> pa.Table: + columns = [k for k in self.dataset.variables if k != self.dim_name] + data = [] + for k in columns: + a = self.dataset._variables[k] + if ( + "digital_encoding" in a.attrs + and "dictionary" in a.attrs["digital_encoding"] + ): + de = a.attrs["digital_encoding"] + data.append( + pa.DictionaryArray.from_arrays( + a.values, + de["dictionary"], + ordered=de.get("ordered", False), + ) + ) + else: + data.append(pa.array(a.values)) + content = dict(zip(columns, data)) + content[self.dim_name] = self.index + return pa.Table.from_pydict(content) + + def 
to_parquet(self, filename): + import pyarrow.parquet as pq + + t = self.to_pyarrow() + pq.write_table(t, filename) + + def to_pandas(self) -> pd.DataFrame: + """ + Convert this dataset into a pandas DataFrame. + + The resulting DataFrame is always a copy of the data in the dataset. + + Returns + ------- + pandas.DataFrame + """ + return pd.DataFrame(self._to_pydict(), index=self.index, copy=True) + + def eval( + self, + expr: str, + parser: str = "pandas", + engine: str | None = None, + local_dict=None, + global_dict=None, + ): + """ + Evaluate a Python expression as a string using various backends. + + Parameters + ---------- + expr : str + The expression to evaluate. This string cannot contain any Python + `statements + `__, + only Python `expressions + `__. + parser : {'pandas', 'python'}, default 'pandas' + The parser to use to construct the syntax tree from the expression. The + default of ``'pandas'`` parses code slightly different than standard + Python. Alternatively, you can parse an expression using the + ``'python'`` parser to retain strict Python semantics. See the + :ref:`enhancing performance ` documentation for + more details. + engine : {'python', 'numexpr'}, default 'numexpr' + The engine used to evaluate the expression. Supported engines are + - None : tries to use ``numexpr``, falls back to ``python`` + - ``'numexpr'`` : This default engine evaluates pandas objects using + numexpr for large speed ups in complex expressions with large frames. + - ``'python'`` : Performs operations as if you had ``eval``'d in top + level python. This engine is generally not that useful. + local_dict : dict or None, optional + A dictionary of local variables, taken from locals() by default. + global_dict : dict or None, optional + A dictionary of global variables, taken from globals() by default. + + Returns + ------- + DataArray or numeric scalar + """ + result = pd.eval( + expr, + parser=parser, + engine=engine, + local_dict=local_dict, + global_dict=global_dict, + resolvers=[self.dataset], ) + if result.size == self.size: + return DataArray(np.asarray(result), coords=self.dataset.coords) + else: + return result @xr.register_dataarray_accessor("single_dim") @@ -705,7 +845,7 @@ class _SingleDimArray: __slots__ = ("dataarray", "dim_name") - def __init__(self, dataarray: "DataArray"): + def __init__(self, dataarray: DataArray): self.dataarray = dataarray if len(self.dataarray.dims) != 1: raise ValueError("single_dim implies a single dimension dataset") @@ -725,6 +865,39 @@ def rename(self, name: str) -> DataArray: return self.dataarray return self.dataarray.rename({self.dim_name: name}) + def to_pandas(self) -> pd.Series: + """ + Convert this array into a pandas Series. + + If this array is categorical (i.e. with a simple dictionary-based + digital encoding) then the result will be a Series with categorical dtype. + + The DataArray's `name` attribute is preserved in the result. 
+ """ + if self.dataarray.cat.is_categorical(): + return pd.Series( + pd.Categorical.from_codes( + self.dataarray, + self.dataarray.cat.categories, + self.dataarray.cat.ordered, + ), + index=self.index, + name=self.dataarray.name, + ) + else: + result = self.dataarray.to_pandas() + if self.dataarray.name: + result = result.rename(self.dataarray.name) + return result + + def to_pyarrow(self): + if self.dataarray.cat.is_categorical(): + return pa.DictionaryArray.from_arrays( + self.dataarray.data, self.dataarray.cat.categories + ) + else: + return pa.array(self.dataarray.data) + @xr.register_dataset_accessor("iloc") class _iLocIndexer: @@ -742,10 +915,10 @@ class _iLocIndexer: __slots__ = ("dataset",) - def __init__(self, dataset: "Dataset"): + def __init__(self, dataset: Dataset): self.dataset = dataset - def __getitem__(self, key: Mapping[Hashable, Any]) -> "Dataset": + def __getitem__(self, key: Mapping[Hashable, Any]) -> Dataset: if not is_dict_like(key): if len(self.dataset.dims) == 1: dim_name = self.dataset.dims.__iter__().__next__() @@ -858,6 +1031,58 @@ def to_zarr_zip(self, *args, **kwargs): return super().to_zarr(*args, **kwargs) +def _to_ast_literal(x): + if isinstance(x, dict): + return ( + "{" + + ", ".join( + f"{_to_ast_literal(k)}: {_to_ast_literal(v)}" for k, v in x.items() + ) + + "}" + ) + elif isinstance(x, list): + return "[" + ", ".join(_to_ast_literal(i) for i in x) + "]" + elif isinstance(x, tuple): + return "(" + ", ".join(_to_ast_literal(i) for i in x) + ")" + elif isinstance(x, pd.Index): + return _to_ast_literal(x.to_list()) + elif isinstance(x, np.ndarray): + return _to_ast_literal(list(x)) + else: + return repr(x) + + +def _to_evalable_string(x): + if x is None: + return " < None > " + elif x is True: + return " < True > " + elif x is False: + return " < False > " + else: + return f" {_to_ast_literal(x)} " + + +def _from_evalable_string(x): + if isinstance(x, str): + # if x.startswith(" {") and x.endswith("} "): + # return ast.literal_eval(x[1:-1]) + if x == " < None > ": + return None + if x == " < True > ": + return True + if x == " < False > ": + return False + if x.startswith(" ") and x.endswith(" "): + try: + return ast.literal_eval(x.strip(" ")) + except Exception: + print(x) + raise + else: + return x + + @register_dataset_method def to_zarr_with_attr(self, *args, **kwargs): """ @@ -938,19 +1163,17 @@ def to_zarr_with_attr(self, *args, **kwargs): for k in self: attrs = {} for aname, avalue in self[k].attrs.items(): - if isinstance(avalue, dict): - avalue = f" {avalue!r} " - if avalue is None: - avalue = " < None > " - attrs[aname] = avalue + attrs[aname] = _to_evalable_string(avalue) obj[k] = self[k].assign_attrs(attrs) + if hasattr(self, "coords"): + for k in self.coords: + attrs = {} + for aname, avalue in self.coords[k].attrs.items(): + attrs[aname] = _to_evalable_string(avalue) + obj.coords[k] = self.coords[k].assign_attrs(attrs) attrs = {} for aname, avalue in self.attrs.items(): - if isinstance(avalue, dict): - avalue = f" {avalue!r} " - if avalue is None: - avalue = " < None > " - attrs[aname] = avalue + attrs[aname] = _to_evalable_string(avalue) obj = obj.assign_attrs(attrs) return obj.to_zarr(*args, **kwargs) @@ -1088,7 +1311,7 @@ def from_named_objects(*args): try: name = a.name except AttributeError: - raise ValueError(f"argument {n} has no name") + raise ValueError(f"argument {n} has no name") from None if name is None: raise ValueError(f"the name for argument {n} is None") objs[name] = a diff --git a/sharrow/datastore.py 
b/sharrow/datastore.py new file mode 100644 index 0000000..4be79df --- /dev/null +++ b/sharrow/datastore.py @@ -0,0 +1,482 @@ +from __future__ import annotations + +import datetime +import os +import shutil +from collections.abc import Collection +from pathlib import Path + +import pandas as pd +import xarray as xr +import yaml + +from .dataset import construct, from_zarr_with_attr +from .relationships import DataTree, Relationship + + +def timestamp(): + return datetime.datetime.now(datetime.timezone.utc).astimezone().isoformat() + + +class ReadOnlyError(ValueError): + """This object is read-only.""" + + +def _read_parquet(filename, index_col=None) -> xr.Dataset: + import pyarrow.parquet as pq + + from sharrow.dataset import from_table + + content = pq.read_table(filename) + if index_col is not None: + index = content.column(index_col) + content = content.drop([index_col]) + else: + index = None + x = from_table(content, index=index, index_name=index_col or "index") + return x + + +class DataStore: + metadata_filename: str = "metadata.yaml" + checkpoint_subdir: str = "checkpoints" + LATEST = "_" + _BY_OFFSET = "digitizedOffset" + + def __init__(self, directory: Path | None, mode="a", storage_format: str = "zarr"): + self._directory = Path(directory) if directory else directory + self._mode = mode + self._checkpoints = {} + self._checkpoint_order = [] + self._tree = DataTree(root_node_name=False) + self._keep_digitized = False + assert storage_format in {"zarr", "parquet", "hdf5"} + self._storage_format = storage_format + try: + self.read_metadata() + except FileNotFoundError: + pass + + @property + def directory(self) -> Path: + if self._directory is None: + raise NotADirectoryError("no directory set") + return self._directory + + def __setitem__(self, key: str, value: xr.Dataset | pd.DataFrame): + assert isinstance(key, str) + if self._mode == "r": + raise ReadOnlyError + if isinstance(value, xr.Dataset): + self._set_dataset(key, value) + elif isinstance(value, pd.DataFrame): + self._set_dataset(key, construct(value)) + else: + raise TypeError(f"cannot put {type(value)}") + + def __getitem__(self, key: str): + assert isinstance(key, str) + return self.get_dataset(key) + + def clone(self, mode="a"): + """ + Create a clone of this DataStore. + + The clone has the same active datasets as the original (and the data + shares the same memory) but it does not retain the checkpoint metadata + and is not connected to the same checkpoint store. 
+ + Returns + ------- + DataStore + """ + duplicate = self.__class__(None, mode=mode) + duplicate._tree = self._tree + return duplicate + + def _set_dataset( + self, + name: str, + data: xr.Dataset, + last_checkpoint: str = None, + ) -> None: + if self._mode == "r": + raise ReadOnlyError + data_vars = {} + coords = {} + for k, v in data.coords.items(): + coords[k] = v.assign_attrs(last_checkpoint=last_checkpoint) + for k, v in data.items(): + if k in coords: + continue + data_vars[k] = v.assign_attrs(last_checkpoint=last_checkpoint) + data = xr.Dataset(data_vars=data_vars, coords=coords, attrs=data.attrs) + self._tree.add_dataset(name, data) + + def _update_dataset( + self, + name: str, + data: xr.Dataset, + last_checkpoint=None, + ) -> xr.Dataset: + if self._mode == "r": + raise ReadOnlyError + if not isinstance(data, xr.Dataset): + raise TypeError(type(data)) + partial_update = self._tree.get_subspace(name, default_empty=True) + for k, v in data.items(): + if k in data.coords: + continue + assert v.name == k + partial_update = self._update_dataarray( + name, v, last_checkpoint, partial_update=partial_update + ) + for k, v in data.coords.items(): + assert v.name == k + partial_update = self._update_dataarray( + name, v, last_checkpoint, as_coord=True, partial_update=partial_update + ) + return partial_update + + def _update_dataarray( + self, + name: str, + data: xr.DataArray, + last_checkpoint=None, + as_coord=False, + partial_update=None, + ) -> xr.Dataset: + if self._mode == "r": + raise ReadOnlyError + if partial_update is None: + base_data = self._tree.get_subspace(name, default_empty=True) + else: + base_data = partial_update + if isinstance(data, xr.DataArray): + if as_coord: + updated_dataset = base_data.assign_coords( + {data.name: data.assign_attrs(last_checkpoint=last_checkpoint)} + ) + else: + updated_dataset = base_data.assign( + {data.name: data.assign_attrs(last_checkpoint=last_checkpoint)} + ) + self._tree = self._tree.replace_datasets( + {name: updated_dataset}, redigitize=self._keep_digitized + ) + return updated_dataset + else: + raise TypeError(type(data)) + + def update( + self, + name: str, + obj: xr.Dataset | xr.DataArray, + last_checkpoint: str = None, + ) -> None: + """ + Make a partial update of an existing named dataset. + + Parameters + ---------- + name : str + obj : Dataset or DataArray + last_checkpoint : str or None + Set the "last_checkpoint" attribute on all updated variables to this + value. Users should typically leave this as "None", which flags the + checkpointing algorithm to write this data to disk the next time a + checkpoint is written. + """ + if isinstance(obj, xr.Dataset): + self._update_dataset(name, obj, last_checkpoint=last_checkpoint) + elif isinstance(obj, xr.DataArray): + self._update_dataarray(name, obj, last_checkpoint=last_checkpoint) + else: + raise TypeError(type(obj)) + + def set_data( + self, + name: str, + data: xr.Dataset | pd.DataFrame, + relationships: str | Relationship | Collection[str | Relationship] = None, + ) -> None: + """ + Set the content of a named dataset. + + This completely overwrites any existing data with the same name. 
+ + Parameters + ---------- + name : str + data : Dataset or DataFrame + relationships : str or Relationship or list thereof + """ + self.__setitem__(name, data) + if relationships is not None: + if isinstance(relationships, (str, Relationship)): + relationships = [relationships] + for r in relationships: + self._tree.add_relationship(r) + + def get_dataset(self, name: str, columns: Collection[str] = None) -> xr.Dataset: + """ + Retrieve some or all of a named dataset. + + Parameters + ---------- + name : str + columns : Collection[str], optional + Get only these variables of the dataset. + """ + if columns is None: + return self._tree.get_subspace(name) + else: + return xr.Dataset({c: self._tree[f"{name}.{c}"] for c in columns}) + + def get_dataframe(self, name: str, columns: Collection[str] = None) -> pd.DataFrame: + """ + Retrieve some or all of a named dataset, as a pandas DataFrame. + + This completely overwrites any existing data with the same name. + + Parameters + ---------- + name : str + columns : Collection[str], optional + Get only these variables of the dataset. + """ + dataset = self.get_dataset(name, columns) + return dataset.single_dim.to_pandas() + + def _to_be_checkpointed(self) -> dict[str, xr.Dataset]: + result = {} + for table_name, table_data in self._tree.subspaces_iter(): + # any data elements that were created without a + # last_checkpoint attr get one now + for _k, v in table_data.variables.items(): + if "last_checkpoint" not in v.attrs: + v.attrs["last_checkpoint"] = None + # collect everything not checkpointed + uncheckpointed = table_data.filter_by_attrs(last_checkpoint=None) + if uncheckpointed: + result[table_name] = uncheckpointed + return result + + def _zarr_subdir(self, table_name, checkpoint_name): + return self.directory.joinpath(table_name, checkpoint_name).with_suffix(".zarr") + + def _parquet_name(self, table_name, checkpoint_name): + return self.directory.joinpath(table_name, checkpoint_name).with_suffix( + ".parquet" + ) + + def make_checkpoint(self, checkpoint_name: str, overwrite: bool = True): + """ + Write data to disk. + + Only new data (since the last time a checkpoint was made) is actually + written out. 
+ + Parameters + ---------- + checkpoint_name : str + overwrite : bool, default True + """ + if self._mode == "r": + raise ReadOnlyError + to_be_checkpointed = self._to_be_checkpointed() + new_checkpoint = { + "timestamp": timestamp(), + "tables": {}, + "relationships": [], + } + # remove checkpoint name from ordered list if it already exists + while checkpoint_name in self._checkpoint_order: + self._checkpoint_order.remove(checkpoint_name) + # add checkpoint name at end ordered list + self._checkpoint_order.append(checkpoint_name) + for table_name, table_data in to_be_checkpointed.items(): + if self._storage_format == "parquet" and len(table_data.dims) == 1: + target = self._parquet_name(table_name, checkpoint_name) + if overwrite and target.is_file(): + os.unlink(target) + target.parent.mkdir(parents=True, exist_ok=True) + table_data.single_dim.to_parquet(str(target)) + elif self._storage_format == "zarr" or ( + self._storage_format == "parquet" and len(table_data.dims) > 1 + ): + # zarr is used if ndim > 1 + target = self._zarr_subdir(table_name, checkpoint_name) + if overwrite and target.is_dir(): + shutil.rmtree(target) + target.mkdir(parents=True, exist_ok=True) + table_data.to_zarr_with_attr(target) + elif self._storage_format == "hdf5": + raise NotImplementedError + else: + raise ValueError( + f"cannot write with storage format {self._storage_format!r}" + ) + self.update(table_name, table_data, last_checkpoint=checkpoint_name) + for table_name, table_data in self._tree.subspaces_iter(): + inventory = {"data_vars": {}, "coords": {}} + for varname, vardata in table_data.items(): + inventory["data_vars"][varname] = { + "last_checkpoint": vardata.attrs.get("last_checkpoint", "MISSING"), + "dtype": str(vardata.dtype), + } + for varname, vardata in table_data.coords.items(): + _cp = checkpoint_name + # coords in every checkpoint with any content + if table_name not in to_be_checkpointed: + _cp = vardata.attrs.get("last_checkpoint", "MISSING") + inventory["coords"][varname] = { + "last_checkpoint": _cp, + "dtype": str(vardata.dtype), + } + new_checkpoint["tables"][table_name] = inventory + for r in self._tree.list_relationships(): + new_checkpoint["relationships"].append(r.to_dict()) + self._checkpoints[checkpoint_name] = new_checkpoint + self._write_checkpoint(checkpoint_name, new_checkpoint) + self._write_metadata() + + def _write_checkpoint(self, name, checkpoint): + if self._mode == "r": + raise ReadOnlyError + checkpoint_metadata_target = self.directory.joinpath( + self.checkpoint_subdir, f"{name}.yaml" + ) + if checkpoint_metadata_target.exists(): + n = 1 + while checkpoint_metadata_target.with_suffix(f".{n}.yaml").exists(): + n += 1 + os.rename( + checkpoint_metadata_target, + checkpoint_metadata_target.with_suffix(f".{n}.yaml"), + ) + checkpoint_metadata_target.parent.mkdir(parents=True, exist_ok=True) + with open(checkpoint_metadata_target, "w") as f: + yaml.safe_dump(checkpoint, f) + + def _write_metadata(self): + if self._mode == "r": + raise ReadOnlyError + metadata_target = self.directory.joinpath(self.metadata_filename) + if metadata_target.exists(): + n = 1 + while metadata_target.with_suffix(f".{n}.yaml").exists(): + n += 1 + os.rename(metadata_target, metadata_target.with_suffix(f".{n}.yaml")) + with open(metadata_target, "w") as f: + metadata = dict( + datastore_format_version=1, + checkpoint_order=self._checkpoint_order, + ) + yaml.safe_dump(metadata, f) + + def read_metadata(self, checkpoints=None): + """ + Read storage metadata + + Parameters + ---------- + 
checkpoints : str | list[str], optional + Read only these checkpoints. If not provided, only the latest + checkpoint metadata is read. Set to "*" to read all. + """ + with open(self.directory.joinpath(self.metadata_filename)) as f: + metadata = yaml.safe_load(f) + datastore_format_version = metadata.get("datastore_format_version", "missing") + if datastore_format_version == 1: + self._checkpoint_order = metadata["checkpoint_order"] + else: + raise NotImplementedError(f"{datastore_format_version=}") + if checkpoints is None or checkpoints == self.LATEST: + checkpoints = [self._checkpoint_order[-1]] + elif isinstance(checkpoints, str): + if checkpoints == "*": + checkpoints = list(self._checkpoint_order) + else: + checkpoints = [checkpoints] + for c in checkpoints: + with open( + self.directory.joinpath(self.checkpoint_subdir, f"{c}.yaml") + ) as f: + self._checkpoints[c] = yaml.safe_load(f) + + def restore_checkpoint(self, checkpoint_name: str): + if checkpoint_name not in self._checkpoints: + try: + self.read_metadata(checkpoint_name) + except FileNotFoundError: + raise KeyError(checkpoint_name) from None + checkpoint = self._checkpoints[checkpoint_name] + self._tree = DataTree(root_node_name=False) + for table_name, table_def in checkpoint["tables"].items(): + if table_name == "timestamp": + continue + t = xr.Dataset() + opened_targets = {} + coords = table_def.get("coords", {}) + if len(coords) == 1: + index_name = list(coords)[0] + else: + index_name = None + for coord_name, coord_def in coords.items(): + target = self._zarr_subdir(table_name, coord_def["last_checkpoint"]) + if target.exists(): + if target not in opened_targets: + opened_targets[target] = from_zarr_with_attr(target) + else: + # zarr not found, try parquet + target2 = self._parquet_name( + table_name, coord_def["last_checkpoint"] + ) + if target2.exists(): + if target not in opened_targets: + opened_targets[target] = _read_parquet(target2, index_name) + else: + raise FileNotFoundError(target) + t = t.assign_coords({coord_name: opened_targets[target][coord_name]}) + data_vars = table_def.get("data_vars", {}) + for var_name, var_def in data_vars.items(): + if var_def["last_checkpoint"] == "MISSING": + raise ValueError(f"missing checkpoint for {table_name}.{var_name}") + target = self._zarr_subdir(table_name, var_def["last_checkpoint"]) + if target.exists(): + if target not in opened_targets: + opened_targets[target] = from_zarr_with_attr(target) + else: + # zarr not found, try parquet + target2 = self._parquet_name(table_name, var_def["last_checkpoint"]) + if target2.exists(): + if target not in opened_targets: + opened_targets[target] = _read_parquet(target2, index_name) + else: + raise FileNotFoundError(target) + t = t.assign({var_name: opened_targets[target][var_name]}) + self._tree.add_dataset(table_name, t) + for r in checkpoint["relationships"]: + self._tree.add_relationship(Relationship(**r)) + + def add_relationship(self, relationship: str | Relationship): + self._tree.add_relationship(relationship) + + def digitize_relationships(self, redigitize=True): + """ + Convert all label-based relationships into position-based. + + Parameters + ---------- + redigitize : bool, default True + Re-compute position-based relationships from labels, even + if the relationship had previously been digitized. 
+ """ + self._keep_digitized = True + self._tree.digitize_relationships(inplace=True, redigitize=redigitize) + + @property + def relationships_are_digitized(self) -> bool: + """bool : Whether all relationships are digital (by position).""" + return self._tree.relationships_are_digitized diff --git a/sharrow/digital_encoding.py b/sharrow/digital_encoding.py index 7dee25f..22a8b27 100644 --- a/sharrow/digital_encoding.py +++ b/sharrow/digital_encoding.py @@ -163,24 +163,35 @@ def find_bins(values, final_width=255): def digitize_by_dictionary(arr, bitwidth=8): result = arr.copy() bins = find_bins(arr, final_width=1 << bitwidth) - bin_edges = (bins[1:] - bins[:-1]) / 2 + bins[:-1] try: - arr_data = arr.data - except AttributeError: - pass + bin_edges = (bins[1:] - bins[:-1]) / 2 + bins[:-1] + except TypeError: + # bins are not numeric + bin_map = {x: n for n, x in enumerate(bins)} + u, inv = np.unique(arr.data, return_inverse=True) + result.data = np.array([bin_map.get(x) for x in u])[inv].reshape(arr.shape) + result.attrs["digital_encoding"] = { + "dictionary": bins, + } + return result else: - if isinstance(arr_data, da.Array): - result.data = da.digitize(arr_data, bin_edges).astype(f"uint{bitwidth}") - result.attrs["digital_encoding"] = { - "dictionary": bins, - } - return result - # fall back to numpy digitize - result.data = np.digitize(arr, bin_edges).astype(f"uint{bitwidth}") - result.attrs["digital_encoding"] = { - "dictionary": bins, - } - return result + try: + arr_data = arr.data + except AttributeError: + pass + else: + if isinstance(arr_data, da.Array): + result.data = da.digitize(arr_data, bin_edges).astype(f"uint{bitwidth}") + result.attrs["digital_encoding"] = { + "dictionary": bins, + } + return result + # fall back to numpy digitize + result.data = np.digitize(arr, bin_edges).astype(f"uint{bitwidth}") + result.attrs["digital_encoding"] = { + "dictionary": bins, + } + return result @xr.register_dataset_accessor("digital_encoding") diff --git a/sharrow/filewrite.py b/sharrow/filewrite.py index 872911c..6d9a55f 100644 --- a/sharrow/filewrite.py +++ b/sharrow/filewrite.py @@ -21,7 +21,7 @@ def blacken(code): except Exception as err: import warnings - warnings.warn(f"error in blacken: {err!r}") + warnings.warn(f"error in blacken: {err!r}", stacklevel=2) return code diff --git a/sharrow/flows.py b/sharrow/flows.py index 5ba3439..20636c1 100644 --- a/sharrow/flows.py +++ b/sharrow/flows.py @@ -46,6 +46,7 @@ class CacheMissWarning(UserWarning): "hard_sigmoid", "transpose_leading", "clip", + "get", } @@ -140,6 +141,61 @@ def filter_name_tokens(expr, matchable_names=None): return name_tokens, arg_tokens +class ExtractOptionalGetTokens(ast.NodeVisitor): + def __init__(self, from_names): + self.optional_get_tokens = set() + self.required_get_tokens = set() + self.from_names = from_names + + def visit_Call(self, node): + if isinstance(node.func, ast.Attribute): + if node.func.attr == "get": + if isinstance(node.func.value, ast.Name): + if node.func.value.id in self.from_names: + if len(node.args) == 1: + if isinstance(node.args[0], ast.Constant): + if len(node.keywords) == 0: + self.required_get_tokens.add( + (node.func.value.id, node.args[0].value) + ) + elif ( + len(node.keywords) == 1 + and node.keywords[0].arg == "default" + ): + self.optional_get_tokens.add( + (node.func.value.id, node.args[0].value) + ) + else: + raise ValueError( + f"{node.func.value.id}.get with unexpected keyword arguments" + ) + if len(node.args) == 2: + if isinstance(node.args[0], ast.Constant): + 
self.optional_get_tokens.add( + (node.func.value.id, node.args[0].value) + ) + if len(node.args) > 2: + raise ValueError( + f"{node.func.value.id}.get with more than 2 positional arguments" + ) + self.generic_visit(node) + + def check(self, node): + if isinstance(node, str): + node = ast.parse(node) + if isinstance(node, ast.AST): + self.visit(node) + else: + try: + node_iter = iter(node) + except TypeError: + pass + else: + for i in node_iter: + self.check(i) + return self.optional_get_tokens + + def coerce_to_range_index(idx): if isinstance(idx, pd.RangeIndex): return idx @@ -152,7 +208,13 @@ def coerce_to_range_index(idx): FUNCTION_TEMPLATE = """ # {init_expr} -@nb.jit(cache=False, error_model='{error_model}', boundscheck={boundscheck}, nopython={nopython}, fastmath={fastmath}, nogil={nopython}) +@nb.jit( + cache=False, + error_model='{error_model}', + boundscheck={boundscheck}, + nopython={nopython}, + fastmath={fastmath}, + nogil={nopython}) def {fname}( {argtokens} _outputs, @@ -164,7 +226,14 @@ def {fname}( IRUNNER_1D_TEMPLATE = """ -@nb.jit(cache=True, parallel=True, error_model='{error_model}', boundscheck={boundscheck}, nopython={nopython}, fastmath={fastmath}, nogil={nopython}) +@nb.jit( + cache=True, + parallel=True, + error_model='{error_model}', + boundscheck={boundscheck}, + nopython={nopython}, + fastmath={fastmath}, + nogil={nopython}) def irunner( argshape, {joined_namespace_names} @@ -185,7 +254,14 @@ def irunner( """ IRUNNER_2D_TEMPLATE = """ -@nb.jit(cache=True, parallel=True, error_model='{error_model}', boundscheck={boundscheck}, nopython={nopython}, fastmath={fastmath}, nogil={nopython}) +@nb.jit( + cache=True, + parallel=True, + error_model='{error_model}', + boundscheck={boundscheck}, + nopython={nopython}, + fastmath={fastmath}, + nogil={nopython}) def irunner( argshape, {joined_namespace_names} @@ -207,7 +283,14 @@ def irunner( """ IDOTTER_1D_TEMPLATE = """ -@nb.jit(cache=True, parallel=True, error_model='{error_model}', boundscheck={boundscheck}, nopython={nopython}, fastmath={fastmath}, nogil={nopython}) +@nb.jit( + cache=True, + parallel=True, + error_model='{error_model}', + boundscheck={boundscheck}, + nopython={nopython}, + fastmath={fastmath}, + nogil={nopython}) def idotter( argshape, {joined_namespace_names} @@ -232,7 +315,14 @@ def idotter( """ IDOTTER_2D_TEMPLATE = """ -@nb.jit(cache=True, parallel=True, error_model='{error_model}', boundscheck={boundscheck}, nopython={nopython}, fastmath={fastmath}, nogil={nopython}) +@nb.jit( + cache=True, + parallel=True, + error_model='{error_model}', + boundscheck={boundscheck}, + nopython={nopython}, + fastmath={fastmath}, + nogil={nopython}) def idotter( argshape, {joined_namespace_names} @@ -259,7 +349,13 @@ def idotter( """ ILINER_1D_TEMPLATE = """ -@nb.jit(cache=False, error_model='{error_model}', boundscheck={boundscheck}, nopython={nopython}, fastmath={fastmath}, nogil={nopython}) +@nb.jit( + cache=False, + error_model='{error_model}', + boundscheck={boundscheck}, + nopython={nopython}, + fastmath={fastmath}, + nogil={nopython}) def linemaker( intermediate, j0, {joined_namespace_names} @@ -269,7 +365,13 @@ def linemaker( """ ILINER_2D_TEMPLATE = """ -@nb.jit(cache=False, error_model='{error_model}', boundscheck={boundscheck}, nopython={nopython}, fastmath={fastmath}, nogil={nopython}) +@nb.jit( + cache=False, + error_model='{error_model}', + boundscheck={boundscheck}, + nopython={nopython}, + fastmath={fastmath}, + nogil={nopython}) def linemaker( intermediate, j0, j1, {joined_namespace_names} @@ -280,7 
+382,13 @@ def linemaker( MNL_GENERIC_TEMPLATE = """ -@nb.jit(cache=True, error_model='{error_model}', boundscheck={boundscheck}, nopython={nopython}, fastmath={fastmath}, nogil={nopython}) +@nb.jit( + cache=True, + error_model='{error_model}', + boundscheck={boundscheck}, + nopython={nopython}, + fastmath={fastmath}, + nogil={nopython}) def _sample_choices_maker( prob_array, random_array, @@ -330,7 +438,13 @@ def _sample_choices_maker( -@nb.jit(cache=True, error_model='{error_model}', boundscheck={boundscheck}, nopython={nopython}, fastmath={fastmath}, nogil={nopython}) +@nb.jit( + cache=True, + error_model='{error_model}', + boundscheck={boundscheck}, + nopython={nopython}, + fastmath={fastmath}, + nogil={nopython}) def _sample_choices_maker_counted( prob_array, random_array, @@ -395,7 +509,14 @@ def _sample_choices_maker_counted( logit_ndims = 1 -@nb.jit(cache=True, parallel=True, error_model='{error_model}', boundscheck={boundscheck}, nopython={nopython}, fastmath={fastmath}, nogil={nopython}) +@nb.jit( + cache=True, + parallel=True, + error_model='{error_model}', + boundscheck={boundscheck}, + nopython={nopython}, + fastmath={fastmath}, + nogil={nopython}) def mnl_transform_plus1d( argshape, {joined_namespace_names} @@ -445,7 +566,13 @@ def mnl_transform_plus1d( """ ) -# @nb.jit(cache=True, parallel=True, error_model='{error_model}', boundscheck={boundscheck}, nopython={nopython}, fastmath={fastmath}) +# @nb.jit( +# cache=True, +# parallel=True, +# error_model='{error_model}', +# boundscheck={boundscheck}, +# nopython={nopython}, +# fastmath={fastmath}) # def mnl_transform_plus1d( # argshape, # {joined_namespace_names} @@ -482,7 +609,9 @@ def mnl_transform_plus1d( # if logsums: # _logsums[j0,k0] = np.log(local_sum) + shifter # if pick_counted: -# _sample_choices_maker_counted(partial, random_draws[j0,k0], result[j0,k0], result_p[j0,k0], pick_count[j0,k0]) +# _sample_choices_maker_counted( +# partial, random_draws[j0,k0], result[j0,k0], result_p[j0,k0], pick_count[j0,k0] +# ) # else: # _sample_choices_maker(partial, random_draws[j0,k0], result[j0,k0], result_p[j0,k0]) # return result, result_p, pick_count, _logsums @@ -494,7 +623,14 @@ def mnl_transform_plus1d( logit_ndims = 2 -@nb.jit(cache=True, parallel=True, error_model='{error_model}', boundscheck={boundscheck}, nopython={nopython}, fastmath={fastmath}, nogil={nopython}) +@nb.jit( + cache=True, + parallel=True, + error_model='{error_model}', + boundscheck={boundscheck}, + nopython={nopython}, + fastmath={fastmath}, + nogil={nopython}) def mnl_transform( argshape, {joined_namespace_names} @@ -558,7 +694,14 @@ def mnl_transform( return result, result_p, pick_count, _logsums -@nb.jit(cache=True, parallel=True, error_model='{error_model}', boundscheck={boundscheck}, nopython={nopython}, fastmath={fastmath}, nogil={nopython}) +@nb.jit( + cache=True, + parallel=True, + error_model='{error_model}', + boundscheck={boundscheck}, + nopython={nopython}, + fastmath={fastmath}, + nogil={nopython}) def mnl_transform_plus1d( argshape, {joined_namespace_names} @@ -613,7 +756,9 @@ def mnl_transform_plus1d( continue partial /= local_sum if pick_counted: - _sample_choices_maker_counted(partial, random_draws[j0,j1], result[j0,j1], result_p[j0,j1], pick_count[j0,j1]) + _sample_choices_maker_counted( + partial, random_draws[j0,j1], result[j0,j1], result_p[j0,j1], pick_count[j0,j1] + ) else: _sample_choices_maker(partial, random_draws[j0,j1], result[j0,j1], result_p[j0,j1]) return result, result_p, pick_count, _logsums @@ -625,7 +770,14 @@ def 
mnl_transform_plus1d( from sharrow.nested_logit import _utility_to_probability -@nb.jit(cache=True, parallel=True, error_model='{error_model}', boundscheck={boundscheck}, nopython={nopython}, fastmath={fastmath}, nogil={nopython}) +@nb.jit( + cache=True, + parallel=True, + error_model='{error_model}', + boundscheck={boundscheck}, + nopython={nopython}, + fastmath={fastmath}, + nogil={nopython}) def nl_transform( argshape, {joined_namespace_names} @@ -695,7 +847,9 @@ def nl_transform( _logsums[j0] = utility[-1] if logsums != 1: if pick_counted: - _sample_choices_maker_counted(probability[:n_alts], random_draws[j0], result[j0], result_p[j0], pick_count[j0]) + _sample_choices_maker_counted( + probability[:n_alts], random_draws[j0], result[j0], result_p[j0], pick_count[j0] + ) else: _sample_choices_maker(probability[:n_alts], random_draws[j0], result[j0], result_p[j0]) return result, result_p, pick_count, _logsums @@ -802,6 +956,7 @@ def __new__( dim_order=None, dim_exclude=None, bool_wrapping=False, + with_root_node_name=None, ): assert isinstance(tree, DataTree) tree.digitize_relationships(inplace=True) @@ -844,9 +999,11 @@ def __new__( parallel=parallel, extra_hash_data=extra_hash_data, write_hash_audit=write_hash_audit, + with_root_node_name=with_root_node_name, ) if flow_library is not None: flow_library[self.flow_hash] = self + self.with_root_node_name = with_root_node_name return self def __initialize_1( @@ -886,7 +1043,7 @@ def __initialize_1( all_raw_names = set() all_name_tokens = set() - for k, expr in defs.items(): + for _k, expr in defs.items(): plain_names, attribute_pairs, subscript_pairs = extract_names_2(expr) all_raw_names |= plain_names if self.tree.root_node_name: @@ -916,6 +1073,37 @@ def __initialize_1( if aux_var in all_raw_names: self._used_aux_vars.append(aux_var) + subspace_names = set() + for k, _ in self.tree.subspaces_iter(): + subspace_names.add(k) + for k in self.tree.subspace_fallbacks: + subspace_names.add(k) + optional_get_tokens = ExtractOptionalGetTokens(from_names=subspace_names).check( + defs.values() + ) + self._optional_get_tokens = [] + if optional_get_tokens: + for _spacename, _varname in optional_get_tokens: + found = False + if ( + _spacename in self.tree.subspaces + and _varname in self.tree.subspaces[_spacename] + ): + self._optional_get_tokens.append(f"__{_spacename}__{_varname}:True") + found = True + elif _spacename in self.tree.subspace_fallbacks: + for _subspacename in self.tree.subspace_fallbacks[_spacename]: + if _varname in self.tree.subspaces[_subspacename]: + self._optional_get_tokens.append( + f"__{_subspacename}__{_varname}:__{_spacename}__{_varname}" + ) + found = True + break + if not found: + self._optional_get_tokens.append( + f"__{_spacename}__{_varname}:False" + ) + self._hashing_level = hashing_level if self._hashing_level > 1: func_code, all_name_tokens = self.init_sub_funcs( @@ -956,6 +1144,8 @@ def _flow_hash_push(x): _flow_hash_push(f"aux_var:{k}") for k in sorted(self._used_extra_funcs): _flow_hash_push(f"func:{k}") + for k in sorted(self._optional_get_tokens): + _flow_hash_push(f"OPTIONAL:{k}") _flow_hash_push("---DataTree---") for k in self.arg_names: _flow_hash_push(f"arg:{k}") @@ -984,15 +1174,19 @@ def _flow_hash_push(x): digital_encoding = self.tree.subspaces[parts[1]][ "__".join(parts[2:]) ].attrs["digital_encoding"] - except (AttributeError, KeyError): + except (AttributeError, KeyError) as err: pass + print(f"$$$$/ndigital_encoding=ERR\n{err}\n\n\n$$$") + else: + print(f"$$$$/n{digital_encoding=}\n\n\n$$$") if 
digital_encoding: for de_k in sorted(digital_encoding.keys()): de_v = digital_encoding[de_k] if de_k == "dictionary": self.encoding_dictionaries[k] = de_v _flow_hash_push((k, "digital_encoding", de_k, de_v)) + for k in extra_hash_data: _flow_hash_push(k) @@ -1086,7 +1280,7 @@ def init_sub_funcs( other_way = False # if other_way is triggered, there may be residual other terms # that were not addressed, so this loop should be applied again. - for spacename, spacearrays in self.tree.subspaces.items(): + for spacename in self.tree.subspaces.keys(): dim_slots, digital_encodings, blenders = meta_data[spacename] try: expr = expression_for_numba( @@ -1164,6 +1358,52 @@ def init_sub_funcs( other_way = True # at least one variable was found in a get break + if not other_way and "get" in expr: + # any remaining "get" expressions with defaults should now use them + try: + expr = expression_for_numba( + expr, + spacename, + dim_slots, + dim_slots, + digital_encodings=digital_encodings, + extra_vars=self.tree.extra_vars, + blenders=blenders, + bool_wrapping=self.bool_wrapping, + get_default=True, + ) + except KeyError as err: # noqa: F841 + pass + else: + other_way = True + # at least one variable was found in a get + break + # check if we can resolve this "get" on any other subspace + for other_spacename in self.tree.subspace_fallbacks.get( + topkey, [] + ): + dim_slots, digital_encodings, blenders = meta_data[ + other_spacename + ] + try: + expr = expression_for_numba( + expr, + spacename, + dim_slots, + dim_slots, + digital_encodings=digital_encodings, + prefer_name=other_spacename, + extra_vars=self.tree.extra_vars, + blenders=blenders, + bool_wrapping=self.bool_wrapping, + get_default=True, + ) + except KeyError as err: # noqa: F841 + pass + else: + other_way = True + # at least one variable was found in a fallback + break if not other_way: raise if prior_expr == expr: @@ -1174,6 +1414,49 @@ def init_sub_funcs( # nothing else needs to change prior_expr = expr + # now process for subspace fallbacks + for gd in [False, True]: + # first run all these with get_default off, nothing drops to defaults + # if we might find it later. Then do a second pass with get_default on. + for ( + alias_spacename, + actual_spacenames, + ) in self.tree.subspace_fallbacks.items(): + for actual_spacename in actual_spacenames: + dim_slots, digital_encodings, blenders = meta_data[ + actual_spacename + ] + try: + expr = expression_for_numba( + expr, + alias_spacename, + dim_slots, + dim_slots, + digital_encodings=digital_encodings, + prefer_name=actual_spacename, + extra_vars=self.tree.extra_vars, + blenders=blenders, + bool_wrapping=self.bool_wrapping, + get_default=gd, + ) + except KeyError: + # there was an error, but lets make sure we process the + # whole expression to rewrite all the things we can before + # moving on to the fallback processing. + expr = expression_for_numba( + expr, + alias_spacename, + dim_slots, + dim_slots, + digital_encodings=digital_encodings, + prefer_name=actual_spacename, + extra_vars=self.tree.extra_vars, + blenders=blenders, + bool_wrapping=self.bool_wrapping, + swallow_errors=True, + get_default=gd, + ) + # now find instances where an identifier is previously created in this flow. 
expr = expression_for_numba( expr, @@ -1248,6 +1531,7 @@ def __initialize_2( parallel=True, extra_hash_data=(), write_hash_audit=True, + with_root_node_name=None, ): """ @@ -1311,7 +1595,7 @@ def __initialize_2( # if an existing __init__ file matches the hash, just use it init_file = os.path.join(self.cache_dir, self.name, "__init__.py") if os.path.isfile(init_file): - with open(init_file, "rt") as f: + with open(init_file) as f: content = f.read() s = re.search("""flow_hash = ['"](.*)['"]""", content) else: @@ -1401,7 +1685,6 @@ def __initialize_2( with rewrite( os.path.join(self.cache_dir, self.name, "__init__.py"), "wt" ) as f_code: - f_code.write( textwrap.dedent( f""" @@ -1455,10 +1738,12 @@ def __initialize_2( f_code.write("\n\n# machinery code\n\n") if self.tree.relationships_are_digitized: + if with_root_node_name is None: + with_root_node_name = self.tree.root_node_name root_dims = list( presorted( - self.tree.root_dataset.dims, + self.tree._graph.nodes[with_root_node_name]["dataset"].dims, self.dim_order, self.dim_exclude, ) @@ -1518,6 +1803,9 @@ def __initialize_2( nl_template = NL_1D_TEMPLATE.format(**locals()).format( **locals() ) + nl_template = NL_1D_TEMPLATE.format(**locals()).format( + **locals() + ) elif n_root_dims == 2: meta_template = IRUNNER_2D_TEMPLATE.format(**locals()).format( **locals() @@ -1536,7 +1824,6 @@ def __initialize_2( raise ValueError(f"invalid n_root_dims {n_root_dims}") else: - raise RuntimeError("digitization is now required") f_code.write(blacken(textwrap.dedent(line_template))) @@ -1610,7 +1897,9 @@ def load_raw(self, rg, args, runner=None, dtype=None, dot=None): assembled_args = [args.get(k) for k in self.arg_name_positions.keys()] for aa in assembled_args: if aa.dtype.kind != "i": - warnings.warn("position arguments are not all integers") + warnings.warn( + "position arguments are not all integers", stacklevel=2 + ) try: if runner is None: if dot is None: @@ -1777,7 +2066,13 @@ def _iload_raw( kwargs.update(nesting) if mask is not None: kwargs["mask"] = mask - tree_root_dims = rg.root_dataset.dims + + if self.with_root_node_name is None: + tree_root_dims = rg.root_dataset.dims + else: + tree_root_dims = rg._graph.nodes[self.with_root_node_name][ + "dataset" + ].dims argshape = [ tree_root_dims[i] for i in presorted(tree_root_dims, self.dim_order, self.dim_exclude) @@ -1903,7 +2198,9 @@ def check_cache_misses(self, *funcs, fresh=True, log_details=True): f"cache miss in {self.flow_hash}{warning_text}\n" f"Compile Time: {timers}" ) - warnings.warn(f"{self.flow_hash}", CacheMissWarning) + warnings.warn( + f"{self.flow_hash}", CacheMissWarning, stacklevel=1 + ) self.compiled_recently = True self._known_cache_misses[runner_name][k] = v return self.compiled_recently @@ -1987,9 +2284,18 @@ def _load( if logit_draws is None and logsums == 1: logit_draws = np.zeros(source.shape + (0,), dtype=dtype) - use_dims = list( - presorted(source.root_dataset.dims, self.dim_order, self.dim_exclude) - ) + if self.with_root_node_name is None: + use_dims = list( + presorted(source.root_dataset.dims, self.dim_order, self.dim_exclude) + ) + else: + use_dims = list( + presorted( + source._graph.nodes[self.with_root_node_name]["dataset"].dims, + self.dim_order, + self.dim_exclude, + ) + ) if logit_draws is not None: if dot is None: @@ -2155,14 +2461,24 @@ def _load( {k: result[:, n] for n, k in enumerate(self._raw_functions.keys())} ) elif as_dataarray: - if result_squeeze: result = squeeze(result, result_squeeze) result_p = squeeze(result_p, result_squeeze) pick_count = 
squeeze(pick_count, result_squeeze) - result_coords = { - k: v for k, v in source.root_dataset.coords.items() if k in result_dims - } + if self.with_root_node_name is None: + result_coords = { + k: v + for k, v in source.root_dataset.coords.items() + if k in result_dims + } + else: + result_coords = { + k: v + for k, v in source._graph.nodes[self.with_root_node_name][ + "dataset" + ].coords.items() + if k in result_dims + } if result is not None: result = xr.DataArray( result, @@ -2457,6 +2773,10 @@ def logit_draws( mask=mask, ) + @property + def defs(self): + return {k: v[0] for (k, v) in self._raw_functions.items()} + @property def function_names(self): return list(self._raw_functions.keys()) @@ -2501,7 +2821,7 @@ def show_code(self, linenos="inline"): from pygments.lexers.python import PythonLexer codefile = os.path.join(self.cache_dir, self.name, "__init__.py") - with open(codefile, "rt") as f_code: + with open(codefile) as f_code: code = f_code.read() pretty = highlight(code, PythonLexer(), HtmlFormatter(linenos=linenos)) css = HtmlFormatter().get_style_defs(".highlight") diff --git a/sharrow/nested_logit.py b/sharrow/nested_logit.py index 452df23..13e21fd 100644 --- a/sharrow/nested_logit.py +++ b/sharrow/nested_logit.py @@ -1,4 +1,4 @@ -from typing import Mapping +from collections.abc import Mapping import numba as nb import numpy as np @@ -17,7 +17,6 @@ def _utility_to_probability( logprob, # float output shape=[nodes] probability, # float output shape=[nodes] ): - for up in range(n_alts, utility.size): up_nest = up - n_alts n_children_for_parent = len_slots[up_nest] diff --git a/sharrow/relationships.py b/sharrow/relationships.py index 50d49aa..f3ffdc6 100644 --- a/sharrow/relationships.py +++ b/sharrow/relationships.py @@ -24,7 +24,8 @@ except ImportError: from astunparse import unparse as _unparse - unparse = lambda *args: _unparse(*args).strip("\n") + def unparse(*args): + return _unparse(*args).strip("\n") logger = logging.getLogger("sharrow") @@ -47,6 +48,8 @@ "clip", } +NOTSET = "<--NOTSET-->" + def _require_string(x): if not isinstance(x, str): @@ -197,10 +200,30 @@ def __init__( def __eq__(self, other): if isinstance(other, self.__class__): - return repr(self) == repr(other) + if self.analog: + left = ( + f" " + f"{self.child_data}[{self.child_name!r}]>" + ) + else: + left = repr(self) + if other.analog: + right = ( + f" " + f"{other.child_data}[{other.child_name!r}]>" + ) + else: + right = repr(other) + return left == right def __repr__(self): - return f" {self.child_data}[{self.child_name!r}]>" + return ( + f" " + f"{self.child_data}[{self.child_name!r}]>" + ) def attrs(self): return dict( @@ -209,6 +232,16 @@ def attrs(self): indexing=self.indexing, ) + def to_dict(self): + return dict( + parent_data=self.parent_data, + parent_name=self.parent_name, + child_data=self.child_data, + child_name=self.child_name, + indexing=self.indexing, + analog=self.analog, + ) + @classmethod def from_string(cls, s): """ @@ -257,7 +290,7 @@ class DataTree: Parameters ---------- graph : networkx.MultiDiGraph - root_node_name : str + root_node_name : str or False The name of the node at the root of the tree. 
extra_funcs : Tuple[Callable] Additional functions that can be called by Flow objects created @@ -366,7 +399,9 @@ def shape(self): def root_dims(self): from .flows import presorted - return tuple(presorted(self.root_dataset, self.dim_order, self.dim_exclude)) + return tuple( + presorted(self.root_dataset.dims, self.dim_order, self.dim_exclude) + ) def __shallow_copy_extras(self): return dict( @@ -439,15 +474,34 @@ def root_node_name(self): @root_node_name.setter def root_node_name(self, name): - if name is None: - self._root_node_name = None + if name is None or name is False: + self._root_node_name = name return if not isinstance(name, str): - raise TypeError(f"root_node_name must be str not {type(name)}") + raise TypeError( + f"root_node_name must be one of [str, None, False] not {type(name)}" + ) if name not in self._graph.nodes: raise KeyError(name) self._root_node_name = name + @property + def root_node_name_str(self): + """str: The root node for this data tree, which is only ever a parent. + + This method raises a ValueError if root node cannot be determined. + """ + if self._root_node_name is None: + for nodename in self._graph.nodes: + if self._graph.in_degree(nodename) == 0: + self._root_node_name = nodename + break + if self._root_node_name is None: + raise ValueError("root node cannot be determined") + if self._root_node_name is False: + raise ValueError("root node is False") + return self._root_node_name + def add_relationship(self, *args, **kwargs): """ Add a relationship to this DataTree. @@ -467,28 +521,7 @@ def add_relationship(self, *args, **kwargs): if len(args) == 1 and isinstance(args[0], Relationship): r = args[0] elif len(args) == 1 and isinstance(args[0], str): - s = args[0] - if "->" in s: - parent, child = s.split("->", 1) - i = "position" - elif "@" in s: - parent, child = s.split("@", 1) - i = "label" - else: - raise ValueError(f"cannot interpret relationship {s!r}") - p1, p2 = parent.split(".", 1) - c1, c2 = child.split(".", 1) - p1 = p1.strip() - p2 = p2.strip() - c1 = c1.strip() - c2 = c2.strip() - r = Relationship( - parent_data=p1, - parent_name=p2, - child_data=c1, - child_name=c2, - indexing=i, - ) + r = Relationship.from_string(args[0]) else: r = Relationship(*args, **kwargs) @@ -509,6 +542,13 @@ def get_relationship(self, parent, child): attrs = self._graph.edges[parent, child] return Relationship(parent_data=parent, child_data=child, **attrs) + def list_relationships(self) -> list[Relationship]: + """list : List all relationships defined in this tree.""" + result = [] + for e in self._graph.edges: + result.append(self._get_relationship(e)) + return result + def add_dataset(self, name, dataset, relationships=(), as_root=False): """ Add a new Dataset node to this DataTree. 
@@ -537,7 +577,6 @@ def add_dataset(self, name, dataset, relationships=(), as_root=False): self.digitize_relationships(inplace=True) def add_items(self, items): - from collections.abc import Mapping, Sequence if isinstance(items, Sequence): @@ -562,11 +601,11 @@ def add_items(self, items): @property def root_node(self): - return self._graph.nodes[self.root_node_name] + return self._graph.nodes[self.root_node_name_str] @property def root_dataset(self): - return self._graph.nodes[self.root_node_name]["dataset"] + return self._graph.nodes[self.root_node_name_str]["dataset"] @root_dataset.setter def root_dataset(self, x): @@ -574,7 +613,7 @@ def root_dataset(self, x): if not isinstance(x, Dataset): x = construct(x) - if self.root_node_name in self.replacement_filters: + if self.root_node_name_str in self.replacement_filters: x = self.replacement_filters[self.root_node_name](x) self._graph.nodes[self.root_node_name]["dataset"] = x @@ -628,16 +667,28 @@ def get(self, item, default=None, broadcast=True, coords=True): raise else: result = xr.DataArray(default) - root_dataset = self.root_dataset - if result.dims != self.root_dims and broadcast: - result, _ = xr.broadcast(result, root_dataset) - if coords: - add_coords = {} - for i in result.dims: - if i not in result.coords and i in root_dataset.coords: - add_coords[i] = root_dataset.coords[i] - if add_coords: - result = result.assign_coords(add_coords) + if self.root_node_name: + root_dataset = self.root_dataset + if result.dims != self.root_dims and broadcast: + result, _ = xr.broadcast(result, root_dataset) + if coords: + add_coords = {} + for i in result.dims: + if i not in result.coords and i in root_dataset.coords: + add_coords[i] = root_dataset.coords[i] + if add_coords: + result = result.assign_coords(add_coords) + elif self.root_node_name is False: + if "." in item: + item_in, item = item.split(".", 1) + base_dataset = self._graph.nodes[item_in]["dataset"] + if coords: + add_coords = {} + for i in result.dims: + if i not in result.coords and i in base_dataset.coords: + add_coords[i] = base_dataset.coords[i] + if add_coords: + result = result.assign_coords(add_coords) return result def finditem(self, item, maybe_in=None): @@ -655,7 +706,6 @@ def _getitem( just_node_name=False, dim_names_from_top=False, ): - if isinstance(item, (list, tuple)): from .dataset import Dataset @@ -663,11 +713,17 @@ def _getitem( if "." 
in item: item_in, item = item.split(".", 1) + queue = [self.root_node_name] + if self.root_node_name is False: + # when root_node_name is False, we don't want to broadcast + # back to the root, but instead only to the given `item_in` + queue = [item_in] + item_in = None else: item_in = None - - queue = [self.root_node_name] + queue = [self.root_node_name_str] examined = set() + start_from = queue[0] while len(queue): current_node = queue.pop(0) if current_node in examined: @@ -684,7 +740,7 @@ def _getitem( if (by_name or by_dims) and (item_in is None or item_in == current_node): if just_node_name: return current_node - if current_node == self.root_node_name: + if current_node == start_from: if by_dims: return xr.DataArray( pd.RangeIndex(dataset.dims[item]), dims=item @@ -709,19 +765,19 @@ def _getitem( dims_in_result = set(result.dims) top_dim_names = {} for path in nx.algorithms.simple_paths.all_simple_edge_paths( - self._graph, self.root_node_name, current_node + self._graph, start_from, current_node ): if dim_names_from_top: e = path[0] top_dim_name = self._graph.edges[e].get("parent_name") - root_dataset = self.root_dataset + start_dataset = self._graph.nodes[start_from]["dataset"] # deconvert digitized dim names back to native dims if ( - top_dim_name not in root_dataset.dims - and top_dim_name in root_dataset.variables + top_dim_name not in start_dataset.dims + and top_dim_name in start_dataset.variables ): - if root_dataset.variables[top_dim_name].ndim == 1: - top_dim_name = root_dataset.variables[ + if start_dataset.variables[top_dim_name].ndim == 1: + top_dim_name = start_dataset.variables[ top_dim_name ].dims[0] else: @@ -732,7 +788,7 @@ def _getitem( # path_indexing = self._graph.edges[path[-1]].get('indexing') t1 = None # intermediate nodes on path - for (e, e_next) in zip(path[:-1], path[1:]): + for e, e_next in zip(path[:-1], path[1:]): r = self._get_relationship(e) r_next = self._get_relationship(e_next) if t1 is None: @@ -826,6 +882,8 @@ def get_expr(self, expression, engine="sharrow", allow_native=True): result = DataArray( pd.eval(expression, resolvers=[self], engine="numexpr"), ) + else: + raise ValueError(f"unknown engine {engine}") from None return result @property @@ -844,12 +902,48 @@ def subspaces_iter(self): if s is not None: yield (k, s) + def contains_subspace(self, key) -> bool: + """ + Is this named Dataset in this tree's subspaces + + Parameters + ---------- + key : str + + Returns + ------- + bool + """ + return key in self._graph.nodes + + def get_subspace(self, key, default_empty=False) -> xr.Dataset: + """ + Access named Dataset from this tree's subspaces + + Parameters + ---------- + key : str + default_empty : bool, default False + Return an empty Dataset if the key is not found. + + Returns + ------- + xr.Dataset + """ + result = self._graph.nodes[key].get("dataset", None) + if result is None: + if default_empty: + result = xr.Dataset() + else: + raise KeyError(key) + return result + def namespace_names(self): namespace = set() for spacename, spacearrays in self.subspaces_iter(): - for k, arr in spacearrays.coords.items(): + for k, _arr in spacearrays.coords.items(): namespace.add(f"__{spacename or 'base'}__{k}") - for k, arr in spacearrays.items(): + for k, _arr in spacearrays.items(): if k.startswith("_s_"): namespace.add(f"__{spacename or 'base'}__{k}__indptr") namespace.add(f"__{spacename or 'base'}__{k}__indices") @@ -864,7 +958,7 @@ def dims(self): Mapping from dimension names to lengths across all dataset nodes. 
""" dims = {} - for k, v in self.subspaces_iter(): + for _k, v in self.subspaces_iter(): for name, length in v.dims.items(): if name in dims: if dims[name] != length: @@ -928,7 +1022,7 @@ def drop_dims(self, dims, inplace=False, ignore_missing_dims=True): # remove subspaces that rely on dropped dim boot_queue = set() booted = set() - for (up, dn, n), e in obj._graph.edges.items(): + for (up, dn, _n), e in obj._graph.edges.items(): if up == obj.root_node_name: _analog = e.get("analog", "") if _analog in dims: @@ -943,7 +1037,7 @@ def drop_dims(self, dims, inplace=False, ignore_missing_dims=True): while boot_queue: b = boot_queue.pop() booted.add(b) - for (up, dn, n), e in obj._graph.edges.items(): + for up, dn, _n in obj._graph.edges.keys(): if up == b: boot_queue.add(dn) @@ -985,7 +1079,8 @@ def get_indexes( if result_shape != result_k.shape: if check_shapes: raise ValueError( - f"inconsistent index shapes {result_k.shape} v {result_shape} (probably an error on {k} or {sorted(dims)[0]})" + f"inconsistent index shapes {result_k.shape} v {result_shape} " + f"(probably an error on {k} or {sorted(dims)[0]})" ) result[k] = result_k @@ -1072,6 +1167,7 @@ def setup_flow( write_hash_audit=True, hashing_level=1, dim_exclude=None, + with_root_node_name=None, ): """ Set up a new Flow for analysis using the structure of this DataTree. @@ -1156,6 +1252,7 @@ def setup_flow( write_hash_audit=write_hash_audit, dim_order=self.dim_order, dim_exclude=dim_exclude, + with_root_node_name=with_root_node_name, ) def _spill(self, all_name_tokens=()): @@ -1235,13 +1332,18 @@ def digitize_relationships(self, inplace=False, redigitize=True): # vectorize version mapper = {i: j for (j, i) in enumerate(_dataarray_to_numpy(downstream))} + + def mapper_get(x, mapper=mapper): + return mapper.get(x, 0) + if upstream.size: - offsets = xr.apply_ufunc(np.vectorize(mapper.get), upstream) + offsets = xr.apply_ufunc(np.vectorize(mapper_get), upstream) else: offsets = xr.DataArray([], dims=["index"]) if offsets.dtype.kind != "i": warnings.warn( f"detected missing values in digitizing {r.parent_data}.{r.parent_name}", + stacklevel=2, ) # candidate name for write back @@ -1291,7 +1393,6 @@ def relationships_are_digitized(self): def _arg_tokenizer( self, spacename, spacearray, spacearrayname, exclude_dims=None, blends=None ): - if blends is None: blends = {} @@ -1351,29 +1452,40 @@ def _arg_tokenizer( try: upside = ", ".join(unparse(t) for t in upside_ast) except: # noqa: E722 - for t in upside_ast: - str_t = str(t) - if len(str_t) < 2000: - print(f"t:{str_t}") - else: - print(f"t:{str_t[:200]}...") - raise - - # check for redirection target - if retarget is not None: - tokens.append( - f"__{spacename}___digitized_{retarget}_of_{this_dim_name}[__{parent_data}__{parent_name}[{upside}]]" - ) - else: - tokens.append(f"__{parent_data}__{parent_name}[{upside}]") + if self.root_node_name is False: + upside = None + else: + print(f"{parent_data=}") + print(f"{parent_name=}") + print(f"{spacearrayname=}") + print(f"{exclude_dims=}") + print(f"{blends=}") + for t in upside_ast: + str_t = str(t) + if len(str_t) < 2000: + print(f"t:{str_t}") + else: + print(f"t:{str_t[:200]}...") + raise + if upside is not None: + # check for redirection target + if retarget is not None: + tokens.append( + f"__{spacename}___digitized_{retarget}_of_{this_dim_name}[__{parent_data}__{parent_name}[{upside}]]" + ) + else: + tokens.append(f"__{parent_data}__{parent_name}[{upside}]") found_token = True break if not found_token: if dimname in 
self.subspaces[spacename].indexes: - ix = self.subspaces[spacename].indexes[dimname] - ix = {i: n for n, i in enumerate(ix)} - tokens.append(ix) - n_missing_tokens += 1 + if self.root_node_name is False: + tokens.append(False) + else: + ix = self.subspaces[spacename].indexes[dimname] + ix = {i: n for n, i in enumerate(ix)} + tokens.append(ix) + n_missing_tokens += 1 elif dimname.endswith("_indices") or dimname.endswith("_indptr"): tokens.append(None) # this dimension corresponds to a blender @@ -1394,7 +1506,7 @@ def coords(self): return self.root_dataset.coords def get_index(self, dim): - for spacename, subspace in self.subspaces.items(): + for _spacename, subspace in self.subspaces.items(): if dim in subspace.coords: return subspace.indexes[dim] diff --git a/sharrow/selectors.py b/sharrow/selectors.py index 5fb6bef..b057f89 100644 --- a/sharrow/selectors.py +++ b/sharrow/selectors.py @@ -127,9 +127,15 @@ def _filter( ds = self._obj if _load: ds = ds.load() + if _func == "isel": + # remove coordinates, we don't need them for isel + ds_ = ds.drop_vars(ds.coords) + else: + ds_ = ds + if _names: result = ( - getattr(ds, _func)(**loaders) + getattr(ds_, _func)(**loaders) .digital_encoding.strip(_names) .drop_vars(_baggage) ) @@ -149,7 +155,7 @@ def _filter( result = result[_name] return result else: - result = getattr(ds, _func)(**loaders) + result = getattr(ds_, _func)(**loaders) names = list(result.keys()) for n in names: if self._obj.redirection.is_blended(n): @@ -202,17 +208,34 @@ def __call__( ): modified_idxs = {} raw_idxs = {} + + keep_raw_idxs = False + is_it_blended = [] + if isinstance(_name, str): + is_it_blended = [_name] + elif _names is not None: + is_it_blended = _names + for n in is_it_blended: + try: + if self._obj.redirection.is_blended(n): + keep_raw_idxs = True + except AttributeError: + keep_raw_idxs = True + for k, v in idxs.items(): target = self._obj.redirection.target(k) if target is None: - raw_idxs[k] = modified_idxs[k] = v + modified_idxs[k] = v + if keep_raw_idxs: + raw_idxs[k] = v else: v_ = np.asarray(v) modified_idxs[target] = self._obj[f"_digitized_{target}_of_{k}"][ v_ ].to_numpy() - raw_idxs[target] = v_ # self._obj[k][v_].to_numpy() - return self._filter( + if keep_raw_idxs: + raw_idxs[target] = v_ # self._obj[k][v_].to_numpy() + out = self._filter( _name=_name, _names=_names, _load=_load, @@ -221,6 +244,7 @@ def __call__( _raw_idxs=raw_idxs, **modified_idxs, ) + return out @xr.register_dataset_accessor("at") diff --git a/sharrow/shared_memory.py b/sharrow/shared_memory.py index 8b493eb..3bd273b 100644 --- a/sharrow/shared_memory.py +++ b/sharrow/shared_memory.py @@ -23,7 +23,6 @@ def si_units(x, kind="B", digits=3, shift=1000): - # nano micro milli kilo mega giga tera peta exa zeta yotta tiers = ["n", "ยต", "m", "", "K", "M", "G", "T", "P", "E", "Z", "Y"] @@ -99,7 +98,7 @@ def create_shared_memory_array(key, size): size=size, ) except FileExistsError: - raise FileExistsError(f"sharrow_shared_memory_array:{key}") + raise FileExistsError(f"sharrow_shared_memory_array:{key}") from None __GLOBAL_MEMORY_ARRAYS[key] = result return result @@ -138,7 +137,7 @@ def open_shared_memory_array(key, mode="r+"): create=False, ) except FileNotFoundError: - raise FileNotFoundError(f"sharrow_shared_memory_array:{key}") + raise FileNotFoundError(f"sharrow_shared_memory_array:{key}") from None else: logger.info( f"shared memory array from ephemeral memory, {si_units(result.size)}" @@ -175,7 +174,7 @@ def create_shared_list(content, key): name=h, ) except FileExistsError: - 
raise FileExistsError(f"sharrow_shared_memory_list:{key}") + raise FileExistsError(f"sharrow_shared_memory_list:{key}") from None __GLOBAL_MEMORY_LISTS[key] = result return result @@ -190,7 +189,7 @@ def read_shared_list(key): try: sl = ShareableList(name=_hexhash(f"sharrow__list__{key}")) except FileNotFoundError: - raise FileNotFoundError(f"sharrow_shared_memory_list:{key}") + raise FileNotFoundError(f"sharrow_shared_memory_list:{key}") from None else: return sl @@ -202,7 +201,7 @@ def get_shared_list_nbytes(key): try: shm = SharedMemory(name=h, create=False) except FileNotFoundError: - raise FileNotFoundError(f"sharrow_shared_memory_list:{key}") + raise FileNotFoundError(f"sharrow_shared_memory_list:{key}") from None else: return shm.size @@ -219,7 +218,6 @@ def delete_shared_memory_files(key): @xr.register_dataset_accessor("shm") class SharedMemDatasetAccessor: - _parent_class = xr.Dataset def __init__(self, xarray_obj): @@ -397,7 +395,7 @@ def shared_memory_key(self): try: return self._shared_memory_key_ except AttributeError: - raise ValueError("this dataset is not in shared memory") + raise ValueError("this dataset is not in shared memory") from None @classmethod def from_shared_memory(cls, key, own_data=False, mode="r+"): @@ -431,7 +429,7 @@ def from_shared_memory(cls, key, own_data=False, mode="r+"): # for memmap, list is loaded from pickle, not shared ram pass - if own_data and not (own_data is True): + if own_data and own_data is not True: mem = own_data own_data = True else: @@ -513,7 +511,7 @@ def shared_memory_size(self): try: return sum(i.size for i in self._shared_memory_objs_) except AttributeError: - raise ValueError("this dataset is not in shared memory") + raise ValueError("this dataset is not in shared memory") from None @property def is_shared_memory(self): diff --git a/sharrow/sparse.py b/sharrow/sparse.py index 2adc8f0..f424e39 100644 --- a/sharrow/sparse.py +++ b/sharrow/sparse.py @@ -1,3 +1,5 @@ +import math + import numba as nb import numpy as np import pandas as pd @@ -197,15 +199,17 @@ def blenders(self): return b -@nb.generated_jit(nopython=True) +# fastmath must be false to ensure NaNs are detected here. +# wrapping this as such allows fastmath to be turned on in outer functions +# but not lose the ability to check for NaNs. Older versions of this function +# checked whether the float cast to an integer was -9223372036854775808, but +# that turns out to be not compatible with all hardware (i.e. Apple Silicon). 
+@nb.generated_jit(nopython=True, fastmath=False) def isnan_fast_safe(x): if isinstance(x, nb.types.Float): def func(x): - if int(x) == -9223372036854775808: - return True - else: - return False + return math.isnan(x) return func elif isinstance(x, (nb.types.UnicodeType, nb.types.UnicodeCharSeq)): diff --git a/sharrow/table.py b/sharrow/table.py index b964080..adde6ba 100644 --- a/sharrow/table.py +++ b/sharrow/table.py @@ -257,9 +257,9 @@ def from_quilt(cls, path, blockname=None): stopper = blockname else: qlog = os.path.join(path, "quilt.log") - with open(qlog, "rt") as logreader: + with open(qlog) as logreader: existing_info = yaml.safe_load(logreader) - for stopper, block in enumerate(existing_info): + for _stopper, block in enumerate(existing_info): if block.get("name", None) == blockname: break else: @@ -267,8 +267,13 @@ def from_quilt(cls, path, blockname=None): else: stopper = 1e99 n = 0 - rowfile = lambda n: os.path.join(path, f"block.{n:03d}.rows") - colfile = lambda n: os.path.join(path, f"block.{n:03d}.cols") + + def rowfile(n): + return os.path.join(path, f"block.{n:03d}.rows") + + def colfile(n): + return os.path.join(path, f"block.{n:03d}.cols") + builder = None look = True while look and n <= stopper: @@ -291,7 +296,7 @@ def from_quilt(cls, path, blockname=None): n += 1 if builder is not None: metadata = builder.schema.metadata - metadata[b"quilt_number"] = f"{n}".encode("utf8") + metadata[b"quilt_number"] = f"{n}".encode() return builder.replace_schema_metadata(metadata) return None @@ -305,7 +310,7 @@ def to_quilt(self, path, blockname=None): ex_cols = [] max_block = -1 else: - with open(qlog, "rt") as logreader: + with open(qlog) as logreader: existing_info = yaml.safe_load(logreader) ex_rows = sum(block.get("rows", 0) for block in existing_info) ex_cols = sum((block.get("cols", []) for block in existing_info), []) diff --git a/sharrow/tests/conftest.py b/sharrow/tests/conftest.py new file mode 100644 index 0000000..e532de1 --- /dev/null +++ b/sharrow/tests/conftest.py @@ -0,0 +1,54 @@ +import pandas as pd +import pytest +import xarray as xr + +from sharrow.dataset import construct + + +@pytest.fixture +def person_dataset() -> xr.Dataset: + """ + Sample persons dataset with dummy data. + """ + df = pd.DataFrame( + { + "Income": [45, 88, 56, 15, 71], + "Name": ["Andre", "Bruce", "Carol", "David", "Eugene"], + "Age": [14, 25, 55, 8, 21], + "WorkMode": ["Car", "Bus", "Car", "Car", "Walk"], + "household_id": [11, 11, 22, 22, 33], + }, + index=pd.Index([441, 445, 552, 556, 934], name="person_id"), + ) + df["WorkMode"] = df["WorkMode"].astype("category") + return construct(df) + + +@pytest.fixture +def household_dataset() -> xr.Dataset: + """ + Sample household dataset with dummy data. + """ + df = pd.DataFrame( + { + "n_cars": [1, 2, 1], + }, + index=pd.Index([11, 22, 33], name="household_id"), + ) + return construct(df) + + +@pytest.fixture +def tours_dataset() -> xr.Dataset: + """ + Sample tours dataset with dummy data. 
+ """ + df = pd.DataFrame( + { + "TourMode": ["Car", "Bus", "Car", "Car", "Walk"], + "person_id": [441, 445, 552, 556, 934], + }, + index=pd.Index([4411, 4451, 5521, 5561, 9341], name="tour_id"), + ) + df["TourMode"] = df["TourMode"].astype("category") + return construct(df) diff --git a/sharrow/tests/test_categorical.py b/sharrow/tests/test_categorical.py new file mode 100644 index 0000000..b500112 --- /dev/null +++ b/sharrow/tests/test_categorical.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from enum import IntEnum + +import numpy as np +import pandas as pd +import xarray as xr + +import sharrow + + +def test_simple_cat(tours_dataset: xr.Dataset): + tree = sharrow.DataTree(tours=tours_dataset) + + assert all(tours_dataset.TourMode.cat.categories == ["Bus", "Car", "Walk"]) + + expr = "tours.TourMode == 'Bus'" + f = tree.setup_flow({expr: expr}) + a = f.load_dataarray(dtype=np.int8) + a = a.isel(expressions=0) + assert all(a == np.asarray([0, 1, 0, 0, 0])) + + tour_mode_bus = tree.get_expr(expr) + assert all(tour_mode_bus == np.asarray([0, 1, 0, 0, 0])) + + +def test_2_level_tree_cat( + tours_dataset: xr.Dataset, + person_dataset: xr.Dataset, +): + tree = sharrow.DataTree(tours=tours_dataset) + tree.add_dataset("persons", person_dataset, "tours.person_id @ persons.person_id") + + assert all(tours_dataset.TourMode.cat.categories == ["Bus", "Car", "Walk"]) + + expr = "tours.TourMode == 'Bus'" + f = tree.setup_flow({expr: expr}) + a = f.load_dataarray(dtype=np.int8) + a = a.isel(expressions=0) + assert all(a == np.asarray([0, 1, 0, 0, 0])) + + tour_mode_bus = tree.get_expr(expr) + assert all(tour_mode_bus == np.asarray([0, 1, 0, 0, 0])) + + work_mode_bus = tree.get_expr("WorkMode == 'Walk'") + assert all(work_mode_bus == np.asarray([0, 0, 0, 0, 1])) + + work_mode_bus1 = tree.get_expr("persons.WorkMode == 'Walk'") + assert all(work_mode_bus1 == np.asarray([0, 0, 0, 0, 1])) + + +def test_3_level_tree_cat( + tours_dataset: xr.Dataset, + person_dataset: xr.Dataset, + household_dataset: xr.Dataset, +): + tree = sharrow.DataTree(tours=tours_dataset) + tree.add_dataset("persons", person_dataset, "tours.person_id @ persons.person_id") + tree.add_dataset( + "households", person_dataset, "persons.household_id @ households.household_id" + ) + + assert all(tours_dataset.TourMode.cat.categories == ["Bus", "Car", "Walk"]) + + expr = "tours.TourMode == 'Bus'" + f = tree.setup_flow({expr: expr}) + a = f.load_dataarray(dtype=np.int8) + a = a.isel(expressions=0) + assert all(a == np.asarray([0, 1, 0, 0, 0])) + + tour_mode_bus = tree.get_expr(expr) + assert all(tour_mode_bus == np.asarray([0, 1, 0, 0, 0])) + + work_mode_bus = tree.get_expr("WorkMode == 'Walk'") + assert all(work_mode_bus == np.asarray([0, 0, 0, 0, 1])) + + work_mode_bus1 = tree.get_expr("persons.WorkMode == 'Walk'") + assert all(work_mode_bus1 == np.asarray([0, 0, 0, 0, 1])) + + +def test_rootless_tree_cat( + tours_dataset: xr.Dataset, + person_dataset: xr.Dataset, + household_dataset: xr.Dataset, +): + tree = sharrow.DataTree(tours=tours_dataset, root_node_name=False) + tree.add_dataset("persons", person_dataset, "tours.person_id @ persons.person_id") + tree.add_dataset( + "households", person_dataset, "persons.household_id @ households.household_id" + ) + + assert all(tours_dataset.TourMode.cat.categories == ["Bus", "Car", "Walk"]) + + expr = "tours.TourMode == 'Bus'" + f = tree.setup_flow({expr: expr}, with_root_node_name="tours") + a = f.load_dataarray(dtype=np.int8) + a = a.isel(expressions=0) + assert all(a == 
np.asarray([0, 1, 0, 0, 0])) + + +def test_int_enum_categorical(): + class TourMode(IntEnum): + Car = 1 + Bus = 2 + Walk = 3 + + df = pd.DataFrame( + { + "TourMode": ["Car", "Bus", "Car", "Car", "Walk"], + "person_id": [441, 445, 552, 556, 934], + }, + index=pd.Index([4411, 4451, 5521, 5561, 9341], name="tour_id"), + ) + df["TourMode2"] = df["TourMode"].as_int_enum(TourMode) + assert df["TourMode2"].dtype == "category" + assert all(df["TourMode2"].cat.categories == ["_0", "Car", "Bus", "Walk"]) + assert all(df["TourMode2"].cat.codes == [1, 2, 1, 1, 3]) diff --git a/sharrow/tests/test_datasets.py b/sharrow/tests/test_datasets.py index f2b8b09..d2da374 100644 --- a/sharrow/tests/test_datasets.py +++ b/sharrow/tests/test_datasets.py @@ -3,6 +3,7 @@ import numpy as np import openmatrix +import pandas as pd from pytest import approx import sharrow as sh @@ -35,3 +36,32 @@ def test_dataset_construct_with_zoneids(): with openmatrix.open_file(t.joinpath("dummy5.omx"), mode="r") as back: ds1 = sh.dataset.from_omx(back, indexes="one-based") assert ds1.coords["otaz"].values == approx(np.asarray([1, 2, 3, 4, 5])) + + +def test_dataset_categoricals(): + hhs = sh.example_data.get_households() + + def income_cat(i): + if i < 12500: + return "LOW" + elif i < 45000: + return "MID" + else: + return "HIGH" + + hhs["income_grp"] = hhs.income.apply(income_cat).astype( + pd.CategoricalDtype(["LOW", "MID", "HIGH"], ordered=True) + ) + assert hhs["income_grp"].dtype == "category" + + hd = sh.dataset.construct(hhs) + assert hd["income_grp"].dtype == np.int8 + + # affirm we can recover categorical and non-categorical data from datarrays + pd.testing.assert_series_equal( + hhs["income_grp"], hd.income_grp.single_dim.to_pandas() + ) + pd.testing.assert_series_equal(hhs["income"], hd.income.single_dim.to_pandas()) + + recovered_df = hd.single_dim.to_pandas() + pd.testing.assert_frame_equal(hhs, recovered_df) diff --git a/sharrow/tests/test_datastore.py b/sharrow/tests/test_datastore.py new file mode 100644 index 0000000..6d0c974 --- /dev/null +++ b/sharrow/tests/test_datastore.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import shutil +from pathlib import Path + +import pytest +import xarray as xr + +from sharrow.datastore import DataStore, ReadOnlyError + + +def test_datasstore_checkpointing(tmp_path: Path, person_dataset): + tm = DataStore(directory=tmp_path) + tm["persons"] = person_dataset + tm.make_checkpoint("init_persons") + + person_dataset["DoubleAge"] = person_dataset["Age"] * 2 + tm.update("persons", person_dataset["DoubleAge"]) + tm.make_checkpoint("annot_persons") + + tm2 = DataStore(directory=tmp_path) + tm2.restore_checkpoint("annot_persons") + xr.testing.assert_equal(tm2.get_dataset("persons"), person_dataset) + + tm2.restore_checkpoint("init_persons") + assert "DoubleAge" not in tm2.get_dataset("persons") + + tm_ro = DataStore(directory=tmp_path, mode="r") + with pytest.raises(ReadOnlyError): + tm_ro.make_checkpoint("will-fail") + + +def test_datasstore_checkpointing_parquet(tmp_path: Path, person_dataset): + tm = DataStore(directory=tmp_path, storage_format="parquet") + tm["persons"] = person_dataset + tm.make_checkpoint("init_persons") + + person_dataset["DoubleAge"] = person_dataset["Age"] * 2 + tm.update("persons", person_dataset["DoubleAge"]) + tm.make_checkpoint("annot_persons") + + tm2 = DataStore(directory=tmp_path) + tm2.restore_checkpoint("annot_persons") + xr.testing.assert_equal(tm2.get_dataset("persons"), person_dataset) + + tm2.restore_checkpoint("init_persons") + assert 
"DoubleAge" not in tm2.get_dataset("persons") + + tm_ro = DataStore(directory=tmp_path, mode="r") + with pytest.raises(ReadOnlyError): + tm_ro.make_checkpoint("will-fail") + + +def test_datasstore_relationships( + tmp_path: Path, person_dataset, household_dataset, tours_dataset +): + pth = tmp_path.joinpath("relations") + + if pth.exists(): + shutil.rmtree(pth) + + pth.mkdir(parents=True, exist_ok=True) + tm = DataStore(directory=pth) + + tm["persons"] = person_dataset + tm.make_checkpoint("init_persons") + + tm["households"] = household_dataset + tm.add_relationship("persons.household_id @ households.household_id") + tm.make_checkpoint("init_households") + + tm["tours"] = tours_dataset + tm.add_relationship("tours.person_id @ persons.person_id") + tm.make_checkpoint("init_tours") + + tm.digitize_relationships() + assert tm.relationships_are_digitized + + tm.make_checkpoint("digitized") + + tm2 = DataStore(directory=pth, mode="r") + tm2.read_metadata("*") + tm2.restore_checkpoint("init_households") + + assert sorted(tm2.get_dataset("persons")) == [ + "Age", + "Income", + "Name", + "WorkMode", + "household_id", + ] + + assert sorted(tm2.get_dataset("households")) == [ + "n_cars", + ] + + tm2.restore_checkpoint("digitized") + assert sorted(tm2.get_dataset("persons")) == [ + "Age", + "Income", + "Name", + "WorkMode", + "digitizedOffsethousehold_id_households_household_id", + "household_id", + ] + + double_age = tm2.get_dataset("persons")["Age"] * 2 + with pytest.raises(ReadOnlyError): + tm2.update("persons", double_age.rename("doubleAge")) + + with pytest.raises(ReadOnlyError): + tm2.make_checkpoint("age-x2") + + tm.update("persons", double_age.rename("doubleAge")) + assert sorted(tm.get_dataset("persons")) == [ + "Age", + "Income", + "Name", + "WorkMode", + "digitizedOffsethousehold_id_households_household_id", + "doubleAge", + "household_id", + ] + + tm.make_checkpoint("age-x2") + tm2.read_metadata() + tm2.restore_checkpoint("age-x2") + + person_restored = tm2.get_dataframe("persons") + print(person_restored.WorkMode.dtype) + assert person_restored.WorkMode.dtype == "category" diff --git a/sharrow/tests/test_relationships.py b/sharrow/tests/test_relationships.py index f2add2b..7a76dfb 100644 --- a/sharrow/tests/test_relationships.py +++ b/sharrow/tests/test_relationships.py @@ -30,7 +30,6 @@ def skims(): def test_shared_data(dataframe_regression, households, skims): - tree = DataTree( base=households, skims=skims, @@ -74,8 +73,52 @@ def test_shared_data(dataframe_regression, households, skims): dataframe_regression.check(result2, basename="test_shared_data_2") -def test_shared_data_reversible(dataframe_regression, households, skims): +def test_subspace_fallbacks(dataframe_regression, households, skims): + tree = DataTree( + base=households, + skims=skims, + relationships=( + "base.otaz_idx->skims.otaz", + "base.dtaz_idx->skims.dtaz", + "base.timeperiod5->skims.time_period", + ), + ) + tree.subspace_fallbacks["df"] = ["base", "skims"] + + flow1 = tree.setup_flow( + { + "income": "df['income']", + "sov_time_by_income": "df['SOV_TIME']/df['income']", + "sov_cost_by_income": "df['HOV3_TIME']", + } + ) + result1 = flow1._load(tree, as_dataframe=True) + dataframe_regression.check(result1, basename="test_shared_data") + + flow2 = tree.setup_flow( + { + "income": "income", + "sov_time_by_income": "SOV_TIME/income", + "sov_cost_by_income": "HOV3_TIME", + } + ) + result2 = flow2._load(tree, as_dataframe=True) + dataframe_regression.check(result2, basename="test_shared_data") + + # names that are not 
valid Python identifiers + flow3 = tree.setup_flow( + { + "income > 10k": "df.income > 10_000", + "income [up to 10k]": "df.income <= 10_000", + "sov_time / income": "df.SOV_TIME/df.income", + "log1p(sov_cost_by_income)": "log1p(df.HOV3_TIME)", + } + ) + result3 = flow3._load(tree, as_dataframe=True) + dataframe_regression.check(result3, basename="test_shared_data_2") + +def test_shared_data_reversible(dataframe_regression, households, skims): tree = DataTree( base=households, odt_skims=skims, @@ -241,7 +284,6 @@ def test_with_2d_base(dataframe_regression): def test_mixed_dtypes(dataframe_regression, households, skims): - tree = DataTree( base=households, skims=skims, @@ -326,7 +368,6 @@ def _get_target(q, token): sys.version_info < (3, 8), reason="shared memory requires python3.8 or higher" ) def test_shared_memory(skims): - token = "skims" + secrets.token_hex(5) skims_2 = skims.shm.to_shared_memory(token) @@ -364,7 +405,6 @@ def test_relationship_init(): def test_replacement_filters(dataframe_regression, households, skims): - tree = DataTree( base=households, skims=skims, @@ -402,7 +442,6 @@ def rename_jncome(x): def test_name_in_wrong_subspace(dataframe_regression, households, skims): - tree = DataTree( base=households, skims=skims, @@ -471,7 +510,6 @@ def test_name_in_wrong_subspace(dataframe_regression, households, skims): def test_shared_data_encoded(dataframe_regression, households, skims): - households = sharrow.dataset.construct(households).digital_encoding.set( "income", bitwidth=32, @@ -562,7 +600,6 @@ def test_joint_dict_encoded(dataframe_regression, skims): def test_isin_and_between(dataframe_regression): - data = example_data.get_data() persons = data["persons"] @@ -645,7 +682,6 @@ def test_isin_and_between(dataframe_regression): def test_nested_where(dataframe_regression): - data = example_data.get_data() base = persons = data["persons"] @@ -767,7 +803,6 @@ def test_isna(): def test_get(dataframe_regression, households, skims): - tree = DataTree( base=households, skims=skims, @@ -778,9 +813,9 @@ def test_get(dataframe_regression, households, skims): ), ) - ss = tree.setup_flow( + flow1 = tree.setup_flow( { - "income": "base.get('income', 0)", + "income": "base.get('income', 0) + base.get('missing_one', 0)", "sov_time_by_income": "skims.SOV_TIME/base.get('income', 0)", "missing_data": "base.get('missing_data', -1)", "missing_skim": "skims.get('missing_core', -2)", @@ -788,12 +823,40 @@ def test_get(dataframe_regression, households, skims): "sov_cost_by_income_2": "skims.get('HOV3_TIME', 999)", }, ) - result = ss._load(tree, as_dataframe=True) + result = flow1._load(tree, as_dataframe=True) dataframe_regression.check(result) - s2 = tree.setup_flow( + tree_plus = DataTree( + base=households.assign(missing_one=1.0), + skims=skims, + relationships=( + "base.otaz_idx->skims.otaz", + "base.dtaz_idx->skims.dtaz", + "base.timeperiod5->skims.time_period", + ), + ) + flow2 = tree_plus.setup_flow(flow1.defs) + result = flow2._load(tree_plus, as_dataframe=True) + dataframe_regression.check(result.eval("income = income-1")) + assert flow2.flow_hash != flow1.flow_hash + + tree.subspace_fallbacks["df"] = ["base"] + flow3 = tree.setup_flow( + { + "income": "base.get('income', 0)", + "sov_time_by_income": "skims.SOV_TIME/df.get('income', 0)", + "missing_data": "df.get('missing_data', -1)", + "missing_skim": "skims.get('missing_core', -2)", + "sov_time_by_income_2": "skims.get('SOV_TIME')/df.income", + "sov_cost_by_income_2": "skims.get('HOV3_TIME', 999)", + }, + ) + result = flow3._load(tree, 
as_dataframe=True) + dataframe_regression.check(result) + + flow4 = tree.setup_flow( { - "income": "base.get('income', default=0)", + "income": "base.get('income', default=0) + df.get('missing_one', 0)", "sov_time_by_income": "skims.SOV_TIME/base.get('income', default=0)", "missing_data": "base.get('missing_data', default=-1)", "missing_skim": "skims.get('missing_core', default=-2)", @@ -801,10 +864,25 @@ def test_get(dataframe_regression, households, skims): "sov_cost_by_income_2": "skims.get('HOV3_TIME', default=999)", }, ) - result = s2._load(tree, as_dataframe=True) - - assert s2.flow_hash != ss.flow_hash + result = flow4._load(tree, as_dataframe=True) + assert flow4.flow_hash != flow1.flow_hash + dataframe_regression.check(result) + # test get when inside another function + flow5 = tree.setup_flow( + { + "income": "np.power(base.get('income', default=0) + df.get('missing_one', 0), 1)", + "sov_time_by_income": "skims.SOV_TIME/np.power(base.get('income', default=0), 1)", + "missing_data": "np.where(np.isnan(df.get('missing_data', default=1)), 0, df.get('missing_data', default=-1))", # noqa: E501 + "missing_skim": "(np.where(np.isnan(df.get('num_escortees', np.nan)), -2 , df.get('num_escortees', np.nan)))", # noqa: E501 + "sov_time_by_income_2": "skims.get('SOV_TIME', default=0)/base.income", + "sov_cost_by_income_2": "skims.get('HOV3_TIME', default=999)", + }, + ) + result = flow5._load(tree, as_dataframe=True) + assert "__skims__HOV3_TIME:True" in flow5._optional_get_tokens + assert "__df__missing_data:False" in flow5._optional_get_tokens + assert "__df__num_escortees:False" in flow5._optional_get_tokens dataframe_regression.check(result) diff --git a/sharrow/translate.py b/sharrow/translate.py index c0e5932..2e90611 100644 --- a/sharrow/translate.py +++ b/sharrow/translate.py @@ -21,7 +21,6 @@ def omx_to_zarr( time_periods=None, time_period_sep="__", ): - bucket = {} r1 = r2 = None diff --git a/sharrow/utils/__init__.py b/sharrow/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sharrow/utils/tar_zst.py b/sharrow/utils/tar_zst.py new file mode 100644 index 0000000..a5e78ad --- /dev/null +++ b/sharrow/utils/tar_zst.py @@ -0,0 +1,104 @@ +# based on https://gist.github.com/scivision/ad241e9cf0474e267240e196d7545eca + +import os +import sys +import tarfile +import tempfile +from pathlib import Path + +try: + import zstandard # pip install zstandard +except ModuleNotFoundError: + zstandard = None + + +def extract_zst(archive: Path, out_path: Path): + """ + extract .zst file + works on Windows, Linux, MacOS, etc. + + Parameters + ---------- + archive: pathlib.Path or str + .zst file to extract + out_path: pathlib.Path or str + directory to extract files and directories to + """ + + if zstandard is None: + raise ImportError("pip install zstandard") + + archive = Path(archive).expanduser() + out_path = Path(out_path).expanduser().resolve() + # need .resolve() in case intermediate relative dir doesn't exist + + dctx = zstandard.ZstdDecompressor() + + with tempfile.TemporaryFile(suffix=".tar") as ofh: + with archive.open("rb") as ifh: + dctx.copy_stream(ifh, ofh) + ofh.seek(0) + with tarfile.open(fileobj=ofh) as z: + z.extractall(out_path) + + +def compress_zst(in_path: Path, archive: Path): + """ + Compress a directory into a .tar.zst file. + + Certain hidden files are excluded, including .git directories and + macOS's .DS_Store files. 
+ + Parameters + ---------- + in_path: pathlib.Path or str + directory to compress + archive: pathlib.Path or str + .tar.zst file to compress into + """ + if zstandard is None: + raise ImportError("pip install zstandard") + dctx = zstandard.ZstdCompressor(level=9, threads=-1, write_checksum=True) + with tempfile.TemporaryFile(suffix=".tar") as ofh: + with tarfile.open(fileobj=ofh, mode="w") as z: + for dirpath, dirnames, filenames in os.walk(in_path): + if os.path.basename(dirpath) == ".git": + continue + for n in range(len(dirnames) - 1, -1, -1): + if dirnames[n] == ".git" or dirnames[n].startswith("---"): + dirnames.pop(n) + for f in filenames: + if f.startswith(".git") or f == ".DS_Store" or f.startswith("---"): + continue + finame = Path(os.path.join(dirpath, f)) + arcname = finame.relative_to(in_path) + print(f"> {arcname}") + z.add(finame, arcname=arcname) + ofh.seek(0) + with archive.open("wb") as ifh: + dctx.copy_stream(ofh, ifh) + + +if __name__ == "__main__": + x = Path(sys.argv[1]) + name = x.name + if name.endswith(".tar.zst"): + y = x.with_name(name[:-8]) + if x.exists(): + if not y.exists(): + print(f"extracting from: {x}") + extract_zst(x, y) + else: + print(f"not extracting, existing target: {y}") + else: + print(f"not extracting, does not exist: {x}") + else: + y = x.with_name(name + ".tar.zst") + if x.exists(): + if not y.exists(): + print(f"compressing to tar.zst: {x}") + compress_zst(x, y) + else: + print(f"not compressing, existing tar.zst: {x}") + else: + print(f"not compressing, does not exist: {x}")
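
A minimal sketch of the rootless-tree pattern exercised by test_rootless_tree_cat above, assuming tours and persons are xarray Datasets built with sharrow.dataset.construct in the style of the new conftest fixtures:

import numpy as np
import sharrow

# root_node_name=False leaves the tree without a fixed root node
tree = sharrow.DataTree(tours=tours, root_node_name=False)
tree.add_dataset("persons", persons, "tours.person_id @ persons.person_id")

# with no root fixed on the tree, the flow is told which node's dimensions
# to broadcast against via the new with_root_node_name argument
flow = tree.setup_flow({"is_bus": "tours.TourMode == 'Bus'"}, with_root_node_name="tours")
is_bus = flow.load_dataarray(dtype=np.int8).isel(expressions=0)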
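
A sketch of the optional get handling added in flows.py, following test_get in sharrow/tests/test_relationships.py; here tree is assumed to be a DataTree with "base" and "skims" subspaces as in the existing regression tests:

# "df" is an alias that falls back to the listed real subspaces
tree.subspace_fallbacks["df"] = ["base", "skims"]

flow = tree.setup_flow(
    {
        "income": "df.get('income', 0)",                # found in base, default unused
        "missing": "df.get('missing_data', -1)",        # found nowhere, default -1 is compiled in
        "hov3": "skims.get('HOV3_TIME', default=999)",  # keyword form is also accepted
    }
)

# Flow.defs returns the name -> expression mapping, so an equivalent flow can
# be rebuilt elsewhere; the flow hash reflects which optional gets resolved
flow_again = tree.setup_flow(flow.defs)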
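
A sketch of the DataStore checkpointing workflow covered by the new test_datastore.py, assuming persons is an xarray Dataset like the person_dataset fixture and work_dir is a pathlib.Path:

from sharrow.datastore import DataStore

store = DataStore(directory=work_dir)
store["persons"] = persons
store.make_checkpoint("init_persons")

# annotate the table and record a second checkpoint
persons["DoubleAge"] = persons["Age"] * 2
store.update("persons", persons["DoubleAge"])
store.make_checkpoint("annot_persons")

# a second store opened on the same directory can restore either checkpoint
store2 = DataStore(directory=work_dir)
store2.restore_checkpoint("init_persons")
assert "DoubleAge" not in store2.get_dataset("persons")

# read-only stores refuse to write: make_checkpoint raises ReadOnlyError
read_only = DataStore(directory=work_dir, mode="r")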