Skip to content

Commit

Permalink
feat: Adds @register_theme decorator (#3526)
Browse files Browse the repository at this point in the history
* feat: Adds `@register_theme` decorator

Resolves one item in #3519

* build: run `update-init-file`

Adds `@register_theme` to top-level

* test: Adds `test_register_theme_decorator`

* refactor(typing): Specify `dict[str, Any]` instead of `dict[Any, Any]`

The latter may give false-positives for json-incompatible dicts

---------

Co-authored-by: Stefan Binder <binder_stefan@outlook.com>
  • Loading branch information
dangotbanned and binste authored Sep 5, 2024
1 parent 5207768 commit f542e9e
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 12 deletions.
1 change: 1 addition & 0 deletions altair/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@
"mixins",
"param",
"parse_shorthand",
"register_theme",
"renderers",
"repeat",
"sample",
Expand Down
2 changes: 1 addition & 1 deletion altair/utils/plugin_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(
self.entry_point_group: str = entry_point_group
self.plugin_type: IsPlugin
if plugin_type is not callable and isinstance(plugin_type, type):
msg = (
msg: Any = (
f"Pass a callable `TypeIs` function to `plugin_type` instead.\n"
f"{type(self).__name__!r}(plugin_type)\n\n"
f"See also:\n"
Expand Down
10 changes: 6 additions & 4 deletions altair/utils/theme.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Dict

from .plugin_registry import PluginRegistry
from .plugin_registry import Plugin, PluginRegistry

if sys.version_info >= (3, 11):
from typing import LiteralString
Expand All @@ -16,10 +16,12 @@
from altair.utils.plugin_registry import PluginEnabler
from altair.vegalite.v5.theme import AltairThemes, VegaThemes

ThemeType = Callable[..., dict]
ThemeType = Plugin[Dict[str, Any]]


class ThemeRegistry(PluginRegistry[ThemeType, dict]):
# HACK: See for `LiteralString` requirement in `name`
# https://github.com/vega/altair/pull/3526#discussion_r1743350127
class ThemeRegistry(PluginRegistry[ThemeType, Dict[str, Any]]):
def enable(
self, name: LiteralString | AltairThemes | VegaThemes | None = None, **options
) -> PluginEnabler:
Expand Down
2 changes: 1 addition & 1 deletion altair/vegalite/v5/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
renderers,
)
from .schema import *
from .theme import themes
from .theme import register_theme, themes
94 changes: 90 additions & 4 deletions altair/vegalite/v5/theme.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,33 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Final, Literal, get_args
import sys
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, Final, Literal, TypeVar, get_args

from altair.utils.theme import ThemeRegistry
from altair.vegalite.v5.schema._typing import VegaThemes

if TYPE_CHECKING:
import sys
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec


if TYPE_CHECKING:
if sys.version_info >= (3, 11):
from typing import LiteralString
else:
from typing_extensions import LiteralString
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

P = ParamSpec("P")
R = TypeVar("R", bound=Dict[str, Any])
AltairThemes: TypeAlias = Literal["default", "opaque"]
VEGA_THEMES: list[str] = list(get_args(VegaThemes))
VEGA_THEMES: list[LiteralString] = list(get_args(VegaThemes))


class VegaTheme:
Expand Down Expand Up @@ -60,3 +72,77 @@ def __repr__(self) -> str:
themes.register(theme, VegaTheme(theme))

themes.enable("default")


# HACK: See for `LiteralString` requirement in `name`
# https://github.com/vega/altair/pull/3526#discussion_r1743350127
def register_theme(
name: LiteralString, *, enable: bool
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Decorator for registering a theme function.
Parameters
----------
name
Unique name assigned in ``alt.themes``.
enable
Auto-enable the wrapped theme.
Examples
--------
Register and enable a theme::
from __future__ import annotations
from typing import Any
import altair as alt
@alt.register_theme("param_font_size", enable=True)
def custom_theme() -> dict[str, Any]:
sizes = 12, 14, 16, 18, 20
return {
"autosize": {"contains": "content", "resize": True},
"background": "#F3F2F1",
"config": {
"axisX": {"labelFontSize": sizes[1], "titleFontSize": sizes[1]},
"axisY": {"labelFontSize": sizes[1], "titleFontSize": sizes[1]},
"font": "'Lato', 'Segoe UI', Tahoma, Verdana, sans-serif",
"headerColumn": {"labelFontSize": sizes[1]},
"headerFacet": {"labelFontSize": sizes[1]},
"headerRow": {"labelFontSize": sizes[1]},
"legend": {"labelFontSize": sizes[0], "titleFontSize": sizes[1]},
"text": {"fontSize": sizes[0]},
"title": {"fontSize": sizes[-1]},
},
"height": {"step": 28},
"width": 350,
}
Until another theme has been enabled, all charts will use defaults set in ``custom_theme``::
from vega_datasets import data
source = data.stocks()
lines = (
alt.Chart(source, title=alt.Title("Stocks"))
.mark_line()
.encode(x="date:T", y="price:Q", color="symbol:N")
)
lines.interactive(bind_y=False)
"""

def decorate(func: Callable[P, R], /) -> Callable[P, R]:
themes.register(name, func)
if enable:
themes.enable(name)

@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return func(*args, **kwargs)

return wrapper

return decorate
17 changes: 15 additions & 2 deletions tests/vegalite/v5/test_theme.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
from __future__ import annotations

import pytest

import altair.vegalite.v5 as alt
from altair.vegalite.v5.theme import VEGA_THEMES
from altair.vegalite.v5.theme import VEGA_THEMES, register_theme, themes


@pytest.fixture
def chart():
return alt.Chart("data.csv").mark_bar().encode(x="x:Q")


def test_vega_themes(chart):
def test_vega_themes(chart) -> None:
for theme in VEGA_THEMES:
with alt.themes.enable(theme):
dct = chart.to_dict()
assert dct["usermeta"] == {"embedOptions": {"theme": theme}}
assert dct["config"] == {
"view": {"continuousWidth": 300, "continuousHeight": 300}
}


def test_register_theme_decorator() -> None:
@register_theme("unique name", enable=True)
def custom_theme() -> dict[str, int]:
return {"height": 400, "width": 700}

assert themes.active == "unique name"
registered = themes.get()
assert registered is not None
assert registered() == {"height": 400, "width": 700} == custom_theme()

0 comments on commit f542e9e

Please sign in to comment.