From 4ba67ea4dad7d8d72dc041a9762584574c679b61 Mon Sep 17 00:00:00 2001 From: Keya Loding Date: Wed, 7 Aug 2024 09:59:35 -0700 Subject: [PATCH] updated video loading and pose_estimation creation --- sleap_io/io/main.py | 4 ++++ sleap_io/io/nwb.py | 42 +++++++++++++++++++++++------------------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/sleap_io/io/main.py b/sleap_io/io/main.py index 4e158d51..0f8fb871 100644 --- a/sleap_io/io/main.py +++ b/sleap_io/io/main.py @@ -61,6 +61,10 @@ def load_nwb(filename: str) -> Labels: The dataset as a `Labels` object. """ with NWBHDF5IO(filename, "r", load_namespaces=True) as io: + nwb_processing = io.read().processing + for module in nwb_processing.values(): + if 'PoseTraining' in module: + return nwb.read_nwb_training(nwb_processing) return nwb.read_nwb(filename) diff --git a/sleap_io/io/nwb.py b/sleap_io/io/nwb.py index 94031e4f..904f1a52 100644 --- a/sleap_io/io/nwb.py +++ b/sleap_io/io/nwb.py @@ -24,6 +24,8 @@ except ImportError: ArrayLike = np.ndarray +from hdmf.utils import LabelledDict + from pynwb import NWBFile, NWBHDF5IO, ProcessingModule from pynwb.file import Subject from pynwb.image import ImageSeries @@ -425,6 +427,21 @@ def read_nwb(path: str) -> Labels: return labels +def read_nwb_training(processing_modules: LabelledDict) -> Labels: + """Read an NWB formatted file with NWB training data to a + SLEAP `Labels` object. + + Args: + processing_modules: A dictionary of processing modules from the NWB file. + + Returns: + A `Labels` object. + """ + for name, processing_module in processing_modules.items(): + if isinstance(processing_module, PoseTraining): + return pose_training_to_labels(processing_module) + + def write_nwb( labels: Labels, nwbfile_path: str, @@ -626,15 +643,6 @@ def append_nwb_training( pose_training = labels_to_pose_training(labels, skeletons_list, video_info) nwb_processing_module.add(pose_training) - camera = nwbfile.create_device( - name="Camera", - description="Camera used to record the video", - manufacturer="N/A", - ) - - reference_frame = ( - "The coordinates are in (x, y) relative to the top-left of the image." - ) confidence_definition = "Softmax output of the deep neural network" pose_estimation_series_list = [] for node in skeletons_list[0].nodes: @@ -657,16 +665,12 @@ def append_nwb_training( except AttributeError: dimensions = np.array([[400, 400]]) - pose_estimation = PoseEstimation( - pose_estimation_series=pose_estimation_series_list, - description="Estimated positions of the nodes in the video", - original_videos=[video.filename for video in labels.videos], - labeled_videos=[video.filename for video in labels.videos], - dimensions=dimensions, - devices=[camera], - source_software="SLEAP", - source_software_version=sleap_version, - skeleton=skeletons_list[0], + pose_estimation = build_pose_estimation_container_for_track( + labels_data_df=pd.DataFrame(), + labels=labels, + track_name="track", + video=video, + pose_estimation_metadata=pose_estimation_metadata, ) nwb_processing_module.add(pose_estimation)