Skip to content

Commit

Permalink
refactor: 🚧 Add EmptyImage
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Mar 27, 2024
1 parent 58e72a6 commit c940cf3
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 34 deletions.
24 changes: 0 additions & 24 deletions src/cellmap_data/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import _BaseDataLoaderIter
from .dataset import CellMapDataset
from .datasplit import CellMapDataSplit
from typing import Callable, Iterable

Expand All @@ -17,25 +15,3 @@ class CellMapDataLoader(DataLoader):
is_train: bool
augmentations: list[dict[str, any]]
to_target: Callable

def __init__(
...)

def __getitem__(self, idx: int) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
"""Returns the input and target data for the given index."""
...

def __len__(self) -> int:
"""Returns the length of the dataset."""
...

def __iter__(self):


def _apply_augmentations(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Applies augmentations to the data."""
...

def _to_target(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Converts the input data to the target data."""
...
42 changes: 32 additions & 10 deletions src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
# %%
import csv
from typing import Callable, Dict, Iterable, Optional
import torch
from torch.utils.data import Dataset
import tensorstore as tswift
from fibsem_tools.io.core import read, read_xarray
from .image import CellMapImage
from .image import CellMapImage, EmptyImage


def split_gt_path(path: str) -> tuple[str, list[str]]:
"""Splits a path to groundtruth data into the main path string, and the classes supplied for it."""
path_prefix, path_rem = path.split("[")
classes, path_suffix = path_rem.split("]")
classes = classes.split(",")
path_string = path_prefix + "{label}" + path_suffix
return path_string, classes


# %%
Expand All @@ -17,7 +27,7 @@ class CellMapDataset(Dataset):
input_arrays: dict[str, dict[str, Iterable[int | float]]]
target_arrays: dict[str, dict[str, Iterable[int | float]]]
input_sources: dict[str, CellMapImage]
target_sources: dict[str, dict[str, CellMapImage]]
target_sources: dict[str, dict[str, CellMapImage | EmptyImage]]

def __init__(
self,
Expand Down Expand Up @@ -51,7 +61,8 @@ def __init__(
}
"""
self.raw_path = raw_path
self.gt_path = gt_path
self.gt_paths = gt_path
self.gt_path_str, self.classes_with_path = split_gt_path(gt_path)
self.classes = classes
self.input_arrays = input_arrays
self.target_arrays = target_arrays
Expand All @@ -65,7 +76,9 @@ def __getitem__(self, idx):
"""Returns a random crop of the input and target data as PyTorch tensors."""
...

def __iter__(self): ...
def __iter__(self):
"""Iterates over the dataset, covering each section of the bounding box. For instance, for calculating validation scores."""
...

def construct(self):
"""Constructs the input and target sources for the dataset."""
Expand All @@ -78,22 +91,31 @@ def construct(self):
array_info["shape"],
)
self.target_sources = {}
self.has_data = False
for array_name, array_info in self.target_arrays.items():
self.target_sources[array_name] = {}
for label in self.classes:
self.target_sources[array_name][label] = CellMapImage(
self.gt_path,
label,
array_info["scale"],
array_info["shape"],
)
if label in self.classes_with_path:
self.target_sources[array_name][label] = CellMapImage(
self.gt_path_str.format(label=label),
label,
array_info["scale"],
array_info["shape"],
)
self.has_data = True
else:
self.target_sources[array_name][label] = EmptyImage(
label, array_info["shape"]
)

@property
def bounding_box(self):
"""Returns the bounding box of the dataset."""
if self._bounding_box is None:
bounding_box = {c: [0, 2**32] for c in "xyz"}
for source in [self.input_sources.values(), self.target_sources.values()]:
if source.bounding_box is None:
continue
for c, (start, stop) in source.bounding_box.items():
bounding_box[c][0] = max(bounding_box[c][0], start)
bounding_box[c][1] = min(bounding_box[c][1], stop)
Expand Down
32 changes: 32 additions & 0 deletions src/cellmap_data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,35 @@ def class_counts(self) -> int:
# TODO
...
return self._class_counts


class EmptyImage:
shape: tuple[float, ...]
label_class: str
class_count: int
store: torch.Tensor

def __init__(
self,
target_class: str,
target_voxel_shape: Iterable[int],
):
self.label_class = target_class
self.output_shape = tuple(target_voxel_shape)
self._bounding_box = None
self._class_counts = 0
self.store = torch.zeros(self.output_shape)

def __getitem__(self, center: Iterable[float]):
"""Returns image data centered around the given point, based on the scale and shape of the target output image."""
return self.store

@property
def bounding_box(self) -> None:
"""Returns the bounding box of the dataset."""
return self._bounding_box

@property
def class_counts(self) -> int:
"""Returns the number of pixels for the contained class in the ground truth data, normalized by the resolution."""
return self._class_counts

0 comments on commit c940cf3

Please sign in to comment.