diff --git a/README.md b/README.md index dba858c..c9ce650 100644 --- a/README.md +++ b/README.md @@ -106,7 +106,7 @@ Let's see a very simple example that sums 2 integers: class Add(Component): """Add two numbers.""" # NOTE: type definition is optional, but it helps with the intellisense. ;) - class InputsTyping(OutputParams): + class InputsTyping(InputParams): a: InputParam b: InputParam @@ -137,6 +137,15 @@ class Add(Component): return ComponentState.OK ``` +**Note** that `Component` can inherint from `nn.Module`. By default inherints from `object`. + +To change the inheritance, before importing any other `limbus` module, set the `COMPONENT_TYPE` variable as: + +```python +from limbus_config import config +config.COMPONENT_TYPE = "torch" +``` + ## Ecosystem Limbus is a core technology to easily build different components and create generic pipelines. In the following list, you can find different examples diff --git a/examples/default_cmps.py b/examples/default_cmps.py index 5e4bdfd..a3d511b 100644 --- a/examples/default_cmps.py +++ b/examples/default_cmps.py @@ -1,5 +1,7 @@ """Basic example with predefined cmps.""" import asyncio +from sys import version_info +import copy import torch @@ -25,15 +27,25 @@ t2.outputs.out >> stack.inputs.tensors.select(1) stack.outputs.tensor >> show.inputs.inp -# create the pipeline and add its nodes -pipeline = Pipeline() -pipeline.add_nodes([c1, t1, t2, stack, show]) - -# run your pipeline (only one iteration, note that this pipeline can run forever) -print("Run with pipeline:") -pipeline.run(1) - -# run 1 iteration using the asyncio loop -print("Run with loop:") -loop = asyncio.get_event_loop() -loop.run_until_complete(asyncio.gather(c1(), t1(), t2(), stack(), show())) +USING_PIPELINE = True +if USING_PIPELINE: + # run your pipeline (only one iteration, note that this pipeline can run forever) + print("Run with pipeline:") + # create the pipeline and add its nodes + pipeline = Pipeline() + pipeline.add_nodes([c1, t1, t2, stack, show]) + pipeline.run(1) +else: + # run 1 iteration using the asyncio loop + print("Run with loop:") + + async def f(): # noqa: D103 + await asyncio.gather(c1(), t1(), t2(), stack(), show()) + + if version_info.minor < 10: + # for python <3.10 the loop must be run in this way to avoid creating a new loop. + loop = asyncio.get_event_loop() + loop.run_until_complete(f()) + elif version_info.minor >= 10: + # for python >=3.10 the loop should be run in this way. + asyncio.run(f()) diff --git a/examples/defining_cmps.py b/examples/defining_cmps.py index 93805fa..b8e5b0b 100644 --- a/examples/defining_cmps.py +++ b/examples/defining_cmps.py @@ -2,8 +2,12 @@ from typing import List, Any import asyncio -from limbus.core import Component, InputParams, OutputParams, ComponentState, VerboseMode, OutputParam, InputParam -from limbus.core.pipeline import Pipeline +# If you want to change the limbus config you need to do it before importing any limbus module!!! +from limbus_config import config +config.COMPONENT_TYPE = "torch" + +from limbus.core import Component, InputParams, OutputParams, ComponentState, OutputParam, InputParam # noqa: E402 +from limbus.core import Pipeline, VerboseMode # noqa: E402 # define the components diff --git a/limbus/core/component.py b/limbus/core/component.py index 7520a3b..d01d381 100644 --- a/limbus/core/component.py +++ b/limbus/core/component.py @@ -1,24 +1,39 @@ """Component definition.""" from __future__ import annotations from abc import abstractmethod -from typing import List, Optional, TYPE_CHECKING, Callable +from typing import List, Optional, TYPE_CHECKING, Callable, Type, Union, Any, Coroutine import logging import asyncio import traceback import functools -import torch.nn as nn +try: + import torch.nn as nn +except ImportError: + pass +from limbus_config import config from limbus.core.params import Params, InputParams, OutputParams from limbus.core.states import ComponentState, ComponentStoppedError # Note that Pipeline class cannot be imported to avoid circular dependencies. if TYPE_CHECKING: from limbus.core.pipeline import Pipeline - log = logging.getLogger(__name__) +base_class: Type = object +if config.COMPONENT_TYPE == "generic": + pass +elif config.COMPONENT_TYPE == "torch": + try: + base_class = nn.Module + except NameError: + log.error("Torch not installed. Using generic base class.") +else: + log.error("Invalid component type. Using generic base class.") + + # this is a decorator that will determine how many iterations must be run def iterations_manager(func: Callable) -> Callable: """Update the last iteration to be run by the component.""" @@ -86,7 +101,7 @@ def verbose(self, value: bool) -> None: self._verbose = value -class Component(nn.Module): +class Component(base_class): """Base class to define a Limbus Component. Args: @@ -110,6 +125,12 @@ def __init__(self, name: str): # Last execution to be run in the __call__ loop. self._stopping_iteration: int = 0 # 0 means run forever + # method called in _run_with_hooks to execute the component forward method + self._run_forward: Callable[..., Coroutine[Any, Any, ComponentState]] = self.forward + if nn.Module in Component.__mro__: + # If the component inherits from nn.Module, the forward method is called by the __call__ method + self._run_forward = nn.Module.__call__ + def init_from_component(self, ref_component: Component) -> None: """Init basic execution params from another component. @@ -311,7 +332,7 @@ async def _run_with_hooks(self, *args, **kwargs) -> bool: if len(self._inputs) == 0: # RUNNING state is set once the input params are received, if there are not inputs the state is set here self.set_state(ComponentState.RUNNING) - self.set_state(await super().__call__(*args, **kwargs)) # internally it calls the forward() method + self.set_state(await self._run_forward(*args, **kwargs)) except ComponentStoppedError as e: self.set_state(e.state) except Exception as e: diff --git a/limbus/core/param.py b/limbus/core/param.py index 344f552..66dff82 100644 --- a/limbus/core/param.py +++ b/limbus/core/param.py @@ -10,13 +10,25 @@ import contextlib import typeguard -import torch +from limbus.core.states import ComponentState, ComponentStoppedError +from limbus.core import async_utils # Note that Component class cannot be imported to avoid circular dependencies. if TYPE_CHECKING: from limbus.core.component import Component -from limbus.core.states import ComponentState, ComponentStoppedError -from limbus.core import async_utils + +SUBSCRIPTABLE_TYPES: List[type] = [] +try: + import torch + SUBSCRIPTABLE_TYPES.append(torch.Tensor) +except ImportError: + pass + +try: + import numpy as np + SUBSCRIPTABLE_TYPES.append(np.ndarray) +except ImportError: + pass class NoValue: @@ -120,7 +132,7 @@ def _check_subscriptable(datatype: type) -> bool: # mypy complaints in the case origin is NoneType if is_abstract_seq or (not is_abstract and isinstance(origin(), typing.Iterable)): # type: ignore if (len(datatype_args) == 1 or (len(datatype_args) == 2 and Ellipsis in datatype_args)): - if datatype_args[0] is torch.Tensor: + if datatype_args[0] in SUBSCRIPTABLE_TYPES: return True return False diff --git a/limbus/widgets/types.py b/limbus/widgets/types.py index fa86c9e..153f0eb 100644 --- a/limbus/widgets/types.py +++ b/limbus/widgets/types.py @@ -6,16 +6,16 @@ import functools try: - # NOTE: we import the cv2 & visdom modules here to avoid having it as a dependency + # NOTE: we import these modules here to avoid having it as a dependency # for the whole project. import cv2 import visdom + import torch + import kornia + import numpy as np except ImportError: pass -import torch -import kornia -import numpy as np from limbus.core import Component from limbus import widgets @@ -121,7 +121,7 @@ def check_status(self) -> bool: raise NotImplementedError @abstractmethod - def show_image(self, component: Component, title: str, image: torch.Tensor): + def show_image(self, component: Component, title: str, image: "torch.Tensor"): """Show an image. Args: @@ -134,7 +134,7 @@ def show_image(self, component: Component, title: str, image: torch.Tensor): @abstractmethod def show_images(self, component: Component, title: str, - images: Union[torch.Tensor, List[torch.Tensor]], + images: Union["torch.Tensor", List["torch.Tensor"]], nrow: Optional[int] = None ) -> None: """Show a batch of images. @@ -203,7 +203,7 @@ def check_status(self) -> bool: @is_enabled @set_title - def show_image(self, component: Component, title: str, image: torch.Tensor) -> None: + def show_image(self, component: Component, title: str, image: "torch.Tensor") -> None: """Show an image. Args: @@ -219,7 +219,7 @@ def show_image(self, component: Component, title: str, image: torch.Tensor) -> N @is_enabled @set_title def show_images(self, component: Component, title: str, - images: Union[torch.Tensor, List[torch.Tensor]], + images: Union["torch.Tensor", List["torch.Tensor"]], nrow: Optional[int] = None ) -> None: """Show a batch of images. @@ -269,7 +269,7 @@ def check_status(self) -> bool: @is_enabled @set_title - def show_image(self, component: Component, title: str, image: torch.Tensor): + def show_image(self, component: Component, title: str, image: "torch.Tensor"): """Show an image. Args: @@ -283,7 +283,7 @@ def show_image(self, component: Component, title: str, image: torch.Tensor): @is_enabled @set_title def show_images(self, component: Component, title: str, - images: Union[torch.Tensor, List[torch.Tensor]], + images: Union["torch.Tensor", List["torch.Tensor"]], nrow: Optional[int] = None ) -> None: """Show a batch of images. @@ -326,7 +326,7 @@ def __init__(self) -> None: @is_enabled @set_title - def show_image(self, component: Component, title: str, image: torch.Tensor): + def show_image(self, component: Component, title: str, image: "torch.Tensor"): """Show an image. Args: @@ -357,7 +357,7 @@ def show_image(self, component: Component, title: str, image: torch.Tensor): @is_enabled @set_title def show_images(self, component: Component, title: str, - images: Union[torch.Tensor, List[torch.Tensor]], + images: Union["torch.Tensor", List["torch.Tensor"]], nrow: Optional[int] = None ) -> None: """Show a batch of images. diff --git a/limbus_config/__init__.py b/limbus_config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/limbus_config/config.py b/limbus_config/config.py new file mode 100644 index 0000000..a58a5c0 --- /dev/null +++ b/limbus_config/config.py @@ -0,0 +1,11 @@ +"""Configuration file for Limbus. + +Usage: + Before importing any other Limbus module, set the COMPONENT_TYPE variable as: + > from limbus_config import config + > config.COMPONENT_TYPE = "torch" + + +""" + +COMPONENT_TYPE = "generic" # generic or torch diff --git a/setup.cfg b/setup.cfg index 2f05d8c..50abe76 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,7 @@ ignore = exclude = docs/src [mypy] -files = limbus, tests +files = examples, limbus, tests, limbus_config show_error_codes = True ignore_missing_imports = True diff --git a/setup.py b/setup.py index 5387f46..6760018 100644 --- a/setup.py +++ b/setup.py @@ -8,10 +8,7 @@ author='Luis Ferraz', url='https://github.com/kornia/limbus', install_requires=[ - 'torch', - 'numpy', 'typeguard', - 'kornia', ], extras_require={ 'dev': [ @@ -30,6 +27,9 @@ 'limbus-components' ], 'widgets': [ + 'kornia', + 'torch', + 'numpy', 'visdom', 'opencv-python', ] diff --git a/tests/config/test_config.py b/tests/config/test_config.py new file mode 100644 index 0000000..aef4e9a --- /dev/null +++ b/tests/config/test_config.py @@ -0,0 +1,31 @@ +"""Config tests.""" +import sys + +from torch import nn + + +def remove_limbus_imports(): + """Remove limbus dependencies from sys.modules.""" + for key in list(sys.modules.keys()): + if key.startswith("limbus"): + del sys.modules[key] + + +def test_torch_base_class(): + remove_limbus_imports() + from limbus_config import config + config.COMPONENT_TYPE = "torch" + import limbus + mro = limbus.Component.__mro__ + remove_limbus_imports() + assert len(mro) == 3 + assert nn.Module in mro + + +def test_generic_base_class(): + remove_limbus_imports() + import limbus + mro = limbus.Component.__mro__ + remove_limbus_imports() + assert len(mro) == 2 + assert nn.Module not in mro