Skip to content

Commit

Permalink
fix: 🐛 Image and Dataset seem debugged.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Mar 31, 2024
1 parent 68ec58f commit 2dbf137
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
30 changes: 15 additions & 15 deletions src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# %%
import csv
from typing import Callable, Dict, Sequence, Optional
from typing import Callable, Dict, Generator, Sequence, Optional
import numpy as np
import torch
from torch.utils.data import Dataset
Expand Down Expand Up @@ -60,6 +60,7 @@ def __init__(
is_train: bool = False,
axis_order: str = "zyx",
context: Optional[tensorstore.Context] = None, # type: ignore
rng: Optional[Generator] = None,
):
"""Initializes the CellMapDataset class.
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
gt_value_transforms (Optional[Callable | Sequence[Callable] | dict[str, Callable]], optional): A function to convert the ground truth data to target arrays. Defaults to None. Example is to convert the ground truth data to a signed distance transform. May be a single function, a list of functions, or a dictionary of functions for each class. In the case of a list of functions, it is assumed that the functions correspond to each class in the classes list in order.
is_train (bool, optional): Whether the dataset is for training. Defaults to False.
context (Optional[tensorstore.Context], optional): The context for the image data. Defaults to None.
rng (Optional[Generator], optional): A random number generator. Defaults to None.
"""
self.raw_path = raw_path
self.gt_paths = gt_path
Expand All @@ -103,6 +105,7 @@ def __init__(
self.is_train = is_train
self.axis_order = axis_order
self.context = context
self._rng = rng
self.construct()

def __len__(self):
Expand All @@ -128,21 +131,17 @@ def __getitem__(self, idx):
spatial_transforms = self.generate_spatial_transforms()
outputs = {}
for array_name in self.input_arrays.keys():
if spatial_transforms is not None:
self.input_sources[array_name].set_spatial_transforms(
spatial_transforms
)
self.input_sources[array_name].set_spatial_transforms(spatial_transforms)
outputs[array_name] = self.input_sources[array_name][center][
None, None, ...
]
# TODO: Allow for distribtion of array gathering to multiple threads
for array_name in self.target_arrays.keys():
class_arrays = []
for label in self.classes:
if spatial_transforms is not None:
self.target_sources[array_name][label].set_spatial_transforms(
spatial_transforms
)
self.target_sources[array_name][label].set_spatial_transforms(
spatial_transforms
)
class_arrays.append(self.target_sources[array_name][label][center])
outputs[array_name] = torch.stack(class_arrays)[None, ...]
return outputs
Expand Down Expand Up @@ -208,12 +207,11 @@ def construct(self):
label, array_info["shape"], empty_store # type: ignore
)

def generate_spatial_transforms(self):
def generate_spatial_transforms(self) -> Optional[dict[str, any]]:
"""Generates spatial transforms for the dataset."""
if self._rng is None:
rng = np.random.default_rng()
else:
rng = self._rng
self._rng = np.random.default_rng()
rng = self._rng

if not self.is_train or self.spatial_transforms is None:
return None
Expand All @@ -223,7 +221,7 @@ def generate_spatial_transforms(self):
# input: "mirror": {"axes": {"x": 0.5, "y": 0.5, "z":0.1}}
# output: {"mirror": ["x", "y"]}
spatial_transforms[transform] = []
for axis, prob in params["axes"]:
for axis, prob in params["axes"].items():
if rng.random() < prob:
spatial_transforms[transform].append(axis)
elif transform == "transpose":
Expand All @@ -237,10 +235,12 @@ def generate_spatial_transforms(self):
shuffled_axes = {
axis: shuffled_axes[i] for i, axis in enumerate(params["axes"])
} # reassign axes
spatial_transforms[transform] = axes.update(shuffled_axes)
axes.update(shuffled_axes)
spatial_transforms[transform] = axes
else:
raise ValueError(f"Unknown spatial transform: {transform}")
self._current_spatial_transforms = spatial_transforms
return spatial_transforms

@property
def largest_voxel_sizes(self):
Expand Down
2 changes: 1 addition & 1 deletion src/cellmap_data/datasplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tenso
Defaults to None.
raw_value_transforms (Optional[Callable], optional): A function to apply to the raw data. Defaults to None. Example is to normalize the raw data.
gt_value_transforms (Optional[Callable | Sequence[Callable] | dict[str, Callable]], optional): A function to convert the ground truth data to target arrays. Defaults to None. Example is to convert the ground truth data to a signed distance transform. May be a single function, a list of functions, or a dictionary of functions for each class. In the case of a list of functions, it is assumed that the functions correspond to each class in the classes list in order.
is_train (bool, optional): Whether the dataset is for training. Defaults to False.
force_has_data (bool, optional): Whether to force the dataset to have data. Defaults to False.
context (Optional[tensorstore.Context], optional): The context for the image data. Defaults to None.
"""
Expand Down Expand Up @@ -137,6 +136,7 @@ def construct(self, dataset_dict):
self.spatial_transforms,
self.raw_value_transforms,
self.gt_value_transforms,
is_train=True,
)
)

Expand Down
2 changes: 1 addition & 1 deletion src/cellmap_data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def get_spatial_metadata(self):
}
break

def set_spatial_transforms(self, transforms: dict[str, any]):
def set_spatial_transforms(self, transforms: dict[str, any] | None):
"""Sets spatial transformations for the image data."""
self._current_spatial_transforms = transforms

Expand Down

0 comments on commit 2dbf137

Please sign in to comment.