Commit e2fd269

refactor: 🐛 Remove hidden property declarations

Also allow kwargs to dataloader

rhoadesScholar committed May 13, 2024
1 parent 001054e commit e2fd269
Showing 3 changed files with 57 additions and 51 deletions.
src/cellmap_data/dataloader.py (18 changes: 11 additions & 7 deletions)
@@ -31,6 +31,7 @@ def __init__(
         sampler: Sampler | None = None,
         is_train: bool = True,
         rng: Optional[torch.Generator] = None,
+        **kwargs,
     ):
         self.dataset = dataset
         self.classes = classes
@@ -47,19 +48,22 @@ def __init__(
             self.sampler = self.dataset.weighted_sampler(self.batch_size, self.rng)
         if torch.cuda.is_available():
             self.dataset.to("cuda")
-        kwargs = {
-            "dataset": self.dataset,
-            "dataset": self.dataset,
-            "batch_size": self.batch_size,
-            "num_workers": self.num_workers,
-            "collate_fn": self.collate_fn,
-        }
+        kwargs.update(
+            {
+                "dataset": self.dataset,
+                "dataset": self.dataset,
+                "batch_size": self.batch_size,
+                "num_workers": self.num_workers,
+                "collate_fn": self.collate_fn,
+            }
+        )
         if self.sampler is not None:
             kwargs["sampler"] = self.sampler
         elif self.is_train:
             kwargs["shuffle"] = True
         else:
             kwargs["shuffle"] = False
+        # TODO: Try persistent workers
         self.loader = DataLoader(**kwargs)

     def collate_fn(self, batch):
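The new **kwargs passthrough means callers can forward any extra torch.utils.data.DataLoader option through the wrapper without the wrapper having to name it. A minimal sketch of the pattern, under assumed names (MyLoader is illustrative, not the actual CellMapDataLoader API):

from torch.utils.data import DataLoader, Dataset

class MyLoader:
    """Illustrative wrapper that forwards unrecognized kwargs to DataLoader."""

    def __init__(self, dataset: Dataset, batch_size: int = 1, **kwargs):
        # update() runs last, so wrapper-managed keys win over caller-supplied ones.
        kwargs.update({"dataset": dataset, "batch_size": batch_size})
        self.loader = DataLoader(**kwargs)

# The caller can now tune options the wrapper never names, for example:
#   loader = MyLoader(ds, batch_size=4, num_workers=2, pin_memory=True,
#                     persistent_workers=True)

Because kwargs.update(...) runs after the caller's options are collected, the wrapper-managed keys (dataset, batch_size, num_workers, collate_fn) silently override any caller-supplied duplicates rather than raising.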
src/cellmap_data/dataset.py (56 changes: 29 additions & 27 deletions)
@@ -110,7 +110,6 @@ def __init__(
         self.context = context
         self._rng = rng
         self.force_has_data = force_has_data
-        self._len = None
         self._current_center = None
         self._current_spatial_transforms = None
         self.input_sources = {}
@@ -161,7 +160,7 @@ def __len__(self):
         """Returns the length of the dataset, determined by the number of coordinates that could be sampled as the center for a cube."""
         if not self.has_data and not self.force_has_data:
             return 0
-        if self._len is None:
+        if not hasattr(self, "_len"):
             size = np.prod([self.sampling_box_shape[c] for c in self.axis_order])
             self._len = int(size)
         return self._len
@@ -282,26 +281,28 @@ def sampling_box_shape(self) -> dict[str, int]:
     @property
     def class_weights(self) -> dict[str, float]:
         """
-        Returns the class weights for the multi-dataset based on the number of samples in each class.
+        Returns the class weights for the dataset based on the number of samples in each class.
         """
-        if len(self.classes) > 1:
-            class_counts = {c: 0 for c in self.classes}
-            class_count_sum = 0
-            for c in self.classes:
-                class_counts[c] += self.class_counts["totals"][c]
-                class_count_sum += self.class_counts["totals"][c]
-
-            class_weights = {
-                c: (
-                    1 - (class_counts[c] / class_count_sum)
-                    if class_counts[c] != class_count_sum
-                    else 0.1
-                )
-                for c in self.classes
-            }
-        else:
-            class_weights = {self.classes[0]: 0.1}  # less than 1 to avoid overflow
-        return class_weights
+        if not hasattr(self, "_class_weights"):
+            if len(self.classes) > 1:
+                class_counts = {c: 0 for c in self.classes}
+                class_count_sum = 0
+                for c in self.classes:
+                    class_counts[c] += self.class_counts["totals"][c]
+                    class_count_sum += self.class_counts["totals"][c]
+
+                class_weights = {
+                    c: (
+                        1 - (class_counts[c] / class_count_sum)
+                        if class_counts[c] != class_count_sum
+                        else 0.1
+                    )
+                    for c in self.classes
+                }
+            else:
+                class_weights = {self.classes[0]: 0.1}  # less than 1 to avoid overflow
+            self._class_weights = class_weights
+        return self._class_weights

     @property
     def class_counts(self) -> Dict[str, Dict[str, int]]:
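For intuition on the weight formula in the hunk above: each class gets 1 minus its share of all counted samples, so rarer classes receive larger weights, and the 0.1 fallback covers the degenerate case where a single class accounts for every sample. A standalone sketch with made-up counts:

def inverse_frequency_weights(counts: dict[str, float]) -> dict[str, float]:
    # Same formula as the class_weights property above (illustrative helper).
    total = sum(counts.values())
    return {c: (1 - n / total) if n != total else 0.1 for c, n in counts.items()}

# With 75 "mito" samples and 25 "er" samples, the rarer class is upweighted:
assert inverse_frequency_weights({"mito": 75.0, "er": 25.0}) == {"mito": 0.25, "er": 0.75}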
@@ -319,10 +320,12 @@ def class_counts(self) -> Dict[str, Dict[str, int]]:
     @property
     def validation_indices(self) -> Sequence[int]:
         """Returns the indices of the dataset that will tile the dataset for validation."""
-        chunk_size = {}
-        for c, size in self.bounding_box_shape.items():
-            chunk_size[c] = np.ceil(size - self.sampling_box_shape[c]).astype(int)
-        return self.get_indices(chunk_size)
+        if not hasattr(self, "_validation_indices"):
+            chunk_size = {}
+            for c, size in self.bounding_box_shape.items():
+                chunk_size[c] = np.ceil(size - self.sampling_box_shape[c]).astype(int)
+            self._validation_indices = self.get_indices(chunk_size)
+        return self._validation_indices

     def _get_box_shape(self, source_box: dict[str, list[float]]) -> dict[str, int]:
         box_shape = {}
@@ -345,7 +348,7 @@ def _get_box(
         return current_box

     def verify(self):
-        """Verifies that the dataset is valid."""
+        """Verifies that the dataset is valid to draw samples from."""
         # TODO: make more robust
         try:
             return len(self) > 0
@@ -367,7 +370,6 @@ def get_indices(self, chunk_size: dict[str, int]) -> Sequence[int]:
             index = [indices_dict[c][j] for c, j in zip(self.axis_order, i)]
             index = np.ravel_multi_index(index, list(self.sampling_box_shape.values()))
             indices.append(index)
-
         return indices

     def to(self, device):
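The recurring change in this file, and the point of the commit title, is replacing eagerly declared sentinels such as self._len = None with hasattr checks at first access: a hand-rolled lazy cache whose backing attribute only exists once the value has been computed. A sketch of the idiom under assumed names (Sized is illustrative, not the actual CellMapDataset); since Python 3.8, functools.cached_property expresses the same idea more directly:

from functools import cached_property

import numpy as np

class Sized:
    """Illustrative stand-in for a dataset with an expensive size computation."""

    def __init__(self, shape: dict[str, int]):
        self.shape = shape  # no self._len = None declaration needed

    def __len__(self) -> int:
        # Compute once on first call, reuse afterwards (the diff's pattern).
        if not hasattr(self, "_len"):
            self._len = int(np.prod(list(self.shape.values())))
        return self._len

    @cached_property
    def size(self) -> int:
        # Equivalent caching via the standard library.
        return int(np.prod(list(self.shape.values())))

One trade-off: the cached attributes no longer appear in __init__, so they are invisible to a reader scanning the constructor, and a stale value survives if the underlying inputs later change, just as with the explicit sentinels.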
src/cellmap_data/datasplit.py (34 changes: 17 additions & 17 deletions)
@@ -119,23 +119,6 @@ def to_target(gt: torch.Tensor, classes: Sequence[str]) -> dict[str, torch.Tensor]:
     def __repr__(self):
         return f"CellMapDataSplit(\n\tInput arrays: {self.input_arrays}\n\tTarget arrays:{self.target_arrays}\n\tClasses: {self.classes}\n\tDataset dict: {self.dataset_dict}\n\tSpatial transforms: {self.spatial_transforms}\n\tRaw value transforms: {self.train_raw_value_transforms}\n\tGT value transforms: {self.target_value_transforms}\n\tForce has data: {self.force_has_data}\n\tContext: {self.context})"

-    def from_csv(self, csv_path):
-        # Load file data from csv file
-        dataset_dict = {}
-        with open(csv_path, "r") as f:
-            reader = csv.reader(f)
-            for row in reader:
-                if row[0] not in dataset_dict:
-                    dataset_dict[row[0]] = []
-                dataset_dict[row[0]].append(
-                    {
-                        "raw": os.path.join(row[1], row[2]),
-                        "gt": os.path.join(row[3], row[4]) if len(row) > 3 else "",
-                    }
-                )
-
-        return dataset_dict
-
     @property
     def train_datasets_combined(self):
         if not hasattr(self, "_train_datasets_combined"):
@@ -185,6 +168,23 @@ def class_counts(self):
             }
         return self._class_counts

+    def from_csv(self, csv_path):
+        # Load file data from csv file
+        dataset_dict = {}
+        with open(csv_path, "r") as f:
+            reader = csv.reader(f)
+            for row in reader:
+                if row[0] not in dataset_dict:
+                    dataset_dict[row[0]] = []
+                dataset_dict[row[0]].append(
+                    {
+                        "raw": os.path.join(row[1], row[2]),
+                        "gt": os.path.join(row[3], row[4]) if len(row) > 3 else "",
+                    }
+                )
+
+        return dataset_dict
+
     def construct(self, dataset_dict):
         self.train_datasets = []
         self.validation_datasets = []
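from_csv (moved below class_counts, otherwise unchanged) expects each row to hold a split name followed by two path pairs: usage, raw root, raw dataset, gt root, gt dataset, with each pair joined via os.path.join. Note that the len(row) > 3 guard effectively assumes a full five-column row, since row[4] is read whenever the guard passes. A standalone sketch of the same parsing, with hypothetical paths:

import csv
import os

def parse_datasplit_csv(csv_path: str) -> dict[str, list[dict[str, str]]]:
    # Same row handling as CellMapDataSplit.from_csv; setdefault condenses
    # the "if row[0] not in dataset_dict" check (illustrative helper).
    dataset_dict: dict[str, list[dict[str, str]]] = {}
    with open(csv_path, "r") as f:
        for row in csv.reader(f):
            dataset_dict.setdefault(row[0], []).append(
                {
                    "raw": os.path.join(row[1], row[2]),
                    "gt": os.path.join(row[3], row[4]) if len(row) > 3 else "",
                }
            )
    return dataset_dict

# Hypothetical rows (usage, raw root, raw dataset, gt root, gt dataset):
#   train,/data/sample1.zarr,recon-1/em,/data/sample1.zarr,labels/mito
#   validate,/data/sample2.zarr,recon-1/em,/data/sample2.zarr,labels/mito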
