diff --git a/limbus/core/param.py b/limbus/core/param.py index 23b9d91..d22258a 100644 --- a/limbus/core/param.py +++ b/limbus/core/param.py @@ -236,7 +236,10 @@ class Param(ABC): value (optional): value of the parameter. Default: NoValue(). arg (optional): name of the argument in the component constructor related with this param. Default: None. parent (optional): parent component. Default: None. - callback (optional): callback to be called when the value of the parameter changes. Default: None. + callback (optional): async callback to be called when the value of the parameter changes. + Prototype: `async def callback(parent: Component, value: TYPE) -> TYPE:` + - MUST return the value to be finally used. + Default: None. """ def __init__(self, name: str, tp: Any = Any, value: Any = NoValue(), arg: None | str = None, @@ -497,8 +500,8 @@ def init_property(self, value: Any) -> None: So, it is not running the callback function. """ - assert self._parent is not None - if self._parent.executions_counter > 0: + # ComponentState.INITIALIZED means that the component was just created + if self._parent is not None and ComponentState.INITIALIZED not in self._parent.state: raise RuntimeError("The property can only be initialized before running the component.") self.value = value @@ -512,7 +515,7 @@ async def set_property(self, value: Any) -> None: if self._callback is None: self.value = value else: - self.value = await self._callback(self._parent, self.value) + self.value = await self._callback(self._parent, value) class InputParam(Param): @@ -522,6 +525,9 @@ async def receive(self) -> Any: """Wait until the input param receives a value from the connected output param. Note that using this metohd will run the callback function as soon as a new value is received. + Note tha the callback changes teh result returned by the received method, not the value inside the + param (Param.value). This is in this way because the param can be shared between several input params, + so each callback call could change its value. """ assert self._parent is not None diff --git a/limbus/core/params.py b/limbus/core/params.py index e70d8d5..5e66418 100644 --- a/limbus/core/params.py +++ b/limbus/core/params.py @@ -69,8 +69,10 @@ def declare(self, name: str, tp: Any = Any, value: Any = NoValue(), callback: Ca name: name of the parameter. tp: type (e.g. str, int, list, str | int,...). Default: typing.Any value (optional): value for the parameter. Default: NoValue(). - - callback (optional): callback function to be called when the parameter value changes. Default: None. + callback (optional): async callback function to be called when the parameter value changes. + Prototype: `async def callback(parent: Component, value: TYPE) -> TYPE:` + - MUST return the value to be finally used. + Default: None. """ if isinstance(value, Param): @@ -92,7 +94,10 @@ def declare(self, name: str, tp: Any = Any, value: Any = NoValue(), callback: Ca name: name of the parameter. tp: type (e.g. str, int, list, str | int,...). Default: typing.Any value (optional): value for the parameter. Default: NoValue(). - callback (optional): callback function to be called when the parameter value changes. Default: None. + callback (optional): async callback function to be called when the parameter value changes. + Prototype: `async def callback(parent: Component, value: TYPE) -> TYPE:` + - MUST return the value to be finally used. + Default: None. """ if isinstance(value, Param): @@ -114,9 +119,12 @@ def declare(self, name: str, tp: Any = Any, arg: None | str = None, callback: Ca name: name of the parameter. tp: type (e.g. str, int, list, str | int,...). Default: typing.Any arg (optional): Component argument directly related with the value of the parameter. Default: None. - E.g. this is useful to propagate datatypes and values from a pin with a default value to - an argument in a Component (GUI). - callback (optional): callback function to be called when the parameter value changes. Default: None. + E.g. this is useful to propagate datatypes and values from a pin with a default value to an argument + in a Component (GUI). + callback (optional): async callback function to be called when the parameter value changes. + Prototype: `async def callback(parent: Component, value: TYPE) -> TYPE:` + - MUST return the value to be finally used. + Default: None. """ setattr(self, name, OutputParam(name, tp, NoValue(), arg, self._parent, callback)) diff --git a/tests/core/test_param.py b/tests/core/test_param.py index 54ebe8e..54e8fe7 100644 --- a/tests/core/test_param.py +++ b/tests/core/test_param.py @@ -1,14 +1,17 @@ import pytest from typing import Any, List, Sequence, Iterable, Tuple import asyncio +import logging import torch import limbus.core.param from limbus.core import NoValue, Component, ComponentState -from limbus.core.param import (Container, Param, InputParam, OutputParam, +from limbus.core.param import (Container, Param, InputParam, OutputParam, PropertyParam, IterableContainer, IterableInputContainers, IterableParam, Reference) +log = logging.getLogger(__name__) + class TestContainer: def test_smoke(self): @@ -479,6 +482,20 @@ async def test_receive_from_iterable_param(self): assert list(pi._refs[1])[0].sent.is_set() is False assert pi.value == [torch.tensor(2), torch.tensor(1)] + async def test_receive_with_callback(self, caplog): + async def callback(self, value): + assert self.name == "a" + log.info(f"callback: {value}") + return 2 + po = OutputParam("b", parent=A("b")) + pi = InputParam("a", parent=A("a"), callback=callback) + po >> pi + with caplog.at_level(logging.INFO): + res = await asyncio.gather(po.send(1), pi.receive()) + assert pi.value == 1 # the callback does not change the internal param value + assert res[1] == 2 # onl changes the return value + assert "callback: 1" in caplog.text + class TestOutputParam: def test_smoke(self): @@ -532,3 +549,53 @@ async def test_send_from_iterable_param(self): assert list(po._refs[1])[0].sent.is_set() is False assert pi0.value == torch.tensor(1) assert pi1.value == torch.tensor(2) + + async def test_send_with_callback(self, caplog): + async def callback(self, value): + assert self.name == "b" + log.info(f"callback: {value}") + return 2 + po = OutputParam("b", parent=A("b"), callback=callback) + pi = InputParam("a", parent=A("a")) + po >> pi + with caplog.at_level(logging.INFO): + await asyncio.gather(po.send(1), pi.receive()) + assert pi.value == 2 + assert "callback: 1" in caplog.text + + +class TestPropertyParam: + def test_smoke(self): + p = PropertyParam("a") + assert isinstance(p, Param) + + def test_init_without_parent(self): + p = PropertyParam("a") + p.init_property(1) + assert p.value == 1 + + def test_init_with_parent(self): + p = PropertyParam("a", parent=A("b")) + p.init_property(1) + assert p.value == 1 + + async def test_set_without_parent(self): + p = PropertyParam("a") + with pytest.raises(AssertionError): + await p.set_property(1) + + async def test_set_property_with_parent(self): + p = PropertyParam("b", parent=A("b")) + await p.set_property(1) + assert p.value == 1 + + async def test_set_property_with_callback(self, caplog): + async def callback(self, value): + assert self.name == "b" + log.info(f"callback: {value}") + return 2 + p = PropertyParam("b", parent=A("b"), callback=callback) + with caplog.at_level(logging.INFO): + await p.set_property(1) + assert p.value == 2 + assert "callback: 1" in caplog.text