diff --git a/kornia/models/detector/rtdetr.py b/kornia/models/detector/rtdetr.py index cabf060027..d16bb9e372 100644 --- a/kornia/models/detector/rtdetr.py +++ b/kornia/models/detector/rtdetr.py @@ -2,7 +2,7 @@ from typing import Optional import torch -import torch.nn as nn +from torch import nn from kornia.contrib.models.rt_detr import DETRPostProcessor from kornia.contrib.models.rt_detr.model import RTDETR, RTDETRConfig @@ -66,7 +66,7 @@ def build( return ObjectDetector( model, ResizePreProcessor(image_size) if image_size is not None else nn.Identity(), - DETRPostProcessor(confidence_threshold) + DETRPostProcessor(confidence_threshold), ) @staticmethod @@ -119,16 +119,13 @@ def to_onnx( if image_size is None: val_image = rand(1, 3, 640, 640) - dynamic_axes={ - 'input' : {0 : 'batch_size', 2: 'height', 3: 'width'}, - 'output' : {0 : 'batch_size', 2: 'height', 3: 'width'} + dynamic_axes = { + "input": {0: "batch_size", 2: "height", 3: "width"}, + "output": {0: "batch_size", 2: "height", 3: "width"}, } else: val_image = rand(1, 3, image_size, image_size) - dynamic_axes={ - 'input' : {0 : 'batch_size'}, - 'output' : {0 : 'batch_size'} - } + dynamic_axes = {"input": {0: "batch_size"}, "output": {0: "batch_size"}} torch.onnx.export( detector, val_image, @@ -136,7 +133,7 @@ def to_onnx( export_params=True, opset_version=17, do_constant_folding=True, - input_names=['input'], - output_names=['output'], - dynamic_axes=dynamic_axes + input_names=["input"], + output_names=["output"], + dynamic_axes=dynamic_axes, )