tong_hop.py
import sys
import cv2
import numpy as np
import torch
from time import time
from PyQt5 import QtGui
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtGui import QPixmap
from PyQt5.QtWidgets import QApplication, QMainWindow
from gui import Ui_MainWindow
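# gui.py is assumed to be the Qt Designer-generated module that defines
# Ui_MainWindow (with the Button_start, Button_stop and display label widgets
# referenced below); it is not part of this file.
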
class MainWindow(QMainWindow):
    def __init__(self):
        super().__init__()
        self.uic = Ui_MainWindow()
        self.uic.setupUi(self)
        self.uic.Button_start.clicked.connect(self.start_capture_video)
        self.uic.Button_stop.clicked.connect(self.stop_capture_video)
        # Dict keyed by stream index; named "threads" so it does not shadow
        # the inherited QObject.thread() method.
        self.threads = {}

    def closeEvent(self, event):
        self.stop_capture_video()

    def stop_capture_video(self):
        # Guard against Stop (or closing the window) firing before Start.
        if 1 in self.threads:
            self.threads[1].pause_stream()
            self.threads[1].stop()

    def start_capture_video(self):
        self.threads[1] = live_stream(index=1)
        self.threads[1].start()
        self.threads[1].signal.connect(self.show_wedcam)

    def show_wedcam(self, cv_img):
        """Updates the image label with a new OpenCV frame."""
        qt_img = convert_cv_qt(cv_img)
        self.uic.label.setPixmap(qt_img)

def convert_cv_qt(cv_img):
    """Convert an OpenCV BGR image to a QPixmap."""
    # Module-level helper: show_wedcam calls it without a self argument.
    rgb_image = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
    h, w, ch = rgb_image.shape
    bytes_per_line = ch * w
    convert_to_Qt_format = QtGui.QImage(rgb_image.data, w, h, bytes_per_line, QtGui.QImage.Format_RGB888)
    p = convert_to_Qt_format.scaled(700, 500, Qt.KeepAspectRatio)
    return QPixmap.fromImage(p)
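
# Note: the QImage above wraps the numpy buffer without copying it, so
# rgb_image must stay alive until scaled()/QPixmap.fromImage() have copied the
# pixels (both happen inside the function, so this is safe). bytes_per_line is
# the row stride, which for a tightly packed RGB888 buffer is channels * width.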

class live_stream(QThread):
    # Emitted once per processed frame with the annotated BGR image.
    signal = pyqtSignal(np.ndarray)

    def __init__(self, index):
        # Initialize the QThread base class before setting anything else up.
        super(live_stream, self).__init__()
        self.device = None
        self.out_file = None
        self.classes = None
        self.model = None
        self.gg = True  # run flag; pause_stream() flips it to end the loop
        self.player = None
        self.index = index
        print("start threading", self.index)

    def run(self):
        """
        Loads the model, selects the compute device, and starts the
        capture loop that writes to the output file.
        """
        self.model = self.load_model()  # load model
        self.classes = self.model.names
        self.out_file = "Labeled_Video.avi"
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.run_program()
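
    # QThread executes run() on a separate worker thread once start() is
    # called, so model loading and per-frame inference never block the Qt
    # event loop and the GUI stays responsive.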

    def get_video_from_url(self):
        """
        Creates a video capture object to extract frames to predict on.
        :return: cv2.VideoCapture reading from the default webcam.
        """
        return cv2.VideoCapture(0)  # or a file path, e.g. "D:/8.Record video/movie.mp4"

    def load_model(self):
        """
        Loads a YOLOv5 model from PyTorch Hub.
        :return: Trained PyTorch model.
        """
        # model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
        model = torch.hub.load('ultralytics/yolov5', 'custom', path='yolov5s.pt')
        # model = torch.hub.load('yolov5-master', 'custom', path='yolov5s.pt', source='local')
        return model
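
    # The three variants differ in where code and weights come from: the first
    # commented line downloads pretrained yolov5s weights through the hub, the
    # active line loads local weights ('yolov5s.pt') via the hub repo, and the
    # last one uses a local clone of the repo (source='local'), which also
    # works offline. torch.hub caches the ultralytics/yolov5 code after the
    # first download, so later calls are fast.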

    def score_frame(self, frame):
        """
        Takes a single frame as input and scores it with the YOLOv5 model.
        :param frame: input frame as a numpy array.
        :return: labels and normalized coordinates of the detected objects.
        """
        self.model.to(self.device)
        frame = [frame]
        results = self.model(frame)
        # .cpu() makes .numpy() work even when the model runs on CUDA.
        labels = results.xyxyn[0][:, -1].cpu().numpy()
        cord = results.xyxyn[0][:, :-1].cpu().numpy()
        return labels, cord
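
    # Each row of results.xyxyn is [x1, y1, x2, y2, confidence, class], with
    # the box corners normalized to [0, 1] relative to the frame size, so the
    # slicing above separates the class column from the box + confidence part.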

    def class_to_label(self, x):
        """
        For a given label value, return the corresponding string label.
        :param x: numeric label
        :return: corresponding string label
        """
        return self.classes[int(x)]

    def plot_boxes(self, results, frame):
        """
        Takes a frame and its results as input, and plots the bounding boxes and labels onto the frame.
        :param results: labels and coordinates predicted by the model on the given frame.
        :param frame: frame which has been scored.
        :return: frame with bounding boxes and labels plotted on it.
        """
        labels, cord = results
        n = len(labels)
        x_shape, y_shape = frame.shape[1], frame.shape[0]
        for i in range(n):
            row = cord[i]
            if row[4] >= 0.2:  # confidence threshold
                # Scale the normalized corners back up to pixel coordinates.
                x1, y1, x2, y2 = (int(row[0] * x_shape), int(row[1] * y_shape),
                                  int(row[2] * x_shape), int(row[3] * y_shape))
                bgr = (0, 255, 0)
                cv2.rectangle(frame, (x1, y1), (x2, y2), bgr, 2)
                cv2.putText(frame, self.class_to_label(labels[i]) + " " + str(round(row[4], 2)),
                            (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.9, bgr, 2)
        return frame
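
    # The 0.2 threshold is deliberately low and favors recall; raising it
    # (e.g. to 0.5) shows fewer but more confident boxes. The label text is
    # drawn at (x1, y1), i.e. anchored to the top-left corner of each box.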

    def run_program(self):
        """
        Runs the main loop: reads the video frame by frame, scores and
        annotates each frame, and emits it to the GUI until the stream
        is paused.
        :return: void
        """
        self.player = self.get_video_from_url()
        assert self.player.isOpened()
        x_shape = int(self.player.get(cv2.CAP_PROP_FRAME_WIDTH))
        y_shape = int(self.player.get(cv2.CAP_PROP_FRAME_HEIGHT))
        four_cc = cv2.VideoWriter_fourcc(*"MJPG")
        out = cv2.VideoWriter(self.out_file, four_cc, 20, (x_shape, y_shape))
        while True:
            start_time = time()
            ret, frame = self.player.read()
            if not ret:  # end of stream or camera error; don't crash the thread
                break
            results = self.score_frame(frame)
            frame = self.plot_boxes(results, frame)
            end_time = time()
            fps = 1 / max(end_time - start_time, 1e-3)  # guard against division by zero
            print(f"Frames Per Second : {round(fps, 2)} FPS")
            # out.write(frame)  # uncomment to also save the labeled video
            self.signal.emit(frame)
            if not self.gg:
                print("stop capture video")
                break
        out.release()
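
    # signal.emit() is called from the worker thread; because cross-thread
    # signal/slot connections are queued by default in Qt, the frame is
    # delivered to show_wedcam on the GUI thread, where setPixmap is safe.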

    def stop(self):
        print("stop threading", self.index)
        self.player.release()
        cv2.destroyAllWindows()
        self.terminate()

    def pause_stream(self):
        # Ask run_program() to leave its loop cooperatively; stop() then
        # releases the capture and terminates the thread.
        self.gg = False

if __name__ == "__main__":
    app = QApplication(sys.argv)
    main_win = MainWindow()
    main_win.show()
    sys.exit(app.exec())
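
# To run: python tong_hop.py. Assumed dependencies (not declared in this
# file): PyQt5, torch, opencv-python and numpy, plus the generated gui.py
# module and the yolov5s.pt weights in the working directory.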