From 4b4589bf85ef891c54f2c8f0f28b8406c97c9c61 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 10 May 2024 15:17:26 -0400 Subject: [PATCH] =?UTF-8?q?perf:=20=E2=9A=A1=EF=B8=8F=20Refactor=20to=20cl?= =?UTF-8?q?ean=20code=20and=20catch=20bugs.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Also begin working on sampler. --- src/cellmap_data/dataset.py | 33 ++++++++++++--------------- src/cellmap_data/datasplit.py | 1 + src/cellmap_data/image.py | 7 +++++- src/cellmap_data/samplers/__init__.py | 1 + src/cellmap_data/samplers/sampler.py | 26 +++++++++++++++++++++ 5 files changed, 48 insertions(+), 20 deletions(-) create mode 100644 src/cellmap_data/samplers/__init__.py create mode 100644 src/cellmap_data/samplers/sampler.py diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index cfd9b24..abcee05 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -119,9 +119,7 @@ def __init__( self._rng = rng self.force_has_data = force_has_data self._bounding_box = None - self._bounding_box_shape = None self._sampling_box = None - self._sampling_box_shape = None self._class_counts = None self._largest_voxel_sizes = None self._len = None @@ -193,7 +191,7 @@ def __getitem__(self, idx): ) logger.warning(f"Returning closest index in bounds") # TODO: This is a hacky temprorary fix. Need to figure out why this is happening - center = [self.sampling_box_shape[c] for c in self.axis_order] + center = [self.sampling_box_shape[c] - 1 for c in self.axis_order] center = { c: center[i] * self.largest_voxel_sizes[c] + self.sampling_box[c][0] for i, c in enumerate(self.axis_order) @@ -267,13 +265,8 @@ def bounding_box(self): @property def bounding_box_shape(self): """Returns the shape of the bounding box of the dataset in voxels of the largest voxel size.""" - if self._bounding_box_shape is None: - bounding_box_shape = {c: 0 for c in self.axis_order} - for c, (start, stop) in self.bounding_box.items(): - size = stop - start - size /= self.largest_voxel_sizes[c] - bounding_box_shape[c] = int(size) - self._bounding_box_shape = bounding_box_shape + if not hasattr(self, "_bounding_box_shape"): + self._bounding_box_shape = self._get_box_shape(self.bounding_box) return self._bounding_box_shape @property @@ -295,13 +288,8 @@ def sampling_box(self): @property def sampling_box_shape(self): """Returns the shape of the sampling box of the dataset in voxels of the largest voxel size.""" - if self._sampling_box_shape is None: - sampling_box_shape = {} - for c, (start, stop) in self.sampling_box.items(): - size = stop - start - size /= self.largest_voxel_sizes[c] - sampling_box_shape[c] = int(np.floor(size)) - self._sampling_box_shape = sampling_box_shape + if not hasattr(self, "_sampling_box_shape"): + self._sampling_box_shape = self._get_box_shape(self.sampling_box) return self._sampling_box_shape @property @@ -349,6 +337,14 @@ def validation_indices(self) -> Sequence[int]: chunk_size[c] = np.ceil(size - self.sampling_box_shape[c]).astype(int) return self.get_indices(chunk_size) + def _get_box_shape(self, source_box: dict[str, list[int]]) -> dict[str, int]: + box_shape = {} + for c, (start, stop) in source_box.items(): + size = stop - start + size /= self.largest_voxel_sizes[c] + box_shape[c] = int(np.floor(size)) + return box_shape + def _get_box( self, source_box: dict[str, list[int]], current_box: dict[str, list[int]] ) -> dict[str, list[int]]: @@ -363,8 +359,7 @@ def verify(self): """Verifies that the dataset is valid.""" # TODO: make more robust try: - length = len(self) - return True + return len(self) > 0 except Exception as e: # print(e) return False diff --git a/src/cellmap_data/datasplit.py b/src/cellmap_data/datasplit.py index 3a79004..40700b0 100644 --- a/src/cellmap_data/datasplit.py +++ b/src/cellmap_data/datasplit.py @@ -114,6 +114,7 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tenso if self.dataset_dict is not None: self.construct(self.dataset_dict) self.verify_datasets() + assert len(self.train_datasets) > 0, "No valid training datasets found." def __repr__(self): return f"CellMapDataSplit(\n\tInput arrays: {self.input_arrays}\n\tTarget arrays:{self.target_arrays}\n\tClasses: {self.classes}\n\tDataset dict: {self.dataset_dict}\n\tSpatial transforms: {self.spatial_transforms}\n\tRaw value transforms: {self.train_raw_value_transforms}\n\tGT value transforms: {self.target_value_transforms}\n\tForce has data: {self.force_has_data}\n\tContext: {self.context})" diff --git a/src/cellmap_data/image.py b/src/cellmap_data/image.py index 5ba5c12..54bdb35 100644 --- a/src/cellmap_data/image.py +++ b/src/cellmap_data/image.py @@ -288,6 +288,7 @@ def __init__( target_voxel_shape: Sequence[int], store: Optional[torch.Tensor] = None, axis_order: str = "zyx", + empty_value: float | int = -100, ): """Initializes an empty image object. @@ -305,10 +306,14 @@ def __init__( self._bounding_box = None self._class_counts = 0 self.scale = {c: 1 for c in self.axes} + self.empty_value = empty_value if store is not None: self.store = store else: - self.store = torch.zeros([1] + [self.output_shape[c] for c in self.axes]) + self.store = ( + torch.ones([1] + [self.output_shape[c] for c in self.axes]) + * self.empty_value + ) def __getitem__(self, center: dict[str, float]) -> torch.Tensor: """Returns image data centered around the given point, based on the scale and shape of the target output image.""" diff --git a/src/cellmap_data/samplers/__init__.py b/src/cellmap_data/samplers/__init__.py new file mode 100644 index 0000000..49f6245 --- /dev/null +++ b/src/cellmap_data/samplers/__init__.py @@ -0,0 +1 @@ +from .sampler import CellMapSampler diff --git a/src/cellmap_data/samplers/sampler.py b/src/cellmap_data/samplers/sampler.py new file mode 100644 index 0000000..3d70bf7 --- /dev/null +++ b/src/cellmap_data/samplers/sampler.py @@ -0,0 +1,26 @@ +from typing import Callable, Optional, Sequence, Sized +import torch +from torch.utils.data import Sampler + + +class CellMapSampler(Sampler): + def __init__( + self, + data_source: Sized, + indices: Sequence[int], + check_function: Optional[Callable] = None, + generator: Optional[torch.Generator] = None, + ): + super().__init__() + self.data_source = data_source + self.indices = indices + self.check_function = check_function + self.generator = generator + + def __iter__(self): + for i in torch.randperm(len(self.indices), generator=self.generator): + if self.check_function is None or self.check_function(self.indices[i]): + yield self.indices[i] + + def __len__(self): + return len(self.indices)