Skip to content

Commit

Permalink
refactor: ♻️ Refactor transforms.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Mar 29, 2024
1 parent 12fcd33 commit 9a8b057
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 14 deletions.
9 changes: 4 additions & 5 deletions src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@ class CellMapDataset(Dataset):
input_sources: dict[str, CellMapImage]
target_sources: dict[str, dict[str, CellMapImage | EmptyImage]]
spatial_transforms: Optional[Sequence[dict[str, any]]] # type: ignore
raw_value_transforms: Optional[Callable] # For instance, normalizing the raw data
gt_value_transforms: Optional[
Callable | Sequence[Callable] | dict[str, Callable]
] # For instance, converting the ground truth data to target arrays
raw_value_transforms: Optional[Callable | Sequence[Callable]]
gt_value_transforms: Optional[Callable | Sequence[Callable] | dict[str, Callable]]
context: Optional[tensorstore.Context] # type: ignore
has_data: bool
is_train: bool
_bounding_box: Optional[Dict[str, list[int]]]
Expand All @@ -53,7 +52,7 @@ def __init__(
input_arrays: dict[str, dict[str, Sequence[int | float]]],
target_arrays: dict[str, dict[str, Sequence[int | float]]],
spatial_transforms: Optional[Sequence[dict[str, any]]] = None, # type: ignore
raw_value_transforms: Optional[Callable] = None,
raw_value_transforms: Optional[Callable | Sequence[Callable]] = None,
gt_value_transforms: Optional[
Callable | Sequence[Callable] | dict[str, Callable]
] = None,
Expand Down
38 changes: 29 additions & 9 deletions src/cellmap_data/datasplit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import csv
from torchvision.transforms.v2 import RandomApply
from typing import Callable, Dict, Iterable, Optional, Sequence
import tensorstore

from .multidataset import CellMapMultiDataset
from .dataset import CellMapDataset
Expand All @@ -20,8 +21,12 @@ class CellMapDataSplit:
validate_datasets: Iterable[CellMapDataset]
train_datasets_combined: CellMapMultiDataset
validate_datasets_combined: CellMapMultiDataset
# TODO: Correct transform passing (value for raw and labels, and spatial)
transforms: RandomApply | None
spatial_transforms: Optional[Sequence[dict[str, any]]]
raw_value_transforms: Optional[Callable | Sequence[Callable]] = None
gt_value_transforms: Optional[
Callable | Sequence[Callable] | dict[str, Callable]
] = None
context: Optional[tensorstore.Context] = None # type: ignore

def __init__(
self,
Expand All @@ -32,7 +37,12 @@ def __init__(
datasets: Optional[Dict[str, Iterable[CellMapDataset]]] = None,
dataset_dict: Optional[Dict[str, Dict[str, str]]] = None,
csv_path: Optional[str] = None,
transforms: Optional[Sequence[Callable]] = None,
spatial_transforms: Optional[Sequence[dict[str, any]]] = None,
raw_value_transforms: Optional[Callable | Sequence[Callable]] = None,
gt_value_transforms: Optional[
Callable | Sequence[Callable] | dict[str, Callable]
] = None,
context: Optional[tensorstore.Context] = None, # type: ignore
):
"""Initializes the CellMapDatasets class.
Expand Down Expand Up @@ -71,7 +81,13 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tenso
}. Defaults to None.
csv_path (Optional[str], optional): A path to a csv file containing the dataset data. Defaults to None. Each row in the csv file should have the following structure:
train | validate, raw path, gt path
transforms (Optional[Iterable[dict[str, any]]], optional): A list of transforms to apply to the data. Each augmentation should be a dictionary containing the following structure:
spatial_transforms (Optional[Sequence[dict[str, any]]], optional): A sequence of dictionaries containing the spatial transformations to apply to the data. Each dictionary should have the following structure:
{transform_name: {transform_args}}
Defaults to None.
raw_value_transforms (Optional[Callable | Sequence[Callable]], optional): A function, or a sequence of functions, to apply to the raw data. Defaults to None. Example is to normalize the raw data.
gt_value_transforms (Optional[Callable | Sequence[Callable] | dict[str, Callable]], optional): A function to convert the ground truth data to target arrays. Defaults to None. Example is to convert the ground truth data to a signed distance transform. May be a single function, a list of functions, or a dictionary of functions for each class. In the case of a list of functions, it is assumed that the functions correspond to each class in the classes list in order.
is_train (bool, optional): Whether the dataset is for training. Defaults to False.
context (Optional[tensorstore.Context], optional): The context for the image data. Defaults to None.
"""
self.input_arrays = input_arrays
self.target_arrays = target_arrays
Expand All @@ -85,8 +101,11 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tenso
self.construct(dataset_dict)
elif csv_path is not None:
self.from_csv(csv_path)
self.to_target = to_target
self.transforms = RandomApply(transforms) if transforms is not None else None
self.spatial_transforms = spatial_transforms
self.raw_value_transforms = raw_value_transforms
self.gt_value_transforms = gt_value_transforms
self.context = context
self.construct(self.dataset_dict)

def from_csv(self, csv_path):
# Load file data from csv file
Expand All @@ -113,8 +132,9 @@ def construct(self, dataset_dict):
self.classes,
self.input_arrays,
self.target_arrays,
self.to_target,
self.transforms,
self.spatial_transforms,
self.raw_value_transforms,
self.gt_value_transforms,
)
)
for raw, gt in zip(
Expand All @@ -127,7 +147,7 @@ def construct(self, dataset_dict):
self.classes,
self.input_arrays,
self.target_arrays,
self.to_target,
gt_value_transforms=self.gt_value_transforms,
)
)
self.train_datasets_combined = CellMapMultiDataset(
Expand Down

0 comments on commit 9a8b057

Please sign in to comment.