diff --git a/src/cellmap_data/dataloader.py b/src/cellmap_data/dataloader.py
index 301bdbc..0cf82d5 100644
--- a/src/cellmap_data/dataloader.py
+++ b/src/cellmap_data/dataloader.py
@@ -31,6 +31,7 @@ def __init__(
         sampler: Sampler | None = None,
         is_train: bool = True,
         rng: Optional[torch.Generator] = None,
+        **kwargs,
     ):
         self.dataset = dataset
         self.classes = classes
@@ -47,19 +48,21 @@ def __init__(
             self.sampler = self.dataset.weighted_sampler(self.batch_size, self.rng)
         if torch.cuda.is_available():
             self.dataset.to("cuda")
-        kwargs = {
-            "dataset": self.dataset,
-            "dataset": self.dataset,
-            "batch_size": self.batch_size,
-            "num_workers": self.num_workers,
-            "collate_fn": self.collate_fn,
-        }
+        kwargs.update(
+            {
+                "dataset": self.dataset,
+                "batch_size": self.batch_size,
+                "num_workers": self.num_workers,
+                "collate_fn": self.collate_fn,
+            }
+        )
         if self.sampler is not None:
             kwargs["sampler"] = self.sampler
         elif self.is_train:
             kwargs["shuffle"] = True
         else:
             kwargs["shuffle"] = False
+        # TODO: Try persistent workers
         self.loader = DataLoader(**kwargs)

     def collate_fn(self, batch):
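A note on the dataloader hunks above: forwarding `**kwargs` into `torch.utils.data.DataLoader` lets callers reach loader options that `CellMapDataLoader` does not expose itself, including the `persistent_workers` flag named in the TODO. The added block also drops the duplicated `"dataset"` key from the old dict literal. Below is a minimal sketch of the kind of options that can now be passed through, using a toy `TensorDataset` in place of a real `CellMapDataset`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for a CellMapDataset: 100 samples of 3 features each.
dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))

# On spawn-based platforms, worker processes require a __main__ guard.
if __name__ == "__main__":
    loader = DataLoader(
        dataset,
        batch_size=8,
        num_workers=2,
        persistent_workers=True,  # keep workers alive between epochs (the TODO above)
        pin_memory=torch.cuda.is_available(),  # faster host-to-GPU copies
    )
    for inputs, labels in loader:
        pass  # training step would go here
```

`persistent_workers=True` avoids re-forking worker processes at every epoch, which mainly pays off when per-worker dataset setup is expensive; it requires `num_workers > 0`.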
""" - if len(self.classes) > 1: - class_counts = {c: 0 for c in self.classes} - class_count_sum = 0 - for c in self.classes: - class_counts[c] += self.class_counts["totals"][c] - class_count_sum += self.class_counts["totals"][c] - - class_weights = { - c: ( - 1 - (class_counts[c] / class_count_sum) - if class_counts[c] != class_count_sum - else 0.1 - ) - for c in self.classes - } - else: - class_weights = {self.classes[0]: 0.1} # less than 1 to avoid overflow - return class_weights + if not hasattr(self, "_class_weights"): + if len(self.classes) > 1: + class_counts = {c: 0 for c in self.classes} + class_count_sum = 0 + for c in self.classes: + class_counts[c] += self.class_counts["totals"][c] + class_count_sum += self.class_counts["totals"][c] + + class_weights = { + c: ( + 1 - (class_counts[c] / class_count_sum) + if class_counts[c] != class_count_sum + else 0.1 + ) + for c in self.classes + } + else: + class_weights = {self.classes[0]: 0.1} # less than 1 to avoid overflow + self._class_weights = class_weights + return self._class_weights @property def class_counts(self) -> Dict[str, Dict[str, int]]: @@ -319,10 +320,12 @@ def class_counts(self) -> Dict[str, Dict[str, int]]: @property def validation_indices(self) -> Sequence[int]: """Returns the indices of the dataset that will tile the dataset for validation.""" - chunk_size = {} - for c, size in self.bounding_box_shape.items(): - chunk_size[c] = np.ceil(size - self.sampling_box_shape[c]).astype(int) - return self.get_indices(chunk_size) + if not hasattr(self, "_validation_indices"): + chunk_size = {} + for c, size in self.bounding_box_shape.items(): + chunk_size[c] = np.ceil(size - self.sampling_box_shape[c]).astype(int) + self._validation_indices = self.get_indices(chunk_size) + return self._validation_indices def _get_box_shape(self, source_box: dict[str, list[float]]) -> dict[str, int]: box_shape = {} @@ -345,7 +348,7 @@ def _get_box( return current_box def verify(self): - """Verifies that the dataset is valid.""" + """Verifies that the dataset is valid to draw samples from.""" # TODO: make more robust try: return len(self) > 0 @@ -367,7 +370,6 @@ def get_indices(self, chunk_size: dict[str, int]) -> Sequence[int]: index = [indices_dict[c][j] for c, j in zip(self.axis_order, i)] index = np.ravel_multi_index(index, list(self.sampling_box_shape.values())) indices.append(index) - return indices def to(self, device): diff --git a/src/cellmap_data/datasplit.py b/src/cellmap_data/datasplit.py index 40700b0..98a03ac 100644 --- a/src/cellmap_data/datasplit.py +++ b/src/cellmap_data/datasplit.py @@ -119,23 +119,6 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tenso def __repr__(self): return f"CellMapDataSplit(\n\tInput arrays: {self.input_arrays}\n\tTarget arrays:{self.target_arrays}\n\tClasses: {self.classes}\n\tDataset dict: {self.dataset_dict}\n\tSpatial transforms: {self.spatial_transforms}\n\tRaw value transforms: {self.train_raw_value_transforms}\n\tGT value transforms: {self.target_value_transforms}\n\tForce has data: {self.force_has_data}\n\tContext: {self.context})" - def from_csv(self, csv_path): - # Load file data from csv file - dataset_dict = {} - with open(csv_path, "r") as f: - reader = csv.reader(f) - for row in reader: - if row[0] not in dataset_dict: - dataset_dict[row[0]] = [] - dataset_dict[row[0]].append( - { - "raw": os.path.join(row[1], row[2]), - "gt": os.path.join(row[3], row[4]) if len(row) > 3 else "", - } - ) - - return dataset_dict - @property def train_datasets_combined(self): if not 
diff --git a/src/cellmap_data/datasplit.py b/src/cellmap_data/datasplit.py
index 40700b0..98a03ac 100644
--- a/src/cellmap_data/datasplit.py
+++ b/src/cellmap_data/datasplit.py
@@ -119,23 +119,6 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tenso

     def __repr__(self):
         return f"CellMapDataSplit(\n\tInput arrays: {self.input_arrays}\n\tTarget arrays:{self.target_arrays}\n\tClasses: {self.classes}\n\tDataset dict: {self.dataset_dict}\n\tSpatial transforms: {self.spatial_transforms}\n\tRaw value transforms: {self.train_raw_value_transforms}\n\tGT value transforms: {self.target_value_transforms}\n\tForce has data: {self.force_has_data}\n\tContext: {self.context})"

-    def from_csv(self, csv_path):
-        # Load file data from csv file
-        dataset_dict = {}
-        with open(csv_path, "r") as f:
-            reader = csv.reader(f)
-            for row in reader:
-                if row[0] not in dataset_dict:
-                    dataset_dict[row[0]] = []
-                dataset_dict[row[0]].append(
-                    {
-                        "raw": os.path.join(row[1], row[2]),
-                        "gt": os.path.join(row[3], row[4]) if len(row) > 3 else "",
-                    }
-                )
-
-        return dataset_dict
-
     @property
     def train_datasets_combined(self):
         if not hasattr(self, "_train_datasets_combined"):
@@ -185,6 +168,23 @@ def class_counts(self):
             }
         return self._class_counts

+    def from_csv(self, csv_path):
+        # Load file data from csv file
+        dataset_dict = {}
+        with open(csv_path, "r") as f:
+            reader = csv.reader(f)
+            for row in reader:
+                if row[0] not in dataset_dict:
+                    dataset_dict[row[0]] = []
+                dataset_dict[row[0]].append(
+                    {
+                        "raw": os.path.join(row[1], row[2]),
+                        "gt": os.path.join(row[3], row[4]) if len(row) > 4 else "",
+                    }
+                )
+
+        return dataset_dict
+
     def construct(self, dataset_dict):
         self.train_datasets = []
         self.validation_datasets = []
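A note on the relocated `from_csv`: each row appears to carry a usage key in column 0 (the key that `construct` later groups datasets by) followed by root/path pairs for the raw data and, optionally, the ground truth; the concrete key names and paths below are assumptions for illustration. A runnable sketch of the parsing, using the tightened `len(row) > 4` guard for the optional ground-truth columns:

```python
import csv
import io
import os

# In-memory stand-in for the CSV file; the paths are made up.
csv_text = (
    "train,/data/crop1,raw.zarr,/data/crop1,gt.zarr\n"
    "validate,/data/crop2,raw.zarr\n"  # row without ground-truth columns
)

dataset_dict = {}
for row in csv.reader(io.StringIO(csv_text)):
    dataset_dict.setdefault(row[0], []).append(
        {
            "raw": os.path.join(row[1], row[2]),
            # row[4] is only safe to read when the row has at least 5 columns
            "gt": os.path.join(row[3], row[4]) if len(row) > 4 else "",
        }
    )

print(dataset_dict)
# {'train': [{'raw': '/data/crop1/raw.zarr', 'gt': '/data/crop1/gt.zarr'}],
#  'validate': [{'raw': '/data/crop2/raw.zarr', 'gt': ''}]}
```

The original `len(row) > 3` check would raise an `IndexError` on a four-column row, since `row[4]` needs five columns; hence the corrected guard in the added block above.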