Skip to content

Commit

Permalink
feat: add context manager for register_type (#470)
Browse files Browse the repository at this point in the history
* feat: add context manager for register_type

* review and add more tests

* check if rc was present
  • Loading branch information
tlambert03 authored Oct 22, 2022
1 parent 88e598b commit f0aabb8
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 4 deletions.
5 changes: 3 additions & 2 deletions magicgui/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@

from ._magicgui import magic_factory, magicgui
from .application import event_loop, use_app
from .type_map import register_type
from .type_map import register_type, type_registered

__all__ = [
"event_loop",
"magicgui",
"magic_factory",
"magicgui",
"register_type",
"type_registered",
"use_app",
]

Expand Down
71 changes: 70 additions & 1 deletion magicgui/type_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,20 @@
import types
import warnings
from collections import defaultdict
from contextlib import contextmanager
from enum import EnumMeta
from typing import Any, Callable, DefaultDict, ForwardRef, Type, TypeVar, cast, overload
from typing import (
Any,
Callable,
DefaultDict,
ForwardRef,
Iterator,
Optional,
Type,
TypeVar,
cast,
overload,
)

from typing_extensions import Literal, get_origin

Expand Down Expand Up @@ -369,6 +381,63 @@ def _deco(type_):
return _deco if type_ is None else _deco(type_)


@contextmanager
def type_registered(
type_: _T,
*,
widget_type: WidgetRef | None = None,
return_callback: ReturnCallback | None = None,
**options,
) -> Iterator[None]:
"""Context manager that temporarily registers a widget type for a given `type_`.
When the context is exited, the previous widget type associations for `type_` is
restored.
Parameters
----------
type_ : _T
The type for which a widget class or return callback will be provided.
widget_type : Optional[WidgetRef]
A widget class from the current backend that should be used whenever ``type_``
is used as the type annotation for an argument in a decorated function,
by default None
return_callback: Optional[callable]
If provided, whenever ``type_`` is declared as the return type of a decorated
function, ``return_callback(widget, value, return_type)`` will be called
whenever the decorated function is called... where ``widget`` is the Widget
instance, and ``value`` is the return value of the decorated function.
**options
key value pairs where the keys are valid `WidgetOptions`
"""
tw = TypeWrapper(type_)
tw.resolve()
_type_ = tw.outer_type_

# check if return_callback is already registered
rc_was_present = return_callback in _RETURN_CALLBACKS.get(_type_, [])
# store any previous widget_type and options for this type
prev_type_def: Optional[WidgetTuple] = _TYPE_DEFS.get(_type_, None)
_type_ = register_type(
type_, widget_type=widget_type, return_callback=return_callback, **options
)
new_type_def: Optional[WidgetTuple] = _TYPE_DEFS.get(_type_, None)
try:
yield
finally:
# restore things to before the context
if return_callback is not None and not rc_was_present:
_RETURN_CALLBACKS[_type_].remove(return_callback)

if _TYPE_DEFS.get(_type_, None) is not new_type_def:
warnings.warn("Type definition changed during context", stacklevel=2)

if prev_type_def is not None:
_TYPE_DEFS[_type_] = prev_type_def
else:
_TYPE_DEFS.pop(_type_, None)


def _type2callback(type_: type) -> list[ReturnCallback]:
"""Check if return callbacks have been registered for ``type_`` and return if so.
Expand Down
49 changes: 48 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from enum import Enum
from pathlib import Path
from typing import Optional, Union
from unittest.mock import Mock

import pytest

from magicgui import magicgui, register_type, type_map, types, widgets
from magicgui import magicgui, register_type, type_map, type_registered, types, widgets
from magicgui.type_map import _RETURN_CALLBACKS


def test_forward_refs():
Expand Down Expand Up @@ -103,3 +106,47 @@ def test_widget_options():
choice3 = widgets.create_widget(annotation=E)
assert choice1._nullable is choice3._nullable is False
assert choice2._nullable is True


def test_type_registered():
assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit)
with type_registered(Path, widget_type=widgets.LineEdit):
assert isinstance(widgets.create_widget(annotation=Path), widgets.LineEdit)
assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit)


def test_type_registered_callbacks():
@magicgui
def func(a: int) -> int:
return a

assert not _RETURN_CALLBACKS[int]
mock = Mock()
func(1)
mock.assert_not_called()

cb = lambda g, v, r: mock(v) # noqa
cb2 = lambda g, v, r: None # noqa

with type_registered(int, return_callback=cb):
func(2)
mock.assert_called_once_with(2)
mock.reset_mock()
assert _RETURN_CALLBACKS[int] == [cb]
register_type(int, return_callback=cb2)
assert _RETURN_CALLBACKS[int] == [cb, cb2]

func(3)
mock.assert_not_called()
assert _RETURN_CALLBACKS[int] == [cb2]


def test_type_registered_warns():
"""Test that type_registered warns if the type was changed during context."""
assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit)
with pytest.warns(UserWarning, match="Type definition changed during context"):
with type_registered(Path, widget_type=widgets.LineEdit):
assert isinstance(widgets.create_widget(annotation=Path), widgets.LineEdit)
register_type(Path, widget_type=widgets.TextEdit)
assert isinstance(widgets.create_widget(annotation=Path), widgets.TextEdit)
assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit)

0 comments on commit f0aabb8

Please sign in to comment.