diff --git a/examples/defining_cmps.py b/examples/defining_cmps.py index b8e5b0b..18d5786 100644 --- a/examples/defining_cmps.py +++ b/examples/defining_cmps.py @@ -6,8 +6,8 @@ from limbus_config import config config.COMPONENT_TYPE = "torch" -from limbus.core import Component, InputParams, OutputParams, ComponentState, OutputParam, InputParam # noqa: E402 -from limbus.core import Pipeline, VerboseMode # noqa: E402 +from limbus.core import (Component, InputParams, OutputParams, PropertyParams, Pipeline, VerboseMode, # noqa: E402 + ComponentState, OutputParam, InputParam, async_utils) # noqa: E402 # define the components @@ -25,14 +25,26 @@ class OutputsTyping(OutputParams): # noqa: D106 inputs: InputsTyping # type: ignore outputs: OutputsTyping # type: ignore + async def val_rec_a(self, value: Any) -> Any: # noqa: D102 + print(f"CALLBACK: Add.a: {value}.") + return value + + async def val_rec_b(self, value: Any) -> Any: # noqa: D102 + print(f"CALLBACK: Add.b: {value}.") + return value + + async def val_sent(self, value: Any) -> Any: # noqa: D102 + print(f"CALLBACK: Add.out: {value}.") + return value + @staticmethod def register_inputs(inputs: InputParams) -> None: # noqa: D102 - inputs.declare("a", int) - inputs.declare("b", int) + inputs.declare("a", int, callback=Add.val_rec_a) + inputs.declare("b", int, callback=Add.val_rec_b) @staticmethod def register_outputs(outputs: OutputParams) -> None: # noqa: D102 - outputs.declare("out", int) + outputs.declare("out", int, callback=Add.val_sent) async def forward(self) -> ComponentState: # noqa: D102 a, b = await asyncio.gather(self._inputs.a.receive(), self._inputs.b.receive()) @@ -49,9 +61,13 @@ class InputsTyping(OutputParams): # noqa: D106 inputs: InputsTyping # type: ignore + async def val_changed(self, value: Any) -> Any: # noqa: D102 + print(f"CALLBACK: Printer.inp: {value}.") + return value + @staticmethod def register_inputs(inputs: InputParams) -> None: # noqa: D102 - inputs.declare("inp", Any) + inputs.declare("inp", Any, callback=Printer.val_changed) async def forward(self) -> ComponentState: # noqa: D102 value = await self._inputs.inp.receive() @@ -98,6 +114,18 @@ def __init__(self, name: str, elements: int = 1): super().__init__(name) self._elements: int = elements + async def set_elements(self, value: int) -> int: # noqa: D102 + print(f"CALLBACK: Acc.elements: {value}.") + # this is a bir tricky since the value is stored in 2 places the property and the variable. + # Since the acc uses the _elements variable in the forward method we need to update it here + # as well. Thanks to the callback we do not need to worry about both sources. + self._elements = value + return value + + @staticmethod + def register_properties(properties: PropertyParams) -> None: # noqa: D102 + properties.declare("elements", int, callback=Acc.set_elements) + @staticmethod def register_inputs(inputs: InputParams) -> None: # noqa: D102 inputs.declare("inp", int) @@ -143,5 +171,12 @@ async def forward(self) -> ComponentState: # noqa: D102 engine.add_nodes([add, printer0]) # there are several states for each component, with this verbose mode we can see them engine.set_verbose_mode(VerboseMode.COMPONENT) -# run all teh components at least once (since there is an accumulator, some components will be run more than once) -engine.run(1) +# run all the components at least 2 times (since there is an accumulator, some components will be run more than once) + + +async def run() -> None: # noqa: D103 + await engine.async_run(1) + await acc.properties.elements.set_property(3) # change the number of elements to accumulate + await engine.async_run(1) + +async_utils.run_coroutine(run()) diff --git a/limbus/core/__init__.py b/limbus/core/__init__.py index 481214c..c91f907 100644 --- a/limbus/core/__init__.py +++ b/limbus/core/__init__.py @@ -1,7 +1,7 @@ from limbus.core.component import Component, executions_manager from limbus.core.states import ComponentState, PipelineState, VerboseMode -from limbus.core.param import NoValue, Param, Reference, InputParam, OutputParam -from limbus.core.params import Params, InputParams, OutputParams +from limbus.core.param import NoValue, Reference, InputParam, OutputParam, PropertyParam +from limbus.core.params import PropertyParams, InputParams, OutputParams from limbus.core.pipeline import Pipeline from limbus.core.app import App @@ -14,11 +14,11 @@ "Component", "executions_manager", "ComponentState", - "Params", "Reference", + "PropertyParams", "InputParams", "OutputParams", + "PropertyParam", "InputParam", "OutputParam", - "Param", "NoValue"] diff --git a/limbus/core/async_utils.py b/limbus/core/async_utils.py index 8db3ecd..e223718 100644 --- a/limbus/core/async_utils.py +++ b/limbus/core/async_utils.py @@ -3,7 +3,6 @@ import asyncio import inspect from typing import Coroutine, TYPE_CHECKING -from sys import version_info if TYPE_CHECKING: from limbus.core.component import Component diff --git a/limbus/core/component.py b/limbus/core/component.py index baeeae3..dd70503 100644 --- a/limbus/core/component.py +++ b/limbus/core/component.py @@ -14,7 +14,7 @@ pass from limbus_config import config -from limbus.core.params import Params, InputParams, OutputParams +from limbus.core.params import InputParams, OutputParams, PropertyParams from limbus.core.states import ComponentState, ComponentStoppedError # Note that Pipeline class cannot be imported to avoid circular dependencies. if TYPE_CHECKING: @@ -143,7 +143,7 @@ def __init__(self, name: str): self.__class__.register_inputs(self._inputs) self._outputs = OutputParams(self) self.__class__.register_outputs(self._outputs) - self._properties = Params(self) + self._properties = PropertyParams(self) self.__class__.register_properties(self._properties) self.__state: _ComponentState = _ComponentState(self, ComponentState.INITIALIZED) self.__pipeline: None | Pipeline = None @@ -152,7 +152,7 @@ def __init__(self, name: str): self.__stopping_execution: int = 0 # 0 means run forever self.__num_params_waiting_to_receive: int = 0 # updated from InputParam - # method called in _run_with_hooks to execute the component forward method + # method called in __run_with_hooks to execute the component forward method self.__run_forward: Callable[..., Coroutine[Any, Any, ComponentState]] = self.forward try: if nn.Module in Component.__mro__: @@ -249,7 +249,7 @@ def outputs(self) -> OutputParams: return self._outputs @property - def properties(self) -> Params: + def properties(self) -> PropertyParams: """Get the set of properties for this component.""" return self._properties @@ -258,7 +258,7 @@ def register_inputs(inputs: InputParams) -> None: """Register the input params. Args: - inputs: Params object to register the inputs. + inputs: object to register the inputs. """ pass @@ -268,51 +268,23 @@ def register_outputs(outputs: OutputParams) -> None: """Register the output params. Args: - outputs: Params object to register the outputs. + outputs: object to register the outputs. """ pass @staticmethod - def register_properties(properties: Params) -> None: + def register_properties(properties: PropertyParams) -> None: """Register the properties. These params are optional. Args: - properties: Params object to register the properties. + properties: object to register the properties. """ pass - def set_properties(self, **kwargs) -> bool: - """Simplify the way to set the viz params. - - You can pass all the viz params you want to set as keyword arguments. - - These 2 codes are equivalent: - >> component.set_properties(param_name_0=value_0, param_name_1=value_1, ...) - - and - >> component.properties.set_param('param_name_0', value_0) - >> component.properties.set_param('param_name_1', value_1) - >> . - >> . - - Returns: - bool: True if all the passed viz params were setted, False otherwise. - - """ - all_ok = True - properties: list[str] = self._properties.get_params() - for key, value in kwargs.items(): - if key in properties: - self._properties.set_param(key, value) - else: - log.warning(f"In component {self._name} the param {key} is not a valid viz param.") - all_ok = False - return all_ok - @property def pipeline(self) -> None | Pipeline: """Get the pipeline object.""" @@ -322,7 +294,7 @@ def set_pipeline(self, pipeline: None | Pipeline) -> None: """Set the pipeline running the component.""" self.__pipeline = pipeline - def _stop_component(self) -> None: + def __stop_component(self) -> None: """Prepare the component to be stopped.""" for input in self._inputs.get_params(): for ref in self._inputs[input].references: @@ -349,7 +321,7 @@ async def __call__(self) -> None: NOTE 1: If you want to use `async for...` instead of `while True` this method must be overridden. E.g.: async for x in xyz: - if await self._run_with_hooks(x): + if await self.__run_with_hooks(x): break Note that in this example the forward method will require 1 parameter. @@ -358,7 +330,7 @@ async def __call__(self) -> None: """ while True: - if await self._run_with_hooks(): + if await self.__run_with_hooks(): break def is_stopped(self) -> bool: @@ -369,23 +341,23 @@ def is_stopped(self) -> bool: return True return False - def _stop_if_needed(self) -> bool: + def __stop_if_needed(self) -> bool: """Stop the component if it is required.""" if self.is_stopped(): if ComponentState.STOPPED_AT_ITER not in self.state: # in this case we need to force the stop of the component. When it is stopped at a given iter # the pipeline ends without forcing anything. - self._stop_component() + self.__stop_component() return True return False - async def _run_with_hooks(self, *args, **kwargs) -> bool: + async def __run_with_hooks(self, *args, **kwargs) -> bool: self.__exec_counter += 1 if self.__pipeline is not None: await self.__pipeline.before_component_hook(self) if self.__pipeline.before_component_user_hook: await self.__pipeline.before_component_user_hook(self) - if self._stop_if_needed(): # just in case the component state is changed in the before_component_hook + if self.__stop_if_needed(): # just in case the component state is changed in the before_component_hook return True # run the component try: @@ -404,7 +376,7 @@ async def _run_with_hooks(self, *args, **kwargs) -> bool: await self.__pipeline.after_component_hook(self) if self.__pipeline.after_component_user_hook: await self.__pipeline.after_component_user_hook(self) - if self._stop_if_needed(): + if self.__stop_if_needed(): return True return False # if there is not a pipeline, the component is executed only once diff --git a/limbus/core/param.py b/limbus/core/param.py index 36692cc..d22258a 100644 --- a/limbus/core/param.py +++ b/limbus/core/param.py @@ -3,11 +3,12 @@ from dataclasses import dataclass from collections import defaultdict import typing -from typing import Any, TYPE_CHECKING +from typing import Any, TYPE_CHECKING, Callable import inspect import collections import asyncio import contextlib +from abc import ABC import typeguard @@ -226,8 +227,8 @@ def __eq__(self, other: Any) -> bool: return False -class Param: - """Class to store data for each parameter. +class Param(ABC): + """Base class to store data for each parameter. Args: name: name of the parameter. @@ -235,10 +236,14 @@ class Param: value (optional): value of the parameter. Default: NoValue(). arg (optional): name of the argument in the component constructor related with this param. Default: None. parent (optional): parent component. Default: None. + callback (optional): async callback to be called when the value of the parameter changes. + Prototype: `async def callback(parent: Component, value: TYPE) -> TYPE:` + - MUST return the value to be finally used. + Default: None. """ def __init__(self, name: str, tp: Any = Any, value: Any = NoValue(), arg: None | str = None, - parent: None | Component = None) -> None: + parent: None | Component = None, callback: Callable | None = None) -> None: # validate that the type is coherent with the value if not isinstance(value, NoValue): typeguard.check_type(name, value, tp) @@ -253,6 +258,7 @@ def __init__(self, name: str, tp: Any = Any, value: Any = NoValue(), arg: None | # only sequences with tensors inside are subscriptable self._is_subscriptable = _check_subscriptable(tp) self._parent: None | Component = parent + self._callback: None | Callable = callback @property def is_subscriptable(self) -> bool: @@ -274,7 +280,11 @@ def parent(self) -> None | Component: @property def arg(self) -> None | str: - """Get the argument related with the param.""" + """Get the argument in the Component constructor related with this param. + + This is a trick to pass a value and type of an argument in the Component constructor to this parameter. + + """ return self._arg @property @@ -295,10 +305,6 @@ def references(self) -> set[Reference]: refs = refs.union(ref_set) return refs - def __call__(self) -> Any: - """Get the value of the parameter.""" - return self.value - @property def value(self) -> Any: """Get the value of the parameter.""" @@ -484,11 +490,46 @@ def disconnect(self, dst: "Param" | IterableParam) -> None: self._disconnect(self, dst) +class PropertyParam(Param): + """Class to manage the comunication for each property parameter.""" + + def init_property(self, value: Any) -> None: + """Initialize the property with the given value. + + This method should be called before running the component to init the property. + So, it is not running the callback function. + + """ + # ComponentState.INITIALIZED means that the component was just created + if self._parent is not None and ComponentState.INITIALIZED not in self._parent.state: + raise RuntimeError("The property can only be initialized before running the component.") + self.value = value + + async def set_property(self, value: Any) -> None: + """Set the value of the property. + + Note: using this method is the only way to run the callback function. + + """ + assert self._parent is not None + if self._callback is None: + self.value = value + else: + self.value = await self._callback(self._parent, value) + + class InputParam(Param): """Class to manage the comunication for each input parameter.""" async def receive(self) -> Any: - """Wait until the input param receives a value from the connected output param.""" + """Wait until the input param receives a value from the connected output param. + + Note that using this metohd will run the callback function as soon as a new value is received. + Note tha the callback changes teh result returned by the received method, not the value inside the + param (Param.value). This is in this way because the param can be shared between several input params, + so each callback call could change its value. + + """ assert self._parent is not None self._parent._Component__num_params_waiting_to_receive += 1 if self.references: @@ -545,8 +586,12 @@ async def receive(self) -> Any: ref.sent.clear() # allow to know to the sender that it can send again else: value = self.value + if self._callback is not None: + # specific callback for this param + value = await self._callback(self._parent, value) await self._are_all_waiting_params_received() if self._parent.pipeline and self._parent.pipeline.param_received_user_hook: + # hook from the pipeline, all the components and input params run the same code await self._parent.pipeline.param_received_user_hook(self) return value @@ -562,9 +607,17 @@ class OutputParam(Param): """Class to manage the comunication for each output parameter.""" async def send(self, value: Any) -> None: - """Send the value of this param to the connected input params.""" + """Send the value of this param to the connected input params. + + Note that using this metohd will run the callback function as soon as a new value is received. + + """ assert self._parent is not None - self.value = value # set the value for the param + if self._callback is None: + self.value = value # set the value for the param + else: + self.value = await self._callback(self._parent, value) + for ref in self.references: assert isinstance(ref.sent, asyncio.Event) assert isinstance(ref.consumed, asyncio.Event) diff --git a/limbus/core/params.py b/limbus/core/params.py index 37f215d..5e66418 100644 --- a/limbus/core/params.py +++ b/limbus/core/params.py @@ -1,48 +1,25 @@ """Classes to define set of parameters.""" from __future__ import annotations -from typing import Any, Iterator, Iterable +from typing import Any, Iterator, Iterable, Callable +from abc import ABC, abstractmethod # Note that Component class cannot be imported to avoid circular dependencies. # Since it is only used for type hints we import the module and use "component.Component" for typing. from limbus.core import component -from limbus.core.param import Param, NoValue, InputParam, OutputParam +from limbus.core.param import Param, NoValue, InputParam, OutputParam, PropertyParam -class Params(Iterable): +class Params(Iterable, ABC): """Class to store parameters.""" def __init__(self, parent_component: None | "component.Component" = None): super().__init__() self._parent = parent_component - def declare(self, name: str, tp: Any = Any, value: Any = NoValue(), arg: None | str = None) -> None: - """Add or modify a param. - - Args: - name: name of the parameter. - tp: type (e.g. str, int, list, str | int,...). Default: typing.Any - value (optional): value for the parameter. Default: NoValue(). - arg (optional): Component argument directly related with the value of the parameter. Default: None. - E.g. this is useful to propagate datatypes and values from a pin with a default value to - an argument in a Component (GUI). - - """ - if isinstance(value, Param): - value = value.value - setattr(self, name, Param(name, tp, value, arg, self._parent)) - - def __getattr__(self, name: str) -> Param: # type: ignore # it should return a Param - """Trick to avoid mypy issues with dinamyc attributes.""" - ... - - def get_related_arg(self, name: str) -> None | str: - """Return the argument in the Component constructor related with a given param. - - Args: - name: name of the param. - - """ - return getattr(self, name).arg + @abstractmethod + def declare(self, *args, **kwargs) -> None: + """Add or modify a param.""" + raise NotImplementedError def get_params(self, only_connected: bool = False) -> list[str]: """Return the name of all the params. @@ -58,43 +35,9 @@ def get_params(self, only_connected: bool = False) -> list[str]: params.append(name) return params - def get_types(self) -> dict[str, type]: - """Return the name and the type of all the params.""" - types: dict[str, type] = { - name: getattr(self, name).type for name in self.__dict__ if not name.startswith('_')} - return types - - def get_type(self, name: str) -> type: - """Return the type of a given param. - - Args: - name: name of the param. - - """ - return getattr(self, name).type - - def get_param(self, name: str) -> Any: - """Return the param value. - - Args: - name: name of the param. - - """ - return getattr(self, name).value - def __len__(self) -> int: return len(self.get_params()) - def set_param(self, name: str, value: Any) -> None: - """Set the param value. - - Args: - name: name of the param. - value: value to be setted. - - """ - getattr(self, name).value = value - def __getitem__(self, name: str) -> Param: return getattr(self, name) @@ -119,45 +62,72 @@ def __repr__(self) -> str: class InputParams(Params): """Class to manage input parameters.""" - def declare(self, name: str, tp: Any = Any, value: Any = NoValue(), arg: None | str = None) -> None: + def declare(self, name: str, tp: Any = Any, value: Any = NoValue(), callback: Callable | None = None) -> None: """Add or modify a param. Args: name: name of the parameter. tp: type (e.g. str, int, list, str | int,...). Default: typing.Any value (optional): value for the parameter. Default: NoValue(). - arg (optional): Component argument directly related with the value of the parameter. Default: None. - E.g. this is useful to propagate datatypes and values from a pin with a default value to - an argument in a Component (GUI). + callback (optional): async callback function to be called when the parameter value changes. + Prototype: `async def callback(parent: Component, value: TYPE) -> TYPE:` + - MUST return the value to be finally used. + Default: None. """ if isinstance(value, Param): value = value.value - setattr(self, name, InputParam(name, tp, value, arg, self._parent)) + setattr(self, name, InputParam(name, tp, value, None, self._parent, callback)) def __getattr__(self, name: str) -> InputParam: # type: ignore # it should return an InitParam """Trick to avoid mypy issues with dinamyc attributes.""" ... -class OutputParams(Params): - """Class to manage output parameters.""" +class PropertyParams(Params): + """Class to manage property parameters.""" - def declare(self, name: str, tp: Any = Any, value: Any = NoValue(), arg: None | str = None) -> None: + def declare(self, name: str, tp: Any = Any, value: Any = NoValue(), callback: Callable | None = None) -> None: """Add or modify a param. Args: name: name of the parameter. tp: type (e.g. str, int, list, str | int,...). Default: typing.Any value (optional): value for the parameter. Default: NoValue(). - arg (optional): Component argument directly related with the value of the parameter. Default: None. - E.g. this is useful to propagate datatypes and values from a pin with a default value to - an argument in a Component (GUI). + callback (optional): async callback function to be called when the parameter value changes. + Prototype: `async def callback(parent: Component, value: TYPE) -> TYPE:` + - MUST return the value to be finally used. + Default: None. """ if isinstance(value, Param): value = value.value - setattr(self, name, OutputParam(name, tp, value, arg, self._parent)) + setattr(self, name, PropertyParam(name, tp, value, None, self._parent, callback)) + + def __getattr__(self, name: str) -> PropertyParam: # type: ignore # it should return an PropParam + """Trick to avoid mypy issues with dinamyc attributes.""" + ... + + +class OutputParams(Params): + """Class to manage output parameters.""" + + def declare(self, name: str, tp: Any = Any, arg: None | str = None, callback: Callable | None = None) -> None: + """Add or modify a param. + + Args: + name: name of the parameter. + tp: type (e.g. str, int, list, str | int,...). Default: typing.Any + arg (optional): Component argument directly related with the value of the parameter. Default: None. + E.g. this is useful to propagate datatypes and values from a pin with a default value to an argument + in a Component (GUI). + callback (optional): async callback function to be called when the parameter value changes. + Prototype: `async def callback(parent: Component, value: TYPE) -> TYPE:` + - MUST return the value to be finally used. + Default: None. + + """ + setattr(self, name, OutputParam(name, tp, NoValue(), arg, self._parent, callback)) def __getattr__(self, name: str) -> OutputParam: # type: ignore # it should return an OutputParam """Trick to avoid mypy issues with dinamyc attributes.""" diff --git a/limbus/widgets/widget_component.py b/limbus/widgets/widget_component.py index 59b3793..c939654 100644 --- a/limbus/widgets/widget_component.py +++ b/limbus/widgets/widget_component.py @@ -5,7 +5,7 @@ from enum import Enum from limbus import widgets -from limbus.core import Component, ComponentState, Params +from limbus.core import Component, ComponentState, PropertyParams class WidgetState(Enum): @@ -65,11 +65,11 @@ class BaseWidgetComponent(WidgetComponent): WIDGET_STATE: WidgetState = WidgetState.ENABLED @staticmethod - def register_properties(properties: Params) -> None: + def register_properties(properties: PropertyParams) -> None: """Register the properties. Args: - properties: Params object to register the properties. + properties: object to register the properties. """ # this line is like super() but for static methods. @@ -81,12 +81,12 @@ async def _show(self, title: str) -> None: """Show the data. Args: - title: same as self._properties.get_param("title"). + title: same as self._properties[]"title"].value. """ raise NotImplementedError @is_disabled async def forward(self) -> ComponentState: # noqa: D102 - await self._show(self._properties.get_param("title")) + await self._show(self._properties["title"].value) return ComponentState.OK diff --git a/tests/core/test_component.py b/tests/core/test_component.py index 3dc0527..113c306 100644 --- a/tests/core/test_component.py +++ b/tests/core/test_component.py @@ -57,7 +57,7 @@ def test_set_state(self): cmp.state_message(ComponentState.ERROR) == "error" cmp.state_message(ComponentState.FORCED_STOP) is None - def test_set_properties(self): + def test_register_properties(self): class A(Component): @staticmethod def register_properties(properties): @@ -66,23 +66,11 @@ def register_properties(properties): properties.declare("b", float, 2.) cmp = A("yuhu") - assert cmp.properties.get_param("a") == 1. - assert cmp.properties.get_param("b") == 2. - assert cmp.properties.a() == 1. - assert cmp.properties.b() == 2. - cmp.properties.set_param("a", 3.) - assert cmp.properties.get_param("a") == 3. - assert cmp.set_properties(a=4., b=5.) - assert cmp.properties.get_param("a") == 4. - assert cmp.properties.get_param("b") == 5. - assert cmp.set_properties(c=4.) is False - p = cmp.properties.get_params() - assert len(p) == 2 - assert p[0] in ["a", "b"] - assert p[1] in ["a", "b"] - p = cmp.properties.get_types() - assert p["a"] == float - assert p["b"] == float + assert len(cmp.properties) == 2 + assert len(cmp.inputs) == 0 + assert len(cmp.outputs) == 0 + assert cmp.properties.a.value == 1. + assert cmp.properties.b.value == 2. def test_register_inputs(self): class A(Component): @@ -92,6 +80,7 @@ def register_inputs(inputs): inputs.declare("b", float, 2.) cmp = A("yuhu") + assert len(cmp.properties) == 0 assert len(cmp.outputs) == 0 assert len(cmp.inputs) == 2 assert cmp.inputs.a.value == 1. @@ -101,14 +90,15 @@ def test_register_outputs(self): class A(Component): @staticmethod def register_outputs(outputs): - outputs.declare("a", float, 1.) - outputs.declare("b", float, 2.) + outputs.declare("a", float) + outputs.declare("b", float) cmp = A("yuhu") + assert len(cmp.properties) == 0 assert len(cmp.inputs) == 0 assert len(cmp.outputs) == 2 - assert cmp.outputs.a.value == 1. - assert cmp.outputs.b.value == 2. + assert cmp.outputs.a.type is float + assert cmp.outputs.b.type is float def test_init_from_component(self): class A(Component): diff --git a/tests/core/test_param.py b/tests/core/test_param.py index 154fbbf..54e8fe7 100644 --- a/tests/core/test_param.py +++ b/tests/core/test_param.py @@ -1,14 +1,17 @@ import pytest from typing import Any, List, Sequence, Iterable, Tuple import asyncio +import logging import torch import limbus.core.param from limbus.core import NoValue, Component, ComponentState -from limbus.core.param import (Container, Param, InputParam, OutputParam, +from limbus.core.param import (Container, Param, InputParam, OutputParam, PropertyParam, IterableContainer, IterableInputContainers, IterableParam, Reference) +log = logging.getLogger(__name__) + class TestContainer: def test_smoke(self): @@ -117,7 +120,6 @@ def test_smoke(self): assert p.arg is None assert p._is_subscriptable is False assert p.is_subscriptable is False - assert p() == p.value def test_subcriptability(self): p = Param("a", List[torch.Tensor], value=[torch.tensor(1), torch.tensor(1)]) @@ -136,7 +138,6 @@ def test_init_with_type(self): def test_init_with_value(self): p = Param("a", tp=int, value=1) assert p.value == 1 - assert p() == 1 def test_init_with_invalid_value_raise_error(self): with pytest.raises(TypeError): @@ -481,6 +482,20 @@ async def test_receive_from_iterable_param(self): assert list(pi._refs[1])[0].sent.is_set() is False assert pi.value == [torch.tensor(2), torch.tensor(1)] + async def test_receive_with_callback(self, caplog): + async def callback(self, value): + assert self.name == "a" + log.info(f"callback: {value}") + return 2 + po = OutputParam("b", parent=A("b")) + pi = InputParam("a", parent=A("a"), callback=callback) + po >> pi + with caplog.at_level(logging.INFO): + res = await asyncio.gather(po.send(1), pi.receive()) + assert pi.value == 1 # the callback does not change the internal param value + assert res[1] == 2 # onl changes the return value + assert "callback: 1" in caplog.text + class TestOutputParam: def test_smoke(self): @@ -534,3 +549,53 @@ async def test_send_from_iterable_param(self): assert list(po._refs[1])[0].sent.is_set() is False assert pi0.value == torch.tensor(1) assert pi1.value == torch.tensor(2) + + async def test_send_with_callback(self, caplog): + async def callback(self, value): + assert self.name == "b" + log.info(f"callback: {value}") + return 2 + po = OutputParam("b", parent=A("b"), callback=callback) + pi = InputParam("a", parent=A("a")) + po >> pi + with caplog.at_level(logging.INFO): + await asyncio.gather(po.send(1), pi.receive()) + assert pi.value == 2 + assert "callback: 1" in caplog.text + + +class TestPropertyParam: + def test_smoke(self): + p = PropertyParam("a") + assert isinstance(p, Param) + + def test_init_without_parent(self): + p = PropertyParam("a") + p.init_property(1) + assert p.value == 1 + + def test_init_with_parent(self): + p = PropertyParam("a", parent=A("b")) + p.init_property(1) + assert p.value == 1 + + async def test_set_without_parent(self): + p = PropertyParam("a") + with pytest.raises(AssertionError): + await p.set_property(1) + + async def test_set_property_with_parent(self): + p = PropertyParam("b", parent=A("b")) + await p.set_property(1) + assert p.value == 1 + + async def test_set_property_with_callback(self, caplog): + async def callback(self, value): + assert self.name == "b" + log.info(f"callback: {value}") + return 2 + p = PropertyParam("b", parent=A("b"), callback=callback) + with caplog.at_level(logging.INFO): + await p.set_property(1) + assert p.value == 2 + assert "callback: 1" in caplog.text diff --git a/tests/core/test_params.py b/tests/core/test_params.py index 341bd59..81f6b6b 100644 --- a/tests/core/test_params.py +++ b/tests/core/test_params.py @@ -1,41 +1,45 @@ import pytest -from typing import Any import torch -from limbus.core import Params, NoValue, InputParams, OutputParams -from limbus.core.param import Param, InputParam, OutputParam +from limbus.core import PropertyParams, NoValue, InputParams, OutputParams +from limbus.core.param import Param, InputParam, OutputParam, PropertyParam + + +class TParams(InputParams): + """Test class to test Params class with all teh posible params for Param. + + NOTE: Inherits from InputParams because it is the only one that allows to use all the args in Param. + + """ + pass class TestParams: def test_smoke(self): - p = Params() + p = TParams() assert p is not None def test_declare(self): - p = Params() - - assert p.x is None # p.x does not exist but Params accept dynamic attributes - + p = TParams() p.declare("x") assert isinstance(p.x.value, NoValue) - assert isinstance(p.get_param("x"), NoValue) + assert isinstance(p["x"].value, NoValue) p.declare("y", float, 1.) assert p.y.value == 1. assert p["y"].value == 1. - assert p.get_param("y") == 1. assert isinstance(p["y"], Param) assert isinstance(p.y, Param) assert isinstance(p["y"].value, float) assert p["y"].type == float assert p["y"].name == "y" assert p["y"].arg is None - assert p.get_related_arg("y") is None + assert p.y.arg is None def test_tensor(self): - p1 = Params() - p2 = Params() + p1 = TParams() + p2 = TParams() p1.declare("x", torch.Tensor, torch.tensor(1.)) assert isinstance(p1["x"].value, torch.Tensor) @@ -43,32 +47,24 @@ def test_tensor(self): p2.declare("y", torch.Tensor, p1.x) assert p1.x.value == p2.y.value - def test_get_param(self): - p = Params() + def test_get_params(self): + p = TParams() p.declare("x") p.declare("y", float, 1.) assert len(p) == 2 assert p.get_params() == ["x", "y"] - assert isinstance(p.get_param("x"), NoValue) - assert p.get_param("y") == 1. - p.set_param("x", "xyz") - assert p.get_param("x") == "xyz" + assert isinstance(p.x.value, NoValue) + assert p.y.value == 1. + p.x.value = "xyz" + assert p.x.value == "xyz" def test_wrong_set_param_type(self): - p = Params() + p = TParams() with pytest.raises(TypeError): p.declare("x", int, 1.) p.declare("x", int) with pytest.raises(TypeError): - p.set_param("x", "xyz") - - def test_get_type(self): - p = Params() - p.declare("x") - p.declare("y", float, 1.) - assert p.get_type("x") == Any - assert p.get_type("y") == float - assert p.get_types() == {"x": Any, "y": float} + p.x.value = "xyz" class TestInputParams: @@ -82,16 +78,26 @@ def test_declare_with_param(self): p0 = Param("x", float, 1.) p.declare("x", float, p0) assert p.x.value == p0.value + assert p.z is None # Intellisense asumes p.z exist as an InputParams class TestOutputParams: def test_declare(self): p = OutputParams() - p.declare("x", float, 1.) + p.declare("x", float) assert isinstance(p.x, OutputParam) + assert p.z is None # Intellisense asumes p.z exist as an OutputParam + + +class TestPropertyParams: + def test_declare(self): + p = PropertyParams() + p.declare("x", float, 1.) + assert isinstance(p.x, PropertyParam) + assert p.z is None # Intellisense asumes p.z exist as an PropParams def test_declare_with_param(self): - p = OutputParams() + p = PropertyParams() p0 = Param("x", float, 1.) p.declare("x", float, p0) assert p.x.value == p0.value