Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change sold2 detector config to dataclass #21

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d86d71d
Introduce dataclasses for discussion
lappemic Apr 10, 2024
291f038
Add dict_to_dataclass function
lappemic Apr 10, 2024
9a8b865
Add dataclass_to_dict function
lappemic Apr 11, 2024
b559a57
Update SOLD2_detector to use newly introduced dataclasses
lappemic Apr 11, 2024
cec2f2b
Remove default_detector_cfg dict
lappemic Apr 12, 2024
7f2e062
Remove comment
lappemic Apr 12, 2024
f8eba24
Extend LineDetectorCfg to all configs used in LineSegmentDetectionModule
lappemic Apr 12, 2024
4175970
Update LineSegmentDetectionModule init to use LineDetectorCfg dataclass
lappemic Apr 12, 2024
65302f3
Update LineSegmentDetectinoModule docstring
lappemic Apr 12, 2024
e107d0c
Update SOLD2_detector docstring
lappemic Apr 12, 2024
8a2fe1b
Update LineSegmentDetectionModule call to use its dataclass
lappemic Apr 12, 2024
e76bf13
Remove dict_to_dataclass since its not used
lappemic Apr 12, 2024
c7c1ba9
Add DeprecationWarning for dict as config in favour of dataclass config
lappemic Apr 15, 2024
f1da475
Fix DeprecationWarning
lappemic Apr 15, 2024
f63d2d8
Fix downstream errors of dataclass changes and typos LineDetectorCfg
lappemic Apr 15, 2024
e77adbb
Add typing to dict_to_dataclass and dataclass_to_dict functions
lappemic Apr 15, 2024
afaf740
Fix typ checking
lappemic Apr 15, 2024
b4340fb
Update dataclass typing to Any
lappemic Apr 16, 2024
6552f89
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2024
91897e1
Update dict_to_dataclass to TypeVar typing
lappemic Apr 16, 2024
af76418
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2024
53673fd
move dataclass_to_dict and dict_to_dataclass to kornia/utils/helpers.py
lappemic Apr 16, 2024
cb6ec15
Fix type checking errors in dict_to_dataclass by bounding TypeVar to …
lappemic Apr 16, 2024
517f43d
Remove any from dict to dataclass
johnnv1 Apr 16, 2024
408aaf3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 3 additions & 20 deletions kornia/feature/sold2/sold2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from kornia.utils import map_location_to_cpu

from .backbones import SOLD2Net
from .sold2_detector import LineSegmentDetectionModule, line_map_to_segments, prob_to_junctions
from .sold2_detector import LineDetectorCfg, LineSegmentDetectionModule, line_map_to_segments, prob_to_junctions

urls: Dict[str, str] = {}
urls["wireframe"] = "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth"
Expand All @@ -22,23 +22,7 @@
"keep_border_valid": True,
"detection_thresh": 0.0153846, # = 1/65: threshold of junction detection
"max_num_junctions": 500, # maximum number of junctions per image
"line_detector_cfg": {
"detect_thresh": 0.5,
"num_samples": 64,
"inlier_thresh": 0.99,
"use_candidate_suppression": True,
"nms_dist_tolerance": 3.0,
"use_heatmap_refinement": True,
"heatmap_refine_cfg": {
"mode": "local",
"ratio": 0.2,
"valid_thresh": 0.001,
"num_blocks": 20,
"overlap_ratio": 0.5,
},
"use_junction_refinement": True,
"junction_refine_cfg": {"num_perturbs": 9, "perturb_interval": 0.25},
},
"line_detector_cfg": LineDetectorCfg(),
"line_matcher_cfg": {
"cross_check": True,
"num_samples": 5,
Expand Down Expand Up @@ -92,8 +76,7 @@ def __init__(self, pretrained: bool = True, config: Optional[Dict[str, Any]] = N
self.eval()

# Initialize the line detector
self.line_detector_cfg = self.config["line_detector_cfg"]
self.line_detector = LineSegmentDetectionModule(**self.config["line_detector_cfg"])
self.line_detector = LineSegmentDetectionModule(LineDetectorCfg())

# Initialize the line matcher
self.line_matcher = WunschLineMatcher(**self.config["line_matcher_cfg"])
Expand Down
220 changes: 128 additions & 92 deletions kornia/feature/sold2/sold2_detector.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,71 @@
import math
import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

import torch

from kornia.core import Module, Tensor, concatenate, sin, stack, tensor, where, zeros
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.geometry.bbox import nms
from kornia.utils import map_location_to_cpu, torch_meshgrid
from kornia.utils import dataclass_to_dict, dict_to_dataclass, map_location_to_cpu, torch_meshgrid

from .backbones import SOLD2Net

urls: Dict[str, str] = {}
urls["wireframe"] = "https://www.polybox.ethz.ch/index.php/s/blOrW89gqSLoHOk/download"


default_detector_cfg = {
"backbone_cfg": {"input_channel": 1, "depth": 4, "num_stacks": 2, "num_blocks": 1, "num_classes": 5},
"use_descriptor": False,
"grid_size": 8,
"keep_border_valid": True,
"detection_thresh": 0.0153846, # = 1/65: threshold of junction detection
"max_num_junctions": 500, # maximum number of junctions per image
"line_detector_cfg": {
"detect_thresh": 0.5,
"num_samples": 64,
"inlier_thresh": 0.99,
"use_candidate_suppression": True,
"nms_dist_tolerance": 3.0,
"use_heatmap_refinement": True,
"heatmap_refine_cfg": {
"mode": "local",
"ratio": 0.2,
"valid_thresh": 0.001,
"num_blocks": 20,
"overlap_ratio": 0.5,
},
"use_junction_refinement": True,
"junction_refine_cfg": {"num_perturbs": 9, "perturb_interval": 0.25},
},
}
@dataclass
class HeatMapRefineCfg:
mode: str = "local"
ratio: float = 0.2
valid_thresh: float = 0.001
num_blocks: int = 20
overlap_ratio: float = 0.5


@dataclass
class JunctionRefineCfg:
num_perturbs: int = 9
perturb_interval: float = 0.25


@dataclass
class LineDetectorCfg:
detect_thresh: float = 0.5
num_samples: int = 64
inlier_thresh: float = 0.99
use_candidate_suppression: bool = True
nms_dist_tolerance: float = 3.0
heatmap_low_thresh: float = 0.15
heatmap_high_thresh: float = 0.2
max_local_patch_radius: float = 3
lambda_radius: float = 2.0
use_heatmap_refinement: bool = True
heatmap_refine_cfg: HeatMapRefineCfg = field(default_factory=HeatMapRefineCfg)
use_junction_refinement: bool = True
junction_refine_cfg: JunctionRefineCfg = field(default_factory=JunctionRefineCfg)


@dataclass
class BackboneCfg:
input_channel: int = 1
depth: int = 4
num_stacks: int = 2
num_blocks: int = 1
num_classes: int = 5


@dataclass
class DetectorCfg:
backbone_cfg: BackboneCfg = field(default_factory=BackboneCfg)
use_descriptor: bool = False
grid_size: int = 8
keep_border_valid: bool = True
detection_thresh: float = 0.0153846 # = 1/65: threshold of junction detection
max_num_junctions: int = 500 # maximum number of junctions per image
line_detector_cfg: LineDetectorCfg = field(default_factory=LineDetectorCfg)


class SOLD2_detector(Module):
Expand All @@ -48,9 +75,10 @@ class SOLD2_detector(Module):
Occlusion-aware Line Detector and Descriptor". See :cite:`SOLD22021` for more details.

Args:
config: Dict specifying parameters. None will load the default parameters,
which are tuned for images in the range 400~800 px.
pretrained: If True, download and set pretrained weights to the model.
config (DetectorCfg): Configuration object containing all parameters. None will load the default parameters,
which are tuned for images in the range 400~800 px. Using a dataclass ensures type safety and clearer
parameter management.
pretrained (bool): If True, download and set pretrained weights to the model.

Returns:
The raw junction and line heatmaps, as well as the list of detected line segments (ij coordinates convention).
Expand All @@ -61,25 +89,34 @@ class SOLD2_detector(Module):
>>> line_segments = sold2_detector(img)["line_segments"]
"""

def __init__(self, pretrained: bool = True, config: Optional[Dict[str, Any]] = None) -> None:
def __init__(self, pretrained: bool = True, config: Optional[DetectorCfg] = None) -> None:
if isinstance(config, dict):
warnings.warn(
"Usage of config as a plain dictionary is deprecated in favor of"
" `kornia.feature.sold2.sold2_detector.DetectorCfg`. The support of plain dictionaries"
"as config will be removed in kornia v0.8.0 (December 2024).",
category=DeprecationWarning,
stacklevel=2,
)
config = dict_to_dataclass(config, DetectorCfg)
super().__init__()
# Initialize some parameters
self.config = default_detector_cfg if config is None else config
self.grid_size = self.config["grid_size"]
self.junc_detect_thresh = self.config.get("detection_thresh", 1 / 65)
self.max_num_junctions = self.config.get("max_num_junctions", 500)
self.config = config if config is not None else DetectorCfg()
self.grid_size = self.config.grid_size
self.junc_detect_thresh = self.config.detection_thresh
self.max_num_junctions = self.config.max_num_junctions

# Load the pre-trained model
self.model = SOLD2Net(self.config)
self.model = SOLD2Net(dataclass_to_dict(self.config))

if pretrained:
pretrained_dict = torch.hub.load_state_dict_from_url(urls["wireframe"], map_location=map_location_to_cpu)
state_dict = self.adapt_state_dict(pretrained_dict["model_state_dict"])
self.model.load_state_dict(state_dict)
self.eval()

# Initialize the line detector
self.line_detector_cfg = self.config["line_detector_cfg"]
self.line_detector = LineSegmentDetectionModule(**self.config["line_detector_cfg"])
# Initialize the line detector with a configuration from the dataclass
self.line_detector = LineSegmentDetectionModule(self.config.line_detector_cfg)

def adapt_state_dict(self, state_dict: Dict[str, Any]) -> Dict[str, Any]:
del state_dict["w_junc"]
Expand Down Expand Up @@ -127,66 +164,62 @@ class LineSegmentDetectionModule:
r"""Module extracting line segments from junctions and line heatmaps.

Args:
detect_thresh: The probability threshold for mean activation (0. ~ 1.)
num_samples: Number of sampling locations along the line segments.
inlier_thresh: The min inlier ratio to satisfy (0. ~ 1.) => 0. means no threshold.
heatmap_low_thresh: The lowest threshold for the pixel to be considered as candidate in junction recovery.
heatmap_high_thresh: The higher threshold for NMS in junction recovery.
max_local_patch_radius: The max patch to be considered in local maximum search.
lambda_radius: The lambda factor in linear local maximum search formulation
use_candidate_suppression: Apply candidate suppression to break long segments into short sub-segments.
nms_dist_tolerance: The distance tolerance for nms. Decide whether the junctions are on the line.
use_heatmap_refinement: Use heatmap refinement method or not.
heatmap_refine_cfg: The configs for heatmap refinement methods.
use_junction_refinement: Use junction refinement method or not.
junction_refine_cfg: The configs for junction refinement methods.
config (LineDetectorCfg): Configuration dataclass containing all settings required for line segment detection.
- detect_thresh (float): Probability threshold for mean activation (0. ~ 1.).
- num_samples (int): Number of sampling locations along the line segments.
- inlier_thresh (float): Minimum inlier ratio to satisfy (0. ~ 1.) => 0. means no threshold.
- heatmap_low_thresh (float): Lowest threshold for pixel considered as a candidate in junction recovery.
- heatmap_high_thresh (float): Higher threshold for NMS in junction recovery.
- max_local_patch_radius (float): Maximum patch to be considered in local maximum search.
- lambda_radius (float): Lambda factor in linear local maximum search formulation.
- use_candidate_suppression (bool): Apply candidate suppression to break long segments into sub-segments.
- nms_dist_tolerance (float): Distance tolerance for NMS. Decides whether the junctions are on the line.
- use_heatmap_refinement (bool): Whether to use heatmap refinement methods.
- heatmap_refine_cfg: Configuration for heatmap refinement methods.
- use_junction_refinement (bool): Whether to use junction refinement methods.
- junction_refine_cfg: Configuration for junction refinement methods.

Example:
>>> config = LineDetectorCfg(detect_thresh=0.6, use_heatmap_refinement=True)
>>> module = LineSegmentDetectionModule(config)
>>> junctions, heatmap = torch.rand(10, 2), torch.rand(256, 256)
>>> line_map, junctions, _ = module.detect(junctions, heatmap)
"""

def __init__(
self,
detect_thresh: float,
num_samples: int = 64,
inlier_thresh: float = 0.0,
heatmap_low_thresh: float = 0.15,
heatmap_high_thresh: float = 0.2,
max_local_patch_radius: float = 3,
lambda_radius: float = 2.0,
use_candidate_suppression: bool = False,
nms_dist_tolerance: float = 3.0,
use_heatmap_refinement: bool = False,
heatmap_refine_cfg: Optional[Dict[str, Any]] = None,
use_junction_refinement: bool = False,
junction_refine_cfg: Optional[Dict[str, Any]] = None,
) -> None:
def __init__(self, config: LineDetectorCfg = LineDetectorCfg()) -> None:
# Load LineDetectorCfg
self.config = config

# Line detection parameters
self.detect_thresh = detect_thresh
self.detect_thresh = self.config.detect_thresh
# self.detect_thresh = detect_thresh

# Line sampling parameters
self.num_samples = num_samples
self.inlier_thresh = inlier_thresh
self.local_patch_radius = max_local_patch_radius
self.lambda_radius = lambda_radius
self.num_samples = self.config.num_samples
self.inlier_thresh = self.config.inlier_thresh
self.local_patch_radius = self.config.max_local_patch_radius
self.lambda_radius = self.config.lambda_radius

# Detecting junctions on the boundary parameters
self.low_thresh = heatmap_low_thresh
self.high_thresh = heatmap_high_thresh
self.low_thresh = self.config.heatmap_low_thresh
self.high_thresh = self.config.heatmap_high_thresh

# Pre-compute the linspace sampler
self.torch_sampler = torch.linspace(0, 1, self.num_samples)

# Long line segment suppression configuration
self.use_candidate_suppression = use_candidate_suppression
self.nms_dist_tolerance = nms_dist_tolerance
self.use_candidate_suppression = self.config.use_candidate_suppression
self.nms_dist_tolerance = self.config.nms_dist_tolerance

# Heatmap refinement configuration
self.use_heatmap_refinement = use_heatmap_refinement
self.heatmap_refine_cfg = heatmap_refine_cfg
self.use_heatmap_refinement = self.config.use_heatmap_refinement
self.heatmap_refine_cfg = self.config.heatmap_refine_cfg
if self.use_heatmap_refinement and self.heatmap_refine_cfg is None:
raise ValueError("[Error] Missing heatmap refinement config.")

# Junction refinement configuration
self.use_junction_refinement = use_junction_refinement
self.junction_refine_cfg = junction_refine_cfg
self.use_junction_refinement = self.config.use_junction_refinement
self.junction_refine_cfg = self.config.junction_refine_cfg
if self.use_junction_refinement and self.junction_refine_cfg is None:
raise ValueError("[Error] Missing junction refinement config.")

Expand All @@ -197,18 +230,18 @@ def detect(self, junctions: Tensor, heatmap: Tensor) -> Tuple[Tensor, Tensor, Te
device = junctions.device

# Perform the heatmap refinement
if self.use_heatmap_refinement and isinstance(self.heatmap_refine_cfg, dict):
if self.heatmap_refine_cfg["mode"] == "global":
if self.use_heatmap_refinement and isinstance(self.heatmap_refine_cfg, HeatMapRefineCfg):
if self.heatmap_refine_cfg.mode == "global":
heatmap = self.refine_heatmap(
heatmap, self.heatmap_refine_cfg["ratio"], self.heatmap_refine_cfg["valid_thresh"]
heatmap, self.heatmap_refine_cfg.ratio, self.heatmap_refine_cfg.valid_thresh
)
elif self.heatmap_refine_cfg["mode"] == "local":
elif self.heatmap_refine_cfg.mode == "local":
heatmap = self.refine_heatmap_local(
heatmap,
self.heatmap_refine_cfg["num_blocks"],
self.heatmap_refine_cfg["overlap_ratio"],
self.heatmap_refine_cfg["ratio"],
self.heatmap_refine_cfg["valid_thresh"],
self.heatmap_refine_cfg.num_blocks,
self.heatmap_refine_cfg.overlap_ratio,
self.heatmap_refine_cfg.ratio,
self.heatmap_refine_cfg.valid_thresh,
)

# Initialize empty line map
Expand Down Expand Up @@ -393,10 +426,13 @@ def refine_junction_perturb(
) -> Tuple[Tensor, Tensor]:
"""Refine the line endpoints in a similar way as in LSD."""
# Fetch refinement parameters
if not isinstance(self.junction_refine_cfg, dict):
raise TypeError(f"Expected to have a dict of config for junction. Gotcha {type(self.junction_refine_cfg)}")
num_perturbs = self.junction_refine_cfg["num_perturbs"]
perturb_interval = self.junction_refine_cfg["perturb_interval"]
if not isinstance(self.junction_refine_cfg, JunctionRefineCfg):
raise TypeError(
"Expected to have dataclass of type JunctionRefineCfg for junction."
f"Gotcha {type(self.junction_refine_cfg)}"
)
num_perturbs = self.junction_refine_cfg.num_perturbs
perturb_interval = self.junction_refine_cfg.perturb_interval
side_perturbs = (num_perturbs - 1) // 2

# Fetch the 2D perturb mat
Expand Down
4 changes: 4 additions & 0 deletions kornia/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from .grid import create_meshgrid, create_meshgrid3d
from .helpers import (
_extract_device_dtype,
dataclass_to_dict,
deprecated,
dict_to_dataclass,
get_cuda_device_if_available,
get_cuda_or_mps_device_if_available,
get_mps_device_if_available,
Expand Down Expand Up @@ -58,4 +60,6 @@
"print_image",
"xla_is_available",
"is_mps_tensor_safe",
"dataclass_to_dict",
"dict_to_dataclass",
]
Loading
Loading