Skip to content

Commit

Permalink
perf: 🎨 Remove use of numpy random number generator.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed May 9, 2024
1 parent 176d568 commit 4964e0a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
30 changes: 16 additions & 14 deletions src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(
is_train: bool = False,
axis_order: str = "zyx",
context: Optional[tensorstore.Context] = None, # type: ignore
rng: Optional[np.random.Generator] = None,
rng: Optional[torch.Generator] = None,
force_has_data: bool = False,
):
"""Initializes the CellMapDataset class.
Expand Down Expand Up @@ -101,7 +101,7 @@ def __init__(
target_value_transforms (Optional[Callable | Sequence[Callable] | dict[str, Callable]], optional): A function to convert the ground truth data to target arrays. Defaults to None. For example, it can convert the ground truth data to a signed distance transform. May be a single function, a list of functions, or a dictionary of functions for each class. In the case of a list of functions, it is assumed that the functions correspond to each class in the classes list in order.
is_train (bool, optional): Whether the dataset is for training. Defaults to False.
context (Optional[tensorstore.Context], optional): The context for the image data. Defaults to None.
rng (Optional[np.random.Generator], optional): A random number generator. Defaults to None.
rng (Optional[torch.Generator], optional): A random number generator. Defaults to None.
force_has_data (bool, optional): Whether to force the dataset to report that it has data. Defaults to False.
"""
self.raw_path = raw_path
Expand All @@ -116,7 +116,7 @@ def __init__(
self.is_train = is_train
self.axis_order = axis_order
self.context = context
self._rng = rng
self.rng = rng
self.force_has_data = force_has_data
self._bounding_box = None
self._bounding_box_shape = None
Expand Down Expand Up @@ -217,8 +217,6 @@ def __getitem__(self, idx):
spatial_transforms
)
array = self.target_sources[array_name][label][center]
# if array.shape[0] != 1:
# array = array[None, ...]
class_arrays.append(array)
outputs[array_name] = torch.stack(class_arrays)
return outputs
Expand Down Expand Up @@ -402,10 +400,7 @@ def to(self, device):

def generate_spatial_transforms(self) -> Optional[dict[str, Any]]:
"""Generates spatial transforms for the dataset."""
# TODO: use torch random number generator so accelerators can synchronize across workers
if self._rng is None:
self._rng = np.random.default_rng()
rng = self._rng
# *TODO: use torch random number generator so accelerators can synchronize across workers

if not self.is_train or self.spatial_transforms is None:
return None
Expand All @@ -416,20 +411,27 @@ def generate_spatial_transforms(self) -> Optional[dict[str, Any]]:
# output: {"mirror": ["x", "y"]}
spatial_transforms[transform] = []
for axis, prob in params["axes"].items():
if rng.random() < prob:
if torch.rand(1, generator=self.rng).item() < prob:
spatial_transforms[transform].append(axis)
elif transform == "transpose":
# only reorder axes specified in params
# input: "transpose": {"axes": ["x", "z"]}
# output: {"transpose": {"x": 2, "y": 1, "z": 0}}
# params["axes"] = ["x", "z"]
# axes = {"x": 0, "y": 1, "z": 2}
axes = {axis: i for i, axis in enumerate(self.axis_order)}
shuffled_axes = rng.permutation(
[axes[a] for a in params["axes"]]
) # shuffle indices
# shuffled_axes = [0, 2]
shuffled_axes = [axes[a] for a in params["axes"]]
# shuffled_axes = [2, 0]
shuffled_axes = shuffled_axes[
torch.randperm(len(shuffled_axes), generator=self.rng)
] # shuffle indices
# shuffled_axes = {"x": 2, "z": 0}
shuffled_axes = {
axis: shuffled_axes[i] for i, axis in enumerate(params["axes"])
} # reassign axes
# axes = {"x": 2, "y": 1, "z": 0}
axes.update(shuffled_axes)
# output: {"transpose": {"x": 2, "y": 1, "z": 0}}
spatial_transforms[transform] = axes
else:
raise ValueError(f"Unknown spatial transform: {transform}")
Expand Down
6 changes: 5 additions & 1 deletion src/cellmap_data/multidataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ def __repr__(self) -> str:

@property
def class_counts(self):
if not hasattr(self, "_class_counts"):
if (
not hasattr(self, "_class_counts")
or self._class_counts is None
or len(self._class_counts) == 0
):
class_counts = {}
for c in self.classes:
class_counts[c] = {}
Expand Down

0 comments on commit 4964e0a

Please sign in to comment.