Skip to content

Commit

Permalink
fix(typing): Preserve generics in PluginEnabler
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned committed Sep 29, 2024
1 parent 5078f88 commit 2339116
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
16 changes: 10 additions & 6 deletions altair/utils/plugin_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __str__(self):
return f"No {self.name!r} entry point found in group {self.group!r}"


class PluginEnabler:
class PluginEnabler(Generic[PluginT, R]):
"""
Context manager for enabling plugins.
Expand All @@ -51,21 +51,23 @@ class PluginEnabler:
# plugins back to original state
"""

def __init__(self, registry: PluginRegistry, name: str, **options):
self.registry: PluginRegistry = registry
def __init__(
self, registry: PluginRegistry[PluginT, R], name: str, **options: Any
) -> None:
self.registry: PluginRegistry[PluginT, R] = registry
self.name: str = name
self.options: dict[str, Any] = options
self.original_state: dict[str, Any] = registry._get_state()
self.registry._enable(name, **options)

def __enter__(self) -> PluginEnabler:
def __enter__(self) -> PluginEnabler[PluginT, R]:
return self

def __exit__(self, typ: type, value: Exception, traceback: TracebackType) -> None:
self.registry._set_state(self.original_state)

def __repr__(self) -> str:
return f"{self.registry.__class__.__name__}.enable({self.name!r})"
return f"{type(self.registry).__name__}.enable({self.name!r})"


class PluginRegistry(Generic[PluginT, R]):
Expand Down Expand Up @@ -211,7 +213,9 @@ def _enable(self, name: str, **options) -> None:
self._global_settings[key] = options.pop(key)
self._options = options

def enable(self, name: str | None = None, **options) -> PluginEnabler:
def enable(
self, name: str | None = None, **options: Any
) -> PluginEnabler[PluginT, R]:
"""
Enable a plugin by name.
Expand Down
2 changes: 1 addition & 1 deletion altair/vegalite/v5/theme.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def enable(
self,
name: LiteralString | AltairThemes | VegaThemes | None = None,
**options: Any,
) -> PluginEnabler:
) -> PluginEnabler[Plugin[ThemeConfig], ThemeConfig]:
"""
Enable a theme by name.
Expand Down

0 comments on commit 2339116

Please sign in to comment.