From f0c1e0a77bf8e3660fdabf484229a47a021d1f8c Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 9 Aug 2024 14:25:14 +0100 Subject: [PATCH] feat: Support a wider range of iterables in `SchemaBase.to_dict` (#3501) * feat: Support a wider range of iterables in `SchemaBase.to_dict` * test: Update validation errors to use non-iterable element type Previously `list[set[str]]` * fix: Prevent `to_dict` method being called on `pd.Series` AFAIK, this is the intended case for `Parameter | Expression` not for converting arbitrary objects * test: Add tests for iterables and ranges * fix(typing): Ignore type errors for tests * refactor: Use `narwhals.stable.v1` https://github.com/vega/altair/pull/3501#discussion_r1690123512 * test: Increase coverage in `test_to_dict_iterables` The original test obscured the fact that this change applies anywhere a `Sequence` is annotated. * docs: Add a doc for `test_to_dict_iterables` * revert: Change `test_chart_validation_errors` back to demonstrate failures to @joelostblom Reverting 1f3dfd33ee273e71c74b67e64ae4410ae0b7c058 https://github.com/vega/altair/pull/3501#discussion_r1703392848 * docs: Update User Guide to use `Sequence` https://github.com/vega/altair/pull/3501#issuecomment-2267945794 * test: Fix `test_chart_validation_errors` failure verbosity https://github.com/vega/altair/pull/3501#discussion_r1704329394 * refactor: Move `inspect.cleandoc` inside of `test_chart_validation_errors` All 18 cases use this, saves 30 lines * revert: Restore original fix to `test_chart_validation_errors` https://github.com/vega/altair/pull/3501#discussion_r1703392848 https://github.com/vega/altair/pull/3501#discussion_r1704329394 https://github.com/vega/altair/pull/3501#discussion_r1709294033 * docs: Remove "ordered" descriptor from `Sequence` Co-authored-by: Joel Ostblom * test: Remove missed `inspect.cleandoc` https://github.com/vega/altair/pull/3501#discussion_r1711280204 * test: Only modify 
message, not input to `test_chart_validation_errors` https://github.com/vega/altair/pull/3501#discussion_r1711262137 * style: fix oddly formatted `test_multiple_field_strings_in_condition` --------- Co-authored-by: Joel Ostblom --- altair/utils/schemapi.py | 23 +- doc/user_guide/encodings/index.rst | 2 +- tests/utils/test_schemapi.py | 327 +++++++++++++++++------------ tools/schemapi/schemapi.py | 23 +- 4 files changed, 236 insertions(+), 139 deletions(-) diff --git a/altair/utils/schemapi.py b/altair/utils/schemapi.py index e8884ccd2..567fd8a8a 100644 --- a/altair/utils/schemapi.py +++ b/altair/utils/schemapi.py @@ -32,6 +32,7 @@ import jsonschema import jsonschema.exceptions import jsonschema.validators +import narwhals.stable.v1 as nw from packaging.version import Version # This leads to circular imports with the vegalite module. Currently, this works @@ -488,6 +489,14 @@ def _subclasses(cls: type[Any]) -> Iterator[type[Any]]: yield cls +def _from_array_like(obj: Iterable[Any], /) -> list[Any]: + try: + ser = nw.from_native(obj, strict=True, series_only=True) + return ser.to_list() + except TypeError: + return list(obj) + + def _todict(obj: Any, context: dict[str, Any] | None, np_opt: Any, pd_opt: Any) -> Any: """Convert an object to a dict representation.""" if np_opt is not None: @@ -512,10 +521,16 @@ def _todict(obj: Any, context: dict[str, Any] | None, np_opt: Any, pd_opt: Any) for k, v in obj.items() if v is not Undefined } - elif hasattr(obj, "to_dict"): + elif ( + hasattr(obj, "to_dict") + and (module_name := obj.__module__) + and module_name.startswith("altair") + ): return obj.to_dict() elif pd_opt is not None and isinstance(obj, pd_opt.Timestamp): return pd_opt.Timestamp(obj).isoformat() + elif _is_iterable(obj, exclude=(str, bytes)): + return _todict(_from_array_like(obj), context, np_opt, pd_opt) else: return obj @@ -1232,6 +1247,12 @@ def _is_list(obj: Any | list[Any]) -> TypeIs[list[Any]]: return isinstance(obj, list) +def _is_iterable( + obj: Any, 
*, exclude: type | tuple[type, ...] = (str, bytes) +) -> TypeIs[Iterable[Any]]: + return not isinstance(obj, exclude) and isinstance(obj, Iterable) + + def _passthrough(*args: Any, **kwds: Any) -> Any | dict[str, Any]: return args[0] if args else kwds diff --git a/doc/user_guide/encodings/index.rst b/doc/user_guide/encodings/index.rst index 264bacc02..b693b93f4 100644 --- a/doc/user_guide/encodings/index.rst +++ b/doc/user_guide/encodings/index.rst @@ -420,7 +420,7 @@ options available to change the sort order: - Passing the name of an encoding channel to ``sort``, such as ``"x"`` or ``"y"``, allows for sorting by that channel. An optional minus prefix can be used for a descending sort. For example ``sort='-x'`` would sort by the x channel in descending order. -- Passing a list to ``sort`` allows you to explicitly set the order in which +- Passing a `Sequence <https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence>`_ to ``sort`` allows you to explicitly set the order in which you would like the encoding to appear - Using the ``field`` and ``op`` parameters to specify a field and aggregation operation to sort by. 
diff --git a/tests/utils/test_schemapi.py b/tests/utils/test_schemapi.py index b0068724b..acea9dcb6 100644 --- a/tests/utils/test_schemapi.py +++ b/tests/utils/test_schemapi.py @@ -1,15 +1,22 @@ # ruff: noqa: W291 +from __future__ import annotations + import copy import inspect import io import json import pickle +import types import warnings +from collections import deque +from functools import partial +from typing import Any, Callable, Iterable, Sequence import jsonschema import jsonschema.exceptions import numpy as np import pandas as pd +import polars as pl import pytest import altair as alt @@ -22,6 +29,8 @@ UndefinedType, _FromDict, ) +from altair.vegalite.v5.schema.channels import X +from altair.vegalite.v5.schema.core import FieldOneOfPredicate, Legend from vega_datasets import data _JSON_SCHEMA_DRAFT_URL = load_schema()["$schema"] @@ -531,9 +540,7 @@ def chart_error_example__wrong_tooltip_type_in_faceted_chart(): def chart_error_example__wrong_tooltip_type_in_layered_chart(): # Error: Wrong data type to pass to tooltip - return alt.layer( - alt.Chart().mark_point().encode(tooltip=[{"wrong"}]), - ) + return alt.layer(alt.Chart().mark_point().encode(tooltip=[{"wrong"}])) def chart_error_example__two_errors_in_layered_chart(): @@ -629,13 +636,24 @@ def chart_error_example__four_errors(): ) -@pytest.mark.parametrize( - ("chart_func", "expected_error_message"), - [ - ( - chart_error_example__invalid_y_option_value_unknown_x_option, - inspect.cleandoc( - r"""Multiple errors were found. +def id_func(val) -> str: + """ + Ensures the generated test-id name uses only `chart_func` and not `expected_error_message`. 
+ + Without this, the name is ``test_chart_validation_errors[chart_func-expected_error_message]`` + """ + if isinstance(val, types.FunctionType): + return val.__name__ + else: + return "" + + +# NOTE: Avoids all cases appearing in a failure traceback +# At the time of writing, this is over 300 lines +chart_funcs_error_message: list[tuple[Callable[..., Any], str]] = [ + ( + chart_error_example__invalid_y_option_value_unknown_x_option, + r"""Multiple errors were found. Error 1: `X` has no parameter named 'unknown' @@ -650,27 +668,21 @@ def chart_error_example__four_errors(): Error 2: 'asdf' is an invalid value for `stack`. Valid values are: - One of \['zero', 'center', 'normalize'\] - - Of type 'null' or 'boolean'$""" - ), - ), - ( - chart_error_example__wrong_tooltip_type_in_faceted_chart, - inspect.cleandoc( - r"""'{'wrong'}' is an invalid value for `field`. Valid values are of type 'string' or 'object'.$""" - ), - ), - ( - chart_error_example__wrong_tooltip_type_in_layered_chart, - inspect.cleandoc( - r"""'{'wrong'}' is an invalid value for `field`. Valid values are of type 'string' or 'object'.$""" - ), - ), - ( - chart_error_example__two_errors_in_layered_chart, - inspect.cleandoc( - r"""Multiple errors were found. - - Error 1: '{'wrong'}' is an invalid value for `field`. Valid values are of type 'string' or 'object'. + - Of type 'null' or 'boolean'$""", + ), + ( + chart_error_example__wrong_tooltip_type_in_faceted_chart, + r"""'\['wrong'\]' is an invalid value for `field`. Valid values are of type 'string' or 'object'.$""", + ), + ( + chart_error_example__wrong_tooltip_type_in_layered_chart, + r"""'\['wrong'\]' is an invalid value for `field`. Valid values are of type 'string' or 'object'.$""", + ), + ( + chart_error_example__two_errors_in_layered_chart, + r"""Multiple errors were found. + + Error 1: '\['wrong'\]' is an invalid value for `field`. Valid values are of type 'string' or 'object'. 
Error 2: `Color` has no parameter named 'invalidArgument' @@ -679,25 +691,21 @@ def chart_error_example__four_errors(): aggregate condition scale title bandPosition field sort type - See the help for `Color` to read the full description of these parameters$""" - ), - ), - ( - chart_error_example__two_errors_in_complex_concat_layered_chart, - inspect.cleandoc( - r"""Multiple errors were found. + See the help for `Color` to read the full description of these parameters$""", + ), + ( + chart_error_example__two_errors_in_complex_concat_layered_chart, + r"""Multiple errors were found. - Error 1: '{'wrong'}' is an invalid value for `field`. Valid values are of type 'string' or 'object'. + Error 1: '\['wrong'\]' is an invalid value for `field`. Valid values are of type 'string' or 'object'. - Error 2: '4' is an invalid value for `bandPosition`. Valid values are of type 'number'.$""" - ), - ), - ( - chart_error_example__three_errors_in_complex_concat_layered_chart, - inspect.cleandoc( - r"""Multiple errors were found. + Error 2: '4' is an invalid value for `bandPosition`. Valid values are of type 'number'.$""", + ), + ( + chart_error_example__three_errors_in_complex_concat_layered_chart, + r"""Multiple errors were found. - Error 1: '{'wrong'}' is an invalid value for `field`. Valid values are of type 'string' or 'object'. + Error 1: '\['wrong'\]' is an invalid value for `field`. Valid values are of type 'string' or 'object'. Error 2: `Color` has no parameter named 'invalidArgument' @@ -708,13 +716,11 @@ def chart_error_example__four_errors(): See the help for `Color` to read the full description of these parameters - Error 3: '4' is an invalid value for `bandPosition`. Valid values are of type 'number'.$""" - ), - ), - ( - chart_error_example__two_errors_with_one_in_nested_layered_chart, - inspect.cleandoc( - r"""Multiple errors were found. + Error 3: '4' is an invalid value for `bandPosition`. 
Valid values are of type 'number'.$""", + ), + ( + chart_error_example__two_errors_with_one_in_nested_layered_chart, + r"""Multiple errors were found. Error 1: `Scale` has no parameter named 'invalidOption' @@ -734,13 +740,11 @@ def chart_error_example__four_errors(): aggregate condition scale title bandPosition field sort type - See the help for `Color` to read the full description of these parameters$""" - ), - ), - ( - chart_error_example__layer, - inspect.cleandoc( - r"""`VConcatChart` has no parameter named 'width' + See the help for `Color` to read the full description of these parameters$""", + ), + ( + chart_error_example__layer, + r"""`VConcatChart` has no parameter named 'width' Existing parameter names are: vconcat center description params title @@ -748,37 +752,29 @@ def chart_error_example__four_errors(): background data padding spacing usermeta bounds datasets - See the help for `VConcatChart` to read the full description of these parameters$""" - ), - ), - ( - chart_error_example__invalid_y_option_value, - inspect.cleandoc( - r"""'asdf' is an invalid value for `stack`. Valid values are: + See the help for `VConcatChart` to read the full description of these parameters$""", + ), + ( + chart_error_example__invalid_y_option_value, + r"""'asdf' is an invalid value for `stack`. Valid values are: - One of \['zero', 'center', 'normalize'\] - - Of type 'null' or 'boolean'$""" - ), - ), - ( - chart_error_example__invalid_y_option_value_with_condition, - inspect.cleandoc( - r"""'asdf' is an invalid value for `stack`. Valid values are: + - Of type 'null' or 'boolean'$""", + ), + ( + chart_error_example__invalid_y_option_value_with_condition, + r"""'asdf' is an invalid value for `stack`. Valid values are: - One of \['zero', 'center', 'normalize'\] - - Of type 'null' or 'boolean'$""" - ), - ), - ( - chart_error_example__hconcat, - inspect.cleandoc( - r"""'{'text': 'Horsepower', 'align': 'right'}' is an invalid value for `title`. 
Valid values are of type 'string', 'array', or 'null'.$""" - ), - ), - ( - chart_error_example__invalid_timeunit_value, - inspect.cleandoc( - r"""'invalid_value' is an invalid value for `timeUnit`. Valid values are: + - Of type 'null' or 'boolean'$""", + ), + ( + chart_error_example__hconcat, + r"""'{'text': 'Horsepower', 'align': 'right'}' is an invalid value for `title`. Valid values are of type 'string', 'array', or 'null'.$""", + ), + ( + chart_error_example__invalid_timeunit_value, + r"""'invalid_value' is an invalid value for `timeUnit`. Valid values are: - One of \['year', 'quarter', 'month', 'week', 'day', 'dayofyear', 'date', 'hours', 'minutes', 'seconds', 'milliseconds'\] - One of \['utcyear', 'utcquarter', 'utcmonth', 'utcweek', 'utcday', 'utcdayofyear', 'utcdate', 'utchours', 'utcminutes', 'utcseconds', 'utcmilliseconds'\] @@ -786,36 +782,28 @@ def chart_error_example__four_errors(): - One of \['utcyearquarter', 'utcyearquartermonth', 'utcyearmonth', 'utcyearmonthdate', 'utcyearmonthdatehours', 'utcyearmonthdatehoursminutes', 'utcyearmonthdatehoursminutesseconds', 'utcyearweek', 'utcyearweekday', 'utcyearweekdayhours', 'utcyearweekdayhoursminutes', 'utcyearweekdayhoursminutesseconds', 'utcyeardayofyear', 'utcquartermonth', 'utcmonthdate', 'utcmonthdatehours', 'utcmonthdatehoursminutes', 'utcmonthdatehoursminutesseconds', 'utcweekday', 'utcweekdayhours', 'utcweekdayhoursminutes', 'utcweekdayhoursminutesseconds', 'utcdayhours', 'utcdayhoursminutes', 'utcdayhoursminutesseconds', 'utchoursminutes', 'utchoursminutesseconds', 'utcminutesseconds', 'utcsecondsmilliseconds'\] - One of \['binnedyear', 'binnedyearquarter', 'binnedyearquartermonth', 'binnedyearmonth', 'binnedyearmonthdate', 'binnedyearmonthdatehours', 'binnedyearmonthdatehoursminutes', 'binnedyearmonthdatehoursminutesseconds', 'binnedyearweek', 'binnedyearweekday', 'binnedyearweekdayhours', 'binnedyearweekdayhoursminutes', 'binnedyearweekdayhoursminutesseconds', 'binnedyeardayofyear'\] - One of 
\['binnedutcyear', 'binnedutcyearquarter', 'binnedutcyearquartermonth', 'binnedutcyearmonth', 'binnedutcyearmonthdate', 'binnedutcyearmonthdatehours', 'binnedutcyearmonthdatehoursminutes', 'binnedutcyearmonthdatehoursminutesseconds', 'binnedutcyearweek', 'binnedutcyearweekday', 'binnedutcyearweekdayhours', 'binnedutcyearweekdayhoursminutes', 'binnedutcyearweekdayhoursminutesseconds', 'binnedutcyeardayofyear'\] - - Of type 'object'$""" - ), - ), - ( - chart_error_example__invalid_sort_value, - inspect.cleandoc( - r"""'invalid_value' is an invalid value for `sort`. Valid values are: + - Of type 'object'$""", + ), + ( + chart_error_example__invalid_sort_value, + r"""'invalid_value' is an invalid value for `sort`. Valid values are: - One of \['ascending', 'descending'\] - One of \['x', 'y', 'color', 'fill', 'stroke', 'strokeWidth', 'size', 'shape', 'fillOpacity', 'strokeOpacity', 'opacity', 'text'\] - One of \['-x', '-y', '-color', '-fill', '-stroke', '-strokeWidth', '-size', '-shape', '-fillOpacity', '-strokeOpacity', '-opacity', '-text'\] - - Of type 'array', 'object', or 'null'$""" - ), - ), - ( - chart_error_example__invalid_bandposition_value, - inspect.cleandoc( - r"""'4' is an invalid value for `bandPosition`. Valid values are of type 'number'.$""" - ), - ), - ( - chart_error_example__invalid_type, - inspect.cleandoc( - r"""'unknown' is an invalid value for `type`. Valid values are one of \['quantitative', 'ordinal', 'temporal', 'nominal', 'geojson'\].$""" - ), - ), - ( - chart_error_example__additional_datum_argument, - inspect.cleandoc( - r"""`X` has no parameter named 'wrong_argument' + - Of type 'array', 'object', or 'null'$""", + ), + ( + chart_error_example__invalid_bandposition_value, + r"""'4' is an invalid value for `bandPosition`. Valid values are of type 'number'.$""", + ), + ( + chart_error_example__invalid_type, + r"""'unknown' is an invalid value for `type`. 
Valid values are one of \['quantitative', 'ordinal', 'temporal', 'nominal', 'geojson'\].$""", + ), + ( + chart_error_example__additional_datum_argument, + r"""`X` has no parameter named 'wrong_argument' Existing parameter names are: shorthand bin scale timeUnit @@ -823,19 +811,15 @@ def chart_error_example__four_errors(): axis impute stack type bandPosition - See the help for `X` to read the full description of these parameters$""" - ), - ), - ( - chart_error_example__invalid_value_type, - inspect.cleandoc( - r"""'1' is an invalid value for `value`. Valid values are of type 'object', 'string', or 'null'.$""" - ), - ), - ( - chart_error_example__four_errors, - inspect.cleandoc( - r"""Multiple errors were found. + See the help for `X` to read the full description of these parameters$""", + ), + ( + chart_error_example__invalid_value_type, + r"""'1' is an invalid value for `value`. Valid values are of type 'object', 'string', or 'null'.$""", + ), + ( + chart_error_example__four_errors, + r"""Multiple errors were found. 
Error 1: `Color` has no parameter named 'another_unknown' @@ -863,10 +847,13 @@ def chart_error_example__four_errors(): axis impute stack type bandPosition - See the help for `X` to read the full description of these parameters$""" - ), - ), - ], + See the help for `X` to read the full description of these parameters$""", + ), +] + + +@pytest.mark.parametrize( + ("chart_func", "expected_error_message"), chart_funcs_error_message, ids=id_func ) def test_chart_validation_errors(chart_func, expected_error_message): # For some wrong chart specifications such as an unknown encoding channel, @@ -876,6 +863,7 @@ def test_chart_validation_errors(chart_func, expected_error_message): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) chart = chart_func() + expected_error_message = inspect.cleandoc(expected_error_message) with pytest.raises(SchemaValidationError, match=expected_error_message): chart.to_dict() @@ -884,14 +872,13 @@ def test_multiple_field_strings_in_condition(): selection = alt.selection_point() expected_error_message = "A field cannot be used for both the `if_true` and `if_false` values of a condition. One of them has to specify a `value` or `datum` definition." 
with pytest.raises(ValueError, match=expected_error_message): - ( + chart = ( # noqa: F841 alt.Chart(data.cars()) .mark_circle() .add_params(selection) - .encode( - color=alt.condition(selection, "Origin", "Origin"), - ) - ).to_dict() + .encode(color=alt.condition(selection, "Origin", "Origin")) + .to_dict() + ) def test_serialize_numpy_types(): @@ -948,3 +935,71 @@ def test_to_dict_expand_mark_spec(): chart = alt.Chart().mark_bar() assert chart.to_dict()["mark"] == {"type": "bar"} assert chart.mark == "bar" + + +@pytest.mark.parametrize( + "expected", + [list("cdfabe"), [0, 3, 4, 5, 8]], +) +@pytest.mark.parametrize( + "tp", + [ + tuple, + list, + deque, + pl.Series, + pd.Series, + pd.Index, + pd.Categorical, + pd.CategoricalIndex, + np.array, + ], +) +@pytest.mark.parametrize( + "schema_param", + [ + (partial(X, "x:N"), "sort"), + (partial(FieldOneOfPredicate, "name"), "oneOf"), + (Legend, "values"), + ], +) +def test_to_dict_iterables( + tp: Callable[..., Iterable[Any]], + expected: Sequence[Any], + schema_param: tuple[Callable[..., SchemaBase], str], +) -> None: + """ + Confirm `SchemaBase` can convert common `(Sequence|Iterable)` types to `list`. + + Parameters + ---------- + tp + Constructor for test `Iterable`. + expected + Values wrapped by `tp`. + schema_param + Constructor for `SchemaBase` subclass, and target parameter name. + + Notes + ----- + `partial` can be used to reshape the `SchemaBase` constructor. 
+ + References + ---------- + - https://github.com/vega/altair/issues/2808 + - https://github.com/vega/altair/issues/2877 + """ + tp_schema, param = schema_param + validated = tp_schema(**{param: tp(expected)}).to_dict() + actual = validated[param] + assert actual == expected + + +@pytest.mark.parametrize( + "tp", [range, np.arange, partial(pl.int_range, eager=True), pd.RangeIndex] +) +def test_to_dict_range(tp) -> None: + expected = [0, 1, 2, 3, 4] + x_dict = alt.X("x:O", sort=tp(0, 5)).to_dict() + actual = x_dict["sort"] # type: ignore + assert actual == expected diff --git a/tools/schemapi/schemapi.py b/tools/schemapi/schemapi.py index bd1827a89..be6725a03 100644 --- a/tools/schemapi/schemapi.py +++ b/tools/schemapi/schemapi.py @@ -30,6 +30,7 @@ import jsonschema import jsonschema.exceptions import jsonschema.validators +import narwhals.stable.v1 as nw from packaging.version import Version # This leads to circular imports with the vegalite module. Currently, this works @@ -486,6 +487,14 @@ def _subclasses(cls: type[Any]) -> Iterator[type[Any]]: yield cls +def _from_array_like(obj: Iterable[Any], /) -> list[Any]: + try: + ser = nw.from_native(obj, strict=True, series_only=True) + return ser.to_list() + except TypeError: + return list(obj) + + def _todict(obj: Any, context: dict[str, Any] | None, np_opt: Any, pd_opt: Any) -> Any: """Convert an object to a dict representation.""" if np_opt is not None: @@ -510,10 +519,16 @@ def _todict(obj: Any, context: dict[str, Any] | None, np_opt: Any, pd_opt: Any) for k, v in obj.items() if v is not Undefined } - elif hasattr(obj, "to_dict"): + elif ( + hasattr(obj, "to_dict") + and (module_name := obj.__module__) + and module_name.startswith("altair") + ): return obj.to_dict() elif pd_opt is not None and isinstance(obj, pd_opt.Timestamp): return pd_opt.Timestamp(obj).isoformat() + elif _is_iterable(obj, exclude=(str, bytes)): + return _todict(_from_array_like(obj), context, np_opt, pd_opt) else: return obj @@ -1230,6 +1245,12 
@@ def _is_list(obj: Any | list[Any]) -> TypeIs[list[Any]]: return isinstance(obj, list) +def _is_iterable( + obj: Any, *, exclude: type | tuple[type, ...] = (str, bytes) +) -> TypeIs[Iterable[Any]]: + return not isinstance(obj, exclude) and isinstance(obj, Iterable) + + def _passthrough(*args: Any, **kwds: Any) -> Any | dict[str, Any]: return args[0] if args else kwds