-
Notifications
You must be signed in to change notification settings - Fork 1
/
pose_estimation.py
85 lines (70 loc) · 2.49 KB
/
pose_estimation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
import cv2

# Path to the MoveNet SinglePose Lightning TFLite model file.
# Extracted to a named constant so the location is changed in one place.
MODEL_PATH = '/home/rohit/PycharmProjects/computer_vision_project/Pose_estimation/lite-model_movenet_singlepose_lightning_3.tflite'

# Load the TFLite model and allocate its input/output tensors once at startup.
interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()
# Skeleton connectivity: each key is a pair of keypoint indices to connect
# with a line; the value is a one-letter color code ('m'/'c'/'y').
# NOTE(review): draw_connections below ignores the color code and draws every
# edge in red. Indices presumably follow the MoveNet 17-keypoint ordering
# (nose, eyes, ears, shoulders, elbows, wrists, hips, knees, ankles) —
# confirm against the model documentation.
EDGES = {
(0, 1): 'm',
(0, 2): 'c',
(1, 3): 'm',
(2, 4): 'c',
(0, 5): 'm',
(0, 6): 'c',
(5, 7): 'm',
(7, 9): 'm',
(6, 8): 'c',
(8, 10): 'c',
(5, 6): 'y',
(5, 11): 'm',
(6, 12): 'c',
(11, 12): 'y',
(11, 13): 'm',
(13, 15): 'm',
(12, 14): 'c',
(14, 16): 'c'
}
def draw_connections(frame, keypoints, edges, confidence_threshold):
    """Draw skeleton edges on ``frame`` (in place) for every keypoint pair
    whose two endpoints both exceed ``confidence_threshold``.

    Args:
        frame: BGR image array of shape (height, width, channels).
        keypoints: model output of normalized (y, x, confidence) triples;
            extra leading axes are squeezed away.
        edges: dict mapping (index, index) keypoint pairs to a color code.
            The code is currently unused — every edge is drawn in red.
        confidence_threshold: minimum confidence required of both endpoints.
    """
    y, x, _ = frame.shape
    # Scale normalized keypoint coordinates up to pixel coordinates.
    shaped = np.squeeze(np.multiply(keypoints, [y, x, 1]))
    for edge, _color in edges.items():
        p1, p2 = edge
        y1, x1, c1 = shaped[p1]
        y2, x2, c2 = shaped[p2]
        # Logical `and`, not bitwise `&`: short-circuits and avoids the
        # precedence trap of `a > t & b > t` if the parentheses are dropped.
        if c1 > confidence_threshold and c2 > confidence_threshold:
            cv2.line(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
def draw_keypoints(frame, keypoints, confidence_threshold):
    """Mark each sufficiently confident keypoint on ``frame`` (in place)
    with a small filled green circle."""
    height, width, _ = frame.shape
    # Convert normalized (y, x, score) triples into pixel coordinates.
    scaled = np.squeeze(np.multiply(keypoints, [height, width, 1]))
    for point_y, point_x, score in scaled:
        if score > confidence_threshold:
            cv2.circle(frame, (int(point_x), int(point_y)), 4, (0, 255, 0), -1)
# Open the default webcam (device 0).
cap = cv2.VideoCapture(0)

# The model's I/O tensor details never change between frames; fetch them once
# instead of on every iteration.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

while True:
    ret, frame = cap.read()
    if not ret:
        # Camera unavailable or stream ended — stop instead of crashing on
        # a None frame.
        break

    # MoveNet Lightning takes a 192x192 float32 image with a batch axis.
    # resize_with_pad preserves aspect ratio and does not mutate `frame`.
    img = tf.image.resize_with_pad(np.expand_dims(frame, axis=0), 192, 192)
    input_image = tf.cast(img, dtype=tf.float32)

    # Run inference.
    interpreter.set_tensor(input_details[0]['index'], np.array(input_image))
    interpreter.invoke()
    # Normalized (y, x, confidence) keypoints; indexing elsewhere in this
    # file suggests shape [1, 1, 17, 3] — confirm against the model docs.
    keypoints_with_scores = interpreter.get_tensor(output_details[0]['index'])

    # Render the skeleton onto the original-resolution frame and display it.
    draw_connections(frame, keypoints_with_scores, EDGES, 0.4)
    draw_keypoints(frame, keypoints_with_scores, 0.4)
    cv2.imshow('video', frame)

    # Wait ~30 ms between frames; quit when 'q' is pressed. The 0xFF mask
    # keeps the key-code comparison correct on platforms where waitKey
    # returns more than 8 significant bits.
    if cv2.waitKey(30) & 0xFF == ord('q'):
        break

# Release the camera and close all OpenCV windows.
cap.release()
cv2.destroyAllWindows()