Skip to content

Commit

Permalink
perf: ⚡️ Refactor to clean code and catch bugs.
Browse files Browse the repository at this point in the history
Also begin working on sampler.
  • Loading branch information
rhoadesScholar committed May 10, 2024
1 parent 504eda9 commit 4b4589b
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 20 deletions.
33 changes: 14 additions & 19 deletions src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,7 @@ def __init__(
self._rng = rng
self.force_has_data = force_has_data
self._bounding_box = None
self._bounding_box_shape = None
self._sampling_box = None
self._sampling_box_shape = None
self._class_counts = None
self._largest_voxel_sizes = None
self._len = None
Expand Down Expand Up @@ -193,7 +191,7 @@ def __getitem__(self, idx):
)
logger.warning(f"Returning closest index in bounds")
# TODO: This is a hacky temporary fix. Need to figure out why this is happening
center = [self.sampling_box_shape[c] for c in self.axis_order]
center = [self.sampling_box_shape[c] - 1 for c in self.axis_order]
center = {
c: center[i] * self.largest_voxel_sizes[c] + self.sampling_box[c][0]
for i, c in enumerate(self.axis_order)
Expand Down Expand Up @@ -267,13 +265,8 @@ def bounding_box(self):
@property
def bounding_box_shape(self):
"""Returns the shape of the bounding box of the dataset in voxels of the largest voxel size."""
if self._bounding_box_shape is None:
bounding_box_shape = {c: 0 for c in self.axis_order}
for c, (start, stop) in self.bounding_box.items():
size = stop - start
size /= self.largest_voxel_sizes[c]
bounding_box_shape[c] = int(size)
self._bounding_box_shape = bounding_box_shape
if not hasattr(self, "_bounding_box_shape"):
self._bounding_box_shape = self._get_box_shape(self.bounding_box)
return self._bounding_box_shape

@property
Expand All @@ -295,13 +288,8 @@ def sampling_box(self):
@property
def sampling_box_shape(self):
"""Returns the shape of the sampling box of the dataset in voxels of the largest voxel size."""
if self._sampling_box_shape is None:
sampling_box_shape = {}
for c, (start, stop) in self.sampling_box.items():
size = stop - start
size /= self.largest_voxel_sizes[c]
sampling_box_shape[c] = int(np.floor(size))
self._sampling_box_shape = sampling_box_shape
if not hasattr(self, "_sampling_box_shape"):
self._sampling_box_shape = self._get_box_shape(self.sampling_box)
return self._sampling_box_shape

@property
Expand Down Expand Up @@ -349,6 +337,14 @@ def validation_indices(self) -> Sequence[int]:
chunk_size[c] = np.ceil(size - self.sampling_box_shape[c]).astype(int)
return self.get_indices(chunk_size)

def _get_box_shape(self, source_box: dict[str, list[int]]) -> dict[str, int]:
    """Return the per-axis shape of ``source_box`` in whole voxels.

    For every axis key in ``source_box`` the extent ``stop - start`` is
    divided by the matching entry of ``self.largest_voxel_sizes`` and
    rounded down to an ``int``.

    Args:
        source_box: Mapping from axis name to a ``(start, stop)`` pair.

    Returns:
        Mapping from axis name to the floored number of voxels along it.
    """
    return {
        axis: int(np.floor((upper - lower) / self.largest_voxel_sizes[axis]))
        for axis, (lower, upper) in source_box.items()
    }

def _get_box(
self, source_box: dict[str, list[int]], current_box: dict[str, list[int]]
) -> dict[str, list[int]]:
Expand All @@ -363,8 +359,7 @@ def verify(self):
"""Verifies that the dataset is valid."""
# TODO: make more robust
try:
length = len(self)
return True
return len(self) > 0
except Exception as e:
# print(e)
return False
Expand Down
1 change: 1 addition & 0 deletions src/cellmap_data/datasplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tenso
if self.dataset_dict is not None:
self.construct(self.dataset_dict)
self.verify_datasets()
assert len(self.train_datasets) > 0, "No valid training datasets found."

def __repr__(self):
return f"CellMapDataSplit(\n\tInput arrays: {self.input_arrays}\n\tTarget arrays:{self.target_arrays}\n\tClasses: {self.classes}\n\tDataset dict: {self.dataset_dict}\n\tSpatial transforms: {self.spatial_transforms}\n\tRaw value transforms: {self.train_raw_value_transforms}\n\tGT value transforms: {self.target_value_transforms}\n\tForce has data: {self.force_has_data}\n\tContext: {self.context})"
Expand Down
7 changes: 6 additions & 1 deletion src/cellmap_data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def __init__(
target_voxel_shape: Sequence[int],
store: Optional[torch.Tensor] = None,
axis_order: str = "zyx",
empty_value: float | int = -100,
):
"""Initializes an empty image object.
Expand All @@ -305,10 +306,14 @@ def __init__(
self._bounding_box = None
self._class_counts = 0
self.scale = {c: 1 for c in self.axes}
self.empty_value = empty_value
if store is not None:
self.store = store
else:
self.store = torch.zeros([1] + [self.output_shape[c] for c in self.axes])
self.store = (
torch.ones([1] + [self.output_shape[c] for c in self.axes])
* self.empty_value
)

def __getitem__(self, center: dict[str, float]) -> torch.Tensor:
"""Returns image data centered around the given point, based on the scale and shape of the target output image."""
Expand Down
1 change: 1 addition & 0 deletions src/cellmap_data/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .sampler import CellMapSampler
26 changes: 26 additions & 0 deletions src/cellmap_data/samplers/sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Callable, Optional, Sequence, Sized
import torch
from torch.utils.data import Sampler


class CellMapSampler(Sampler):
def __init__(
self,
data_source: Sized,
indices: Sequence[int],
check_function: Optional[Callable] = None,
generator: Optional[torch.Generator] = None,
):
super().__init__()
self.data_source = data_source
self.indices = indices
self.check_function = check_function
self.generator = generator

def __iter__(self):
for i in torch.randperm(len(self.indices), generator=self.generator):
if self.check_function is None or self.check_function(self.indices[i]):
yield self.indices[i]

def __len__(self):
return len(self.indices)

0 comments on commit 4b4589b

Please sign in to comment.