From d83d1f353c68d7aab76719f2324bb2b22a86dc04 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 1 Apr 2024 22:01:45 -0400 Subject: [PATCH] =?UTF-8?q?feat:=20=E2=9C=A8=20Dataloading=20works.=20Read?= =?UTF-8?q?y=20to=20train.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/cellmap_data/dataset.py | 6 ++---- src/cellmap_data/multidataset.py | 1 + 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index c7e0be2..57afc7d 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -138,9 +138,7 @@ def __getitem__(self, idx): outputs = {} for array_name in self.input_arrays.keys(): self.input_sources[array_name].set_spatial_transforms(spatial_transforms) - outputs[array_name] = self.input_sources[array_name][center][ - None, None, ... - ] + outputs[array_name] = self.input_sources[array_name][center][None, ...] # TODO: Allow for distribtion of array gathering to multiple threads for array_name in self.target_arrays.keys(): class_arrays = [] @@ -149,7 +147,7 @@ def __getitem__(self, idx): spatial_transforms ) class_arrays.append(self.target_sources[array_name][label][center]) - outputs[array_name] = torch.stack(class_arrays)[None, ...] + outputs[array_name] = torch.stack(class_arrays) return outputs def __iter__(self): diff --git a/src/cellmap_data/multidataset.py b/src/cellmap_data/multidataset.py index fee787f..bc4dc9c 100644 --- a/src/cellmap_data/multidataset.py +++ b/src/cellmap_data/multidataset.py @@ -25,6 +25,7 @@ def __init__( target_arrays: dict[str, dict[str, Sequence[int | float]]], datasets: Iterable[CellMapDataset], ): + super().__init__(datasets) self.input_arrays = input_arrays self.target_arrays = target_arrays self.classes = classes