From 1dfaf1a71205f8011f1a0913053f90088d510af6 Mon Sep 17 00:00:00 2001 From: shijianjian Date: Sun, 15 Sep 2024 16:19:06 +0300 Subject: [PATCH] update --- kornia/__init__.py | 1 + kornia/config.py | 51 +++++++++++++++++++ .../contrib/models/rt_detr/post_processor.py | 2 +- kornia/contrib/object_detection.py | 10 ++-- kornia/core/external.py | 39 +++++++++----- kornia/models/__init__.py | 2 +- .../{detector => detection}/__init__.py | 0 kornia/models/{detector => detection}/base.py | 7 +-- .../models/{detector => detection}/rtdetr.py | 6 +-- .../models/{detector => detection}/utils.py | 0 .../__init__.py | 0 .../{edge_detector => edge_detection}/base.py | 0 .../dexined.py | 7 +-- .../{segmentor => segmentation}/__init__.py | 0 .../{segmentor => segmentation}/base.py | 0 .../segmentation_models.py | 0 .../models/{tracker => tracking}/__init__.py | 0 .../{tracker => tracking}/boxmot_tracker.py | 4 +- kornia/models/utils.py | 8 +-- kornia/onnx/sequential.py | 5 +- kornia/onnx/utils.py | 10 ++-- 21 files changed, 111 insertions(+), 41 deletions(-) create mode 100644 kornia/config.py rename kornia/models/{detector => detection}/__init__.py (100%) rename kornia/models/{detector => detection}/base.py (95%) rename kornia/models/{detector => detection}/rtdetr.py (97%) rename kornia/models/{detector => detection}/utils.py (100%) rename kornia/models/{edge_detector => edge_detection}/__init__.py (100%) rename kornia/models/{edge_detector => edge_detection}/base.py (100%) rename kornia/models/{edge_detector => edge_detection}/dexined.py (86%) rename kornia/models/{segmentor => segmentation}/__init__.py (100%) rename kornia/models/{segmentor => segmentation}/base.py (100%) rename kornia/models/{segmentor => segmentation}/segmentation_models.py (100%) rename kornia/models/{tracker => tracking}/__init__.py (100%) rename kornia/models/{tracker => tracking}/boxmot_tracker.py (97%) diff --git a/kornia/__init__.py b/kornia/__init__.py index 9bdc9c1d66..b1b9782747 100644 --- a/kornia/__init__.py +++ b/kornia/__init__.py @@ -10,6 +10,7 @@ color, contrib, core, + config, enhance, feature, io, diff --git a/kornia/config.py b/kornia/config.py new file mode 100644 index 0000000000..857a566e28 --- /dev/null +++ b/kornia/config.py @@ -0,0 +1,51 @@ +from enum import Enum +from dataclasses import dataclass, field + +__all__ = ["config", "InstallationMode"] + + +class InstallationMode(str, Enum): + # Ask the user if to install the dependencies + ASK = "ASK" + # Install the dependencies + AUTO = "AUTO" + # Raise an error if the dependencies are not installed + RAISE = "RAISE" + + def __eq__(self, other): + if isinstance(other, str): + return self.value.lower() == other.lower() # Case-insensitive comparison + return super().__eq__(other) + + +class LazyLoaderConfig: + _installation_mode: InstallationMode = InstallationMode.ASK + + @property + def installation_mode(self) -> InstallationMode: + return self._installation_mode + + @installation_mode.setter + def installation_mode(self, value: str): + # Allow setting via string by converting to the Enum + if isinstance(value, str): + try: + self._installation_mode = InstallationMode(value.upper()) + except ValueError: + raise ValueError(f"{value} is not a valid InstallationMode. Choose from: {list(InstallationMode)}") + elif isinstance(value, InstallationMode): + self._installation_mode = value + else: + raise TypeError("installation_mode must be a string or InstallationMode Enum.") + + +@dataclass +class KorniaConfig: + output_dir: str = "kornia_outputs" + hub_cache_dir: str = ".kornia_hub" + hub_models_dir: str = ".kornia_hub/models" + hub_onnx_dir: str = ".kornia_hub/onnx_models" + lazyloader: LazyLoaderConfig = field(default_factory=LazyLoaderConfig) + + +kornia_config = KorniaConfig() diff --git a/kornia/contrib/models/rt_detr/post_processor.py b/kornia/contrib/models/rt_detr/post_processor.py index ae9c4ac791..0df2c1190c 100644 --- a/kornia/contrib/models/rt_detr/post_processor.py +++ b/kornia/contrib/models/rt_detr/post_processor.py @@ -7,7 +7,7 @@ import torch from kornia.core import Module, Tensor, concatenate, tensor -from kornia.models.detector.utils import BoxFiltering +from kornia.models.detection.utils import BoxFiltering def mod(a: Tensor, b: int) -> Tensor: diff --git a/kornia/contrib/object_detection.py b/kornia/contrib/object_detection.py index 2917834368..c8c4e336e5 100644 --- a/kornia/contrib/object_detection.py +++ b/kornia/contrib/object_detection.py @@ -1,18 +1,18 @@ import warnings -from kornia.models.detector.base import ( +from kornia.models.detection.base import ( BoundingBox as BoundingBoxBase, ) -from kornia.models.detector.base import ( +from kornia.models.detection.base import ( BoundingBoxDataFormat, ) -from kornia.models.detector.base import ( +from kornia.models.detection.base import ( ObjectDetector as ObjectDetectorBase, ) -from kornia.models.detector.base import ( +from kornia.models.detection.base import ( ObjectDetectorResult as ObjectDetectorResultBase, ) -from kornia.models.detector.base import ( +from kornia.models.detection.base import ( results_from_detections as results_from_detections_base, ) from kornia.models.utils import ResizePreProcessor as ResizePreProcessorBase diff --git a/kornia/core/external.py b/kornia/core/external.py index de1544f7ca..39c0b789d1 100644 --- a/kornia/core/external.py +++ b/kornia/core/external.py @@ -5,6 +5,8 @@ from types import ModuleType from typing import List, Optional +from kornia.config import kornia_config, InstallationMode + logger = logging.getLogger(__name__) @@ -45,23 +47,36 @@ def _load(self) -> None: try: self.module = importlib.import_module(self.module_name) except ImportError as e: - if self.auto_install: + if kornia_config.lazyloader.installation_mode == InstallationMode.AUTO or self.auto_install: self._install_package(self.module_name) - else: + elif kornia_config.lazyloader.installation_mode == InstallationMode.ASK: + to_ask = True if_install = input( f"Optional dependency '{self.module_name}' is not installed. " + "You may silent this prompt by `kornia_config.lazyloader.installation_mode = 'auto'`. " "Do you wish to install the dependency? [Y]es, [N]o, [A]ll." ) - if if_install.lower() == "y": - self._install_package(self.module_name) - elif if_install.lower() == "a": - self.auto_install = True - self._install_package(self.module_name) - else: - raise ImportError( - f"Optional dependency '{self.module_name}' is not installed. " - f"Please install it to use this functionality." - ) from e + while to_ask: + if if_install.lower() == "y" or if_install.lower() == "yes": + self._install_package(self.module_name) + to_ask = False + elif if_install.lower() == "a" or if_install.lower() == "all": + self.auto_install = True + self._install_package(self.module_name) + to_ask = False + elif if_install.lower() == "n" or if_install.lower() == "no": + raise ImportError( + f"Optional dependency '{self.module_name}' is not installed. " + f"Please install it to use this functionality." + ) from e + else: + if_install = input("Invalid input. Please enter 'Y', 'N', or 'A'.") + + elif kornia_config.lazyloader.installation_mode == InstallationMode.RAISE: + raise ImportError( + f"Optional dependency '{self.module_name}' is not installed. " + f"Please install it to use this functionality." + ) from e def __getattr__(self, item: str) -> object: """Loads the module (if not already loaded) and returns the requested attribute. diff --git a/kornia/models/__init__.py b/kornia/models/__init__.py index ec936c3a85..9117c87c71 100644 --- a/kornia/models/__init__.py +++ b/kornia/models/__init__.py @@ -1 +1 @@ -from . import detector, segmentor, tracker +from . import detection, segmentation, tracking diff --git a/kornia/models/detector/__init__.py b/kornia/models/detection/__init__.py similarity index 100% rename from kornia/models/detector/__init__.py rename to kornia/models/detection/__init__.py diff --git a/kornia/models/detector/base.py b/kornia/models/detection/base.py similarity index 95% rename from kornia/models/detector/base.py rename to kornia/models/detection/base.py index 2edf978eca..744fc579d2 100644 --- a/kornia/models/detector/base.py +++ b/kornia/models/detection/base.py @@ -15,6 +15,7 @@ from kornia.core.external import numpy as np from kornia.io import write_image from kornia.utils.draw import draw_rectangle +from kornia.utils.image import tensor_to_image __all__ = [ "BoundingBoxDataFormat", @@ -156,7 +157,7 @@ def draw( if output_type == "torch": output.append(out_img[0]) elif output_type == "pil": - output.append(Image.fromarray((out_img[0] * 255).permute(1, 2, 0).numpy().astype(np.uint8))) # type: ignore + output.append(Image.fromarray((tensor_to_image(out_img[0]) * 255).astype(np.uint8))) # type: ignore else: raise RuntimeError(f"Unsupported output type `{output_type}`.") return output @@ -171,8 +172,8 @@ def save( n_row: Number of images displayed in each row of the grid. """ if directory is None: - name = f"detection-{datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y%m%d%H%M%S')!s}" - directory = os.path.join("Kornia_outputs", name) + name = f"detection_{datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y%m%d%H%M%S')!s}" + directory = os.path.join("kornia_outputs", name) outputs = self.draw(images, detections) os.makedirs(directory, exist_ok=True) for i, out_image in enumerate(outputs): diff --git a/kornia/models/detector/rtdetr.py b/kornia/models/detection/rtdetr.py similarity index 97% rename from kornia/models/detector/rtdetr.py rename to kornia/models/detection/rtdetr.py index 96ead94a89..60cc296f75 100644 --- a/kornia/models/detector/rtdetr.py +++ b/kornia/models/detection/rtdetr.py @@ -7,7 +7,7 @@ from kornia.contrib.models.rt_detr import DETRPostProcessor from kornia.contrib.models.rt_detr.model import RTDETR, RTDETRConfig from kornia.core import rand -from kornia.models.detector.base import ObjectDetector +from kornia.models.detection.base import ObjectDetector from kornia.models.utils import ResizePreProcessor __all__ = ["RTDETRDetectorBuilder"] @@ -128,10 +128,10 @@ def to_onnx( if onnx_name is None: _model_name = model_name if model_name is None and config is not None: - _model_name = "rtdetr-customized" + _model_name = "rtdetr_customized" elif model_name is None and config is None: _model_name = "rtdetr_r18vd" - onnx_name = f"Kornia-RTDETR-{_model_name}-{image_size}.onnx" + onnx_name = f"kornia_{_model_name}_{image_size}.onnx" if image_size is None: val_image = rand(1, 3, 640, 640) diff --git a/kornia/models/detector/utils.py b/kornia/models/detection/utils.py similarity index 100% rename from kornia/models/detector/utils.py rename to kornia/models/detection/utils.py diff --git a/kornia/models/edge_detector/__init__.py b/kornia/models/edge_detection/__init__.py similarity index 100% rename from kornia/models/edge_detector/__init__.py rename to kornia/models/edge_detection/__init__.py diff --git a/kornia/models/edge_detector/base.py b/kornia/models/edge_detection/base.py similarity index 100% rename from kornia/models/edge_detector/base.py rename to kornia/models/edge_detection/base.py diff --git a/kornia/models/edge_detector/dexined.py b/kornia/models/edge_detection/dexined.py similarity index 86% rename from kornia/models/edge_detector/dexined.py rename to kornia/models/edge_detection/dexined.py index 841f1f1b74..64e1bf4139 100644 --- a/kornia/models/edge_detector/dexined.py +++ b/kornia/models/edge_detection/dexined.py @@ -5,17 +5,18 @@ from kornia.core import rand from kornia.filters.dexined import DexiNed -from kornia.models.edge_detector.base import EdgeDetector +from kornia.models.edge_detection.base import EdgeDetector from kornia.models.utils import ResizePostProcessor, ResizePreProcessor class DexiNedBuilder: + @staticmethod def build(pretrained: bool = True, image_size: Optional[int] = 352) -> EdgeDetector: model = DexiNed(pretrained=pretrained) return EdgeDetector( model, - ResizePreProcessor((image_size, image_size)) if image_size is not None else nn.Identity(), + ResizePreProcessor(image_size, image_size) if image_size is not None else nn.Identity(), ResizePostProcessor() if image_size is not None else nn.Identity(), ) @@ -27,7 +28,7 @@ def to_onnx( ) -> Tuple[str, EdgeDetector]: edge_detector = DexiNedBuilder.build(pretrained, image_size) if onnx_name is None: - onnx_name = f"Kornia-DexiNed-{image_size}.onnx" + onnx_name = f"kornia_dexined_{image_size}.onnx" if image_size is None: val_image = rand(1, 3, 352, 352) diff --git a/kornia/models/segmentor/__init__.py b/kornia/models/segmentation/__init__.py similarity index 100% rename from kornia/models/segmentor/__init__.py rename to kornia/models/segmentation/__init__.py diff --git a/kornia/models/segmentor/base.py b/kornia/models/segmentation/base.py similarity index 100% rename from kornia/models/segmentor/base.py rename to kornia/models/segmentation/base.py diff --git a/kornia/models/segmentor/segmentation_models.py b/kornia/models/segmentation/segmentation_models.py similarity index 100% rename from kornia/models/segmentor/segmentation_models.py rename to kornia/models/segmentation/segmentation_models.py diff --git a/kornia/models/tracker/__init__.py b/kornia/models/tracking/__init__.py similarity index 100% rename from kornia/models/tracker/__init__.py rename to kornia/models/tracking/__init__.py diff --git a/kornia/models/tracker/boxmot_tracker.py b/kornia/models/tracking/boxmot_tracker.py similarity index 97% rename from kornia/models/tracker/boxmot_tracker.py rename to kornia/models/tracking/boxmot_tracker.py index 4cc8be105e..a9feeb25df 100644 --- a/kornia/models/tracker/boxmot_tracker.py +++ b/kornia/models/tracking/boxmot_tracker.py @@ -5,8 +5,8 @@ from kornia.core import Tensor from kornia.core.external import boxmot from kornia.core.external import numpy as np -from kornia.models.detector.base import ObjectDetector -from kornia.models.detector.rtdetr import RTDETRDetectorBuilder +from kornia.models.detection.base import ObjectDetector +from kornia.models.detection.rtdetr import RTDETRDetectorBuilder from kornia.utils.image import tensor_to_image __all__ = ["BoxMotTracker"] diff --git a/kornia/models/utils.py b/kornia/models/utils.py index 08676160c2..f0f4931531 100644 --- a/kornia/models/utils.py +++ b/kornia/models/utils.py @@ -14,16 +14,16 @@ class ResizePreProcessor(Module): Additionally, also returns the original image sizes for further post-processing. """ - def __init__(self, size: tuple[int, int], interpolation_mode: str = "bilinear") -> None: + def __init__(self, height: int, width: int, interpolation_mode: str = "bilinear") -> None: """ Args: - size: images will be resized to this value. If a 2-integer tuple is given, it is interpreted as - (height, width). + height: height of the resized image. + width: width of the resized image. interpolation_mode: interpolation mode for image resizing. Supported values: ``nearest``, ``bilinear``, ``bicubic``, ``area``, and ``nearest-exact``. """ super().__init__() - self.size = size + self.size = (height, width) self.interpolation_mode = interpolation_mode def forward(self, imgs: Union[Tensor, list[Tensor]]) -> tuple[Tensor, Tensor]: diff --git a/kornia/onnx/sequential.py b/kornia/onnx/sequential.py index c934a0550e..155fada053 100644 --- a/kornia/onnx/sequential.py +++ b/kornia/onnx/sequential.py @@ -3,6 +3,7 @@ from kornia.core.external import numpy as np from kornia.core.external import onnx from kornia.core.external import onnxruntime as ort +from kornia.config import kornia_config from .utils import ONNXLoader @@ -10,7 +11,7 @@ class ONNXSequential: - """ONNXSequential to chain multiple ONNX operators together. + f"""ONNXSequential to chain multiple ONNX operators together. Args: *args: A variable number of ONNX models (either ONNX ModelProto objects or file paths). @@ -24,7 +25,7 @@ class ONNXSequential: only one input and output node for each graph. If not None, `io_maps[0]` shall represent the `io_map` for combining the first and second ONNX models. cache_dir: The directory where ONNX models are cached locally (only for downloading from HuggingFace). - Defaults to None, which will use a default `.kornia_hub/onnx_models` directory. + Defaults to None, which will use a default `{kornia_config.hub_onnx_dir}` directory. """ def __init__( diff --git a/kornia/onnx/utils.py b/kornia/onnx/utils.py index 98cd774f30..669cbce79a 100644 --- a/kornia/onnx/utils.py +++ b/kornia/onnx/utils.py @@ -7,18 +7,18 @@ import requests from kornia.core.external import onnx - +from kornia.config import kornia_config __all__ = ["ONNXLoader"] logger = logging.getLogger(__name__) class ONNXLoader: - """Manages ONNX models, handling local caching, downloading from Hugging Face, and loading models. + f"""Manages ONNX models, handling local caching, downloading from Hugging Face, and loading models. Attributes: cache_dir: The directory where ONNX models are cached locally. - Defaults to None, which will use a default `.kornia_hub/onnx_models` directory. + Defaults to None, which will use a default `{kornia_config.hub_onnx_dir}` directory. """ def __init__(self, cache_dir: Optional[str] = None): @@ -39,7 +39,7 @@ def _get_file_path(self, model_name: str, cache_dir: Optional[str]) -> str: if self.cache_dir is not None: cache_dir = self.cache_dir else: - cache_dir = ".kornia_hub/onnx_models" + cache_dir = kornia_config.hub_onnx_dir # The filename is the model name (without directory path) file_name = f"{model_name.split('/')[-1]}.onnx" @@ -61,7 +61,7 @@ def load_model(self, model_name: str, download: bool = True, **kwargs) -> "onnx. """ if model_name.startswith("hf://"): model_name = model_name[len("hf://") :] - cache_dir = kwargs.get("cache_dir", None) or self.cache_dir + cache_dir = kwargs.get(kornia_config.hub_onnx_dir, None) or self.cache_dir file_path = self._get_file_path(model_name, cache_dir) if not os.path.exists(file_path): # Construct the raw URL for the ONNX file