diff --git a/magicgui/__init__.py b/magicgui/__init__.py index 441715d5a..b335c9c04 100644 --- a/magicgui/__init__.py +++ b/magicgui/__init__.py @@ -10,13 +10,14 @@ from ._magicgui import magic_factory, magicgui from .application import event_loop, use_app -from .type_map import register_type +from .type_map import register_type, type_registered __all__ = [ "event_loop", - "magicgui", "magic_factory", + "magicgui", "register_type", + "type_registered", "use_app", ] diff --git a/magicgui/type_map.py b/magicgui/type_map.py index 8d1f8a1fc..b650dd872 100644 --- a/magicgui/type_map.py +++ b/magicgui/type_map.py @@ -7,8 +7,20 @@ import types import warnings from collections import defaultdict +from contextlib import contextmanager from enum import EnumMeta -from typing import Any, Callable, DefaultDict, ForwardRef, Type, TypeVar, cast, overload +from typing import ( + Any, + Callable, + DefaultDict, + ForwardRef, + Iterator, + Optional, + Type, + TypeVar, + cast, + overload, +) from typing_extensions import Literal, get_origin @@ -369,6 +381,63 @@ def _deco(type_): return _deco if type_ is None else _deco(type_) +@contextmanager +def type_registered( + type_: _T, + *, + widget_type: WidgetRef | None = None, + return_callback: ReturnCallback | None = None, + **options, +) -> Iterator[None]: + """Context manager that temporarily registers a widget type for a given `type_`. + + When the context is exited, the previous widget type associations for `type_` is + restored. + + Parameters + ---------- + type_ : _T + The type for which a widget class or return callback will be provided. + widget_type : Optional[WidgetRef] + A widget class from the current backend that should be used whenever ``type_`` + is used as the type annotation for an argument in a decorated function, + by default None + return_callback: Optional[callable] + If provided, whenever ``type_`` is declared as the return type of a decorated + function, ``return_callback(widget, value, return_type)`` will be called + whenever the decorated function is called... where ``widget`` is the Widget + instance, and ``value`` is the return value of the decorated function. + **options + key value pairs where the keys are valid `WidgetOptions` + """ + tw = TypeWrapper(type_) + tw.resolve() + _type_ = tw.outer_type_ + + # check if return_callback is already registered + rc_was_present = return_callback in _RETURN_CALLBACKS.get(_type_, []) + # store any previous widget_type and options for this type + prev_type_def: Optional[WidgetTuple] = _TYPE_DEFS.get(_type_, None) + _type_ = register_type( + type_, widget_type=widget_type, return_callback=return_callback, **options + ) + new_type_def: Optional[WidgetTuple] = _TYPE_DEFS.get(_type_, None) + try: + yield + finally: + # restore things to before the context + if return_callback is not None and not rc_was_present: + _RETURN_CALLBACKS[_type_].remove(return_callback) + + if _TYPE_DEFS.get(_type_, None) is not new_type_def: + warnings.warn("Type definition changed during context", stacklevel=2) + + if prev_type_def is not None: + _TYPE_DEFS[_type_] = prev_type_def + else: + _TYPE_DEFS.pop(_type_, None) + + def _type2callback(type_: type) -> list[ReturnCallback]: """Check if return callbacks have been registered for ``type_`` and return if so. diff --git a/tests/test_types.py b/tests/test_types.py index f3ac99096..333e6e1f8 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,9 +1,12 @@ from enum import Enum +from pathlib import Path from typing import Optional, Union +from unittest.mock import Mock import pytest -from magicgui import magicgui, register_type, type_map, types, widgets +from magicgui import magicgui, register_type, type_map, type_registered, types, widgets +from magicgui.type_map import _RETURN_CALLBACKS def test_forward_refs(): @@ -103,3 +106,47 @@ def test_widget_options(): choice3 = widgets.create_widget(annotation=E) assert choice1._nullable is choice3._nullable is False assert choice2._nullable is True + + +def test_type_registered(): + assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit) + with type_registered(Path, widget_type=widgets.LineEdit): + assert isinstance(widgets.create_widget(annotation=Path), widgets.LineEdit) + assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit) + + +def test_type_registered_callbacks(): + @magicgui + def func(a: int) -> int: + return a + + assert not _RETURN_CALLBACKS[int] + mock = Mock() + func(1) + mock.assert_not_called() + + cb = lambda g, v, r: mock(v) # noqa + cb2 = lambda g, v, r: None # noqa + + with type_registered(int, return_callback=cb): + func(2) + mock.assert_called_once_with(2) + mock.reset_mock() + assert _RETURN_CALLBACKS[int] == [cb] + register_type(int, return_callback=cb2) + assert _RETURN_CALLBACKS[int] == [cb, cb2] + + func(3) + mock.assert_not_called() + assert _RETURN_CALLBACKS[int] == [cb2] + + +def test_type_registered_warns(): + """Test that type_registered warns if the type was changed during context.""" + assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit) + with pytest.warns(UserWarning, match="Type definition changed during context"): + with type_registered(Path, widget_type=widgets.LineEdit): + assert isinstance(widgets.create_widget(annotation=Path), widgets.LineEdit) + register_type(Path, widget_type=widgets.TextEdit) + assert isinstance(widgets.create_widget(annotation=Path), widgets.TextEdit) + assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit)