Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
shijianjian committed Sep 15, 2024
1 parent aa707a5 commit 1dfaf1a
Show file tree
Hide file tree
Showing 21 changed files with 111 additions and 41 deletions.
1 change: 1 addition & 0 deletions kornia/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
color,
contrib,
core,
config,
enhance,
feature,
io,
Expand Down
51 changes: 51 additions & 0 deletions kornia/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from enum import Enum
from dataclasses import dataclass, field

__all__ = ["config", "InstallationMode"]


class InstallationMode(str, Enum):
# Ask the user if to install the dependencies
ASK = "ASK"
# Install the dependencies
AUTO = "AUTO"
# Raise an error if the dependencies are not installed
RAISE = "RAISE"

def __eq__(self, other):
if isinstance(other, str):
return self.value.lower() == other.lower() # Case-insensitive comparison
return super().__eq__(other)


class LazyLoaderConfig:
_installation_mode: InstallationMode = InstallationMode.ASK

@property
def installation_mode(self) -> InstallationMode:
return self._installation_mode

@installation_mode.setter
def installation_mode(self, value: str):
# Allow setting via string by converting to the Enum
if isinstance(value, str):
try:
self._installation_mode = InstallationMode(value.upper())
except ValueError:
raise ValueError(f"{value} is not a valid InstallationMode. Choose from: {list(InstallationMode)}")
elif isinstance(value, InstallationMode):
self._installation_mode = value
else:
raise TypeError("installation_mode must be a string or InstallationMode Enum.")


@dataclass
class KorniaConfig:
output_dir: str = "kornia_outputs"
hub_cache_dir: str = ".kornia_hub"
hub_models_dir: str = ".kornia_hub/models"
hub_onnx_dir: str = ".kornia_hub/onnx_models"
lazyloader: LazyLoaderConfig = field(default_factory=LazyLoaderConfig)


kornia_config = KorniaConfig()
2 changes: 1 addition & 1 deletion kornia/contrib/models/rt_detr/post_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch

from kornia.core import Module, Tensor, concatenate, tensor
from kornia.models.detector.utils import BoxFiltering
from kornia.models.detection.utils import BoxFiltering


def mod(a: Tensor, b: int) -> Tensor:
Expand Down
10 changes: 5 additions & 5 deletions kornia/contrib/object_detection.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import warnings

from kornia.models.detector.base import (
from kornia.models.detection.base import (
BoundingBox as BoundingBoxBase,
)
from kornia.models.detector.base import (
from kornia.models.detection.base import (
BoundingBoxDataFormat,
)
from kornia.models.detector.base import (
from kornia.models.detection.base import (
ObjectDetector as ObjectDetectorBase,
)
from kornia.models.detector.base import (
from kornia.models.detection.base import (
ObjectDetectorResult as ObjectDetectorResultBase,
)
from kornia.models.detector.base import (
from kornia.models.detection.base import (
results_from_detections as results_from_detections_base,
)
from kornia.models.utils import ResizePreProcessor as ResizePreProcessorBase
Expand Down
39 changes: 27 additions & 12 deletions kornia/core/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from types import ModuleType
from typing import List, Optional

from kornia.config import kornia_config, InstallationMode

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -45,23 +47,36 @@ def _load(self) -> None:
try:
self.module = importlib.import_module(self.module_name)
except ImportError as e:
if self.auto_install:
if kornia_config.lazyloader.installation_mode == InstallationMode.AUTO or self.auto_install:
self._install_package(self.module_name)
else:
elif kornia_config.lazyloader.installation_mode == InstallationMode.ASK:
to_ask = True
if_install = input(
f"Optional dependency '{self.module_name}' is not installed. "
"You may silent this prompt by `kornia_config.lazyloader.installation_mode = 'auto'`. "
"Do you wish to install the dependency? [Y]es, [N]o, [A]ll."
)
if if_install.lower() == "y":
self._install_package(self.module_name)
elif if_install.lower() == "a":
self.auto_install = True
self._install_package(self.module_name)
else:
raise ImportError(
f"Optional dependency '{self.module_name}' is not installed. "
f"Please install it to use this functionality."
) from e
while to_ask:
if if_install.lower() == "y" or if_install.lower() == "yes":
self._install_package(self.module_name)
to_ask = False
elif if_install.lower() == "a" or if_install.lower() == "all":
self.auto_install = True
self._install_package(self.module_name)
to_ask = False
elif if_install.lower() == "n" or if_install.lower() == "no":
raise ImportError(
f"Optional dependency '{self.module_name}' is not installed. "
f"Please install it to use this functionality."
) from e
else:
if_install = input("Invalid input. Please enter 'Y', 'N', or 'A'.")

elif kornia_config.lazyloader.installation_mode == InstallationMode.RAISE:
raise ImportError(
f"Optional dependency '{self.module_name}' is not installed. "
f"Please install it to use this functionality."
) from e

def __getattr__(self, item: str) -> object:
"""Loads the module (if not already loaded) and returns the requested attribute.
Expand Down
2 changes: 1 addition & 1 deletion kornia/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import detector, segmentor, tracker
from . import detection, segmentation, tracking
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from kornia.core.external import numpy as np
from kornia.io import write_image
from kornia.utils.draw import draw_rectangle
from kornia.utils.image import tensor_to_image

__all__ = [
"BoundingBoxDataFormat",
Expand Down Expand Up @@ -156,7 +157,7 @@ def draw(
if output_type == "torch":
output.append(out_img[0])
elif output_type == "pil":
output.append(Image.fromarray((out_img[0] * 255).permute(1, 2, 0).numpy().astype(np.uint8))) # type: ignore
output.append(Image.fromarray((tensor_to_image(out_img[0]) * 255).astype(np.uint8))) # type: ignore
else:
raise RuntimeError(f"Unsupported output type `{output_type}`.")
return output
Expand All @@ -171,8 +172,8 @@ def save(
n_row: Number of images displayed in each row of the grid.
"""
if directory is None:
name = f"detection-{datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y%m%d%H%M%S')!s}"
directory = os.path.join("Kornia_outputs", name)
name = f"detection_{datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y%m%d%H%M%S')!s}"
directory = os.path.join("kornia_outputs", name)
outputs = self.draw(images, detections)
os.makedirs(directory, exist_ok=True)
for i, out_image in enumerate(outputs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from kornia.contrib.models.rt_detr import DETRPostProcessor
from kornia.contrib.models.rt_detr.model import RTDETR, RTDETRConfig
from kornia.core import rand
from kornia.models.detector.base import ObjectDetector
from kornia.models.detection.base import ObjectDetector
from kornia.models.utils import ResizePreProcessor

__all__ = ["RTDETRDetectorBuilder"]
Expand Down Expand Up @@ -128,10 +128,10 @@ def to_onnx(
if onnx_name is None:
_model_name = model_name
if model_name is None and config is not None:
_model_name = "rtdetr-customized"
_model_name = "rtdetr_customized"
elif model_name is None and config is None:
_model_name = "rtdetr_r18vd"
onnx_name = f"Kornia-RTDETR-{_model_name}-{image_size}.onnx"
onnx_name = f"kornia_{_model_name}_{image_size}.onnx"

if image_size is None:
val_image = rand(1, 3, 640, 640)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@

from kornia.core import rand
from kornia.filters.dexined import DexiNed
from kornia.models.edge_detector.base import EdgeDetector
from kornia.models.edge_detection.base import EdgeDetector
from kornia.models.utils import ResizePostProcessor, ResizePreProcessor


class DexiNedBuilder:

@staticmethod
def build(pretrained: bool = True, image_size: Optional[int] = 352) -> EdgeDetector:
model = DexiNed(pretrained=pretrained)
return EdgeDetector(
model,
ResizePreProcessor((image_size, image_size)) if image_size is not None else nn.Identity(),
ResizePreProcessor(image_size, image_size) if image_size is not None else nn.Identity(),
ResizePostProcessor() if image_size is not None else nn.Identity(),
)

Expand All @@ -27,7 +28,7 @@ def to_onnx(
) -> Tuple[str, EdgeDetector]:
edge_detector = DexiNedBuilder.build(pretrained, image_size)
if onnx_name is None:
onnx_name = f"Kornia-DexiNed-{image_size}.onnx"
onnx_name = f"kornia_dexined_{image_size}.onnx"

if image_size is None:
val_image = rand(1, 3, 352, 352)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from kornia.core import Tensor
from kornia.core.external import boxmot
from kornia.core.external import numpy as np
from kornia.models.detector.base import ObjectDetector
from kornia.models.detector.rtdetr import RTDETRDetectorBuilder
from kornia.models.detection.base import ObjectDetector
from kornia.models.detection.rtdetr import RTDETRDetectorBuilder
from kornia.utils.image import tensor_to_image

__all__ = ["BoxMotTracker"]
Expand Down
8 changes: 4 additions & 4 deletions kornia/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ class ResizePreProcessor(Module):
Additionally, also returns the original image sizes for further post-processing.
"""

def __init__(self, size: tuple[int, int], interpolation_mode: str = "bilinear") -> None:
def __init__(self, height: int, width: int, interpolation_mode: str = "bilinear") -> None:
"""
Args:
size: images will be resized to this value. If a 2-integer tuple is given, it is interpreted as
(height, width).
height: height of the resized image.
width: width of the resized image.
interpolation_mode: interpolation mode for image resizing. Supported values: ``nearest``, ``bilinear``,
``bicubic``, ``area``, and ``nearest-exact``.
"""
super().__init__()
self.size = size
self.size = (height, width)
self.interpolation_mode = interpolation_mode

def forward(self, imgs: Union[Tensor, list[Tensor]]) -> tuple[Tensor, Tensor]:
Expand Down
5 changes: 3 additions & 2 deletions kornia/onnx/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from kornia.core.external import numpy as np
from kornia.core.external import onnx
from kornia.core.external import onnxruntime as ort
from kornia.config import kornia_config

from .utils import ONNXLoader

__all__ = ["ONNXSequential", "load"]


class ONNXSequential:
"""ONNXSequential to chain multiple ONNX operators together.
f"""ONNXSequential to chain multiple ONNX operators together.
Args:
*args: A variable number of ONNX models (either ONNX ModelProto objects or file paths).
Expand All @@ -24,7 +25,7 @@ class ONNXSequential:
only one input and output node for each graph.
If not None, `io_maps[0]` shall represent the `io_map` for combining the first and second ONNX models.
cache_dir: The directory where ONNX models are cached locally (only for downloading from HuggingFace).
Defaults to None, which will use a default `.kornia_hub/onnx_models` directory.
Defaults to None, which will use a default `{kornia_config.hub_onnx_dir}` directory.
"""

def __init__(
Expand Down
10 changes: 5 additions & 5 deletions kornia/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
import requests

from kornia.core.external import onnx

from kornia.config import kornia_config
__all__ = ["ONNXLoader"]

logger = logging.getLogger(__name__)


class ONNXLoader:
"""Manages ONNX models, handling local caching, downloading from Hugging Face, and loading models.
f"""Manages ONNX models, handling local caching, downloading from Hugging Face, and loading models.
Attributes:
cache_dir: The directory where ONNX models are cached locally.
Defaults to None, which will use a default `.kornia_hub/onnx_models` directory.
Defaults to None, which will use a default `{kornia_config.hub_onnx_dir}` directory.
"""

def __init__(self, cache_dir: Optional[str] = None):
Expand All @@ -39,7 +39,7 @@ def _get_file_path(self, model_name: str, cache_dir: Optional[str]) -> str:
if self.cache_dir is not None:
cache_dir = self.cache_dir
else:
cache_dir = ".kornia_hub/onnx_models"
cache_dir = kornia_config.hub_onnx_dir

# The filename is the model name (without directory path)
file_name = f"{model_name.split('/')[-1]}.onnx"
Expand All @@ -61,7 +61,7 @@ def load_model(self, model_name: str, download: bool = True, **kwargs) -> "onnx.
"""
if model_name.startswith("hf://"):
model_name = model_name[len("hf://") :]
cache_dir = kwargs.get("cache_dir", None) or self.cache_dir
cache_dir = kwargs.get(kornia_config.hub_onnx_dir, None) or self.cache_dir
file_path = self._get_file_path(model_name, cache_dir)
if not os.path.exists(file_path):
# Construct the raw URL for the ONNX file
Expand Down

0 comments on commit 1dfaf1a

Please sign in to comment.