diff --git a/kornia/contrib/models/rt_detr/post_processor.py b/kornia/contrib/models/rt_detr/post_processor.py index 95ad6109c7..0ae898ad5e 100644 --- a/kornia/contrib/models/rt_detr/post_processor.py +++ b/kornia/contrib/models/rt_detr/post_processor.py @@ -1,11 +1,35 @@ +"""Post-processor for the RT-DETR model.""" + from __future__ import annotations -# TODO: import torch from kornia.core import Module, Tensor, concatenate +def mod(a, b): + """Compute the modulo operation for two numbers. + + This function calculates the remainder of the division of 'a' by 'b' + using the formula: a - (a // b) * b, which is equivalent to the modulo operation. + + Args: + a: The dividend. + b: The divisor. + + Returns: + The remainder of a divided by b. + + Example: + >>> mod(7, 3) + 1 + >>> mod(8.5, 3.0) + 2.5 + """ + return a - (a // b) * b + + +# TODO: deprecate the confidence threshold and add the num_top_queries as a parameter and num_classes as a parameter class DETRPostProcessor(Module): def __init__(self, confidence_threshold: float) -> None: super().__init__() @@ -45,16 +69,13 @@ def forward(self, logits: Tensor, boxes: Tensor, original_sizes: Tensor) -> list boxes_xy = boxes_xy * sizes_wh scores = logits.sigmoid() # RT-DETR was trained with focal loss. 
thus sigmoid is used instead of softmax - # the original code is slightly different - # it allows 1 bounding box to have multiple classes (multi-label) - scores, labels = scores.max(-1) - - detections: list[Tensor] = [] - for i in range(scores.shape[0]): - mask = scores[i] >= self.confidence_threshold - labels_i = labels[i, mask].unsqueeze(-1) - scores_i = scores[i, mask].unsqueeze(-1) - boxes_i = boxes_xy[i, mask] - detections.append(concatenate([labels_i, scores_i, boxes_i], -1)) + # retrieve the boxes with the highest score for each class + # https://github.com/lyuwenyu/RT-DETR/blob/b6bf0200b249a6e35b44e0308b6058f55b99696b/rtdetrv2_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py#L55-L62 + num_top_queries = 300 # TODO: make this configurable + num_classes = 80 # TODO: make this configurable + scores, index = torch.topk(scores.flatten(1), num_top_queries, dim=-1) + labels = mod(index, num_classes) + index = index // num_classes + boxes = boxes_xy.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes_xy.shape[-1])) - return detections + return concatenate([labels[..., None], scores[..., None], boxes], -1) diff --git a/kornia/contrib/object_detection.py b/kornia/contrib/object_detection.py index b348899a3b..9a289b4f17 100644 --- a/kornia/contrib/object_detection.py +++ b/kornia/contrib/object_detection.py @@ -12,6 +12,7 @@ from kornia.core.check import KORNIA_CHECK_SHAPE from kornia.core.external import PILImage as Image from kornia.core.external import numpy as np +from kornia.geometry.transform import resize from kornia.io import write_image from kornia.utils.draw import draw_rectangle @@ -126,13 +127,12 @@ def forward(self, imgs: list[Tensor]) -> tuple[Tensor, Tensor]: """ # TODO: support other input formats e.g. 
file path, numpy resized_imgs, original_sizes = [], [] - for i in range(len(imgs)): + for i in range(imgs.shape[0]): img = imgs[i] - # NOTE: assume that image layout is CHW original_sizes.append([img.shape[-2], img.shape[-1]]) resized_imgs.append( - # TODO: fix kornia resize to support onnx - torch.nn.functional.interpolate(img.unsqueeze(0), size=self.size, mode=self.interpolation_mode) + # TODO: fix kornia resize warnings + resize(img[None], size=self.size, interpolation=self.interpolation_mode) ) return concatenate(resized_imgs), as_tensor(original_sizes) @@ -181,7 +181,8 @@ def draw(self, images: list[Tensor], output_type: str = "torch") -> list[Tensor] out_img = image[None].clone() for out in detection: out_img = draw_rectangle( - out_img, torch.Tensor([[[out[-4], out[-3], out[-4] + out[-2], out[-3] + out[-1]]]]) + out_img, + torch.Tensor([[[out[-4], out[-3], out[-4] + out[-2], out[-3] + out[-1]]]]), ) if output_type == "torch": output.append(out_img[0]) @@ -204,7 +205,10 @@ def save(self, images: list[Tensor], directory: Optional[str] = None) -> None: outputs = self.draw(images) os.makedirs(directory, exist_ok=True) for i, out_image in enumerate(outputs): - write_image(os.path.join(directory, f"{str(i).zfill(6)}.jpg"), out_image.mul(255.0).byte()) + write_image( + os.path.join(directory, f"{str(i).zfill(6)}.jpg"), + out_image.mul(255.0).byte(), + ) def compile( self,