diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index 18865d2..6634654 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -30,10 +30,9 @@ class CellMapDataset(Dataset): input_sources: dict[str, CellMapImage] target_sources: dict[str, dict[str, CellMapImage | EmptyImage]] spatial_transforms: Optional[Sequence[dict[str, any]]] # type: ignore - raw_value_transforms: Optional[Callable] # For instance, normalizing the raw data - gt_value_transforms: Optional[ - Callable | Sequence[Callable] | dict[str, Callable] - ] # For instance, converting the ground truth data to target arrays + raw_value_transforms: Optional[Callable | Sequence[Callable]] + gt_value_transforms: Optional[Callable | Sequence[Callable] | dict[str, Callable]] + context: Optional[tensorstore.Context] # type: ignore has_data: bool is_train: bool _bounding_box: Optional[Dict[str, list[int]]] @@ -53,7 +52,7 @@ def __init__( input_arrays: dict[str, dict[str, Sequence[int | float]]], target_arrays: dict[str, dict[str, Sequence[int | float]]], spatial_transforms: Optional[Sequence[dict[str, any]]] = None, # type: ignore - raw_value_transforms: Optional[Callable] = None, + raw_value_transforms: Optional[Callable | Sequence[Callable]] = None, gt_value_transforms: Optional[ Callable | Sequence[Callable] | dict[str, Callable] ] = None, diff --git a/src/cellmap_data/datasplit.py b/src/cellmap_data/datasplit.py index 0616caf..ffe3c9c 100644 --- a/src/cellmap_data/datasplit.py +++ b/src/cellmap_data/datasplit.py @@ -1,6 +1,7 @@ import csv from torchvision.transforms.v2 import RandomApply from typing import Callable, Dict, Iterable, Optional, Sequence +import tensorstore from .multidataset import CellMapMultiDataset from .dataset import CellMapDataset @@ -20,8 +21,12 @@ class CellMapDataSplit: validate_datasets: Iterable[CellMapDataset] train_datasets_combined: CellMapMultiDataset validate_datasets_combined: CellMapMultiDataset - # TODO: Correct transform passing (value for raw and labels, and spatial) - transforms: RandomApply | None + spatial_transforms: Optional[Sequence[dict[str, any]]] + raw_value_transforms: Optional[Callable | Sequence[Callable]] = None + gt_value_transforms: Optional[ + Callable | Sequence[Callable] | dict[str, Callable] + ] = None + context: Optional[tensorstore.Context] = None # type: ignore def __init__( self, @@ -32,7 +37,12 @@ def __init__( datasets: Optional[Dict[str, Iterable[CellMapDataset]]] = None, dataset_dict: Optional[Dict[str, Dict[str, str]]] = None, csv_path: Optional[str] = None, - transforms: Optional[Sequence[Callable]] = None, + spatial_transforms: Optional[Sequence[dict[str, any]]] = None, + raw_value_transforms: Optional[Callable | Sequence[Callable]] = None, + gt_value_transforms: Optional[ + Callable | Sequence[Callable] | dict[str, Callable] + ] = None, + context: Optional[tensorstore.Context] = None, # type: ignore ): """Initializes the CellMapDatasets class. @@ -71,7 +81,13 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tenso }. Defaults to None. csv_path (Optional[str], optional): A path to a csv file containing the dataset data. Defaults to None. Each row in the csv file should have the following structure: train | validate, raw path, gt path - transforms (Optional[Iterable[dict[str, any]]], optional): A list of transforms to apply to the data. Each augmentation should be a dictionary containing the following structure: + spatial_transforms (Optional[Sequence[dict[str, any]]], optional): A sequence of dictionaries containing the spatial transformations to apply to the data. The dictionary should have the following structure: + {transform_name: {transform_args}} + Defaults to None. + raw_value_transforms (Optional[Callable], optional): A function to apply to the raw data. Defaults to None. Example is to normalize the raw data. + gt_value_transforms (Optional[Callable | Sequence[Callable] | dict[str, Callable]], optional): A function to convert the ground truth data to target arrays. Defaults to None. Example is to convert the ground truth data to a signed distance transform. May be a single function, a list of functions, or a dictionary of functions for each class. In the case of a list of functions, it is assumed that the functions correspond to each class in the classes list in order. + is_train (bool, optional): Whether the dataset is for training. Defaults to False. + context (Optional[tensorstore.Context], optional): The context for the image data. Defaults to None. """ self.input_arrays = input_arrays self.target_arrays = target_arrays @@ -85,8 +101,11 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tenso self.construct(dataset_dict) elif csv_path is not None: self.from_csv(csv_path) - self.to_target = to_target - self.transforms = RandomApply(transforms) if transforms is not None else None + self.spatial_transforms = spatial_transforms + self.raw_value_transforms = raw_value_transforms + self.gt_value_transforms = gt_value_transforms + self.context = context + self.construct(self.dataset_dict) def from_csv(self, csv_path): # Load file data from csv file @@ -113,8 +132,9 @@ def construct(self, dataset_dict): self.classes, self.input_arrays, self.target_arrays, - self.to_target, - self.transforms, + self.spatial_transforms, + self.raw_value_transforms, + self.gt_value_transforms, ) ) for raw, gt in zip( @@ -127,7 +147,7 @@ def construct(self, dataset_dict): self.classes, self.input_arrays, self.target_arrays, - self.to_target, + gt_value_transforms=self.gt_value_transforms, ) ) self.train_datasets_combined = CellMapMultiDataset(