Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
shijianjian committed Sep 5, 2024
1 parent d322903 commit 37699e4
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 9 deletions.
9 changes: 4 additions & 5 deletions kornia/contrib/models/rt_detr/architecture/rtdetr_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,6 @@ def forward(
class TransformerDecoder(Module):
def __init__(self, hidden_dim: int, decoder_layer: nn.Module, num_layers: int, eval_idx: int = -1) -> None:
super().__init__()
# self.layers = decoder_layers
# TODO: come back to this later
self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
self.hidden_dim = hidden_dim
self.num_layers = num_layers
Expand Down Expand Up @@ -271,14 +269,15 @@ def __init__(
num_decoder_layers: int,
num_heads: int = 8,
num_decoder_points: int = 4,
# num_levels: int = 3,
num_levels: int = 3,
dropout: float = 0.0,
num_denoising: int = 100,
) -> None:
super().__init__()
self.num_queries = num_queries
# TODO: verify this is correct
self.num_levels = len(in_channels)
assert len(in_channels) <= num_levels
self.num_levels = num_levels

# build the input projection layers
self.input_proj = nn.ModuleList()
Expand All @@ -292,7 +291,7 @@ def __init__(
embed_dim=hidden_dim,
num_heads=num_heads,
dropout=dropout,
num_levels=len(in_channels),
num_levels=self.num_levels,
num_points=num_decoder_points,
)

Expand Down
2 changes: 0 additions & 2 deletions kornia/contrib/models/rt_detr/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,6 @@ def forward(self, images: Tensor) -> tuple[Tensor, Tensor]:
:math:`K` is the number of classes.
- **boxes** - Tensor of shape :math:`(N, Q, 4)`, where :math:`Q` is the number of queries.
"""
# if self.training:
# raise RuntimeError("Only evaluation mode is supported. Please call model.eval().")

feats = self.backbone(images)
feats_buf = self.encoder(feats)
Expand Down
2 changes: 1 addition & 1 deletion kornia/contrib/models/rt_detr/post_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def forward(self, logits: Tensor, boxes: Tensor, original_sizes: Tensor) -> list

sizes_wh = torch.empty(1, 1, 2, device=boxes.device, dtype=boxes.dtype)
sizes_wh[..., 0] = original_sizes[0][0]
sizes_wh[..., 1] = original_sizes[0][0]
sizes_wh[..., 1] = original_sizes[0][1]
sizes_wh = sizes_wh.repeat(1, 1, 2)

boxes_xy = boxes_xy * sizes_wh
Expand Down
7 changes: 6 additions & 1 deletion kornia/contrib/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,17 @@ def __init__(self, size: tuple[int, int], interpolation_mode: str = "bilinear")
self.interpolation_mode = interpolation_mode

def forward(self, imgs: list[Tensor]) -> tuple[Tensor, Tensor]:
"""
Returns:
resized_imgs: resized images in a batch.
original_sizes: the original image sizes of (height, width).
"""
# TODO: support other input formats e.g. file path, numpy
resized_imgs, original_sizes = [], []
for i in range(len(imgs)):
img = imgs[i]
# NOTE: assume that image layout is CHW
original_sizes.append([img.shape[1], img.shape[2]])
original_sizes.append([img.shape[-2], img.shape[-1]])
resized_imgs.append(
# TODO: fix kornia resize to support onnx
torch.nn.functional.interpolate(img.unsqueeze(0), size=self.size, mode=self.interpolation_mode)
Expand Down

0 comments on commit 37699e4

Please sign in to comment.