diff --git a/kornia/feature/sold2/sold2.py b/kornia/feature/sold2/sold2.py index 8958064c0a..76e3be406e 100644 --- a/kornia/feature/sold2/sold2.py +++ b/kornia/feature/sold2/sold2.py @@ -9,7 +9,7 @@ from kornia.utils import map_location_to_cpu from .backbones import SOLD2Net -from .sold2_detector import LineSegmentDetectionModule, line_map_to_segments, prob_to_junctions +from .sold2_detector import LineDetectorCfg, LineSegmentDetectionModule, line_map_to_segments, prob_to_junctions urls: Dict[str, str] = {} urls["wireframe"] = "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth" @@ -22,23 +22,7 @@ "keep_border_valid": True, "detection_thresh": 0.0153846, # = 1/65: threshold of junction detection "max_num_junctions": 500, # maximum number of junctions per image - "line_detector_cfg": { - "detect_thresh": 0.5, - "num_samples": 64, - "inlier_thresh": 0.99, - "use_candidate_suppression": True, - "nms_dist_tolerance": 3.0, - "use_heatmap_refinement": True, - "heatmap_refine_cfg": { - "mode": "local", - "ratio": 0.2, - "valid_thresh": 0.001, - "num_blocks": 20, - "overlap_ratio": 0.5, - }, - "use_junction_refinement": True, - "junction_refine_cfg": {"num_perturbs": 9, "perturb_interval": 0.25}, - }, + "line_detector_cfg": LineDetectorCfg(), "line_matcher_cfg": { "cross_check": True, "num_samples": 5, @@ -92,8 +76,7 @@ def __init__(self, pretrained: bool = True, config: Optional[Dict[str, Any]] = N self.eval() # Initialize the line detector - self.line_detector_cfg = self.config["line_detector_cfg"] - self.line_detector = LineSegmentDetectionModule(**self.config["line_detector_cfg"]) + self.line_detector = LineSegmentDetectionModule(LineDetectorCfg()) # Initialize the line matcher self.line_matcher = WunschLineMatcher(**self.config["line_matcher_cfg"]) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 60e97822e9..61ed8280d7 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -1,4 +1,6 @@ import math +import warnings +from dataclasses import dataclass, field from typing import Any, Dict, Optional, Tuple import torch @@ -6,7 +8,7 @@ from kornia.core import Module, Tensor, concatenate, sin, stack, tensor, where, zeros from kornia.core.check import KORNIA_CHECK_SHAPE from kornia.geometry.bbox import nms -from kornia.utils import map_location_to_cpu, torch_meshgrid +from kornia.utils import dataclass_to_dict, dict_to_dataclass, map_location_to_cpu, torch_meshgrid from .backbones import SOLD2Net @@ -14,31 +16,56 @@ urls["wireframe"] = "https://www.polybox.ethz.ch/index.php/s/blOrW89gqSLoHOk/download" -default_detector_cfg = { - "backbone_cfg": {"input_channel": 1, "depth": 4, "num_stacks": 2, "num_blocks": 1, "num_classes": 5}, - "use_descriptor": False, - "grid_size": 8, - "keep_border_valid": True, - "detection_thresh": 0.0153846, # = 1/65: threshold of junction detection - "max_num_junctions": 500, # maximum number of junctions per image - "line_detector_cfg": { - "detect_thresh": 0.5, - "num_samples": 64, - "inlier_thresh": 0.99, - "use_candidate_suppression": True, - "nms_dist_tolerance": 3.0, - "use_heatmap_refinement": True, - "heatmap_refine_cfg": { - "mode": "local", - "ratio": 0.2, - "valid_thresh": 0.001, - "num_blocks": 20, - "overlap_ratio": 0.5, - }, - "use_junction_refinement": True, - "junction_refine_cfg": {"num_perturbs": 9, "perturb_interval": 0.25}, - }, -} +@dataclass +class HeatMapRefineCfg: + mode: str = "local" + ratio: float = 0.2 + valid_thresh: float = 0.001 + num_blocks: int = 20 + overlap_ratio: float = 0.5 + + +@dataclass +class JunctionRefineCfg: + num_perturbs: int = 9 + perturb_interval: float = 0.25 + + +@dataclass +class LineDetectorCfg: + detect_thresh: float = 0.5 + num_samples: int = 64 + inlier_thresh: float = 0.99 + use_candidate_suppression: bool = True + nms_dist_tolerance: float = 3.0 + heatmap_low_thresh: float = 0.15 + heatmap_high_thresh: float = 0.2 + max_local_patch_radius: float = 3 + lambda_radius: float = 2.0 + use_heatmap_refinement: bool = True + heatmap_refine_cfg: HeatMapRefineCfg = field(default_factory=HeatMapRefineCfg) + use_junction_refinement: bool = True + junction_refine_cfg: JunctionRefineCfg = field(default_factory=JunctionRefineCfg) + + +@dataclass +class BackboneCfg: + input_channel: int = 1 + depth: int = 4 + num_stacks: int = 2 + num_blocks: int = 1 + num_classes: int = 5 + + +@dataclass +class DetectorCfg: + backbone_cfg: BackboneCfg = field(default_factory=BackboneCfg) + use_descriptor: bool = False + grid_size: int = 8 + keep_border_valid: bool = True + detection_thresh: float = 0.0153846 # = 1/65: threshold of junction detection + max_num_junctions: int = 500 # maximum number of junctions per image + line_detector_cfg: LineDetectorCfg = field(default_factory=LineDetectorCfg) class SOLD2_detector(Module): @@ -48,9 +75,10 @@ class SOLD2_detector(Module): Occlusion-aware Line Detector and Descriptor". See :cite:`SOLD22021` for more details. Args: - config: Dict specifying parameters. None will load the default parameters, - which are tuned for images in the range 400~800 px. - pretrained: If True, download and set pretrained weights to the model. + config (DetectorCfg): Configuration object containing all parameters. None will load the default parameters, + which are tuned for images in the range 400~800 px. Using a dataclass ensures type safety and clearer + parameter management. + pretrained (bool): If True, download and set pretrained weights to the model. Returns: The raw junction and line heatmaps, as well as the list of detected line segments (ij coordinates convention). @@ -61,25 +89,34 @@ class SOLD2_detector(Module): >>> line_segments = sold2_detector(img)["line_segments"] """ - def __init__(self, pretrained: bool = True, config: Optional[Dict[str, Any]] = None) -> None: + def __init__(self, pretrained: bool = True, config: Optional[DetectorCfg] = None) -> None: + if isinstance(config, dict): + warnings.warn( + "Usage of config as a plain dictionary is deprecated in favor of" + " `kornia.feature.sold2.sold2_detector.DetectorCfg`. The support of plain dictionaries" + "as config will be removed in kornia v0.8.0 (December 2024).", + category=DeprecationWarning, + stacklevel=2, + ) + config = dict_to_dataclass(config, DetectorCfg) super().__init__() # Initialize some parameters - self.config = default_detector_cfg if config is None else config - self.grid_size = self.config["grid_size"] - self.junc_detect_thresh = self.config.get("detection_thresh", 1 / 65) - self.max_num_junctions = self.config.get("max_num_junctions", 500) + self.config = config if config is not None else DetectorCfg() + self.grid_size = self.config.grid_size + self.junc_detect_thresh = self.config.detection_thresh + self.max_num_junctions = self.config.max_num_junctions # Load the pre-trained model - self.model = SOLD2Net(self.config) + self.model = SOLD2Net(dataclass_to_dict(self.config)) + if pretrained: pretrained_dict = torch.hub.load_state_dict_from_url(urls["wireframe"], map_location=map_location_to_cpu) state_dict = self.adapt_state_dict(pretrained_dict["model_state_dict"]) self.model.load_state_dict(state_dict) self.eval() - # Initialize the line detector - self.line_detector_cfg = self.config["line_detector_cfg"] - self.line_detector = LineSegmentDetectionModule(**self.config["line_detector_cfg"]) + # Initialize the line detector with a configuration from the dataclass + self.line_detector = LineSegmentDetectionModule(self.config.line_detector_cfg) def adapt_state_dict(self, state_dict: Dict[str, Any]) -> Dict[str, Any]: del state_dict["w_junc"] @@ -127,66 +164,62 @@ class LineSegmentDetectionModule: r"""Module extracting line segments from junctions and line heatmaps. Args: - detect_thresh: The probability threshold for mean activation (0. ~ 1.) - num_samples: Number of sampling locations along the line segments. - inlier_thresh: The min inlier ratio to satisfy (0. ~ 1.) => 0. means no threshold. - heatmap_low_thresh: The lowest threshold for the pixel to be considered as candidate in junction recovery. - heatmap_high_thresh: The higher threshold for NMS in junction recovery. - max_local_patch_radius: The max patch to be considered in local maximum search. - lambda_radius: The lambda factor in linear local maximum search formulation - use_candidate_suppression: Apply candidate suppression to break long segments into short sub-segments. - nms_dist_tolerance: The distance tolerance for nms. Decide whether the junctions are on the line. - use_heatmap_refinement: Use heatmap refinement method or not. - heatmap_refine_cfg: The configs for heatmap refinement methods. - use_junction_refinement: Use junction refinement method or not. - junction_refine_cfg: The configs for junction refinement methods. + config (LineDetectorCfg): Configuration dataclass containing all settings required for line segment detection. + - detect_thresh (float): Probability threshold for mean activation (0. ~ 1.). + - num_samples (int): Number of sampling locations along the line segments. + - inlier_thresh (float): Minimum inlier ratio to satisfy (0. ~ 1.) => 0. means no threshold. + - heatmap_low_thresh (float): Lowest threshold for pixel considered as a candidate in junction recovery. + - heatmap_high_thresh (float): Higher threshold for NMS in junction recovery. + - max_local_patch_radius (float): Maximum patch to be considered in local maximum search. + - lambda_radius (float): Lambda factor in linear local maximum search formulation. + - use_candidate_suppression (bool): Apply candidate suppression to break long segments into sub-segments. + - nms_dist_tolerance (float): Distance tolerance for NMS. Decides whether the junctions are on the line. + - use_heatmap_refinement (bool): Whether to use heatmap refinement methods. + - heatmap_refine_cfg: Configuration for heatmap refinement methods. + - use_junction_refinement (bool): Whether to use junction refinement methods. + - junction_refine_cfg: Configuration for junction refinement methods. + + Example: + >>> config = LineDetectorCfg(detect_thresh=0.6, use_heatmap_refinement=True) + >>> module = LineSegmentDetectionModule(config) + >>> junctions, heatmap = torch.rand(10, 2), torch.rand(256, 256) + >>> line_map, junctions, _ = module.detect(junctions, heatmap) """ - def __init__( - self, - detect_thresh: float, - num_samples: int = 64, - inlier_thresh: float = 0.0, - heatmap_low_thresh: float = 0.15, - heatmap_high_thresh: float = 0.2, - max_local_patch_radius: float = 3, - lambda_radius: float = 2.0, - use_candidate_suppression: bool = False, - nms_dist_tolerance: float = 3.0, - use_heatmap_refinement: bool = False, - heatmap_refine_cfg: Optional[Dict[str, Any]] = None, - use_junction_refinement: bool = False, - junction_refine_cfg: Optional[Dict[str, Any]] = None, - ) -> None: + def __init__(self, config: LineDetectorCfg = LineDetectorCfg()) -> None: + # Load LineDetectorCfg + self.config = config + # Line detection parameters - self.detect_thresh = detect_thresh + self.detect_thresh = self.config.detect_thresh + # self.detect_thresh = detect_thresh # Line sampling parameters - self.num_samples = num_samples - self.inlier_thresh = inlier_thresh - self.local_patch_radius = max_local_patch_radius - self.lambda_radius = lambda_radius + self.num_samples = self.config.num_samples + self.inlier_thresh = self.config.inlier_thresh + self.local_patch_radius = self.config.max_local_patch_radius + self.lambda_radius = self.config.lambda_radius # Detecting junctions on the boundary parameters - self.low_thresh = heatmap_low_thresh - self.high_thresh = heatmap_high_thresh + self.low_thresh = self.config.heatmap_low_thresh + self.high_thresh = self.config.heatmap_high_thresh # Pre-compute the linspace sampler self.torch_sampler = torch.linspace(0, 1, self.num_samples) # Long line segment suppression configuration - self.use_candidate_suppression = use_candidate_suppression - self.nms_dist_tolerance = nms_dist_tolerance + self.use_candidate_suppression = self.config.use_candidate_suppression + self.nms_dist_tolerance = self.config.nms_dist_tolerance # Heatmap refinement configuration - self.use_heatmap_refinement = use_heatmap_refinement - self.heatmap_refine_cfg = heatmap_refine_cfg + self.use_heatmap_refinement = self.config.use_heatmap_refinement + self.heatmap_refine_cfg = self.config.heatmap_refine_cfg if self.use_heatmap_refinement and self.heatmap_refine_cfg is None: raise ValueError("[Error] Missing heatmap refinement config.") # Junction refinement configuration - self.use_junction_refinement = use_junction_refinement - self.junction_refine_cfg = junction_refine_cfg + self.use_junction_refinement = self.config.use_junction_refinement + self.junction_refine_cfg = self.config.junction_refine_cfg if self.use_junction_refinement and self.junction_refine_cfg is None: raise ValueError("[Error] Missing junction refinement config.") @@ -197,18 +230,18 @@ def detect(self, junctions: Tensor, heatmap: Tensor) -> Tuple[Tensor, Tensor, Te device = junctions.device # Perform the heatmap refinement - if self.use_heatmap_refinement and isinstance(self.heatmap_refine_cfg, dict): - if self.heatmap_refine_cfg["mode"] == "global": + if self.use_heatmap_refinement and isinstance(self.heatmap_refine_cfg, HeatMapRefineCfg): + if self.heatmap_refine_cfg.mode == "global": heatmap = self.refine_heatmap( - heatmap, self.heatmap_refine_cfg["ratio"], self.heatmap_refine_cfg["valid_thresh"] + heatmap, self.heatmap_refine_cfg.ratio, self.heatmap_refine_cfg.valid_thresh ) - elif self.heatmap_refine_cfg["mode"] == "local": + elif self.heatmap_refine_cfg.mode == "local": heatmap = self.refine_heatmap_local( heatmap, - self.heatmap_refine_cfg["num_blocks"], - self.heatmap_refine_cfg["overlap_ratio"], - self.heatmap_refine_cfg["ratio"], - self.heatmap_refine_cfg["valid_thresh"], + self.heatmap_refine_cfg.num_blocks, + self.heatmap_refine_cfg.overlap_ratio, + self.heatmap_refine_cfg.ratio, + self.heatmap_refine_cfg.valid_thresh, ) # Initialize empty line map @@ -393,10 +426,13 @@ def refine_junction_perturb( ) -> Tuple[Tensor, Tensor]: """Refine the line endpoints in a similar way as in LSD.""" # Fetch refinement parameters - if not isinstance(self.junction_refine_cfg, dict): - raise TypeError(f"Expected to have a dict of config for junction. Gotcha {type(self.junction_refine_cfg)}") - num_perturbs = self.junction_refine_cfg["num_perturbs"] - perturb_interval = self.junction_refine_cfg["perturb_interval"] + if not isinstance(self.junction_refine_cfg, JunctionRefineCfg): + raise TypeError( + "Expected to have dataclass of type JunctionRefineCfg for junction." + f"Gotcha {type(self.junction_refine_cfg)}" + ) + num_perturbs = self.junction_refine_cfg.num_perturbs + perturb_interval = self.junction_refine_cfg.perturb_interval side_perturbs = (num_perturbs - 1) // 2 # Fetch the 2D perturb mat diff --git a/kornia/utils/__init__.py b/kornia/utils/__init__.py index 10efb4ee85..8fc674f6fb 100644 --- a/kornia/utils/__init__.py +++ b/kornia/utils/__init__.py @@ -3,7 +3,9 @@ from .grid import create_meshgrid, create_meshgrid3d from .helpers import ( _extract_device_dtype, + dataclass_to_dict, deprecated, + dict_to_dataclass, get_cuda_device_if_available, get_cuda_or_mps_device_if_available, get_mps_device_if_available, @@ -58,4 +60,6 @@ "print_image", "xla_is_available", "is_mps_tensor_safe", + "dataclass_to_dict", + "dict_to_dataclass", ] diff --git a/kornia/utils/helpers.py b/kornia/utils/helpers.py index 8d03713d63..3738ee21c6 100644 --- a/kornia/utils/helpers.py +++ b/kornia/utils/helpers.py @@ -2,9 +2,10 @@ import platform import sys import warnings +from dataclasses import asdict, fields, is_dataclass from functools import wraps from inspect import isclass, isfunction -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union, overload +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, overload import torch from torch.linalg import inv_ex @@ -318,3 +319,35 @@ def is_autocast_enabled(both: bool = True) -> bool: return torch.is_autocast_enabled() or torch.is_autocast_cpu_enabled() return torch.is_autocast_enabled() + + +def dataclass_to_dict(obj: Any) -> Any: + """Recursively convert dataclass instances to dictionaries.""" + if is_dataclass(obj) and not isinstance(obj, type): + return {key: dataclass_to_dict(value) for key, value in asdict(obj).items()} + elif isinstance(obj, (list, tuple)): + return type(obj)(dataclass_to_dict(item) for item in obj) + elif isinstance(obj, dict): + return {key: dataclass_to_dict(value) for key, value in obj.items()} + else: + return obj + + +T = TypeVar("T") + + +def dict_to_dataclass(dict_obj: Dict[str, Any], dataclass_type: Type[T]) -> T: + """Recursively convert dictionaries to dataclass instances.""" + if not isinstance(dict_obj, dict): + raise TypeError("Input conf must be dict") + if not is_dataclass(dataclass_type): + raise TypeError("dataclass_type must be a dataclass") + field_types = {f.name: f.type for f in fields(dataclass_type)} + constructor_args = {} + for key, value in dict_obj.items(): + if key in field_types and is_dataclass(field_types[key]): + constructor_args[key] = dict_to_dataclass(value, field_types[key]) + else: + constructor_args[key] = value + # TODO: remove type ignore when https://github.com/python/mypy/issues/14941 be andressed + return dataclass_type(**constructor_args) # type: ignore[return-value]