refactor: 🎨 Refactor calls
rhoadesScholar committed May 13, 2024
1 parent 68bbdff commit 036fc49
Showing 2 changed files with 5 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/cellmap_data/dataloader.py
@@ -45,7 +45,7 @@ def __init__(
             assert isinstance(
                 self.dataset, CellMapMultiDataset
             ), "Weighted sampler only relevant for CellMapMultiDataset"
-            self.sampler = self.dataset.weighted_sampler(self.batch_size, self.rng)
+            self.sampler = self.dataset.get_weighted_sampler(self.batch_size, self.rng)
         if torch.cuda.is_available():
             self.dataset.to("cuda")
         kwargs.update(
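For context, a sampler produced this way is normally handed to a PyTorch DataLoader through its sampler argument. The sketch below is illustrative only — build_weighted_loader and the surrounding wiring are hypothetical, not part of CellMapDataLoader — but it shows the renamed get_weighted_sampler call in a runnable setting:

```python
# Hypothetical wiring, not the cellmap_data implementation: shows how the
# renamed get_weighted_sampler call can feed a standard PyTorch DataLoader.
import torch
from torch.utils.data import DataLoader


def build_weighted_loader(dataset, batch_size: int, seed: int = 0) -> DataLoader:
    rng = torch.Generator().manual_seed(seed)
    # This commit renames the method: weighted_sampler -> get_weighted_sampler.
    sampler = dataset.get_weighted_sampler(batch_size, rng)
    if torch.cuda.is_available():
        dataset.to("cuda")  # CellMapMultiDataset.to() moves its member datasets to the device
    # The sampler yields individual indices; DataLoader groups them into batches.
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
```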
14 changes: 4 additions & 10 deletions src/cellmap_data/multidataset.py
@@ -28,7 +28,6 @@ def __init__(
         self.target_arrays = target_arrays
         self.classes = classes
         self.datasets = datasets
-        self._weighted_sampler = None
 
     def __repr__(self) -> str:
         out_string = f"CellMapMultiDataset(["
@@ -138,17 +137,12 @@ def to(self, device: str):
             dataset.to(device)
         return self
 
-    def weighted_sampler(
+    def get_weighted_sampler(
         self, batch_size: int = 1, rng: Optional[torch.Generator] = None
     ):
-        if self._weighted_sampler is None:
-            # TODO: calculate weights for each sample
-            sample_weights = self.sample_weights
-
-            self._weighted_sampler = WeightedRandomSampler(
-                sample_weights, batch_size, replacement=False, generator=rng
-            )
-        return self._weighted_sampler
+        return WeightedRandomSampler(
+            self.sample_weights, batch_size, replacement=False, generator=rng
+        )
 
     def get_subset_random_sampler(
         self,
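The net effect in multidataset.py: get_weighted_sampler now builds a fresh WeightedRandomSampler on every call instead of returning a cached self._weighted_sampler, so a later call with a different batch size or generator is no longer ignored. A minimal standalone sketch of the resulting pattern follows (TinyMultiDataset and its hard-coded weights are made up for illustration; the real class is CellMapMultiDataset):

```python
# Minimal sketch of the pattern after this commit: no cached sampler, a new
# WeightedRandomSampler is constructed per call. TinyMultiDataset and its
# weights are illustrative, not the real CellMapMultiDataset.
from typing import Optional

import torch
from torch.utils.data import WeightedRandomSampler


class TinyMultiDataset:
    def __init__(self, sample_weights):
        # In cellmap_data, self.sample_weights comes from CellMapMultiDataset;
        # here it is just a fixed per-sample weight list for illustration.
        self.sample_weights = sample_weights

    def get_weighted_sampler(
        self, batch_size: int = 1, rng: Optional[torch.Generator] = None
    ) -> WeightedRandomSampler:
        # Built fresh on every call, so batch_size and rng always take effect.
        return WeightedRandomSampler(
            self.sample_weights, batch_size, replacement=False, generator=rng
        )


rng = torch.Generator().manual_seed(0)
dataset = TinyMultiDataset(sample_weights=[0.1, 0.5, 0.2, 0.9])
print(list(dataset.get_weighted_sampler(batch_size=2, rng=rng)))  # two distinct indices
```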
