-
Notifications
You must be signed in to change notification settings - Fork 8
/
yolonas_ui.py
83 lines (70 loc) · 3.2 KB
/
yolonas_ui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import cv2
from PIL import Image
import numpy as np
import gradio as gr
from super_gradients.training import models
import warnings
# Silence every warning category for the whole process (third-party
# deprecation noise included) by prepending a catch-all "ignore" filter.
warnings.simplefilter("ignore")
class YOLO_NAS_WebUI:
    """Gradio callback wrapper around a super_gradients YOLO-NAS detector.

    The model is loaded lazily by the first ``predict`` call and reused by
    later calls for as long as the selected model and device stay the same.
    """

    # Model names offered by the UI; each is loaded with COCO-pretrained weights.
    BUILTIN_MODELS = ("yolo_nas_s", "yolo_nas_m", "yolo_nas_l")
    # Thickness used when the UI line width cannot be interpreted; it matches
    # the value the drawing code previously hard-coded.
    DEFAULT_LINE_WIDTH = 2

    def __init__(self):
        # Populated by predict(); None until the first image is processed.
        self.model = None
        self.device = None
        # (model name, device) that self.model was loaded for.
        self._model_key = None

    @classmethod
    def _resolve_model_name(cls, model_type, model_path):
        """Return the name passed to ``models.get`` for the UI selection.

        Built-in choices are used as-is; any other ``model_type`` falls back
        to ``model_path``.
        """
        # NOTE(review): models.get() receives this value as a model name, not
        # a checkpoint file, so a value such as "yolo_nas_s.pt" presumably
        # fails to load - confirm how custom checkpoints are meant to work.
        return model_type if model_type in cls.BUILTIN_MODELS else model_path

    @classmethod
    def _line_thickness(cls, line_width):
        """Convert the UI line width (int, float or None) to a thickness >= 1."""
        try:
            return max(1, int(line_width))
        except (TypeError, ValueError, OverflowError):
            # None / NaN / inf from the number box: fall back to the default.
            return cls.DEFAULT_LINE_WIDTH

    def _load_model(self, model_name, device):
        """Return ``model_name`` on ``device``, reloading only when either changed."""
        key = (model_name, device)
        if self.model is None or self._model_key != key:
            self.model = models.get(model_name, pretrained_weights="coco").to(device)
            self._model_key = key
        self.device = device
        return self.model

    def predict(self, image_path, conf, iou, line_width, device, model_type, model_path):
        """Run detection on one image and return it with boxes and labels drawn.

        Args:
            image_path: Image from the Gradio "image" input, passed unchanged
                to the super_gradients ``predict`` call (presumably an RGB
                numpy array despite the name - confirm for the Gradio version).
            conf: Confidence threshold. It is forwarded to the model, and boxes
                whose score is not strictly above it are also skipped here.
            iou: IoU threshold forwarded to the model's ``predict``.
            line_width: Thickness of the drawn rectangles and label text.
            device: Device string for the model, e.g. "cpu" or "cuda".
            model_type: One of ``BUILTIN_MODELS``; any other value selects
                ``model_path`` instead.
            model_path: Model name used when ``model_type`` is not built in.

        Returns:
            The prediction's image array, with detections drawn on it in place.
        """
        model = self._load_model(self._resolve_model_name(model_type, model_path), device)
        # Forward both UI thresholds; previously the IoU slider was ignored.
        results = model.predict(image_path, iou=iou, conf=conf)

        image_prediction = results._images_prediction_lst[0]
        image = image_prediction.image
        class_names = image_prediction.class_names
        prediction = image_prediction.prediction
        thickness = self._line_thickness(line_width)

        for bbox, label, confidence in zip(
            prediction.bboxes_xyxy, prediction.labels, prediction.confidence
        ):
            if not confidence > conf:
                continue
            color = tuple(np.random.randint(0, 255, 3).tolist())
            # Plain Python ints (truncated like astype(int)) are what cv2 expects.
            x1, y1, x2, y2 = (int(v) for v in bbox)
            cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)
            label_name = f"{class_names[int(label)]}: {confidence:.2f}"
            # Keep the caption inside the frame when a box touches the top edge.
            text_y = max(y1 - 10, 10)
            cv2.putText(image, label_name, (x1, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, thickness)
        return image
if __name__ == '__main__':
    # One detector instance serves every request; its predict() is the callback.
    detector = YOLO_NAS_WebUI()
    # Components come from the top-level gradio namespace: the legacy
    # `gr.inputs.*` classes and their `default=` keyword were removed in
    # Gradio 4. Inputs are listed in the order of predict()'s parameters.
    iface = gr.Interface(
        fn=detector.predict,
        inputs=[
            "image",
            gr.Slider(minimum=0, maximum=1, step=0.01, value=0.25,
                      label="Confidence Threshold"),
            gr.Slider(minimum=0, maximum=1, step=0.01, value=0.45,
                      label="IoU Threshold"),
            # precision=0 makes Gradio pass an int, as cv2 thickness requires.
            gr.Number(value=2, precision=0, label="Line Width"),
            gr.Radio(["cpu", "cuda"], value="cpu", label="Device"),
            gr.Radio(["yolo_nas_s", "yolo_nas_m", "yolo_nas_l"],
                     value="yolo_nas_s", label="Model Type"),
            gr.Textbox(value="yolo_nas_s.pt", label="Model Path"),
        ],
        outputs="image",
        title="YOLO-NAS Object Detector",
        description="Detect objects in an image using YOLO-NAS model.",
        theme="default",
        # `layout`, `allow_flagging=True` and the server_* arguments were
        # dropped. `layout` had no effect. The server settings belong to
        # launch(), whose defaults are already None. Gradio mapped
        # allow_flagging=True to its default flagging mode ("manual").
        analytics_enabled=True,
    )
    # share=True additionally exposes the app through a public Gradio share link.
    iface.launch(share=True)