Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
keyaloding committed Aug 9, 2024
1 parent 1cd6180 commit a2f96ab
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
8 changes: 4 additions & 4 deletions sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ def slp_skeleton_to_nwb(
subject = Subject(
species_id="No specified species", subject_id="No specified id"
)
skeleton_edges = dict(enumerate(skeleton.nodes))
nwb_edges = []
skeleton_edges = dict(enumerate(skeleton.nodes))
for i, source in skeleton_edges.items():
for destination in list(skeleton_edges.values())[i:]:
if Edge(source, destination) in skeleton.edges:
Expand Down Expand Up @@ -267,7 +267,7 @@ def write_video_to_path(
"""
index_data = {}
if frame_inds is None:
frame_inds = list(range(video.shape[0]))
frame_inds = list(range(video.backend.num_frames))

if isinstance(video.filename, list):
save_path = video.filename[0].split(".")[0]
Expand Down Expand Up @@ -643,9 +643,9 @@ def append_nwb_training(
]
skeletons = Skeletons(skeletons=skeletons_list)
nwb_processing_module.add(skeletons)

print(labels.videos)
video_info = write_video_to_path(
labels.video, frame_inds, frame_path=frame_path
labels.videos[0], frame_inds, frame_path=frame_path
)
pose_training = labels_to_pose_training(labels, skeletons_list, video_info)
nwb_processing_module.add(pose_training)
Expand Down
10 changes: 9 additions & 1 deletion tests/io/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,21 @@ def test_load_slp(slp_typical):

def test_nwb_training(tmp_path, slp_typical):
labels = load_slp(slp_typical)
save_nwb(labels, tmp_path / "test.nwb", True)
save_nwb(labels, tmp_path / "test_nwb.nwb")
loaded_labels = load_nwb(tmp_path / "test.nwb")
assert type(loaded_labels) == Labels
assert type(load_file(tmp_path / "test.nwb")) == Labels
assert len(loaded_labels) == len(labels)


def test_nwb(tmp_path, slp_typical):
labels = load_slp(slp_typical)
save_nwb(labels, tmp_path / "test.nwb")
loaded_labels = load_nwb(tmp_path / "test_nwb.nwb")
assert type(loaded_labels) == Labels
assert type(load_file(tmp_path / "test_nwb.nwb")) == Labels


def test_labelstudio(tmp_path, slp_typical):
labels = load_slp(slp_typical)
save_labelstudio(labels, tmp_path / "test_labelstudio.json")
Expand Down

0 comments on commit a2f96ab

Please sign in to comment.