Skip to content

Commit

Permalink
implemented metadata indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
keyaloding committed Jul 29, 2024
1 parent 7cb04c4 commit a66cf8f
Showing 1 changed file with 13 additions and 21 deletions.
34 changes: 13 additions & 21 deletions sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,22 +108,24 @@ def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton: # type: igno


def labels_to_pose_training(
labels: Labels, skeletons_list: list[NWBSkeleton], index_data, # type: ignore[return]
labels: Labels,
skeletons_list: list[NWBSkeleton], # type: ignore[return]
video_info: tuple[dict[int, str], Video, ImageSeries],
) -> PoseTraining: # type: ignore[return]
"""Creates an NWB PoseTraining object from a Labels object.
Args:
labels: A Labels object.
skeletons_list: A list of NWB skeletons.
video_info: A tuple containing a dictionary mapping frame indices to file paths,
the video, and the `ImageSeries`.
Returns:
A PoseTraining object.
"""
training_frame_list = []
skeleton_instances_list = []
source_video_list = []
# image_series: dict[Video, ImageSeries] = {}
# path_index: dict[tuple[Video, int], str] = {}
for i, labeled_frame in enumerate(labels.labeled_frames):
for instance, skeleton in zip(labeled_frame.instances, skeletons_list):
skeleton_instance = instance_to_skeleton_instance(instance, skeleton)
Expand All @@ -135,19 +137,8 @@ def labels_to_pose_training(
training_frame_video = labeled_frame.video
training_frame_video_index = labeled_frame.frame_idx

source_video = ImageSeries(
name=f"video_{i}",
description="N/A",
unit="NA",
format="external",
external_file=[training_frame_video.filename],
dimension=[
training_frame_video.backend.img_shape[0],
training_frame_video.backend.img_shape[1],
],
starting_frame=[0],
rate=30.0, # change to `video.backend.fps` when available
)
_, _, image_series = video_info
source_video = image_series
source_video_list.append(source_video)
training_frame = TrainingFrame(
name=f"training_frame_{i}",
Expand Down Expand Up @@ -247,7 +238,8 @@ def write_video_to_path(
image_format: str = "png",
) -> tuple[dict[int, str], Video, ImageSeries]:
"""
Write individual frames of a video to a path and return .
Write individual frames of a video to a path and return the frame indices,
file paths, video, and `ImageSeries`.
Args:
video: The video to write.
Expand All @@ -256,7 +248,7 @@ def write_video_to_path(
Returns:
A tuple containing a dictionary mapping frame indices to file paths,
the video, and the ImageSeries.
the video, and the `ImageSeries`.
"""
index_data = {}
if frame_inds is None:
Expand Down Expand Up @@ -288,7 +280,7 @@ def write_video_to_path(
name="video",
external_file=img_paths,
starting_frame=frame_inds,
rate=30.0,
rate=30.0, # TODO - change to `video.backend.fps` when available
)
return index_data, video, image_series

Expand Down Expand Up @@ -608,8 +600,8 @@ def append_nwb_training(
skeletons = Skeletons(skeletons=skeletons_list)
nwb_processing_module.add(skeletons)

index_data = write_video_to_path(labels.video, frame_inds)[0]
pose_training = labels_to_pose_training(labels, skeletons_list, index_data)
video_info = write_video_to_path(labels.video, frame_inds)
pose_training = labels_to_pose_training(labels, skeletons_list, video_info)
nwb_processing_module.add(pose_training)

camera = nwbfile.create_device(
Expand Down

0 comments on commit a66cf8f

Please sign in to comment.