Skip to content

Commit

Permalink
fix: 🐛 Debug validation_blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed May 3, 2024
1 parent 7422143 commit 2bd5ec8
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 28 deletions.
1 change: 1 addition & 0 deletions src/cellmap_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
from .datasplit import CellMapDataSplit
from .dataset import CellMapDataset
from .image import CellMapImage
from .subdataset import CellMapSubset
18 changes: 4 additions & 14 deletions src/cellmap_data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torch.utils.data import DataLoader, Sampler, Subset
from .dataset import CellMapDataset
from .multidataset import CellMapMultiDataset
from .subdataset import CellMapSubset

from typing import Iterable, Optional

Expand All @@ -10,11 +11,7 @@ class CellMapDataLoader:
# TODO: docstring corrections
"""This subclasses PyTorch DataLoader to load CellMap data for training. It maintains the same API as the DataLoader class. This includes applying augmentations to the data and returning the data in the correct format for training, such as generating the target arrays (e.g. signed distance transform of labels). It retrieves raw and groundtruth data from a CellMapDataSplit object, which is a subclass of PyTorch Dataset. Training and validation data are split using the CellMapDataSplit object, and separate dataloaders are maintained as `train_loader` and `validate_loader` respectively."""

dataset: (
CellMapMultiDataset
| CellMapDataset
| Subset[CellMapDataset | CellMapMultiDataset]
)
dataset: CellMapMultiDataset | CellMapDataset | CellMapSubset
classes: Iterable[str]
loader = DataLoader
batch_size: int
Expand All @@ -26,11 +23,7 @@ class CellMapDataLoader:

def __init__(
self,
dataset: (
CellMapMultiDataset
| CellMapDataset
| Subset[CellMapDataset | CellMapMultiDataset]
),
dataset: CellMapMultiDataset | CellMapDataset | CellMapSubset,
classes: Iterable[str],
batch_size: int = 1,
num_workers: int = 0,
Expand All @@ -53,10 +46,7 @@ def __init__(
), "Weighted sampler only relevant for CellMapMultiDataset"
self.sampler = self.dataset.weighted_sampler(self.batch_size, self.rng)
if torch.cuda.is_available():
if isinstance(self.dataset, Subset):
self.dataset.dataset.to("cuda") # type: ignore
else:
self.dataset.to("cuda")
self.dataset.to("cuda")
kwargs = {
"dataset": self.dataset,
"dataset": self.dataset,
Expand Down
4 changes: 2 additions & 2 deletions src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ def __getitem__(self, idx):
spatial_transforms
)
array = self.target_sources[array_name][label][center]
if array.shape[0] != 1:
array = array[None, ...]
# if array.shape[0] != 1:
# array = array[None, ...]
class_arrays.append(array)
outputs[array_name] = torch.stack(class_arrays)
return outputs
Expand Down
30 changes: 22 additions & 8 deletions src/cellmap_data/datasplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
from .dataset import CellMapDataset
from .multidataset import CellMapMultiDataset
from .subdataset import CellMapSubset


class CellMapDataSplit:
Expand All @@ -20,7 +21,7 @@ class CellMapDataSplit:
train_datasets: Sequence[CellMapDataset]
validation_datasets: Sequence[CellMapDataset]
spatial_transforms: Optional[dict[str, Any]] = None
raw_value_transforms: Optional[Callable] = None
train_raw_value_transforms: Optional[Callable] = None
target_value_transforms: Optional[
Callable | Sequence[Callable] | dict[str, Callable]
] = None
Expand All @@ -36,7 +37,8 @@ def __init__(
dataset_dict: Optional[Mapping[str, Sequence[Dict[str, str]]]] = None,
csv_path: Optional[str] = None,
spatial_transforms: Optional[dict[str, Any]] = None,
raw_value_transforms: Optional[Callable] = None,
train_raw_value_transforms: Optional[Callable] = None,
val_raw_value_transforms: Optional[Callable] = None,
target_value_transforms: Optional[
Callable | Sequence[Callable] | dict[str, Callable]
] = None,
Expand Down Expand Up @@ -105,15 +107,16 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tenso
elif csv_path is not None:
self.dataset_dict = self.from_csv(csv_path)
self.spatial_transforms = spatial_transforms
self.raw_value_transforms = raw_value_transforms
self.train_raw_value_transforms = train_raw_value_transforms
self.val_raw_value_transforms = val_raw_value_transforms
self.target_value_transforms = target_value_transforms
self.context = context
if self.dataset_dict is not None:
self.construct(self.dataset_dict)
self.verify_datasets()

def __repr__(self):
    """Return a human-readable, multi-line summary of the data split configuration."""
    # NOTE(review): only the training raw value transforms are shown;
    # val_raw_value_transforms is omitted from the summary — confirm this is intended.
    return f"CellMapDataSplit(\n\tInput arrays: {self.input_arrays}\n\tTarget arrays:{self.target_arrays}\n\tClasses: {self.classes}\n\tDataset dict: {self.dataset_dict}\n\tSpatial transforms: {self.spatial_transforms}\n\tRaw value transforms: {self.train_raw_value_transforms}\n\tGT value transforms: {self.target_value_transforms}\n\tForce has data: {self.force_has_data}\n\tContext: {self.context})"

def from_csv(self, csv_path):
# Load file data from csv file
Expand Down Expand Up @@ -166,7 +169,7 @@ def validation_datasets_combined(self):
@property
def validation_blocks(self):
if not hasattr(self, "_validation_blocks"):
self._validation_blocks = torch.utils.data.Subset(
self._validation_blocks = CellMapSubset(
self.validation_datasets_combined,
self.validation_datasets_combined.get_validation_indices(),
)
Expand Down Expand Up @@ -195,7 +198,7 @@ def construct(self, dataset_dict):
self.input_arrays,
self.target_arrays,
self.spatial_transforms,
self.raw_value_transforms,
self.train_raw_value_transforms,
self.target_value_transforms,
is_train=True,
context=self.context,
Expand All @@ -219,6 +222,7 @@ def construct(self, dataset_dict):
self.classes,
self.input_arrays,
self.target_arrays,
raw_value_transforms=self.val_raw_value_transforms,
target_value_transforms=self.target_value_transforms,
is_train=False,
context=self.context,
Expand All @@ -231,6 +235,8 @@ def construct(self, dataset_dict):
self.datasets["validate"] = self.validation_datasets

def verify_datasets(self):
if self.force_has_data:
return
verified_datasets = []
for ds in self.train_datasets:
if ds.verify():
Expand All @@ -243,10 +249,18 @@ def verify_datasets(self):
verified_datasets.append(ds)
self.validation_datasets = verified_datasets

def set_raw_value_transforms(self, transforms: Callable):
def set_raw_value_transforms(
    self, train_transforms: Callable, val_transforms: Callable
):
    """Sets the raw value transforms for the training and validation datasets.

    Applies ``train_transforms`` to every training dataset and
    ``val_transforms`` to every validation dataset.
    """
    for dataset in self.train_datasets:
        dataset.set_raw_value_transforms(train_transforms)
    # The combined multi-dataset appears to be cached lazily under a private
    # attribute; update the cached copy too so it stays in sync — the hasattr
    # guard avoids constructing it just to set transforms.
    if hasattr(self, "_train_datasets_combined"):
        self._train_datasets_combined.set_raw_value_transforms(train_transforms)
    for dataset in self.validation_datasets:
        dataset.set_raw_value_transforms(val_transforms)
    # Same lazy-cache pattern for the combined validation multi-dataset.
    if hasattr(self, "_validation_datasets_combined"):
        self._validation_datasets_combined.set_raw_value_transforms(val_transforms)

def set_target_value_transforms(self, transforms: Callable):
"""Sets the target value transforms for each dataset in the multi-datasets."""
Expand Down
8 changes: 4 additions & 4 deletions src/cellmap_data/multidataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_subset_random_sampler(
generator=rng,
)
else:
dataset_weights = self.get_dataset_weights()
dataset_weights = list(self.get_dataset_weights().values())

datasets_sampled = torch.multinomial(
torch.tensor(dataset_weights), num_samples, replacement=True
Expand Down Expand Up @@ -120,15 +120,15 @@ def get_dataset_weights(self):
class_weights = {
c: 1 - (class_counts[c] / class_count_sum) for c in self.classes
}
dataset_weights = []
dataset_weights = {}
for dataset in self.datasets:
dataset_weight = np.sum(
[
dataset.class_counts["totals"][c] * class_weights[c]
for c in self.classes
]
)
dataset_weights.append(dataset_weight)
dataset_weights[dataset] = dataset_weight
return dataset_weights

def get_sample_weights(self):
def get_validation_indices(self) -> Sequence[int]:
    """Return validation sample indices valid for the concatenated multi-dataset.

    Each member dataset reports indices local to itself; they must be shifted
    by the cumulative length of the preceding datasets so the returned indices
    address the correct samples in the concatenated index space.

    Returns:
        A flat list of global integer indices into this multi-dataset.
    """
    validation_indices = []
    index_offset = 0
    for dataset in self.datasets:
        # Bug fix: index_offset was accumulated but never applied, so every
        # dataset after the first contributed local indices pointing at the
        # wrong samples. Shift them into the global (concatenated) space.
        validation_indices.extend(
            idx + index_offset for idx in dataset.get_validation_indices()
        )
        index_offset += len(dataset)
    return validation_indices

Expand Down
18 changes: 18 additions & 0 deletions src/cellmap_data/subdataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from torch.utils.data import Dataset


class CellMapSubset(Dataset):
    """A subset of a dataset, selected by a sequence of indices.

    Behaves like ``torch.utils.data.Subset`` but additionally forwards
    ``to(device)`` to the wrapped dataset so the whole subset can be moved
    between devices with a single call.
    """

    def __init__(self, dataset, indices):
        """Wrap ``dataset``, exposing only the samples at ``indices``."""
        super().__init__()
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        """Return the number of samples in the subset."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Translate the subset index and fetch the sample from the wrapped dataset."""
        underlying_index = self.indices[idx]
        return self.dataset[underlying_index]

    def to(self, device):
        """Move the wrapped dataset to ``device``; return ``self`` for chaining."""
        self.dataset.to(device)
        return self

0 comments on commit 2bd5ec8

Please sign in to comment.