Skip to content

Commit

Permalink
Merge pull request #9 from kornia/reduce_dependencies
Browse files Browse the repository at this point in the history
remove torch and numpy dependencies of limbus
  • Loading branch information
lferraz authored Mar 15, 2023
2 parents 868ec30 + 5d6c99b commit e555542
Show file tree
Hide file tree
Showing 11 changed files with 140 additions and 40 deletions.
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ Let's see a very simple example that sums 2 integers:
class Add(Component):
"""Add two numbers."""
# NOTE: type definition is optional, but it helps with the intellisense. ;)
class InputsTyping(OutputParams):
class InputsTyping(InputParams):
a: InputParam
b: InputParam

Expand Down Expand Up @@ -137,6 +137,15 @@ class Add(Component):
return ComponentState.OK
```

**Note** that `Component` can inherint from `nn.Module`. By default inherints from `object`.

To change the inheritance, before importing any other `limbus` module, set the `COMPONENT_TYPE` variable as:

```python
from limbus_config import config
config.COMPONENT_TYPE = "torch"
```

## Ecosystem

Limbus is a core technology to easily build different components and create generic pipelines. In the following list, you can find different examples
Expand Down
36 changes: 24 additions & 12 deletions examples/default_cmps.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Basic example with predefined cmps."""
import asyncio
from sys import version_info
import copy

import torch

Expand All @@ -25,15 +27,25 @@
t2.outputs.out >> stack.inputs.tensors.select(1)
stack.outputs.tensor >> show.inputs.inp

# create the pipeline and add its nodes
pipeline = Pipeline()
pipeline.add_nodes([c1, t1, t2, stack, show])

# run your pipeline (only one iteration, note that this pipeline can run forever)
print("Run with pipeline:")
pipeline.run(1)

# run 1 iteration using the asyncio loop
print("Run with loop:")
loop = asyncio.get_event_loop()
loop.run_until_complete(asyncio.gather(c1(), t1(), t2(), stack(), show()))
USING_PIPELINE = True
if USING_PIPELINE:
# run your pipeline (only one iteration, note that this pipeline can run forever)
print("Run with pipeline:")
# create the pipeline and add its nodes
pipeline = Pipeline()
pipeline.add_nodes([c1, t1, t2, stack, show])
pipeline.run(1)
else:
# run 1 iteration using the asyncio loop
print("Run with loop:")

async def f(): # noqa: D103
await asyncio.gather(c1(), t1(), t2(), stack(), show())

if version_info.minor < 10:
# for python <3.10 the loop must be run in this way to avoid creating a new loop.
loop = asyncio.get_event_loop()
loop.run_until_complete(f())
elif version_info.minor >= 10:
# for python >=3.10 the loop should be run in this way.
asyncio.run(f())
8 changes: 6 additions & 2 deletions examples/defining_cmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
from typing import List, Any
import asyncio

from limbus.core import Component, InputParams, OutputParams, ComponentState, VerboseMode, OutputParam, InputParam
from limbus.core.pipeline import Pipeline
# If you want to change the limbus config you need to do it before importing any limbus module!!!
from limbus_config import config
config.COMPONENT_TYPE = "torch"

from limbus.core import Component, InputParams, OutputParams, ComponentState, OutputParam, InputParam # noqa: E402
from limbus.core import Pipeline, VerboseMode # noqa: E402


# define the components
Expand Down
31 changes: 26 additions & 5 deletions limbus/core/component.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,39 @@
"""Component definition."""
from __future__ import annotations
from abc import abstractmethod
from typing import List, Optional, TYPE_CHECKING, Callable
from typing import List, Optional, TYPE_CHECKING, Callable, Type, Union, Any, Coroutine
import logging
import asyncio
import traceback
import functools

import torch.nn as nn
try:
import torch.nn as nn
except ImportError:
pass

from limbus_config import config
from limbus.core.params import Params, InputParams, OutputParams
from limbus.core.states import ComponentState, ComponentStoppedError
# Note that Pipeline class cannot be imported to avoid circular dependencies.
if TYPE_CHECKING:
from limbus.core.pipeline import Pipeline


log = logging.getLogger(__name__)


base_class: Type = object
if config.COMPONENT_TYPE == "generic":
pass
elif config.COMPONENT_TYPE == "torch":
try:
base_class = nn.Module
except NameError:
log.error("Torch not installed. Using generic base class.")
else:
log.error("Invalid component type. Using generic base class.")


# this is a decorator that will determine how many iterations must be run
def iterations_manager(func: Callable) -> Callable:
"""Update the last iteration to be run by the component."""
Expand Down Expand Up @@ -86,7 +101,7 @@ def verbose(self, value: bool) -> None:
self._verbose = value


class Component(nn.Module):
class Component(base_class):
"""Base class to define a Limbus Component.
Args:
Expand All @@ -110,6 +125,12 @@ def __init__(self, name: str):
# Last execution to be run in the __call__ loop.
self._stopping_iteration: int = 0 # 0 means run forever

# method called in _run_with_hooks to execute the component forward method
self._run_forward: Callable[..., Coroutine[Any, Any, ComponentState]] = self.forward
if nn.Module in Component.__mro__:
# If the component inherits from nn.Module, the forward method is called by the __call__ method
self._run_forward = nn.Module.__call__

def init_from_component(self, ref_component: Component) -> None:
"""Init basic execution params from another component.
Expand Down Expand Up @@ -311,7 +332,7 @@ async def _run_with_hooks(self, *args, **kwargs) -> bool:
if len(self._inputs) == 0:
# RUNNING state is set once the input params are received, if there are not inputs the state is set here
self.set_state(ComponentState.RUNNING)
self.set_state(await super().__call__(*args, **kwargs)) # internally it calls the forward() method
self.set_state(await self._run_forward(*args, **kwargs))
except ComponentStoppedError as e:
self.set_state(e.state)
except Exception as e:
Expand Down
20 changes: 16 additions & 4 deletions limbus/core/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,25 @@
import contextlib

import typeguard
import torch

from limbus.core.states import ComponentState, ComponentStoppedError
from limbus.core import async_utils
# Note that Component class cannot be imported to avoid circular dependencies.
if TYPE_CHECKING:
from limbus.core.component import Component
from limbus.core.states import ComponentState, ComponentStoppedError
from limbus.core import async_utils

SUBSCRIPTABLE_TYPES: List[type] = []
try:
import torch
SUBSCRIPTABLE_TYPES.append(torch.Tensor)
except ImportError:
pass

try:
import numpy as np
SUBSCRIPTABLE_TYPES.append(np.ndarray)
except ImportError:
pass


class NoValue:
Expand Down Expand Up @@ -120,7 +132,7 @@ def _check_subscriptable(datatype: type) -> bool:
# mypy complaints in the case origin is NoneType
if is_abstract_seq or (not is_abstract and isinstance(origin(), typing.Iterable)): # type: ignore
if (len(datatype_args) == 1 or (len(datatype_args) == 2 and Ellipsis in datatype_args)):
if datatype_args[0] is torch.Tensor:
if datatype_args[0] in SUBSCRIPTABLE_TYPES:
return True
return False

Expand Down
24 changes: 12 additions & 12 deletions limbus/widgets/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
import functools

try:
# NOTE: we import the cv2 & visdom modules here to avoid having it as a dependency
# NOTE: we import these modules here to avoid having it as a dependency
# for the whole project.
import cv2
import visdom
import torch
import kornia
import numpy as np
except ImportError:
pass

import torch
import kornia
import numpy as np

from limbus.core import Component
from limbus import widgets
Expand Down Expand Up @@ -121,7 +121,7 @@ def check_status(self) -> bool:
raise NotImplementedError

@abstractmethod
def show_image(self, component: Component, title: str, image: torch.Tensor):
def show_image(self, component: Component, title: str, image: "torch.Tensor"):
"""Show an image.
Args:
Expand All @@ -134,7 +134,7 @@ def show_image(self, component: Component, title: str, image: torch.Tensor):

@abstractmethod
def show_images(self, component: Component, title: str,
images: Union[torch.Tensor, List[torch.Tensor]],
images: Union["torch.Tensor", List["torch.Tensor"]],
nrow: Optional[int] = None
) -> None:
"""Show a batch of images.
Expand Down Expand Up @@ -203,7 +203,7 @@ def check_status(self) -> bool:

@is_enabled
@set_title
def show_image(self, component: Component, title: str, image: torch.Tensor) -> None:
def show_image(self, component: Component, title: str, image: "torch.Tensor") -> None:
"""Show an image.
Args:
Expand All @@ -219,7 +219,7 @@ def show_image(self, component: Component, title: str, image: torch.Tensor) -> N
@is_enabled
@set_title
def show_images(self, component: Component, title: str,
images: Union[torch.Tensor, List[torch.Tensor]],
images: Union["torch.Tensor", List["torch.Tensor"]],
nrow: Optional[int] = None
) -> None:
"""Show a batch of images.
Expand Down Expand Up @@ -269,7 +269,7 @@ def check_status(self) -> bool:

@is_enabled
@set_title
def show_image(self, component: Component, title: str, image: torch.Tensor):
def show_image(self, component: Component, title: str, image: "torch.Tensor"):
"""Show an image.
Args:
Expand All @@ -283,7 +283,7 @@ def show_image(self, component: Component, title: str, image: torch.Tensor):
@is_enabled
@set_title
def show_images(self, component: Component, title: str,
images: Union[torch.Tensor, List[torch.Tensor]],
images: Union["torch.Tensor", List["torch.Tensor"]],
nrow: Optional[int] = None
) -> None:
"""Show a batch of images.
Expand Down Expand Up @@ -326,7 +326,7 @@ def __init__(self) -> None:

@is_enabled
@set_title
def show_image(self, component: Component, title: str, image: torch.Tensor):
def show_image(self, component: Component, title: str, image: "torch.Tensor"):
"""Show an image.
Args:
Expand Down Expand Up @@ -357,7 +357,7 @@ def show_image(self, component: Component, title: str, image: torch.Tensor):
@is_enabled
@set_title
def show_images(self, component: Component, title: str,
images: Union[torch.Tensor, List[torch.Tensor]],
images: Union["torch.Tensor", List["torch.Tensor"]],
nrow: Optional[int] = None
) -> None:
"""Show a batch of images.
Expand Down
Empty file added limbus_config/__init__.py
Empty file.
11 changes: 11 additions & 0 deletions limbus_config/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Configuration file for Limbus.
Usage:
Before importing any other Limbus module, set the COMPONENT_TYPE variable as:
> from limbus_config import config
> config.COMPONENT_TYPE = "torch"
"""

COMPONENT_TYPE = "generic" # generic or torch
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ ignore =
exclude = docs/src

[mypy]
files = limbus, tests
files = examples, limbus, tests, limbus_config
show_error_codes = True
ignore_missing_imports = True

Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
author='Luis Ferraz',
url='https://github.com/kornia/limbus',
install_requires=[
'torch',
'numpy',
'typeguard',
'kornia',
],
extras_require={
'dev': [
Expand All @@ -30,6 +27,9 @@
'limbus-components'
],
'widgets': [
'kornia',
'torch',
'numpy',
'visdom',
'opencv-python',
]
Expand Down
31 changes: 31 additions & 0 deletions tests/config/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Config tests."""
import sys

from torch import nn


def remove_limbus_imports():
"""Remove limbus dependencies from sys.modules."""
for key in list(sys.modules.keys()):
if key.startswith("limbus"):
del sys.modules[key]


def test_torch_base_class():
remove_limbus_imports()
from limbus_config import config
config.COMPONENT_TYPE = "torch"
import limbus
mro = limbus.Component.__mro__
remove_limbus_imports()
assert len(mro) == 3
assert nn.Module in mro


def test_generic_base_class():
remove_limbus_imports()
import limbus
mro = limbus.Component.__mro__
remove_limbus_imports()
assert len(mro) == 2
assert nn.Module not in mro

0 comments on commit e555542

Please sign in to comment.