Skip to content

Commit

Permalink
Merge pull request #15 from kornia/add_callables
Browse files Browse the repository at this point in the history
Add callables
  • Loading branch information
lferraz authored May 1, 2023
2 parents 3678fae + 2d28a5e commit 8836519
Show file tree
Hide file tree
Showing 10 changed files with 297 additions and 207 deletions.
51 changes: 43 additions & 8 deletions examples/defining_cmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from limbus_config import config
config.COMPONENT_TYPE = "torch"

from limbus.core import Component, InputParams, OutputParams, ComponentState, OutputParam, InputParam # noqa: E402
from limbus.core import Pipeline, VerboseMode # noqa: E402
from limbus.core import (Component, InputParams, OutputParams, PropertyParams, Pipeline, VerboseMode, # noqa: E402
ComponentState, OutputParam, InputParam, async_utils) # noqa: E402


# define the components
Expand All @@ -25,14 +25,26 @@ class OutputsTyping(OutputParams): # noqa: D106
inputs: InputsTyping # type: ignore
outputs: OutputsTyping # type: ignore

async def val_rec_a(self, value: Any) -> Any: # noqa: D102
print(f"CALLBACK: Add.a: {value}.")
return value

async def val_rec_b(self, value: Any) -> Any: # noqa: D102
print(f"CALLBACK: Add.b: {value}.")
return value

async def val_sent(self, value: Any) -> Any: # noqa: D102
print(f"CALLBACK: Add.out: {value}.")
return value

@staticmethod
def register_inputs(inputs: InputParams) -> None: # noqa: D102
inputs.declare("a", int)
inputs.declare("b", int)
inputs.declare("a", int, callback=Add.val_rec_a)
inputs.declare("b", int, callback=Add.val_rec_b)

@staticmethod
def register_outputs(outputs: OutputParams) -> None: # noqa: D102
outputs.declare("out", int)
outputs.declare("out", int, callback=Add.val_sent)

async def forward(self) -> ComponentState: # noqa: D102
a, b = await asyncio.gather(self._inputs.a.receive(), self._inputs.b.receive())
Expand All @@ -49,9 +61,13 @@ class InputsTyping(OutputParams): # noqa: D106

inputs: InputsTyping # type: ignore

async def val_changed(self, value: Any) -> Any: # noqa: D102
print(f"CALLBACK: Printer.inp: {value}.")
return value

@staticmethod
def register_inputs(inputs: InputParams) -> None: # noqa: D102
inputs.declare("inp", Any)
inputs.declare("inp", Any, callback=Printer.val_changed)

async def forward(self) -> ComponentState: # noqa: D102
value = await self._inputs.inp.receive()
Expand Down Expand Up @@ -98,6 +114,18 @@ def __init__(self, name: str, elements: int = 1):
super().__init__(name)
self._elements: int = elements

async def set_elements(self, value: int) -> int: # noqa: D102
print(f"CALLBACK: Acc.elements: {value}.")
# this is a bir tricky since the value is stored in 2 places the property and the variable.
# Since the acc uses the _elements variable in the forward method we need to update it here
# as well. Thanks to the callback we do not need to worry about both sources.
self._elements = value
return value

@staticmethod
def register_properties(properties: PropertyParams) -> None: # noqa: D102
properties.declare("elements", int, callback=Acc.set_elements)

@staticmethod
def register_inputs(inputs: InputParams) -> None: # noqa: D102
inputs.declare("inp", int)
Expand Down Expand Up @@ -143,5 +171,12 @@ async def forward(self) -> ComponentState: # noqa: D102
engine.add_nodes([add, printer0])
# there are several states for each component, with this verbose mode we can see them
engine.set_verbose_mode(VerboseMode.COMPONENT)
# run all teh components at least once (since there is an accumulator, some components will be run more than once)
engine.run(1)
# run all the components at least 2 times (since there is an accumulator, some components will be run more than once)


async def run() -> None: # noqa: D103
await engine.async_run(1)
await acc.properties.elements.set_property(3) # change the number of elements to accumulate
await engine.async_run(1)

async_utils.run_coroutine(run())
8 changes: 4 additions & 4 deletions limbus/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from limbus.core.component import Component, executions_manager
from limbus.core.states import ComponentState, PipelineState, VerboseMode
from limbus.core.param import NoValue, Param, Reference, InputParam, OutputParam
from limbus.core.params import Params, InputParams, OutputParams
from limbus.core.param import NoValue, Reference, InputParam, OutputParam, PropertyParam
from limbus.core.params import PropertyParams, InputParams, OutputParams
from limbus.core.pipeline import Pipeline
from limbus.core.app import App

Expand All @@ -14,11 +14,11 @@
"Component",
"executions_manager",
"ComponentState",
"Params",
"Reference",
"PropertyParams",
"InputParams",
"OutputParams",
"PropertyParam",
"InputParam",
"OutputParam",
"Param",
"NoValue"]
1 change: 0 additions & 1 deletion limbus/core/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import asyncio
import inspect
from typing import Coroutine, TYPE_CHECKING
from sys import version_info

if TYPE_CHECKING:
from limbus.core.component import Component
Expand Down
60 changes: 16 additions & 44 deletions limbus/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
pass

from limbus_config import config
from limbus.core.params import Params, InputParams, OutputParams
from limbus.core.params import InputParams, OutputParams, PropertyParams
from limbus.core.states import ComponentState, ComponentStoppedError
# Note that Pipeline class cannot be imported to avoid circular dependencies.
if TYPE_CHECKING:
Expand Down Expand Up @@ -143,7 +143,7 @@ def __init__(self, name: str):
self.__class__.register_inputs(self._inputs)
self._outputs = OutputParams(self)
self.__class__.register_outputs(self._outputs)
self._properties = Params(self)
self._properties = PropertyParams(self)
self.__class__.register_properties(self._properties)
self.__state: _ComponentState = _ComponentState(self, ComponentState.INITIALIZED)
self.__pipeline: None | Pipeline = None
Expand All @@ -152,7 +152,7 @@ def __init__(self, name: str):
self.__stopping_execution: int = 0 # 0 means run forever
self.__num_params_waiting_to_receive: int = 0 # updated from InputParam

# method called in _run_with_hooks to execute the component forward method
# method called in __run_with_hooks to execute the component forward method
self.__run_forward: Callable[..., Coroutine[Any, Any, ComponentState]] = self.forward
try:
if nn.Module in Component.__mro__:
Expand Down Expand Up @@ -249,7 +249,7 @@ def outputs(self) -> OutputParams:
return self._outputs

@property
def properties(self) -> Params:
def properties(self) -> PropertyParams:
"""Get the set of properties for this component."""
return self._properties

Expand All @@ -258,7 +258,7 @@ def register_inputs(inputs: InputParams) -> None:
"""Register the input params.
Args:
inputs: Params object to register the inputs.
inputs: object to register the inputs.
"""
pass
Expand All @@ -268,51 +268,23 @@ def register_outputs(outputs: OutputParams) -> None:
"""Register the output params.
Args:
outputs: Params object to register the outputs.
outputs: object to register the outputs.
"""
pass

@staticmethod
def register_properties(properties: Params) -> None:
def register_properties(properties: PropertyParams) -> None:
"""Register the properties.
These params are optional.
Args:
properties: Params object to register the properties.
properties: object to register the properties.
"""
pass

def set_properties(self, **kwargs) -> bool:
"""Simplify the way to set the viz params.
You can pass all the viz params you want to set as keyword arguments.
These 2 codes are equivalent:
>> component.set_properties(param_name_0=value_0, param_name_1=value_1, ...)
and
>> component.properties.set_param('param_name_0', value_0)
>> component.properties.set_param('param_name_1', value_1)
>> .
>> .
Returns:
bool: True if all the passed viz params were setted, False otherwise.
"""
all_ok = True
properties: list[str] = self._properties.get_params()
for key, value in kwargs.items():
if key in properties:
self._properties.set_param(key, value)
else:
log.warning(f"In component {self._name} the param {key} is not a valid viz param.")
all_ok = False
return all_ok

@property
def pipeline(self) -> None | Pipeline:
"""Get the pipeline object."""
Expand All @@ -322,7 +294,7 @@ def set_pipeline(self, pipeline: None | Pipeline) -> None:
"""Set the pipeline running the component."""
self.__pipeline = pipeline

def _stop_component(self) -> None:
def __stop_component(self) -> None:
"""Prepare the component to be stopped."""
for input in self._inputs.get_params():
for ref in self._inputs[input].references:
Expand All @@ -349,7 +321,7 @@ async def __call__(self) -> None:
NOTE 1: If you want to use `async for...` instead of `while True` this method must be overridden.
E.g.:
async for x in xyz:
if await self._run_with_hooks(x):
if await self.__run_with_hooks(x):
break
Note that in this example the forward method will require 1 parameter.
Expand All @@ -358,7 +330,7 @@ async def __call__(self) -> None:
"""
while True:
if await self._run_with_hooks():
if await self.__run_with_hooks():
break

def is_stopped(self) -> bool:
Expand All @@ -369,23 +341,23 @@ def is_stopped(self) -> bool:
return True
return False

def _stop_if_needed(self) -> bool:
def __stop_if_needed(self) -> bool:
"""Stop the component if it is required."""
if self.is_stopped():
if ComponentState.STOPPED_AT_ITER not in self.state:
# in this case we need to force the stop of the component. When it is stopped at a given iter
# the pipeline ends without forcing anything.
self._stop_component()
self.__stop_component()
return True
return False

async def _run_with_hooks(self, *args, **kwargs) -> bool:
async def __run_with_hooks(self, *args, **kwargs) -> bool:
self.__exec_counter += 1
if self.__pipeline is not None:
await self.__pipeline.before_component_hook(self)
if self.__pipeline.before_component_user_hook:
await self.__pipeline.before_component_user_hook(self)
if self._stop_if_needed(): # just in case the component state is changed in the before_component_hook
if self.__stop_if_needed(): # just in case the component state is changed in the before_component_hook
return True
# run the component
try:
Expand All @@ -404,7 +376,7 @@ async def _run_with_hooks(self, *args, **kwargs) -> bool:
await self.__pipeline.after_component_hook(self)
if self.__pipeline.after_component_user_hook:
await self.__pipeline.after_component_user_hook(self)
if self._stop_if_needed():
if self.__stop_if_needed():
return True
return False
# if there is not a pipeline, the component is executed only once
Expand Down
Loading

0 comments on commit 8836519

Please sign in to comment.