Skip to content

Commit

Permalink
refactor: ♻️ Add get_class_weights function
Browse files: browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed May 7, 2024
1 parent be10de8 commit 81666e6
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/cellmap_data/multidataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -105,11 +105,10 @@ def get_subset_random_sampler(
indices = indices[torch.randperm(len(indices), generator=rng)]
return torch.utils.data.SubsetRandomSampler(indices, generator=rng)

def get_dataset_weights(self):
def get_class_weights(self):
"""
Returns the weights for each dataset in the multi-dataset based on the number of samples in each dataset.
Returns the class weights for the multi-dataset based on the number of samples in each class.
"""

if len(self.classes) > 1:
class_counts = {c: 0 for c in self.classes}
class_count_sum = 0
Expand All @@ -122,7 +121,14 @@ def get_dataset_weights(self):
c: 1 - (class_counts[c] / class_count_sum) for c in self.classes
}
else:
class_weights = {self.classes[0]: 1}
class_weights = {self.classes[0]: 0.1} # less than 1 to avoid overflow
return class_weights

def get_dataset_weights(self):
"""
Returns the weights for each dataset in the multi-dataset based on the number of samples in each dataset.
"""
class_weights = self.get_class_weights()

dataset_weights = {}
for dataset in self.datasets:
Expand Down

0 comments on commit 81666e6

Please sign in to comment.