-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlunge_labeling.py
73 lines (59 loc) · 2.59 KB
/
lunge_labeling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# import libraries
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import csv
import os
# Shortcuts into MediaPipe's legacy "solutions" API.
# Only mp_pose is used by the code below; the drawing and holistic aliases
# are kept for interactive use / visual debugging.
_mp_solutions = mp.solutions
mp_drawing = _mp_solutions.drawing_utils
mp_drawing_styles = _mp_solutions.drawing_styles
mp_holistic = _mp_solutions.holistic
mp_pose = _mp_solutions.pose
# Number of landmarks produced by the MediaPipe Pose model.
NUM_POSE_LANDMARKS = 33

# CSV column header: source file name, action label ("class"), posture type,
# then x/y/z/visibility for every pose landmark (x1, y1, z1, v1, ..., v33).
# 'filename' must come first because export_landmark writes the image file
# name as the first value of every row; without it the header is one column
# short and every label/coordinate column is shifted by one.
landmarks = ['filename', 'class', 'posture_type']
for val in range(1, NUM_POSE_LANDMARKS + 1):
    landmarks += [f'x{val}', f'y{val}', f'z{val}', f'v{val}']

# Initialise the CSV file: truncate any previous output and write the header.
output_csv = 'lunge_labeling.csv'
with open(output_csv, mode='w', newline='', encoding='utf-8') as f:
    csv_writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
    csv_writer.writerow(landmarks)
def export_landmark(results, filename, action_label, posture_type):
    """Append one row of pose landmarks to the output CSV.

    The row is ``[filename, action_label, posture_type, x1, y1, z1, v1, ...]``:
    the three identifying fields followed by x, y, z and visibility of every
    landmark in ``results.pose_landmarks.landmark``, in order.

    Args:
        results: Result of MediaPipe ``Pose.process``; only its
            ``pose_landmarks`` attribute is read.
        filename: Source image file name (first column).
        action_label: Action label (second column).
        posture_type: Posture type (third column).

    Rows are appended to the file named by the module-level ``output_csv``.
    If no pose landmarks are present, or the row cannot be written, a message
    is printed and no row is written; processing of other images can continue.
    """
    if not results.pose_landmarks:
        print(f"No pose landmarks for: {filename}")
        return

    keypoints = [filename, action_label, posture_type]
    for res in results.pose_landmarks.landmark:
        keypoints.extend([res.x, res.y, res.z, res.visibility])

    try:
        # Append mode: rows already written persist even if a later image
        # fails. UTF-8 keeps non-ASCII file names portable across platforms.
        with open(output_csv, mode='a', newline='', encoding='utf-8') as f:
            csv_writer = csv.writer(f, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            csv_writer.writerow(keypoints)
    except (OSError, csv.Error) as e:
        # Best effort: report the failing file and let the caller move on.
        print(f"Error writing row for {filename}: {e}")
# Folder containing the lunge images to label.
folder = "/root/posepal/lunge/"
image_files = sorted(file for file in os.listdir(folder) if file.endswith('.jpg'))

# Action label that marks a correctly performed lunge; every other label is
# treated as incorrect.
CORRECT_ACTION_LABEL = '081'

# Iterate over the images and extract pose landmarks.
# static_image_mode=True: the files are independent still images, not frames
# of one video, so person detection must run on every image rather than
# tracking landmarks carried over from the previously processed file.
# (min_tracking_confidence is ignored in this mode, so it is not passed.)
with mp_pose.Pose(static_image_mode=True, min_detection_confidence=0.5) as pose:
    for file in image_files:
        file_path = os.path.join(folder, file)

        # Action label: the 3-digit number before the first '-' in the file
        # name (e.g. '081-0001.jpg' -> '081'). Names without that prefix are
        # skipped instead of being silently labelled 'incorrect'.
        action_label, sep, _ = file.partition('-')
        if not sep or len(action_label) != 3 or not action_label.isdigit():
            print(f"Invalid file name format: {file}")
            continue

        posture_type = 'correct' if action_label == CORRECT_ACTION_LABEL else 'incorrect'

        image = cv2.imread(file_path)
        if image is None:
            print(f"Could not read image: {file}")
            continue
        # OpenCV loads images as BGR; MediaPipe is fed RGB.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Estimate the pose with MediaPipe.
        results = pose.process(image_rgb)

        # Extract landmarks and append them to the CSV.
        if not results.pose_landmarks:
            print(f"No pose detected: {file}")
            continue
        export_landmark(results, file, action_label, posture_type)
        print(f"Processed: {file} | Label: {action_label} | Type: {posture_type}")