# rt_detr_ui.py
# Gradio web UI for running RT-DETR object detection on an uploaded image.
import cv2
import gradio as gr
import numpy as np
from ultralytics import RTDETR
import warnings
from PIL import Image
warnings.filterwarnings("ignore")
class RT_DETR_WebUI:
    """Callback holder that runs RT-DETR detection for a Gradio interface.

    Loaded models are cached per weights path so that repeated requests do
    not re-read the checkpoint from disk on every call.
    """

    # Checkpoints selectable by name; any other ``model_type`` falls back to
    # the ``model_path`` supplied by the caller.
    _BUNDLED_WEIGHTS = {
        "rtdetr-l": '../../weights/rtdetr/rtdetr-l.pt',
        "rtdetr-x": '../../weights/rtdetr/rtdetr-x.pt',
    }

    def __init__(self):
        # Successfully loaded models, keyed by weights path.
        self._models = {}
        # Most recently used model. Kept as a public attribute because the
        # previous implementation assigned ``self.model`` inside ``predict``.
        self.model = None

    def _load_model(self, model_type, model_path):
        """Return the (cached) model for ``model_type`` / ``model_path``."""
        weights = self._BUNDLED_WEIGHTS.get(model_type, model_path)
        if weights not in self._models:
            # Only cache after construction succeeds, so a bad path raises
            # again on the next request instead of being remembered.
            self._models[weights] = RTDETR(weights)
        return self._models[weights]

    @staticmethod
    def _swap_red_blue(image):
        """Reverse the channel order of a 3-channel image (RGB <-> BGR).

        Arrays that are not H x W x 3 are returned unchanged.
        """
        pixels = np.asarray(image)
        if pixels.ndim != 3 or pixels.shape[2] != 3:
            return pixels
        # The reversed view has negative strides; make it contiguous so that
        # downstream image libraries accept it.
        return np.ascontiguousarray(pixels[..., ::-1])

    def predict(self, image, conf, iou, line_width, device, model_type, model_path):
        """Detect objects in ``image`` and return it with the boxes drawn.

        Args:
            image: Input picture. Assumed to be an RGB array as delivered by
                Gradio's default image input; ``None`` when nothing was
                uploaded.
            conf: Minimum confidence for a detection to be kept.
            iou: IoU threshold forwarded to the predictor.
            line_width: Box line width in pixels; a falsy or non-positive
                value lets the plotting routine pick a width automatically.
            device: Inference device, e.g. ``"cpu"`` or ``"cuda"``.
            model_type: ``"rtdetr-l"`` or ``"rtdetr-x"`` to use a bundled
                checkpoint; any other value selects ``model_path``.
            model_path: Path to a custom checkpoint, used only when
                ``model_type`` is not one of the bundled names.

        Returns:
            The annotated image in RGB channel order, or ``None`` when no
            image was supplied.
        """
        if image is None:
            # Nothing to run detection on (e.g. submit without an upload).
            return None

        self.model = self._load_model(model_type, model_path)

        # Per the Ultralytics docs, raw arrays are interpreted as BGR, while
        # the UI hands over RGB; convert on the way in and on the way out.
        bgr_input = self._swap_red_blue(image)
        results = self.model.predict(bgr_input, conf=conf, iou=iou, device=device)

        # The UI's number field yields a float; the plotter wants an int.
        width = int(line_width) if line_width and line_width > 0 else None
        annotated_bgr = results[0].plot(line_width=width)
        return self._swap_red_blue(annotated_bgr)
def main():
    """Build the Gradio interface around the detector and launch it."""
    detector = RT_DETR_WebUI()

    # Components are taken from the top-level ``gr`` namespace with
    # ``value=`` for the initial value; the older ``gr.inputs.*`` classes and
    # their ``default=`` keyword are deprecated.
    iface = gr.Interface(
        fn=detector.predict,
        inputs=[
            "image",
            gr.Slider(minimum=0, maximum=1, step=0.01, value=0.25,
                      label="Confidence Threshold"),
            gr.Slider(minimum=0, maximum=1, step=0.01, value=0.45,
                      label="IoU Threshold"),
            gr.Number(value=2, label="Line Width"),
            gr.Radio(["cpu", "cuda"], value="cpu", label="Device"),
            # "custom" makes the Model Path box below take effect; with only
            # the two bundled choices it could never be used.
            gr.Radio(["rtdetr-l", "rtdetr-x", "custom"], value="rtdetr-l",
                     label="Model Type"),
            gr.Textbox(value="rtdetr-l.pt", label="Model Path"),
        ],
        outputs="image",
        title="RT_DETR Object Detector",
        description="Detect objects in an image using RT_DETR model.",
        # Layout/server_* keywords were dropped: they are not Interface
        # options (server settings belong to ``launch``), and the theme and
        # flagging settings previously passed matched the defaults.
        analytics_enabled=True,
    )

    # share=True additionally requests a public share link.
    iface.launch(share=True)


if __name__ == '__main__':
    main()