Skip to content

Commit

Permalink
modify image_data loading to convert images to grayscale before runni…
Browse files Browse the repository at this point in the history
…ng through model training
  • Loading branch information
sidhulyalkar committed Oct 31, 2023
1 parent 35d4287 commit b050b6c
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions element_facemap/facemap_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def activate(
create_schema (bool): when True (default), create schema in the database if it
does not yet exist.
create_tables (bool): when True (default), create schema tables in the database
if they do not yet exist.
if they do not yet exist.i
linking_module (str): a module (or name) containing the required dependencies.
Dependencies:
Expand Down Expand Up @@ -298,6 +298,7 @@ def make(self, key):
from facemap.pose import pose
from facemap import utils
import torch
import cv2

train_output_dir = (FacemapModelTrainingTask & key).fetch1("train_output_dir")
output_dir = find_full_path(get_facemap_root_data_dir(), train_output_dir)
Expand Down Expand Up @@ -349,18 +350,28 @@ def make(self, key):

# Currently, only support single video training
assert len(video_files) == 1

video_file = video_files[0]
if len(pre_selected_frame_ind) == 0: # set selected frames to all frames
import cv2

cap = cv2.VideoCapture(video_file)
selected_frame_ind = np.arange(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))
# Load video capture to iterate through frames and convert to grayscale
cap = cv2.VideoCapture(video_file)
if len(pre_selected_frame_ind) == 0: # set selected frames to all frames
selected_frame_indices = np.arange(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))
else:
selected_frame_ind = pre_selected_frame_ind
selected_frame_indices = pre_selected_frame_ind
frames = []
for frame_ind in selected_frame_indices:
if int(cap.get(cv2.CAP_PROP_POS_FRAMES)) != frame_ind:
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_ind)
ret, frame = cap.read()
gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
if ret:
frames.append(gray_frame)
else:
print("Error reading frame")
image_data = np.array(frames)

# Load image frames from video
image_data = utils.load_images_from_video(video_file, selected_frame_ind)
# image_data = utils.load_images_from_video(video_file, selected_frame_ind)

keypoints_data = utils.load_keypoints(
list(zip(*facemap_inference.BodyPart.contents))[0], keypoints_file
Expand All @@ -375,10 +386,8 @@ def make(self, key):
) # default = "refined_model"

# Train model using train function defined in Pose class
train_model.net = train_model.train(
image_data[
:, :, :, 0
], # note: using 0 index for now (could average across this dimension)
train_model.train(
image_data,
keypoints_data.T, # needs to be transposed
int(training_params["epochs"]),
int(training_params["batch_size"]),
Expand Down

0 comments on commit b050b6c

Please sign in to comment.