Skip to content

Commit

Permalink
post processor as in the original codew
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarriba committed Sep 7, 2024
1 parent 7e87160 commit db1cb53
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 19 deletions.
47 changes: 34 additions & 13 deletions kornia/contrib/models/rt_detr/post_processor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,35 @@
"""Post-processor for the RT-DETR model."""

from __future__ import annotations

# TODO:
import torch

from kornia.core import Module, Tensor, concatenate


def mod(a, b):
"""Compute the modulo operation for two numbers.
This function calculates the remainder of the division of 'a' by 'b'
using the formula: a - (a // b) * b, which is equivalent to the modulo operation.
Args:
a: The dividend.
b: The divisor.
Returns:
The remainder of a divided by b.
Example:
>>> mod(7, 3)
1
>>> mod(8.5, 3.2)
2.1
"""
return a - (a // b) * b


# TODO: deprecate the confidence threshold and add the num_top_queries as a parameter and num_classes as a parameter
class DETRPostProcessor(Module):
def __init__(self, confidence_threshold: float) -> None:
super().__init__()
Expand Down Expand Up @@ -45,16 +69,13 @@ def forward(self, logits: Tensor, boxes: Tensor, original_sizes: Tensor) -> list
boxes_xy = boxes_xy * sizes_wh
scores = logits.sigmoid() # RT-DETR was trained with focal loss. thus sigmoid is used instead of softmax

# the original code is slightly different
# it allows 1 bounding box to have multiple classes (multi-label)
scores, labels = scores.max(-1)

detections: list[Tensor] = []
for i in range(scores.shape[0]):
mask = scores[i] >= self.confidence_threshold
labels_i = labels[i, mask].unsqueeze(-1)
scores_i = scores[i, mask].unsqueeze(-1)
boxes_i = boxes_xy[i, mask]
detections.append(concatenate([labels_i, scores_i, boxes_i], -1))
# retrieve the boxes with the highest score for each class
# https://github.com/lyuwenyu/RT-DETR/blob/b6bf0200b249a6e35b44e0308b6058f55b99696b/rtdetrv2_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py#L55-L62
num_top_queries = 300 # TODO: make this configurable
num_classes = 80 # TODO: make this configurable
scores, index = torch.topk(scores.flatten(1), num_top_queries, dim=-1)
labels = mod(index, num_classes)
index = index // num_classes
boxes = boxes_xy.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes_xy.shape[-1]))

return detections
return concatenate([labels[..., None], scores[..., None], boxes], -1)
16 changes: 10 additions & 6 deletions kornia/contrib/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.core.external import PILImage as Image
from kornia.core.external import numpy as np
from kornia.geometry.transform import resize
from kornia.io import write_image
from kornia.utils.draw import draw_rectangle

Expand Down Expand Up @@ -126,13 +127,12 @@ def forward(self, imgs: list[Tensor]) -> tuple[Tensor, Tensor]:
"""
# TODO: support other input formats e.g. file path, numpy
resized_imgs, original_sizes = [], []
for i in range(len(imgs)):
for i in range(imgs.shape[0]):
img = imgs[i]
# NOTE: assume that image layout is CHW
original_sizes.append([img.shape[-2], img.shape[-1]])
resized_imgs.append(
# TODO: fix kornia resize to support onnx
torch.nn.functional.interpolate(img.unsqueeze(0), size=self.size, mode=self.interpolation_mode)
# TODO: fix kornia resize warnings
resize(img[None], size=self.size, interpolation=self.interpolation_mode)
)
return concatenate(resized_imgs), as_tensor(original_sizes)

Expand Down Expand Up @@ -181,7 +181,8 @@ def draw(self, images: list[Tensor], output_type: str = "torch") -> list[Tensor]
out_img = image[None].clone()
for out in detection:
out_img = draw_rectangle(
out_img, torch.Tensor([[[out[-4], out[-3], out[-4] + out[-2], out[-3] + out[-1]]]])
out_img,
torch.Tensor([[[out[-4], out[-3], out[-4] + out[-2], out[-3] + out[-1]]]]),
)
if output_type == "torch":
output.append(out_img[0])
Expand All @@ -204,7 +205,10 @@ def save(self, images: list[Tensor], directory: Optional[str] = None) -> None:
outputs = self.draw(images)
os.makedirs(directory, exist_ok=True)
for i, out_image in enumerate(outputs):
write_image(os.path.join(directory, f"{str(i).zfill(6)}.jpg"), out_image.mul(255.0).byte())
write_image(
os.path.join(directory, f"{str(i).zfill(6)}.jpg"),
out_image.mul(255.0).byte(),
)

def compile(
self,
Expand Down

0 comments on commit db1cb53

Please sign in to comment.