diff --git a/src/cellmap_data/dataloader.py b/src/cellmap_data/dataloader.py index 2a8f944..1c49441 100644 --- a/src/cellmap_data/dataloader.py +++ b/src/cellmap_data/dataloader.py @@ -1,6 +1,4 @@ from torch.utils.data import DataLoader -from torch.utils.data.dataloader import _BaseDataLoaderIter -from .dataset import CellMapDataset from .datasplit import CellMapDataSplit from typing import Callable, Iterable @@ -17,25 +15,3 @@ class CellMapDataLoader(DataLoader): is_train: bool augmentations: list[dict[str, any]] to_target: Callable - - def __init__( - ...) - - def __getitem__(self, idx: int) -> dict[str, tuple[torch.Tensor, torch.Tensor]]: - """Returns the input and target data for the given index.""" - ... - - def __len__(self) -> int: - """Returns the length of the dataset.""" - ... - - def __iter__(self): - - - def _apply_augmentations(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - """Applies augmentations to the data.""" - ... - - def _to_target(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - """Converts the input data to the target data.""" - ... diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index b3ff157..5ece7e5 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -1,10 +1,20 @@ # %% import csv from typing import Callable, Dict, Iterable, Optional +import torch from torch.utils.data import Dataset import tensorstore as tswift from fibsem_tools.io.core import read, read_xarray -from .image import CellMapImage +from .image import CellMapImage, EmptyImage + + +def split_gt_path(path: str) -> tuple[str, list[str]]: + """Splits a path to groundtruth data into the main path string, and the classes supplied for it.""" + path_prefix, path_rem = path.split("[") + classes, path_suffix = path_rem.split("]") + classes = classes.split(",") + path_string = path_prefix + "{label}" + path_suffix + return path_string, classes # %% @@ -17,7 +27,7 @@ class CellMapDataset(Dataset): input_arrays: dict[str, dict[str, Iterable[int | float]]] target_arrays: dict[str, dict[str, Iterable[int | float]]] input_sources: dict[str, CellMapImage] - target_sources: dict[str, dict[str, CellMapImage]] + target_sources: dict[str, dict[str, CellMapImage | EmptyImage]] def __init__( self, @@ -51,7 +61,8 @@ def __init__( } """ self.raw_path = raw_path - self.gt_path = gt_path + self.gt_paths = gt_path + self.gt_path_str, self.classes_with_path = split_gt_path(gt_path) self.classes = classes self.input_arrays = input_arrays self.target_arrays = target_arrays @@ -65,7 +76,9 @@ def __getitem__(self, idx): """Returns a random crop of the input and target data as PyTorch tensors.""" ... - def __iter__(self): ... + def __iter__(self): + """Iterates over the dataset, covering each section of the bounding box. For instance, for calculating validation scores.""" + ... def construct(self): """Constructs the input and target sources for the dataset.""" @@ -78,15 +91,22 @@ def construct(self): array_info["shape"], ) self.target_sources = {} + self.has_data = False for array_name, array_info in self.target_arrays.items(): self.target_sources[array_name] = {} for label in self.classes: - self.target_sources[array_name][label] = CellMapImage( - self.gt_path, - label, - array_info["scale"], - array_info["shape"], - ) + if label in self.classes_with_path: + self.target_sources[array_name][label] = CellMapImage( + self.gt_path_str.format(label=label), + label, + array_info["scale"], + array_info["shape"], + ) + self.has_data = True + else: + self.target_sources[array_name][label] = EmptyImage( + label, array_info["shape"] + ) @property def bounding_box(self): @@ -94,6 +114,8 @@ def bounding_box(self): if self._bounding_box is None: bounding_box = {c: [0, 2**32] for c in "xyz"} for source in [self.input_sources.values(), self.target_sources.values()]: + if source.bounding_box is None: + continue for c, (start, stop) in source.bounding_box.items(): bounding_box[c][0] = max(bounding_box[c][0], start) bounding_box[c][1] = min(bounding_box[c][1], stop) diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index 8ded8d0..577dad4 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -52,3 +52,35 @@ def class_counts(self) -> int: # TODO ... return self._class_counts + + +class EmptyImage: + shape: tuple[float, ...] + label_class: str + class_count: int + store: torch.Tensor + + def __init__( + self, + target_class: str, + target_voxel_shape: Iterable[int], + ): + self.label_class = target_class + self.output_shape = tuple(target_voxel_shape) + self._bounding_box = None + self._class_counts = 0 + self.store = torch.zeros(self.output_shape) + + def __getitem__(self, center: Iterable[float]): + """Returns image data centered around the given point, based on the scale and shape of the target output image.""" + return self.store + + @property + def bounding_box(self) -> None: + """Returns the bounding box of the dataset.""" + return self._bounding_box + + @property + def class_counts(self) -> int: + """Returns the number of pixels for the contained class in the ground truth data, normalized by the resolution.""" + return self._class_counts