Skip to content

Commit

Permalink
compute area for all types of boxes (kornia#2996)
Browse files Browse the repository at this point in the history
* compute area for all types of boxes

* tests for compute area of box

* linter
  • Loading branch information
Isalia20 authored Aug 28, 2024
1 parent f841e6c commit 343357d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
16 changes: 13 additions & 3 deletions kornia/geometry/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,19 @@ def filter_boxes_by_area(

def compute_area(self) -> torch.Tensor:
"""Returns :math:`(B, N)`."""
w = self._data[..., 1, 0] - self._data[..., 0, 0]
h = self._data[..., 2, 1] - self._data[..., 0, 1]
return (w * h).unsqueeze(0) if self._data.ndim == 3 else (w * h)
coords = self._data.view((-1, 4, 2)) if self._data.ndim == 4 else self._data
# calculate centroid of the box
centroid = coords.mean(dim=1, keepdim=True)
# calculate the angle from centroid to each corner
angles = torch.atan2(coords[..., 1] - centroid[..., 1], coords[..., 0] - centroid[..., 0])
# sort the corners by angle to get an order for shoelace formula
_, clockwise_indices = torch.sort(angles, dim=1, descending=True)
# gather the corners in the new order
ordered_corners = torch.gather(coords, 1, clockwise_indices.unsqueeze(-1).expand(-1, -1, 2))
x, y = ordered_corners[..., 0], ordered_corners[..., 1]
# Gaussian/Shoelace formula https://en.wikipedia.org/wiki/Shoelace_formula
area = 0.5 * torch.abs(torch.sum((x * torch.roll(y, 1, 1)) - (y * torch.roll(x, 1, 1)), dim=1))
return area.view(self._data.shape[:2]) if self._data.ndim == 4 else area

@classmethod
def from_tensor(
Expand Down
37 changes: 37 additions & 0 deletions tests/geometry/test_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,43 @@ def apply_boxes_method(tensor: torch.Tensor, method: str, **kwargs):
self.gradcheck(lambda x: Boxes.from_tensor(x, mode="xyxy_plus").data, (t_boxes_xyxy,))
self.gradcheck(lambda x: Boxes.from_tensor(x, mode="xywh").data, (t_boxes_xyxy1,))

def test_compute_area(self):
# Rectangle
box_1 = [[0.0, 0.0], [100.0, 0.0], [100.0, 50.0], [0.0, 50.0]]
# Trapezoid
box_2 = [[0.0, 0.0], [60.0, 0.0], [40.0, 50.0], [20.0, 50.0]]
# Parallelogram
box_3 = [[0.0, 0.0], [100.0, 0.0], [120.0, 50.0], [20.0, 50.0]]
# Random quadrilateral
box_4 = [
[50.0, 50.0],
[150.0, 250.0],
[0.0, 500.0],
[27.0, 80],
]
# Random quadrilateral
box_5 = [
[0.0, 0.0],
[150.0, 0.0],
[150.0, 150.0],
[0.0, 0.5],
]
# Rectangle with minus coordinates
box_6 = [[-500.0, -500.0], [-300.0, -500.0], [-300.0, -300.0], [-500.0, -300.0]]

expected_values = [5000.0, 2000.0, 5000.0, 31925.0, 11287.5, 40000.0]
box_coordinates = torch.tensor([box_1, box_2, box_3, box_4, box_5, box_6])
computed_areas = Boxes(box_coordinates).compute_area().tolist()
computed_areas_w_batch = Boxes(box_coordinates.reshape(2, 3, 4, 2)).compute_area().tolist()
flattened_computed_areas_w_batch = [area for batch in computed_areas_w_batch for area in batch]
assert all(
computed_area == expected_area for computed_area, expected_area in zip(computed_areas, expected_values)
)
assert all(
computed_area == expected_area
for computed_area, expected_area in zip(flattened_computed_areas_w_batch, expected_values)
)


class TestTransformBoxes2D(BaseTester):
def test_transform_boxes(self, device, dtype):
Expand Down

0 comments on commit 343357d

Please sign in to comment.