Skip to content

Commit

Permalink
updated video loading and pose_estimation creation
Browse files Browse the repository at this point in the history
  • Loading branch information
keyaloding committed Aug 7, 2024
1 parent 0abaca4 commit 4ba67ea
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 19 deletions.
4 changes: 4 additions & 0 deletions sleap_io/io/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def load_nwb(filename: str) -> Labels:
The dataset as a `Labels` object.
"""
with NWBHDF5IO(filename, "r", load_namespaces=True) as io:
nwb_processing = io.read().processing
for module in nwb_processing.values():
if 'PoseTraining' in module:
return nwb.read_nwb_training(nwb_processing)
return nwb.read_nwb(filename)


Expand Down
42 changes: 23 additions & 19 deletions sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
except ImportError:
ArrayLike = np.ndarray

from hdmf.utils import LabelledDict

from pynwb import NWBFile, NWBHDF5IO, ProcessingModule
from pynwb.file import Subject
from pynwb.image import ImageSeries
Expand Down Expand Up @@ -425,6 +427,21 @@ def read_nwb(path: str) -> Labels:
return labels


def read_nwb_training(processing_modules: LabelledDict) -> Labels:
"""Read an NWB formatted file with NWB training data to a
SLEAP `Labels` object.
Args:
processing_modules: A dictionary of processing modules from the NWB file.
Returns:
A `Labels` object.
"""
for name, processing_module in processing_modules.items():
if isinstance(processing_module, PoseTraining):
return pose_training_to_labels(processing_module)


def write_nwb(
labels: Labels,
nwbfile_path: str,
Expand Down Expand Up @@ -626,15 +643,6 @@ def append_nwb_training(
pose_training = labels_to_pose_training(labels, skeletons_list, video_info)
nwb_processing_module.add(pose_training)

camera = nwbfile.create_device(
name="Camera",
description="Camera used to record the video",
manufacturer="N/A",
)

reference_frame = (
"The coordinates are in (x, y) relative to the top-left of the image."
)
confidence_definition = "Softmax output of the deep neural network"
pose_estimation_series_list = []
for node in skeletons_list[0].nodes:
Expand All @@ -657,16 +665,12 @@ def append_nwb_training(
except AttributeError:
dimensions = np.array([[400, 400]])

pose_estimation = PoseEstimation(
pose_estimation_series=pose_estimation_series_list,
description="Estimated positions of the nodes in the video",
original_videos=[video.filename for video in labels.videos],
labeled_videos=[video.filename for video in labels.videos],
dimensions=dimensions,
devices=[camera],
source_software="SLEAP",
source_software_version=sleap_version,
skeleton=skeletons_list[0],
pose_estimation = build_pose_estimation_container_for_track(
labels_data_df=pd.DataFrame(),
labels=labels,
track_name="track",
video=video,
pose_estimation_metadata=pose_estimation_metadata,
)
nwb_processing_module.add(pose_estimation)

Expand Down

0 comments on commit 4ba67ea

Please sign in to comment.