diff --git a/src/cellmap_data/multidataset.py b/src/cellmap_data/multidataset.py index a8042bd..251bc4f 100644 --- a/src/cellmap_data/multidataset.py +++ b/src/cellmap_data/multidataset.py @@ -105,11 +105,10 @@ def get_subset_random_sampler( indices = indices[torch.randperm(len(indices), generator=rng)] return torch.utils.data.SubsetRandomSampler(indices, generator=rng) - def get_dataset_weights(self): + def get_class_weights(self): """ - Returns the weights for each dataset in the multi-dataset based on the number of samples in each dataset. + Returns the class weights for the multi-dataset based on the number of samples in each class. """ - if len(self.classes) > 1: class_counts = {c: 0 for c in self.classes} class_count_sum = 0 @@ -122,7 +121,14 @@ def get_dataset_weights(self): c: 1 - (class_counts[c] / class_count_sum) for c in self.classes } else: - class_weights = {self.classes[0]: 1} + class_weights = {self.classes[0]: 0.1} # less than 1 to avoid overflow + return class_weights + + def get_dataset_weights(self): + """ + Returns the weights for each dataset in the multi-dataset based on the number of samples in each dataset. + """ + class_weights = self.get_class_weights() dataset_weights = {} for dataset in self.datasets: