From d86d71d728183dcbb68e515b3267eb4c7128d7d2 Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Wed, 10 Apr 2024 15:05:51 +0200 Subject: [PATCH 01/25] Introduce dataclasses for discussion --- kornia/feature/sold2/sold2_detector.py | 49 ++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 60e97822e9..15abc50363 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -1,4 +1,5 @@ import math +from dataclasses import dataclass, field from typing import Any, Dict, Optional, Tuple import torch @@ -14,6 +15,54 @@ urls["wireframe"] = "https://www.polybox.ethz.ch/index.php/s/blOrW89gqSLoHOk/download" +@dataclass +class HeatMapRefineCfg: + mode: str = "local" + ratio: float = 0.2 + valid_thresh: float = 0.001 + num_blocks: int = 20 + overlap_ratio: float = 0.5 + + +@dataclass +class JunctionRefineCfg: + num_perturbs: int = 9 + perturb_interval: float = 0.25 + + +@dataclass +class LineDetectorCfg: + detect_thresh: float = 0.5 + num_samples: int = 64 + inlier_thresh: float = 0.99 + use_candidate_suppression: bool = True + nms_dist_tolerance: float = 3.0 + use_heatmap_refinement: bool = True + heatmap_refine_cfg: HeatMapRefineCfg = field(default_factory=HeatMapRefineCfg) + use_junction_refinement: bool = True + junction_refine_cfg: JunctionRefineCfg = field(default_factory=JunctionRefineCfg) + + +@dataclass +class BackboneCfg: + input_channel: int = 1 + depth: int = 4 + num_stacks: int = 2 + num_blocks: int = 1 + num_classes: int = 5 + + +@dataclass +class DetectorCfg: + backbone_cfg: BackboneCfg = field(default_factory=BackboneCfg) + use_descriptor: bool = False + grid_size: int = 8 + keep_border_valid: bool = True + detection_thresh: float = 0.0153846 + max_num_junctions: int = 500 + line_detector_cfg: LineDetectorCfg = field(default_factory=LineDetectorCfg) + + default_detector_cfg = { "backbone_cfg": {"input_channel": 1, "depth": 4, "num_stacks": 2, "num_blocks": 1, "num_classes": 5}, "use_descriptor": False, From 291f0380d244a57154a55c9b56e4b7d8af8d62d1 Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Wed, 10 Apr 2024 15:16:29 +0200 Subject: [PATCH 02/25] Add dict_to_dataclass function --- kornia/feature/sold2/sold2_detector.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 15abc50363..847e026230 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -1,5 +1,5 @@ import math -from dataclasses import dataclass, field +from dataclasses import dataclass, field, is_dataclass from typing import Any, Dict, Optional, Tuple import torch @@ -63,6 +63,20 @@ class DetectorCfg: line_detector_cfg: LineDetectorCfg = field(default_factory=LineDetectorCfg) +def dict_to_dataclass(dict_obj, dataclass_type): + """Recursively convert dictionaries to dataclass instances.""" + if not isinstance(dict_obj, dict): + return TypeError("Input conf must be dict") + field_types = {f.name: f.type for f in dataclass_type.__dataclass_fields__.values()} + constructor_args = {} + for key, value in dict_obj.items(): + if key in field_types and is_dataclass(field_types[key]): + constructor_args[key] = dict_to_dataclass(value, field_types[key]) + else: + constructor_args[key] = value + return dataclass_type(**constructor_args) + + default_detector_cfg = { "backbone_cfg": {"input_channel": 1, "depth": 4, "num_stacks": 2, "num_blocks": 1, "num_classes": 5}, "use_descriptor": False, From 9a8b8658119cb23d703322f9e04a41c655d5a18e Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Thu, 11 Apr 2024 08:20:03 +0200 Subject: [PATCH 03/25] Add dataclass_to_dict function --- kornia/feature/sold2/sold2_detector.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 847e026230..29733bacf9 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -1,5 +1,5 @@ import math -from dataclasses import dataclass, field, is_dataclass +from dataclasses import asdict, dataclass, field, is_dataclass from typing import Any, Dict, Optional, Tuple import torch @@ -77,6 +77,18 @@ def dict_to_dataclass(dict_obj, dataclass_type): return dataclass_type(**constructor_args) +def dataclass_to_dict(obj): + """Recursively convert dataclass instances to dictionaries.""" + if is_dataclass(obj) and not isinstance(obj, type): + return {key: dataclass_to_dict(value) for key, value in asdict(obj).items()} + elif isinstance(obj, (list, tuple)): + return type(obj)(dataclass_to_dict(item) for item in obj) + elif isinstance(obj, dict): + return {key: dataclass_to_dict(value) for key, value in obj.items()} + else: + return obj + + default_detector_cfg = { "backbone_cfg": {"input_channel": 1, "depth": 4, "num_stacks": 2, "num_blocks": 1, "num_classes": 5}, "use_descriptor": False, From b559a571392a8b65f22586288d41cf7425a55bcf Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Thu, 11 Apr 2024 10:18:24 +0200 Subject: [PATCH 04/25] Update SOLD2_detector to use newly introduced dataclasses --- kornia/feature/sold2/sold2_detector.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 29733bacf9..610368f7a6 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -136,16 +136,17 @@ class SOLD2_detector(Module): >>> line_segments = sold2_detector(img)["line_segments"] """ - def __init__(self, pretrained: bool = True, config: Optional[Dict[str, Any]] = None) -> None: + def __init__(self, pretrained: bool = True, config: Optional[DetectorCfg] = None) -> None: super().__init__() # Initialize some parameters - self.config = default_detector_cfg if config is None else config - self.grid_size = self.config["grid_size"] - self.junc_detect_thresh = self.config.get("detection_thresh", 1 / 65) - self.max_num_junctions = self.config.get("max_num_junctions", 500) + # self.config = default_detector_cfg if config is None else config + self.config = config if config is not None else DetectorCfg() + self.grid_size = self.config.grid_size + self.junc_detect_thresh = self.config.detection_thresh + self.max_num_junctions = self.config.max_num_junctions # Load the pre-trained model - self.model = SOLD2Net(self.config) + self.model = SOLD2Net(**dataclass_to_dict(self.config)) if pretrained: pretrained_dict = torch.hub.load_state_dict_from_url(urls["wireframe"], map_location=map_location_to_cpu) state_dict = self.adapt_state_dict(pretrained_dict["model_state_dict"]) @@ -153,8 +154,8 @@ def __init__(self, pretrained: bool = True, config: Optional[Dict[str, Any]] = N self.eval() # Initialize the line detector - self.line_detector_cfg = self.config["line_detector_cfg"] - self.line_detector = LineSegmentDetectionModule(**self.config["line_detector_cfg"]) + self.line_detector_cfg = self.config.line_detector_cfg + self.line_detector = LineSegmentDetectionModule(**dataclass_to_dict(self.line_detector_cfg)) def adapt_state_dict(self, state_dict: Dict[str, Any]) -> Dict[str, Any]: del state_dict["w_junc"] From cec2f2b3c61b3e96b2af32300b62120ffc2e968f Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Fri, 12 Apr 2024 10:02:48 +0200 Subject: [PATCH 05/25] Remove default_detector_cfg dict --- kornia/feature/sold2/sold2_detector.py | 27 -------------------------- 1 file changed, 27 deletions(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 610368f7a6..d489a4b63d 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -89,33 +89,6 @@ def dataclass_to_dict(obj): return obj -default_detector_cfg = { - "backbone_cfg": {"input_channel": 1, "depth": 4, "num_stacks": 2, "num_blocks": 1, "num_classes": 5}, - "use_descriptor": False, - "grid_size": 8, - "keep_border_valid": True, - "detection_thresh": 0.0153846, # = 1/65: threshold of junction detection - "max_num_junctions": 500, # maximum number of junctions per image - "line_detector_cfg": { - "detect_thresh": 0.5, - "num_samples": 64, - "inlier_thresh": 0.99, - "use_candidate_suppression": True, - "nms_dist_tolerance": 3.0, - "use_heatmap_refinement": True, - "heatmap_refine_cfg": { - "mode": "local", - "ratio": 0.2, - "valid_thresh": 0.001, - "num_blocks": 20, - "overlap_ratio": 0.5, - }, - "use_junction_refinement": True, - "junction_refine_cfg": {"num_perturbs": 9, "perturb_interval": 0.25}, - }, -} - - class SOLD2_detector(Module): r"""Module, which detects line segments in an image. From 7f2e0621ce0e59bd796d5644260ee7e83df53ad5 Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Fri, 12 Apr 2024 10:12:41 +0200 Subject: [PATCH 06/25] Remove comment --- kornia/feature/sold2/sold2_detector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index d489a4b63d..5eb585085e 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -112,7 +112,6 @@ class SOLD2_detector(Module): def __init__(self, pretrained: bool = True, config: Optional[DetectorCfg] = None) -> None: super().__init__() # Initialize some parameters - # self.config = default_detector_cfg if config is None else config self.config = config if config is not None else DetectorCfg() self.grid_size = self.config.grid_size self.junc_detect_thresh = self.config.detection_thresh From f8eba24b3af3b5f5348548ce5f56bfc8ed140718 Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Fri, 12 Apr 2024 10:16:39 +0200 Subject: [PATCH 07/25] Extend LineDetectorCfg to all configs used in LineSegmentDetectionModule --- kornia/feature/sold2/sold2_detector.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 5eb585085e..11a4dc6ed2 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -37,6 +37,10 @@ class LineDetectorCfg: inlier_thresh: float = 0.99 use_candidate_suppression: bool = True nms_dist_tolerance: float = 3.0 + heatmap_low_thresh: float = (0.15,) + heatmap_high_thresh: float = (0.2,) + max_local_patch_radius: float = (3,) + lambda_radius: float = (2.0,) use_heatmap_refinement: bool = True heatmap_refine_cfg: HeatMapRefineCfg = field(default_factory=HeatMapRefineCfg) use_junction_refinement: bool = True From 4175970d11c54af31ceda75442c8f29a6faebd0b Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Fri, 12 Apr 2024 10:24:15 +0200 Subject: [PATCH 08/25] Update LineSegmentDetectionModule init to use LineDetectorCfg dataclass --- kornia/feature/sold2/sold2_detector.py | 44 ++++++++++---------------- 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 11a4dc6ed2..13193a39b8 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -194,51 +194,39 @@ class LineSegmentDetectionModule: junction_refine_cfg: The configs for junction refinement methods. """ - def __init__( - self, - detect_thresh: float, - num_samples: int = 64, - inlier_thresh: float = 0.0, - heatmap_low_thresh: float = 0.15, - heatmap_high_thresh: float = 0.2, - max_local_patch_radius: float = 3, - lambda_radius: float = 2.0, - use_candidate_suppression: bool = False, - nms_dist_tolerance: float = 3.0, - use_heatmap_refinement: bool = False, - heatmap_refine_cfg: Optional[Dict[str, Any]] = None, - use_junction_refinement: bool = False, - junction_refine_cfg: Optional[Dict[str, Any]] = None, - ) -> None: + def __init__(self, detect_thresh: float = 0.5, config: LineDetectorCfg = LineDetectorCfg()) -> None: # Line detection parameters self.detect_thresh = detect_thresh + # Load LineDetectorCfg + self.config = config + # Line sampling parameters - self.num_samples = num_samples - self.inlier_thresh = inlier_thresh - self.local_patch_radius = max_local_patch_radius - self.lambda_radius = lambda_radius + self.num_samples = self.config.num_samples + self.inlier_thresh = self.config.inlier_thresh + self.local_patch_radius = self.config.max_local_patch_radius + self.lambda_radius = self.config.lambda_radius # Detecting junctions on the boundary parameters - self.low_thresh = heatmap_low_thresh - self.high_thresh = heatmap_high_thresh + self.low_thresh = self.config.heatmap_low_thresh + self.high_thresh = self.config.heatmap_high_thresh # Pre-compute the linspace sampler self.torch_sampler = torch.linspace(0, 1, self.num_samples) # Long line segment suppression configuration - self.use_candidate_suppression = use_candidate_suppression - self.nms_dist_tolerance = nms_dist_tolerance + self.use_candidate_suppression = self.config.use_candidate_suppression + self.nms_dist_tolerance = self.config.nms_dist_tolerance # Heatmap refinement configuration - self.use_heatmap_refinement = use_heatmap_refinement - self.heatmap_refine_cfg = heatmap_refine_cfg + self.use_heatmap_refinement = self.config.use_heatmap_refinement + self.heatmap_refine_cfg = self.config.heatmap_refine_cfg if self.use_heatmap_refinement and self.heatmap_refine_cfg is None: raise ValueError("[Error] Missing heatmap refinement config.") # Junction refinement configuration - self.use_junction_refinement = use_junction_refinement - self.junction_refine_cfg = junction_refine_cfg + self.use_junction_refinement = self.config.use_junction_refinement + self.junction_refine_cfg = self.config.junction_refine_cfg if self.use_junction_refinement and self.junction_refine_cfg is None: raise ValueError("[Error] Missing junction refinement config.") From 65302f3a94fa6234490490586d7b27cefb44c78b Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Fri, 12 Apr 2024 10:36:26 +0200 Subject: [PATCH 09/25] Update LineSegmentDetectinoModule docstring --- kornia/feature/sold2/sold2_detector.py | 39 +++++++++++++++----------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 13193a39b8..41dd9c5bc6 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -179,28 +179,35 @@ class LineSegmentDetectionModule: r"""Module extracting line segments from junctions and line heatmaps. Args: - detect_thresh: The probability threshold for mean activation (0. ~ 1.) - num_samples: Number of sampling locations along the line segments. - inlier_thresh: The min inlier ratio to satisfy (0. ~ 1.) => 0. means no threshold. - heatmap_low_thresh: The lowest threshold for the pixel to be considered as candidate in junction recovery. - heatmap_high_thresh: The higher threshold for NMS in junction recovery. - max_local_patch_radius: The max patch to be considered in local maximum search. - lambda_radius: The lambda factor in linear local maximum search formulation - use_candidate_suppression: Apply candidate suppression to break long segments into short sub-segments. - nms_dist_tolerance: The distance tolerance for nms. Decide whether the junctions are on the line. - use_heatmap_refinement: Use heatmap refinement method or not. - heatmap_refine_cfg: The configs for heatmap refinement methods. - use_junction_refinement: Use junction refinement method or not. - junction_refine_cfg: The configs for junction refinement methods. + config (LineDetectorCfg): Configuration dataclass containing all settings required for line segment detection. + - detect_thresh (float): Probability threshold for mean activation (0. ~ 1.). + - num_samples (int): Number of sampling locations along the line segments. + - inlier_thresh (float): Minimum inlier ratio to satisfy (0. ~ 1.) => 0. means no threshold. + - heatmap_low_thresh (float): Lowest threshold for pixel considered as a candidate in junction recovery. + - heatmap_high_thresh (float): Higher threshold for NMS in junction recovery. + - max_local_patch_radius (float): Maximum patch to be considered in local maximum search. + - lambda_radius (float): Lambda factor in linear local maximum search formulation. + - use_candidate_suppression (bool): Apply candidate suppression to break long segments into sub-segments. + - nms_dist_tolerance (float): Distance tolerance for NMS. Decides whether the junctions are on the line. + - use_heatmap_refinement (bool): Whether to use heatmap refinement methods. + - heatmap_refine_cfg: Configuration for heatmap refinement methods. + - use_junction_refinement (bool): Whether to use junction refinement methods. + - junction_refine_cfg: Configuration for junction refinement methods. + + Example: + >>> config = LineDetectorCfg(detect_thresh=0.6, use_heatmap_refinement=True) + >>> module = LineSegmentDetectionModule(config) + >>> junctions, heatmap = torch.rand(10, 2), torch.rand(256, 256) + >>> line_map, junctions, _ = module.detect(junctions, heatmap) """ def __init__(self, detect_thresh: float = 0.5, config: LineDetectorCfg = LineDetectorCfg()) -> None: - # Line detection parameters - self.detect_thresh = detect_thresh - # Load LineDetectorCfg self.config = config + # Line detection parameters + self.detect_thresh = self.config.detect_thresh + # Line sampling parameters self.num_samples = self.config.num_samples self.inlier_thresh = self.config.inlier_thresh From e107d0cb8562a256a24036ce90527ea1f4aaa4b0 Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Fri, 12 Apr 2024 10:44:31 +0200 Subject: [PATCH 10/25] Update SOLD2_detector docstring --- kornia/feature/sold2/sold2_detector.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 41dd9c5bc6..5916d2db19 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -100,9 +100,10 @@ class SOLD2_detector(Module): Occlusion-aware Line Detector and Descriptor". See :cite:`SOLD22021` for more details. Args: - config: Dict specifying parameters. None will load the default parameters, - which are tuned for images in the range 400~800 px. - pretrained: If True, download and set pretrained weights to the model. + config (DetectorCfg): Configuration object containing all parameters. None will load the default parameters, + which are tuned for images in the range 400~800 px. Using a dataclass ensures type safety and clearer + parameter management. + pretrained (bool): If True, download and set pretrained weights to the model. Returns: The raw junction and line heatmaps, as well as the list of detected line segments (ij coordinates convention). From 8a2fe1b4249a7569e8a196a4072f6cee8d07c4ed Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Fri, 12 Apr 2024 10:50:41 +0200 Subject: [PATCH 11/25] Update LineSegmentDetectionModule call to use its dataclass --- kornia/feature/sold2/sold2_detector.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 5916d2db19..9af02add65 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -130,9 +130,8 @@ def __init__(self, pretrained: bool = True, config: Optional[DetectorCfg] = None self.model.load_state_dict(state_dict) self.eval() - # Initialize the line detector - self.line_detector_cfg = self.config.line_detector_cfg - self.line_detector = LineSegmentDetectionModule(**dataclass_to_dict(self.line_detector_cfg)) + # Initialize the line detector with a configuration from the dataclass + self.line_detector = LineSegmentDetectionModule(self.config.line_detector_cfg) def adapt_state_dict(self, state_dict: Dict[str, Any]) -> Dict[str, Any]: del state_dict["w_junc"] From e76bf139855dcc385aba1b2a5091ca35c3903a45 Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Fri, 12 Apr 2024 10:51:40 +0200 Subject: [PATCH 12/25] Remove dict_to_dataclass since its not used --- kornia/feature/sold2/sold2_detector.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 9af02add65..21905de2e9 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -67,20 +67,6 @@ class DetectorCfg: line_detector_cfg: LineDetectorCfg = field(default_factory=LineDetectorCfg) -def dict_to_dataclass(dict_obj, dataclass_type): - """Recursively convert dictionaries to dataclass instances.""" - if not isinstance(dict_obj, dict): - return TypeError("Input conf must be dict") - field_types = {f.name: f.type for f in dataclass_type.__dataclass_fields__.values()} - constructor_args = {} - for key, value in dict_obj.items(): - if key in field_types and is_dataclass(field_types[key]): - constructor_args[key] = dict_to_dataclass(value, field_types[key]) - else: - constructor_args[key] = value - return dataclass_type(**constructor_args) - - def dataclass_to_dict(obj): """Recursively convert dataclass instances to dictionaries.""" if is_dataclass(obj) and not isinstance(obj, type): From c7c1ba97ea2afc531c1c7402aa1d53df91f54d5d Mon Sep 17 00:00:00 2001 From: Michael <61876623+lappemic@users.noreply.github.com> Date: Mon, 15 Apr 2024 09:12:53 +0200 Subject: [PATCH 13/25] Add DeprecationWarning for dict as config in favour of dataclass config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: João Gustavo A. Amorim --- kornia/feature/sold2/sold2_detector.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 21905de2e9..a026bf190a 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -101,6 +101,13 @@ class SOLD2_detector(Module): """ def __init__(self, pretrained: bool = True, config: Optional[DetectorCfg] = None) -> None: + if isinstance(cofig, dict): + warnings.warn( + f"Usage of config as a plain dictionary is deprecated in favor of `kornia.feature.sold2.sold2_detector.DetectorCfg`. The support of plain dictionaries as config will be removed in kornia v0.8.0 (December 2024).", + category=DeprecationWarning, + stacklevel=2, + ) + config = dict_to_dataclass(**config, DetectorCfg) super().__init__() # Initialize some parameters self.config = config if config is not None else DetectorCfg() From f1da4758ecff09c45d58182ff4bd6329b718335a Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Mon, 15 Apr 2024 10:30:55 +0200 Subject: [PATCH 14/25] Fix DeprecationWarning --- kornia/feature/sold2/sold2_detector.py | 31 ++++++++++++++++++++------ 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index a026bf190a..70f0bdaa91 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -1,4 +1,5 @@ import math +import warnings from dataclasses import asdict, dataclass, field, is_dataclass from typing import Any, Dict, Optional, Tuple @@ -79,6 +80,20 @@ def dataclass_to_dict(obj): return obj +def dict_to_dataclass(dict_obj, dataclass_type): + """Recursively convert dictionaries to dataclass instances.""" + if not isinstance(dict_obj, dict): + return TypeError("Input conf must be dict") + field_types = {f.name: f.type for f in dataclass_type.__dataclass_fields__.values()} + constructor_args = {} + for key, value in dict_obj.items(): + if key in field_types and is_dataclass(field_types[key]): + constructor_args[key] = dict_to_dataclass(value, field_types[key]) + else: + constructor_args[key] = value + return dataclass_type(**constructor_args) + + class SOLD2_detector(Module): r"""Module, which detects line segments in an image. @@ -101,13 +116,15 @@ class SOLD2_detector(Module): """ def __init__(self, pretrained: bool = True, config: Optional[DetectorCfg] = None) -> None: - if isinstance(cofig, dict): + if isinstance(config, dict): warnings.warn( - f"Usage of config as a plain dictionary is deprecated in favor of `kornia.feature.sold2.sold2_detector.DetectorCfg`. The support of plain dictionaries as config will be removed in kornia v0.8.0 (December 2024).", - category=DeprecationWarning, - stacklevel=2, - ) - config = dict_to_dataclass(**config, DetectorCfg) + "Usage of config as a plain dictionary is deprecated in favor of" + " `kornia.feature.sold2.sold2_detector.DetectorCfg`. The support of plain dictionaries" + "as config will be removed in kornia v0.8.0 (December 2024).", + category=DeprecationWarning, + stacklevel=2, + ) + config = dict_to_dataclass(config, DetectorCfg) super().__init__() # Initialize some parameters self.config = config if config is not None else DetectorCfg() @@ -194,7 +211,7 @@ class LineSegmentDetectionModule: >>> line_map, junctions, _ = module.detect(junctions, heatmap) """ - def __init__(self, detect_thresh: float = 0.5, config: LineDetectorCfg = LineDetectorCfg()) -> None: + def __init__(self, config: LineDetectorCfg = LineDetectorCfg()) -> None: # Load LineDetectorCfg self.config = config From f63d2d8475606d0269c21ac57b6b84f51986f6f4 Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Mon, 15 Apr 2024 11:37:55 +0200 Subject: [PATCH 15/25] Fix downstream errors of dataclass changes and typos LineDetectorCfg --- kornia/feature/sold2/sold2.py | 23 ++------------ kornia/feature/sold2/sold2_detector.py | 43 ++++++++++++++------------ 2 files changed, 27 insertions(+), 39 deletions(-) diff --git a/kornia/feature/sold2/sold2.py b/kornia/feature/sold2/sold2.py index 8958064c0a..76e3be406e 100644 --- a/kornia/feature/sold2/sold2.py +++ b/kornia/feature/sold2/sold2.py @@ -9,7 +9,7 @@ from kornia.utils import map_location_to_cpu from .backbones import SOLD2Net -from .sold2_detector import LineSegmentDetectionModule, line_map_to_segments, prob_to_junctions +from .sold2_detector import LineDetectorCfg, LineSegmentDetectionModule, line_map_to_segments, prob_to_junctions urls: Dict[str, str] = {} urls["wireframe"] = "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth" @@ -22,23 +22,7 @@ "keep_border_valid": True, "detection_thresh": 0.0153846, # = 1/65: threshold of junction detection "max_num_junctions": 500, # maximum number of junctions per image - "line_detector_cfg": { - "detect_thresh": 0.5, - "num_samples": 64, - "inlier_thresh": 0.99, - "use_candidate_suppression": True, - "nms_dist_tolerance": 3.0, - "use_heatmap_refinement": True, - "heatmap_refine_cfg": { - "mode": "local", - "ratio": 0.2, - "valid_thresh": 0.001, - "num_blocks": 20, - "overlap_ratio": 0.5, - }, - "use_junction_refinement": True, - "junction_refine_cfg": {"num_perturbs": 9, "perturb_interval": 0.25}, - }, + "line_detector_cfg": LineDetectorCfg(), "line_matcher_cfg": { "cross_check": True, "num_samples": 5, @@ -92,8 +76,7 @@ def __init__(self, pretrained: bool = True, config: Optional[Dict[str, Any]] = N self.eval() # Initialize the line detector - self.line_detector_cfg = self.config["line_detector_cfg"] - self.line_detector = LineSegmentDetectionModule(**self.config["line_detector_cfg"]) + self.line_detector = LineSegmentDetectionModule(LineDetectorCfg()) # Initialize the line matcher self.line_matcher = WunschLineMatcher(**self.config["line_matcher_cfg"]) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 70f0bdaa91..cd19c797d1 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -38,10 +38,10 @@ class LineDetectorCfg: inlier_thresh: float = 0.99 use_candidate_suppression: bool = True nms_dist_tolerance: float = 3.0 - heatmap_low_thresh: float = (0.15,) - heatmap_high_thresh: float = (0.2,) - max_local_patch_radius: float = (3,) - lambda_radius: float = (2.0,) + heatmap_low_thresh: float = 0.15 + heatmap_high_thresh: float = 0.2 + max_local_patch_radius: float = 3 + lambda_radius: float = 2.0 use_heatmap_refinement: bool = True heatmap_refine_cfg: HeatMapRefineCfg = field(default_factory=HeatMapRefineCfg) use_junction_refinement: bool = True @@ -63,8 +63,8 @@ class DetectorCfg: use_descriptor: bool = False grid_size: int = 8 keep_border_valid: bool = True - detection_thresh: float = 0.0153846 - max_num_junctions: int = 500 + detection_thresh: float = 0.0153846 # = 1/65: threshold of junction detection + max_num_junctions: int = 500 # maximum number of junctions per image line_detector_cfg: LineDetectorCfg = field(default_factory=LineDetectorCfg) @@ -133,7 +133,8 @@ def __init__(self, pretrained: bool = True, config: Optional[DetectorCfg] = None self.max_num_junctions = self.config.max_num_junctions # Load the pre-trained model - self.model = SOLD2Net(**dataclass_to_dict(self.config)) + self.model = SOLD2Net(dataclass_to_dict(self.config)) + if pretrained: pretrained_dict = torch.hub.load_state_dict_from_url(urls["wireframe"], map_location=map_location_to_cpu) state_dict = self.adapt_state_dict(pretrained_dict["model_state_dict"]) @@ -217,6 +218,7 @@ def __init__(self, config: LineDetectorCfg = LineDetectorCfg()) -> None: # Line detection parameters self.detect_thresh = self.config.detect_thresh + # self.detect_thresh = detect_thresh # Line sampling parameters self.num_samples = self.config.num_samples @@ -254,18 +256,18 @@ def detect(self, junctions: Tensor, heatmap: Tensor) -> Tuple[Tensor, Tensor, Te device = junctions.device # Perform the heatmap refinement - if self.use_heatmap_refinement and isinstance(self.heatmap_refine_cfg, dict): - if self.heatmap_refine_cfg["mode"] == "global": + if self.use_heatmap_refinement and isinstance(self.heatmap_refine_cfg, HeatMapRefineCfg): + if self.heatmap_refine_cfg.mode == "global": heatmap = self.refine_heatmap( - heatmap, self.heatmap_refine_cfg["ratio"], self.heatmap_refine_cfg["valid_thresh"] + heatmap, self.heatmap_refine_cfg.ratio, self.heatmap_refine_cfg.valid_thresh ) - elif self.heatmap_refine_cfg["mode"] == "local": + elif self.heatmap_refine_cfg.mode == "local": heatmap = self.refine_heatmap_local( heatmap, - self.heatmap_refine_cfg["num_blocks"], - self.heatmap_refine_cfg["overlap_ratio"], - self.heatmap_refine_cfg["ratio"], - self.heatmap_refine_cfg["valid_thresh"], + self.heatmap_refine_cfg.num_blocks, + self.heatmap_refine_cfg.overlap_ratio, + self.heatmap_refine_cfg.ratio, + self.heatmap_refine_cfg.valid_thresh, ) # Initialize empty line map @@ -450,10 +452,13 @@ def refine_junction_perturb( ) -> Tuple[Tensor, Tensor]: """Refine the line endpoints in a similar way as in LSD.""" # Fetch refinement parameters - if not isinstance(self.junction_refine_cfg, dict): - raise TypeError(f"Expected to have a dict of config for junction. Gotcha {type(self.junction_refine_cfg)}") - num_perturbs = self.junction_refine_cfg["num_perturbs"] - perturb_interval = self.junction_refine_cfg["perturb_interval"] + if not isinstance(self.junction_refine_cfg, JunctionRefineCfg): + raise TypeError( + "Expected to have dataclass of type JunctionRefineCfg for junction." + f"Gotcha {type(self.junction_refine_cfg)}" + ) + num_perturbs = self.junction_refine_cfg.num_perturbs + perturb_interval = self.junction_refine_cfg.perturb_interval side_perturbs = (num_perturbs - 1) // 2 # Fetch the 2D perturb mat From e77adbb81552ee29f92b1b013f80f2d1d2f414cc Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Mon, 15 Apr 2024 12:01:19 +0200 Subject: [PATCH 16/25] Add typing to dict_to_dataclass and dataclass_to_dict functions --- kornia/feature/sold2/sold2_detector.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index cd19c797d1..1225422ae1 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -1,7 +1,7 @@ import math import warnings from dataclasses import asdict, dataclass, field, is_dataclass -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import torch @@ -68,7 +68,7 @@ class DetectorCfg: line_detector_cfg: LineDetectorCfg = field(default_factory=LineDetectorCfg) -def dataclass_to_dict(obj): +def dataclass_to_dict(obj: Any) -> Union[Dict[str, Any], list, tuple]: """Recursively convert dataclass instances to dictionaries.""" if is_dataclass(obj) and not isinstance(obj, type): return {key: dataclass_to_dict(value) for key, value in asdict(obj).items()} @@ -80,7 +80,9 @@ def dataclass_to_dict(obj): return obj -def dict_to_dataclass(dict_obj, dataclass_type): +def dict_to_dataclass( + dict_obj: Dict[str, Any], dataclass_type: Union[Dict[str, Any], list, tuple] +) -> Union[Dict[str, Any], list, tuple]: """Recursively convert dictionaries to dataclass instances.""" if not isinstance(dict_obj, dict): return TypeError("Input conf must be dict") From afaf740d56d3c3a8cdb9a239ffff97cfd840bafc Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Mon, 15 Apr 2024 12:18:57 +0200 Subject: [PATCH 17/25] Fix typ checking --- kornia/feature/sold2/sold2_detector.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 1225422ae1..5c00a3f42b 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -1,7 +1,7 @@ import math import warnings -from dataclasses import asdict, dataclass, field, is_dataclass -from typing import Any, Dict, Optional, Tuple, Union +from dataclasses import asdict, dataclass, field, fields, is_dataclass +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch @@ -68,7 +68,7 @@ class DetectorCfg: line_detector_cfg: LineDetectorCfg = field(default_factory=LineDetectorCfg) -def dataclass_to_dict(obj: Any) -> Union[Dict[str, Any], list, tuple]: +def dataclass_to_dict(obj: Any) -> Union[Dict[str, Any], List[Any], Tuple[Any, ...]]: """Recursively convert dataclass instances to dictionaries.""" if is_dataclass(obj) and not isinstance(obj, type): return {key: dataclass_to_dict(value) for key, value in asdict(obj).items()} @@ -80,13 +80,11 @@ def dataclass_to_dict(obj: Any) -> Union[Dict[str, Any], list, tuple]: return obj -def dict_to_dataclass( - dict_obj: Dict[str, Any], dataclass_type: Union[Dict[str, Any], list, tuple] -) -> Union[Dict[str, Any], list, tuple]: +def dict_to_dataclass(dict_obj: Dict[str, Any], dataclass_type: Type[Any]) -> Any: """Recursively convert dictionaries to dataclass instances.""" if not isinstance(dict_obj, dict): - return TypeError("Input conf must be dict") - field_types = {f.name: f.type for f in dataclass_type.__dataclass_fields__.values()} + raise TypeError("Input conf must be dict") + field_types = {f.name: f.type for f in fields(dataclass_type)} constructor_args = {} for key, value in dict_obj.items(): if key in field_types and is_dataclass(field_types[key]): From b4340fba8458b8612bffd503e473b608569da5a4 Mon Sep 17 00:00:00 2001 From: Michael <61876623+lappemic@users.noreply.github.com> Date: Tue, 16 Apr 2024 15:21:45 +0200 Subject: [PATCH 18/25] Update dataclass typing to Any MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: João Gustavo A. Amorim --- kornia/feature/sold2/sold2_detector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 5c00a3f42b..9d9399e121 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -68,7 +68,7 @@ class DetectorCfg: line_detector_cfg: LineDetectorCfg = field(default_factory=LineDetectorCfg) -def dataclass_to_dict(obj: Any) -> Union[Dict[str, Any], List[Any], Tuple[Any, ...]]: +def dataclass_to_dict(obj: Any) -> Any: """Recursively convert dataclass instances to dictionaries.""" if is_dataclass(obj) and not isinstance(obj, type): return {key: dataclass_to_dict(value) for key, value in asdict(obj).items()} From 6552f893212b5611b6db10543702880c224dd54d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Apr 2024 13:23:58 +0000 Subject: [PATCH 19/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- kornia/feature/sold2/sold2_detector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 9d9399e121..36cc926d56 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -1,7 +1,7 @@ import math import warnings from dataclasses import asdict, dataclass, field, fields, is_dataclass -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Optional, Tuple, Type import torch From 91897e11eca74858b65b47204ac5b208cdc40efa Mon Sep 17 00:00:00 2001 From: Michael <61876623+lappemic@users.noreply.github.com> Date: Tue, 16 Apr 2024 15:26:49 +0200 Subject: [PATCH 20/25] Update dict_to_dataclass to TypeVar typing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: João Gustavo A. Amorim --- kornia/feature/sold2/sold2_detector.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 36cc926d56..edfb92bb6b 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -79,8 +79,11 @@ def dataclass_to_dict(obj: Any) -> Any: else: return obj +from typing import TypeVar -def dict_to_dataclass(dict_obj: Dict[str, Any], dataclass_type: Type[Any]) -> Any: +T = TypeVar('T') + +def dict_to_dataclass(dict_obj: Dict[str, Any], dataclass_type: T) -> T: """Recursively convert dictionaries to dataclass instances.""" if not isinstance(dict_obj, dict): raise TypeError("Input conf must be dict") From af76418eab0457111b2957b46891a3004870a2b3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Apr 2024 13:27:31 +0000 Subject: [PATCH 21/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- kornia/feature/sold2/sold2_detector.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index edfb92bb6b..36f90829c1 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -1,7 +1,7 @@ import math import warnings from dataclasses import asdict, dataclass, field, fields, is_dataclass -from typing import Any, Dict, Optional, Tuple, Type +from typing import Any, Dict, Optional, Tuple import torch @@ -79,9 +79,11 @@ def dataclass_to_dict(obj: Any) -> Any: else: return obj + from typing import TypeVar -T = TypeVar('T') +T = TypeVar("T") + def dict_to_dataclass(dict_obj: Dict[str, Any], dataclass_type: T) -> T: """Recursively convert dictionaries to dataclass instances.""" From 53673fde0465bdd76c4d189d33ae64aa1e2f7b3c Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Tue, 16 Apr 2024 15:47:05 +0200 Subject: [PATCH 22/25] move dataclass_to_dict and dict_to_dataclass to kornia/utils/helpers.py --- kornia/feature/sold2/sold2_detector.py | 35 ++------------------------ kornia/utils/__init__.py | 4 +++ kornia/utils/helpers.py | 32 ++++++++++++++++++++++- 3 files changed, 37 insertions(+), 34 deletions(-) diff --git a/kornia/feature/sold2/sold2_detector.py b/kornia/feature/sold2/sold2_detector.py index 36f90829c1..61ed8280d7 100644 --- a/kornia/feature/sold2/sold2_detector.py +++ b/kornia/feature/sold2/sold2_detector.py @@ -1,6 +1,6 @@ import math import warnings -from dataclasses import asdict, dataclass, field, fields, is_dataclass +from dataclasses import dataclass, field from typing import Any, Dict, Optional, Tuple import torch @@ -8,7 +8,7 @@ from kornia.core import Module, Tensor, concatenate, sin, stack, tensor, where, zeros from kornia.core.check import KORNIA_CHECK_SHAPE from kornia.geometry.bbox import nms -from kornia.utils import map_location_to_cpu, torch_meshgrid +from kornia.utils import dataclass_to_dict, dict_to_dataclass, map_location_to_cpu, torch_meshgrid from .backbones import SOLD2Net @@ -68,37 +68,6 @@ class DetectorCfg: line_detector_cfg: LineDetectorCfg = field(default_factory=LineDetectorCfg) -def dataclass_to_dict(obj: Any) -> Any: - """Recursively convert dataclass instances to dictionaries.""" - if is_dataclass(obj) and not isinstance(obj, type): - return {key: dataclass_to_dict(value) for key, value in asdict(obj).items()} - elif isinstance(obj, (list, tuple)): - return type(obj)(dataclass_to_dict(item) for item in obj) - elif isinstance(obj, dict): - return {key: dataclass_to_dict(value) for key, value in obj.items()} - else: - return obj - - -from typing import TypeVar - -T = TypeVar("T") - - -def dict_to_dataclass(dict_obj: Dict[str, Any], dataclass_type: T) -> T: - """Recursively convert dictionaries to dataclass instances.""" - if not isinstance(dict_obj, dict): - raise TypeError("Input conf must be dict") - field_types = {f.name: f.type for f in fields(dataclass_type)} - constructor_args = {} - for key, value in dict_obj.items(): - if key in field_types and is_dataclass(field_types[key]): - constructor_args[key] = dict_to_dataclass(value, field_types[key]) - else: - constructor_args[key] = value - return dataclass_type(**constructor_args) - - class SOLD2_detector(Module): r"""Module, which detects line segments in an image. diff --git a/kornia/utils/__init__.py b/kornia/utils/__init__.py index 10efb4ee85..8fc674f6fb 100644 --- a/kornia/utils/__init__.py +++ b/kornia/utils/__init__.py @@ -3,7 +3,9 @@ from .grid import create_meshgrid, create_meshgrid3d from .helpers import ( _extract_device_dtype, + dataclass_to_dict, deprecated, + dict_to_dataclass, get_cuda_device_if_available, get_cuda_or_mps_device_if_available, get_mps_device_if_available, @@ -58,4 +60,6 @@ "print_image", "xla_is_available", "is_mps_tensor_safe", + "dataclass_to_dict", + "dict_to_dataclass", ] diff --git a/kornia/utils/helpers.py b/kornia/utils/helpers.py index 8d03713d63..81cadba660 100644 --- a/kornia/utils/helpers.py +++ b/kornia/utils/helpers.py @@ -2,9 +2,10 @@ import platform import sys import warnings +from dataclasses import asdict, fields, is_dataclass from functools import wraps from inspect import isclass, isfunction -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union, overload +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union, overload import torch from torch.linalg import inv_ex @@ -318,3 +319,32 @@ def is_autocast_enabled(both: bool = True) -> bool: return torch.is_autocast_enabled() or torch.is_autocast_cpu_enabled() return torch.is_autocast_enabled() + + +def dataclass_to_dict(obj: Any) -> Any: + """Recursively convert dataclass instances to dictionaries.""" + if is_dataclass(obj) and not isinstance(obj, type): + return {key: dataclass_to_dict(value) for key, value in asdict(obj).items()} + elif isinstance(obj, (list, tuple)): + return type(obj)(dataclass_to_dict(item) for item in obj) + elif isinstance(obj, dict): + return {key: dataclass_to_dict(value) for key, value in obj.items()} + else: + return obj + + +T = TypeVar("T") + + +def dict_to_dataclass(dict_obj: Dict[str, Any], dataclass_type: T) -> T: + """Recursively convert dictionaries to dataclass instances.""" + if not isinstance(dict_obj, dict): + raise TypeError("Input conf must be dict") + field_types = {f.name: f.type for f in fields(dataclass_type)} + constructor_args = {} + for key, value in dict_obj.items(): + if key in field_types and is_dataclass(field_types[key]): + constructor_args[key] = dict_to_dataclass(value, field_types[key]) + else: + constructor_args[key] = value + return dataclass_type(**constructor_args) From cb6ec158f30835db1f749acac43cf07aa04251a0 Mon Sep 17 00:00:00 2001 From: Michael Lappert Date: Tue, 16 Apr 2024 16:51:47 +0200 Subject: [PATCH 23/25] Fix type checking errors in dict_to_dataclass by bounding TypeVar to dataclass type --- kornia/utils/helpers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/kornia/utils/helpers.py b/kornia/utils/helpers.py index 81cadba660..a4531496d1 100644 --- a/kornia/utils/helpers.py +++ b/kornia/utils/helpers.py @@ -5,7 +5,7 @@ from dataclasses import asdict, fields, is_dataclass from functools import wraps from inspect import isclass, isfunction -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, overload import torch from torch.linalg import inv_ex @@ -333,13 +333,15 @@ def dataclass_to_dict(obj: Any) -> Any: return obj -T = TypeVar("T") +T = TypeVar("T", bound=Type[Any]) -def dict_to_dataclass(dict_obj: Dict[str, Any], dataclass_type: T) -> T: +def dict_to_dataclass(dict_obj: Dict[str, Any], dataclass_type: T) -> Any: """Recursively convert dictionaries to dataclass instances.""" if not isinstance(dict_obj, dict): raise TypeError("Input conf must be dict") + if not is_dataclass(dataclass_type): + raise TypeError("dataclass_type must be a dataclass") field_types = {f.name: f.type for f in fields(dataclass_type)} constructor_args = {} for key, value in dict_obj.items(): From 517f43d4325a3d6440ae0f25b939bf20abcfa7f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Gustavo=20A=2E=20Amorim?= Date: Tue, 16 Apr 2024 20:16:17 -0300 Subject: [PATCH 24/25] Remove any from dict to dataclass --- kornia/utils/helpers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/kornia/utils/helpers.py b/kornia/utils/helpers.py index a4531496d1..b83c385c85 100644 --- a/kornia/utils/helpers.py +++ b/kornia/utils/helpers.py @@ -333,10 +333,10 @@ def dataclass_to_dict(obj: Any) -> Any: return obj -T = TypeVar("T", bound=Type[Any]) +T = TypeVar("T") -def dict_to_dataclass(dict_obj: Dict[str, Any], dataclass_type: T) -> Any: +def dict_to_dataclass(dict_obj: Dict[str, Any], dataclass_type: Type[T]) -> T: """Recursively convert dictionaries to dataclass instances.""" if not isinstance(dict_obj, dict): raise TypeError("Input conf must be dict") @@ -349,4 +349,6 @@ def dict_to_dataclass(dict_obj: Dict[str, Any], dataclass_type: T) -> Any: constructor_args[key] = dict_to_dataclass(value, field_types[key]) else: constructor_args[key] = value - return dataclass_type(**constructor_args) + # TODO: remove type ignore when https://github.com/python/mypy/issues/14941 be andressed + return dataclass_type(**constructor_args) # type: ignore[return-value] + From 408aaf38edaa8bfaeec8e99d414c49916e1b9ae0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Apr 2024 23:16:29 +0000 Subject: [PATCH 25/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- kornia/utils/helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/kornia/utils/helpers.py b/kornia/utils/helpers.py index b83c385c85..3738ee21c6 100644 --- a/kornia/utils/helpers.py +++ b/kornia/utils/helpers.py @@ -351,4 +351,3 @@ def dict_to_dataclass(dict_obj: Dict[str, Any], dataclass_type: Type[T]) -> T: constructor_args[key] = value # TODO: remove type ignore when https://github.com/python/mypy/issues/14941 be andressed return dataclass_type(**constructor_args) # type: ignore[return-value] -