diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index ff27f00..96c1487 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -7,6 +7,10 @@ import tensorstore from .image import CellMapImage, EmptyImage +import logging + +logger = logging.getLogger(__name__) + def split_target_path(path: str) -> tuple[str, list[str]]: """Splits a path to groundtruth data into the main path string, and the classes supplied for it.""" @@ -170,18 +174,23 @@ def __len__(self): if not self.has_data and not self.force_has_data: return 0 if self._len is None: - size = 1 - for _, (start, stop) in self.sampling_box.items(): - size *= abs(stop - start) - size /= np.prod(list(self.largest_voxel_sizes.values())) - self._len = int(np.floor(size)) + size = np.prod([self.sampling_box_shape[c] for c in self.axis_order]) + self._len = int(size) return self._len def __getitem__(self, idx): """Returns a crop of the input and target data as PyTorch tensors, corresponding to the coordinate of the unwrapped index.""" - center = np.unravel_index( - idx, [self.sampling_box_shape[c] for c in self.axis_order] - ) + try: + center = np.unravel_index( + idx, [self.sampling_box_shape[c] for c in self.axis_order] + ) + except ValueError: + logger.error( + f"Index {idx} out of bounds for dataset {self} of length {len(self)}" + ) + logger.warning(f"Returning closest index in bounds") + # TODO: This is a hacky temporary fix. Need to figure out why this is happening + center = [self.sampling_box_shape[c] for c in self.axis_order] center = { c: center[i] * self.largest_voxel_sizes[c] + self.sampling_box[c][0] for i, c in enumerate(self.axis_order)