From 9a0e96b27cccfd09ccbfbac5b3069e4ce47cd851 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Sun, 29 Oct 2023 21:07:44 +0100 Subject: [PATCH] Add shorthand str and list[str] type hints to signature of encode method --- altair/vegalite/v5/api.py | 2 +- altair/vegalite/v5/schema/core.py | 395 ++++++++++++++++++++++++++++++ tools/generate_schema_wrapper.py | 21 +- tools/schemapi/codegen.py | 40 ++- 4 files changed, 443 insertions(+), 15 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 9614ffb75..5f2313d03 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -2610,7 +2610,7 @@ def resolve_scale(self, *args, **kwargs) -> Self: class _EncodingMixin: - @utils.use_signature(core.FacetedEncoding) + @utils.use_signature(core._encode_signature) def encode(self, *args, **kwargs) -> Self: # Convert args to kwargs based on their types. kwargs = utils.infer_encoding_types(args, kwargs, channels) diff --git a/altair/vegalite/v5/schema/core.py b/altair/vegalite/v5/schema/core.py index 15ddf6cab..285453b06 100644 --- a/altair/vegalite/v5/schema/core.py +++ b/altair/vegalite/v5/schema/core.py @@ -37,6 +37,401 @@ def _to_expr(self) -> str: ... +def _encode_signature( + self, + angle: Union[ + str, + List[str], + Union[ + "NumericMarkPropDef", + Union["FieldOrDatumDefWithConditionDatumDefnumber", dict], + Union["FieldOrDatumDefWithConditionMarkPropFieldDefnumber", dict], + Union["ValueDefWithConditionMarkPropFieldOrDatumDefnumber", dict], + ], + UndefinedType, + ] = Undefined, + color: Union[ + str, + List[str], + Union[ + "ColorDef", + Union["FieldOrDatumDefWithConditionDatumDefGradientstringnull", dict], + Union[ + "FieldOrDatumDefWithConditionMarkPropFieldDefGradientstringnull", dict + ], + Union[ + "ValueDefWithConditionMarkPropFieldOrDatumDefGradientstringnull", dict + ], + ], + UndefinedType, + ] = Undefined, + column: Union[ + str, List[str], Union["RowColumnEncodingFieldDef", dict], UndefinedType + ] = Undefined, + description: Union[ + str, + List[str], + Union[ + Union["StringFieldDefWithCondition", dict], + Union["StringValueDefWithCondition", dict], + ], + UndefinedType, + ] = Undefined, + detail: Union[ + str, + List[str], + Union[ + Sequence[Union["FieldDefWithoutScale", dict]], + Union["FieldDefWithoutScale", dict], + ], + UndefinedType, + ] = Undefined, + facet: Union[ + str, List[str], Union["FacetEncodingFieldDef", dict], UndefinedType + ] = Undefined, + fill: Union[ + str, + List[str], + Union[ + "ColorDef", + Union["FieldOrDatumDefWithConditionDatumDefGradientstringnull", dict], + Union[ + "FieldOrDatumDefWithConditionMarkPropFieldDefGradientstringnull", dict + ], + Union[ + "ValueDefWithConditionMarkPropFieldOrDatumDefGradientstringnull", dict + ], + ], + UndefinedType, + ] = Undefined, + fillOpacity: Union[ + str, + List[str], + Union[ + "NumericMarkPropDef", + Union["FieldOrDatumDefWithConditionDatumDefnumber", dict], + Union["FieldOrDatumDefWithConditionMarkPropFieldDefnumber", dict], + Union["ValueDefWithConditionMarkPropFieldOrDatumDefnumber", dict], + ], + UndefinedType, + ] = Undefined, + href: Union[ + str, + List[str], + Union[ + Union["StringFieldDefWithCondition", dict], + Union["StringValueDefWithCondition", dict], + ], + UndefinedType, + ] = Undefined, + key: Union[ + str, List[str], Union["FieldDefWithoutScale", dict], UndefinedType + ] = Undefined, + latitude: Union[ + str, + List[str], + Union["LatLongDef", Union["DatumDef", dict], Union["LatLongFieldDef", dict]], + UndefinedType, + ] = Undefined, + latitude2: Union[ + str, + List[str], + Union[ + "Position2Def", + Union["DatumDef", dict], + Union["PositionValueDef", dict], + Union["SecondaryFieldDef", dict], + ], + UndefinedType, + ] = Undefined, + longitude: Union[ + str, + List[str], + Union["LatLongDef", Union["DatumDef", dict], Union["LatLongFieldDef", dict]], + UndefinedType, + ] = Undefined, + longitude2: Union[ + str, + List[str], + Union[ + "Position2Def", + Union["DatumDef", dict], + Union["PositionValueDef", dict], + Union["SecondaryFieldDef", dict], + ], + UndefinedType, + ] = Undefined, + opacity: Union[ + str, + List[str], + Union[ + "NumericMarkPropDef", + Union["FieldOrDatumDefWithConditionDatumDefnumber", dict], + Union["FieldOrDatumDefWithConditionMarkPropFieldDefnumber", dict], + Union["ValueDefWithConditionMarkPropFieldOrDatumDefnumber", dict], + ], + UndefinedType, + ] = Undefined, + order: Union[ + str, + List[str], + Union[ + Sequence[Union["OrderFieldDef", dict]], + Union["OrderFieldDef", dict], + Union["OrderOnlyDef", dict], + Union["OrderValueDef", dict], + ], + UndefinedType, + ] = Undefined, + radius: Union[ + str, + List[str], + Union[ + "PolarDef", + Union["PositionDatumDefBase", dict], + Union["PositionFieldDefBase", dict], + Union["PositionValueDef", dict], + ], + UndefinedType, + ] = Undefined, + radius2: Union[ + str, + List[str], + Union[ + "Position2Def", + Union["DatumDef", dict], + Union["PositionValueDef", dict], + Union["SecondaryFieldDef", dict], + ], + UndefinedType, + ] = Undefined, + row: Union[ + str, List[str], Union["RowColumnEncodingFieldDef", dict], UndefinedType + ] = Undefined, + shape: Union[ + str, + List[str], + Union[ + "ShapeDef", + Union["FieldOrDatumDefWithConditionDatumDefstringnull", dict], + Union[ + "FieldOrDatumDefWithConditionMarkPropFieldDefTypeForShapestringnull", + dict, + ], + Union[ + "ValueDefWithConditionMarkPropFieldOrDatumDefTypeForShapestringnull", + dict, + ], + ], + UndefinedType, + ] = Undefined, + size: Union[ + str, + List[str], + Union[ + "NumericMarkPropDef", + Union["FieldOrDatumDefWithConditionDatumDefnumber", dict], + Union["FieldOrDatumDefWithConditionMarkPropFieldDefnumber", dict], + Union["ValueDefWithConditionMarkPropFieldOrDatumDefnumber", dict], + ], + UndefinedType, + ] = Undefined, + stroke: Union[ + str, + List[str], + Union[ + "ColorDef", + Union["FieldOrDatumDefWithConditionDatumDefGradientstringnull", dict], + Union[ + "FieldOrDatumDefWithConditionMarkPropFieldDefGradientstringnull", dict + ], + Union[ + "ValueDefWithConditionMarkPropFieldOrDatumDefGradientstringnull", dict + ], + ], + UndefinedType, + ] = Undefined, + strokeDash: Union[ + str, + List[str], + Union[ + "NumericArrayMarkPropDef", + Union["FieldOrDatumDefWithConditionDatumDefnumberArray", dict], + Union["FieldOrDatumDefWithConditionMarkPropFieldDefnumberArray", dict], + Union["ValueDefWithConditionMarkPropFieldOrDatumDefnumberArray", dict], + ], + UndefinedType, + ] = Undefined, + strokeOpacity: Union[ + str, + List[str], + Union[ + "NumericMarkPropDef", + Union["FieldOrDatumDefWithConditionDatumDefnumber", dict], + Union["FieldOrDatumDefWithConditionMarkPropFieldDefnumber", dict], + Union["ValueDefWithConditionMarkPropFieldOrDatumDefnumber", dict], + ], + UndefinedType, + ] = Undefined, + strokeWidth: Union[ + str, + List[str], + Union[ + "NumericMarkPropDef", + Union["FieldOrDatumDefWithConditionDatumDefnumber", dict], + Union["FieldOrDatumDefWithConditionMarkPropFieldDefnumber", dict], + Union["ValueDefWithConditionMarkPropFieldOrDatumDefnumber", dict], + ], + UndefinedType, + ] = Undefined, + text: Union[ + str, + List[str], + Union[ + "TextDef", + Union["FieldOrDatumDefWithConditionStringDatumDefText", dict], + Union["FieldOrDatumDefWithConditionStringFieldDefText", dict], + Union["ValueDefWithConditionStringFieldDefText", dict], + ], + UndefinedType, + ] = Undefined, + theta: Union[ + str, + List[str], + Union[ + "PolarDef", + Union["PositionDatumDefBase", dict], + Union["PositionFieldDefBase", dict], + Union["PositionValueDef", dict], + ], + UndefinedType, + ] = Undefined, + theta2: Union[ + str, + List[str], + Union[ + "Position2Def", + Union["DatumDef", dict], + Union["PositionValueDef", dict], + Union["SecondaryFieldDef", dict], + ], + UndefinedType, + ] = Undefined, + tooltip: Union[ + str, + List[str], + Union[ + None, + Sequence[Union["StringFieldDef", dict]], + Union["StringFieldDefWithCondition", dict], + Union["StringValueDefWithCondition", dict], + ], + UndefinedType, + ] = Undefined, + url: Union[ + str, + List[str], + Union[ + Union["StringFieldDefWithCondition", dict], + Union["StringValueDefWithCondition", dict], + ], + UndefinedType, + ] = Undefined, + x: Union[ + str, + List[str], + Union[ + "PositionDef", + Union["PositionDatumDef", dict], + Union["PositionFieldDef", dict], + Union["PositionValueDef", dict], + ], + UndefinedType, + ] = Undefined, + x2: Union[ + str, + List[str], + Union[ + "Position2Def", + Union["DatumDef", dict], + Union["PositionValueDef", dict], + Union["SecondaryFieldDef", dict], + ], + UndefinedType, + ] = Undefined, + xError: Union[ + str, + List[str], + Union[Union["SecondaryFieldDef", dict], Union["ValueDefnumber", dict]], + UndefinedType, + ] = Undefined, + xError2: Union[ + str, + List[str], + Union[Union["SecondaryFieldDef", dict], Union["ValueDefnumber", dict]], + UndefinedType, + ] = Undefined, + xOffset: Union[ + str, + List[str], + Union[ + "OffsetDef", + Union["ScaleDatumDef", dict], + Union["ScaleFieldDef", dict], + Union["ValueDefnumber", dict], + ], + UndefinedType, + ] = Undefined, + y: Union[ + str, + List[str], + Union[ + "PositionDef", + Union["PositionDatumDef", dict], + Union["PositionFieldDef", dict], + Union["PositionValueDef", dict], + ], + UndefinedType, + ] = Undefined, + y2: Union[ + str, + List[str], + Union[ + "Position2Def", + Union["DatumDef", dict], + Union["PositionValueDef", dict], + Union["SecondaryFieldDef", dict], + ], + UndefinedType, + ] = Undefined, + yError: Union[ + str, + List[str], + Union[Union["SecondaryFieldDef", dict], Union["ValueDefnumber", dict]], + UndefinedType, + ] = Undefined, + yError2: Union[ + str, + List[str], + Union[Union["SecondaryFieldDef", dict], Union["ValueDefnumber", dict]], + UndefinedType, + ] = Undefined, + yOffset: Union[ + str, + List[str], + Union[ + "OffsetDef", + Union["ScaleDatumDef", dict], + Union["ScaleFieldDef", dict], + Union["ValueDefnumber", dict], + ], + UndefinedType, + ] = Undefined, + **kwds +): + ... + + class VegaLiteSchema(SchemaBase): _rootschema = load_schema() diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index 193a3e975..10759224a 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -235,6 +235,11 @@ def configure_{prop}(self, *args, **kwargs) -> Self: return copy """ +ENCODE_SIGNATURE: Final = """ +def _encode_signature({encode_method_args}): + ... +""" + class SchemaGenerator(codegen.SchemaGenerator): schema_class_template = textwrap.dedent( @@ -352,7 +357,7 @@ def _add_shorthand_property_to_field_encodings(schema: dict) -> dict: for prop, propschema in encoding.properties.items(): def_dict = get_field_datum_value_defs(propschema, schema) - + field_ref = def_dict.get("field") if field_ref is not None: defschema = {"$ref": field_ref} @@ -461,6 +466,8 @@ def generate_vegalite_schema_wrapper(schema_file: str) -> str: definitions: Dict[str, SchemaGenerator] = {} + encode_method_args = "" + for name in rootschema["definitions"]: defschema = {"$ref": "#/definitions/" + name} defschema_repr = {"$ref": "#/definitions/" + name} @@ -473,6 +480,17 @@ def generate_vegalite_schema_wrapper(schema_file: str) -> str: basename=basename, rootschemarepr=CodeSnippet("{}._rootschema".format(basename)), ) + if name == "FacetedEncoding": + # For the .encode method in api.py:_EncodingMixin we need the same + # type signature as in core.FacetedEncoding but additionally, for every + # arguemtn it should also accept a "shorthand" as a string or list of + # strings. + encode_method_args = ", ".join(definitions[name].init_args( + additional_types=["str", "List[str]"] + )[0]) + + assert len(encode_method_args) > 0 + encode_method_signature = ENCODE_SIGNATURE.format(encode_method_args=encode_method_args) graph: Dict[str, List[str]] = {} @@ -497,6 +515,7 @@ def generate_vegalite_schema_wrapper(schema_file: str) -> str: LOAD_SCHEMA.format(schemafile="vega-lite-schema.json"), ] contents.append(PARAMETER_PROTOCOL) + contents.append(encode_method_signature) contents.append(BASE_SCHEMA.format(basename=basename)) contents.append( schema_class( diff --git a/tools/schemapi/codegen.py b/tools/schemapi/codegen.py index c5603deeb..03f8e3685 100644 --- a/tools/schemapi/codegen.py +++ b/tools/schemapi/codegen.py @@ -1,7 +1,7 @@ """Code generation utilities""" import re import textwrap -from typing import Set, Final, Optional, List, Union, Dict +from typing import Set, Final, Optional, List, Union, Dict, Tuple from dataclasses import dataclass from .utils import ( @@ -237,6 +237,21 @@ def docstring(self, indent: int = 0) -> str: def init_code(self, indent: int = 0) -> str: """Return code suitable for the __init__ function of a Schema class""" + args, super_args = self.init_args() + + initfunc = self.init_template.format( + classname=self.classname, + arglist=", ".join(args), + super_arglist=", ".join(super_args), + ) + if indent: + initfunc = ("\n" + indent * " ").join(initfunc.splitlines()) + return initfunc + + def init_args( + self, additional_types: Optional[List[str]] = None + ) -> Tuple[List[str], List[str]]: + additional_types = additional_types or [] info = self.info arg_info = self.arg_info @@ -257,10 +272,17 @@ def init_code(self, indent: int = 0) -> str: args.extend( f"{p}: Union[" - + info.properties[p].get_python_type_representation( - for_type_hints=True, altair_classes_prefix=self.altair_classes_prefix + + ", ".join( + [ + *additional_types, + info.properties[p].get_python_type_representation( + for_type_hints=True, + altair_classes_prefix=self.altair_classes_prefix, + ), + "UndefinedType", + ] ) - + ", UndefinedType] = Undefined" + + "] = Undefined" for p in sorted(arg_info.required) + sorted(arg_info.kwds) ) super_args.extend( @@ -273,15 +295,7 @@ def init_code(self, indent: int = 0) -> str: if arg_info.additional: args.append("**kwds") super_args.append("**kwds") - - initfunc = self.init_template.format( - classname=self.classname, - arglist=", ".join(args), - super_arglist=", ".join(super_args), - ) - if indent: - initfunc = ("\n" + indent * " ").join(initfunc.splitlines()) - return initfunc + return args, super_args def get_args(self, si: SchemaInfo) -> List[str]: contents = ["self"]