Skip to content

Commit

Permalink
refactor: apply pylint rule fixes, add an inline optimization for `…
Browse files Browse the repository at this point in the history
…_selection`

`param_kwds` in `_selection` triggered `PLR6201`. Instead rewrote as a dictcomp.
  • Loading branch information
dangotbanned committed Jun 5, 2024
1 parent 007b8b3 commit 10eb5d5
Show file tree
Hide file tree
Showing 15 changed files with 98 additions and 114 deletions.
2 changes: 1 addition & 1 deletion altair/utils/_transformed_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def name_views(
chart, (_chart_class_mapping[Chart], _chart_class_mapping[FacetChart])
):
if chart.name not in exclude:
if chart.name in (None, Undefined):
if chart.name in {None, Undefined}:
# Add name since none is specified
chart.name = Chart._get_name()
return [chart.name]
Expand Down
14 changes: 7 additions & 7 deletions altair/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,27 +211,27 @@ def infer_vegalite_type(
"""
typ = infer_dtype(data, skipna=False)

if typ in [
if typ in {
"floating",
"mixed-integer-float",
"integer",
"mixed-integer",
"complex",
]:
}:
return "quantitative"
elif typ == "categorical" and hasattr(data, "cat") and data.cat.ordered:
return ("ordinal", data.cat.categories.tolist())
elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]:
elif typ in {"string", "bytes", "categorical", "boolean", "mixed", "unicode"}:
return "nominal"
elif typ in [
elif typ in {
"datetime",
"datetime64",
"timedelta",
"timedelta64",
"date",
"time",
"period",
]:
}:
return "temporal"
else:
warnings.warn(
Expand Down Expand Up @@ -674,9 +674,9 @@ def infer_vegalite_type_for_dfi_column(
categories_column = column.describe_categorical["categories"]
categories_array = column_to_array(categories_column)
return "ordinal", categories_array.to_pylist()
if kind in (DtypeKind.STRING, DtypeKind.CATEGORICAL, DtypeKind.BOOL):
if kind in {DtypeKind.STRING, DtypeKind.CATEGORICAL, DtypeKind.BOOL}:
return "nominal"
elif kind in (DtypeKind.INT, DtypeKind.UINT, DtypeKind.FLOAT):
elif kind in {DtypeKind.INT, DtypeKind.UINT, DtypeKind.FLOAT}:
return "quantitative"
elif kind == DtypeKind.DATETIME:
return "temporal"
Expand Down
5 changes: 3 additions & 2 deletions altair/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@


from typing import Protocol, TypedDict, Literal
import locale


if TYPE_CHECKING:
Expand Down Expand Up @@ -186,7 +187,7 @@ def to_json(
data_json = _data_to_json_string(data)
data_hash = _compute_data_hash(data_json)
filename = filename.format(prefix=prefix, hash=data_hash, extension=extension)
Path(filename).write_text(data_json)
Path(filename).write_text(data_json, encoding=locale.getpreferredencoding(False))
return {"url": str(Path(urlpath, filename)), "format": {"type": "json"}}


Expand All @@ -202,7 +203,7 @@ def to_csv(
data_csv = _data_to_csv_string(data)
data_hash = _compute_data_hash(data_csv)
filename = filename.format(prefix=prefix, hash=data_hash, extension=extension)
Path(filename).write_text(data_csv)
Path(filename).write_text(data_csv, encoding=locale.getpreferredencoding(False))
return {"url": str(Path(urlpath, filename)), "format": {"type": "csv"}}


Expand Down
2 changes: 1 addition & 1 deletion altair/utils/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def spec_to_html(

mode = embed_options.setdefault("mode", mode)

if mode not in ["vega", "vega-lite"]:
if mode not in {"vega", "vega-lite"}:
msg = "mode must be either 'vega' or 'vega-lite'"
raise ValueError(msg)

Expand Down
2 changes: 1 addition & 1 deletion altair/utils/mimebundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def spec_to_mimebundle(

embed_options = preprocess_embed_options(final_embed_options)

if format in ["png", "svg", "pdf", "vega"]:
if format in {"png", "svg", "pdf", "vega"}:
format = cast(Literal["png", "svg", "pdf", "vega"], format)
return _spec_to_mimebundle_with_engine(
spec,
Expand Down
2 changes: 1 addition & 1 deletion altair/utils/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def perform_save():
write_file_or_filename(
fp, mimebundle["text/html"], mode="w", encoding=encoding
)
elif format in ["png", "svg", "pdf", "vega"]:
elif format in {"png", "svg", "pdf", "vega"}:
mimebundle = spec_to_mimebundle(
spec=spec,
format=format,
Expand Down
33 changes: 14 additions & 19 deletions altair/utils/schemapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,20 +220,15 @@ def _prepare_refs(d: Dict[str, Any]) -> Dict[str, Any]:
for key, value in d.items():
if key == "$ref":
d[key] = _VEGA_LITE_ROOT_URI + d[key]
else:
# $ref values can only be nested in dictionaries or lists
# as the passed in `d` dictionary comes from the Vega-Lite json schema
# and in json we only have arrays (-> lists in Python) and objects
# (-> dictionaries in Python) which we need to iterate through.
if isinstance(value, dict):
d[key] = _prepare_refs(value)
elif isinstance(value, list):
prepared_values = []
for v in value:
if isinstance(v, dict):
v = _prepare_refs(v)
prepared_values.append(v)
d[key] = prepared_values
elif isinstance(value, dict):
d[key] = _prepare_refs(value)
elif isinstance(value, list):
prepared_values = []
for v in value:
if isinstance(v, dict):
v = _prepare_refs(v)
prepared_values.append(v)
d[key] = prepared_values
return d

schema = _prepare_refs(schema)
Expand Down Expand Up @@ -616,7 +611,7 @@ def _format_params_as_table(param_dict_keys: Iterable[str]) -> str:
*[
(name, len(name))
for name in param_dict_keys
if name not in ["kwds", "self"]
if name not in {"kwds", "self"}
]
)
# Worst case scenario with the same longest param name in the same
Expand Down Expand Up @@ -712,7 +707,7 @@ def _get_default_error_message(
# considered so far. This is not expected to be used but more exists
# as a fallback for cases which were not known during development.
for validator, errors in errors_by_validator.items():
if validator not in ("enum", "type"):
if validator not in {"enum", "type"}:
message += "\n".join([e.message for e in errors])

return message
Expand Down Expand Up @@ -762,7 +757,7 @@ def __init__(self, *args: Any, **kwds: Any) -> None:
if kwds:
assert len(args) == 0
else:
assert len(args) in [0, 1]
assert len(args) in {0, 1}

# use object.__setattr__ because we override setattr below.
object.__setattr__(self, "_args", args)
Expand Down Expand Up @@ -942,7 +937,7 @@ def to_dict(
# when a non-ordinal data type is specifed manually
# or if the encoding channel does not support sorting
if "sort" in parsed_shorthand and (
"sort" not in kwds or kwds["type"] not in ["ordinal", Undefined]
"sort" not in kwds or kwds["type"] not in {"ordinal", Undefined}
):
parsed_shorthand.pop("sort")

Expand All @@ -954,7 +949,7 @@ def to_dict(
}
)
kwds = {
k: v for k, v in kwds.items() if k not in [*list(ignore), "shorthand"]
k: v for k, v in kwds.items() if k not in {*list(ignore), "shorthand"}
}
if "mark" in kwds and isinstance(kwds["mark"], str):
kwds["mark"] = {"type": kwds["mark"]}
Expand Down
2 changes: 1 addition & 1 deletion altair/vegalite/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class DataTransformerRegistry(_DataTransformerRegistry):
def disable_max_rows(self) -> PluginEnabler:
"""Disable the MaxRowsError."""
options = self.options
if self.active in ("default", "vegafusion"):
if self.active in {"default", "vegafusion"}:
options = options.copy()
options["max_rows"] = None
return self.enable(**options)
Expand Down
59 changes: 24 additions & 35 deletions altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,18 +461,16 @@ def _selection(
type: Union[Literal["interval", "point"], UndefinedType] = Undefined, **kwds
) -> Parameter:
# We separate out the parameter keywords from the selection keywords
param_kwds = {}

for kwd in {"name", "bind", "value", "empty", "init", "views"}:
if kwd in kwds:
param_kwds[kwd] = kwds.pop(kwd)
select_kwds = {"name", "bind", "value", "empty", "init", "views"}
param_kwds = {key: kwds.pop(key) for key in select_kwds & kwds.keys()}

select: Union[core.IntervalSelectionConfig, core.PointSelectionConfig]
if type == "interval":
select = core.IntervalSelectionConfig(type=type, **kwds)
elif type == "point":
select = core.PointSelectionConfig(type=type, **kwds)
elif type in ["single", "multi"]:
elif type in {"single", "multi"}:
select = core.PointSelectionConfig(type="point", **kwds)
warnings.warn(
"""The types 'single' and 'multi' are now
Expand Down Expand Up @@ -928,7 +926,7 @@ def to_dict(
"""

# Validate format
if format not in ("vega-lite", "vega"):
if format not in {"vega-lite", "vega"}:
msg = f'The format argument must be either "vega-lite" or "vega". Received {format!r}'
raise ValueError(msg)

Expand Down Expand Up @@ -998,15 +996,14 @@ def to_dict(
raise ValueError(msg)
else:
return _compile_with_vegafusion(vegalite_spec)
elif format == "vega":
plugin = vegalite_compilers.get()
if plugin is None:
msg = "No active vega-lite compiler plugin found"
raise ValueError(msg)
return plugin(vegalite_spec)
else:
if format == "vega":
plugin = vegalite_compilers.get()
if plugin is None:
msg = "No active vega-lite compiler plugin found"
raise ValueError(msg)
return plugin(vegalite_spec)
else:
return vegalite_spec
return vegalite_spec

def to_json(
self,
Expand Down Expand Up @@ -1242,7 +1239,6 @@ def save(
save(**kwds)
else:
save(**kwds)
return

# Fallback for when rendering fails; the full repr is too long to be
# useful in nearly all cases.
Expand Down Expand Up @@ -2448,10 +2444,9 @@ def transform_timeunit(
"""
if as_ is Undefined:
as_ = kwargs.pop("as", Undefined)
else:
if "as" in kwargs:
msg = "transform_timeunit: both 'as_' and 'as' passed as arguments."
raise ValueError(msg)
elif "as" in kwargs:
msg = "transform_timeunit: both 'as_' and 'as' passed as arguments."
raise ValueError(msg)
if as_ is not Undefined:
dct = {"as": as_, "timeUnit": timeUnit, "field": field}
self = self._add_transform(core.TimeUnitTransform(**dct)) # type: ignore[arg-type]
Expand Down Expand Up @@ -3271,7 +3266,7 @@ def repeat(
-------
repeat : RepeatRef object
"""
if repeater not in ["row", "column", "repeat", "layer"]:
if repeater not in {"row", "column", "repeat", "layer"}:
msg = "repeater must be one of ['row', 'column', 'repeat', 'layer']"
raise ValueError(msg)
return core.RepeatRef(repeat=repeater)
Expand Down Expand Up @@ -3819,11 +3814,8 @@ def remove_data(subchart):
if subdata is not Undefined and all(c.data is subdata for c in subcharts):
data = subdata
subcharts = [remove_data(c) for c in subcharts]
else:
# Top level has data; subchart data must be either
# undefined or identical to proceed.
if all(c.data is Undefined or c.data is data for c in subcharts):
subcharts = [remove_data(c) for c in subcharts]
elif all(c.data is Undefined or c.data is data for c in subcharts):
subcharts = [remove_data(c) for c in subcharts]

return data, subcharts

Expand Down Expand Up @@ -4053,17 +4045,14 @@ def remove_prop(subchart, prop):
else:
msg = f"There are inconsistent values {values} for {prop}"
raise ValueError(msg)
elif all(
getattr(c, prop, Undefined) is Undefined or c[prop] == chart[prop]
for c in subcharts
):
output_dict[prop] = chart[prop]
else:
# Top level has this prop; subchart must either not have the prop
# or it must be Undefined or identical to proceed.
if all(
getattr(c, prop, Undefined) is Undefined or c[prop] == chart[prop]
for c in subcharts
):
output_dict[prop] = chart[prop]
else:
msg = f"There are inconsistent values {values} for {prop}"
raise ValueError(msg)
msg = f"There are inconsistent values {values} for {prop}"
raise ValueError(msg)
subcharts = [remove_prop(c, prop) for c in subcharts]

return output_dict, subcharts
Expand Down
10 changes: 6 additions & 4 deletions sphinxext/altairgallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from altair.utils.execeval import eval_block
from tests.examples_arguments_syntax import iter_examples_arguments_syntax
from tests.examples_methods_syntax import iter_examples_methods_syntax
import locale


EXAMPLE_MODULE = "altair.examples"
Expand Down Expand Up @@ -158,7 +159,7 @@ def save_example_pngs(examples, image_dir, make_thumbnails=True):
hash_file = os.path.join(image_dir, "_image_hashes.json")

if os.path.exists(hash_file):
with open(hash_file) as f:
with open(hash_file, encoding=locale.getpreferredencoding(False)) as f:
hashes = json.load(f)
else:
hashes = {}
Expand All @@ -183,7 +184,7 @@ def save_example_pngs(examples, image_dir, make_thumbnails=True):
warnings.warn("Unable to save image: using generic image", stacklevel=1)
create_generic_image(image_file)

with open(hash_file, "w") as f:
with open(hash_file, "w", encoding=locale.getpreferredencoding(False)) as f:
json.dump(hashes, f)

if make_thumbnails:
Expand All @@ -197,7 +198,7 @@ def save_example_pngs(examples, image_dir, make_thumbnails=True):
create_thumbnail(image_file, thumb_file, **params)

# Save hashes so we know whether we need to re-generate plots
with open(hash_file, "w") as f:
with open(hash_file, "w", encoding=locale.getpreferredencoding(False)) as f:
json.dump(hashes, f)


Expand Down Expand Up @@ -343,7 +344,8 @@ def main(app):
examples=examples_toc.items(),
image_dir="/_static",
gallery_ref=gallery_ref,
)
),
encoding=locale.getpreferredencoding(False),
)

# save the images to file
Expand Down
13 changes: 7 additions & 6 deletions tests/vegalite/v5/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _make_chart_type(chart_type):
)
)

if chart_type in ["layer", "hconcat", "vconcat", "concat"]:
if chart_type in {"layer", "hconcat", "vconcat", "concat"}:
func = getattr(alt, chart_type)
return func(base.mark_square(), base.mark_circle())
elif chart_type == "facet":
Expand Down Expand Up @@ -289,14 +289,14 @@ def test_selection_expression():
@pytest.mark.parametrize("format", ["html", "json", "png", "svg", "pdf", "bogus"])
@pytest.mark.parametrize("engine", ["vl-convert"])
def test_save(format, engine, basic_chart):
if format in ["pdf", "png"]:
if format in {"pdf", "png"}:
out = io.BytesIO()
mode = "rb"
else:
out = io.StringIO()
mode = "r"

if format in ["svg", "png", "pdf", "bogus"] and engine == "vl-convert":
if format in {"svg", "png", "pdf", "bogus"} and engine == "vl-convert":
if format == "bogus":
with pytest.raises(ValueError) as err:
basic_chart.save(out, format=format, engine=engine)
Expand Down Expand Up @@ -605,9 +605,10 @@ def test_transforms():

# kwargs don't maintain order in Python < 3.6, so window list can
# be reversed
assert chart.transform == [
alt.WindowTransform(frame=[None, 0], window=window)
] or chart.transform == [alt.WindowTransform(frame=[None, 0], window=window[::-1])]
assert chart.transform in (
[alt.WindowTransform(frame=[None, 0], window=window)],
[alt.WindowTransform(frame=[None, 0], window=window[::-1])],
)


def test_filter_transform_selection_predicates():
Expand Down
Loading

0 comments on commit 10eb5d5

Please sign in to comment.