Skip to content

Commit

Permalink
feat: ✨ Dataloading works. Ready to train.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Apr 2, 2024
1 parent 85fd727 commit d83d1f3
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,7 @@ def __getitem__(self, idx):
outputs = {}
for array_name in self.input_arrays.keys():
self.input_sources[array_name].set_spatial_transforms(spatial_transforms)
outputs[array_name] = self.input_sources[array_name][center][
None, None, ...
]
outputs[array_name] = self.input_sources[array_name][center][None, ...]
# TODO: Allow for distribtion of array gathering to multiple threads
for array_name in self.target_arrays.keys():
class_arrays = []
Expand All @@ -149,7 +147,7 @@ def __getitem__(self, idx):
spatial_transforms
)
class_arrays.append(self.target_sources[array_name][label][center])
outputs[array_name] = torch.stack(class_arrays)[None, ...]
outputs[array_name] = torch.stack(class_arrays)
return outputs

def __iter__(self):
Expand Down
1 change: 1 addition & 0 deletions src/cellmap_data/multidataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
target_arrays: dict[str, dict[str, Sequence[int | float]]],
datasets: Iterable[CellMapDataset],
):
super().__init__(datasets)
self.input_arrays = input_arrays
self.target_arrays = target_arrays
self.classes = classes
Expand Down

0 comments on commit d83d1f3

Please sign in to comment.