From 4964e0a198b6a793c9fca681e88e5f21050ea2b7 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Thu, 9 May 2024 01:08:09 -0400 Subject: [PATCH] =?UTF-8?q?perf:=20=F0=9F=8E=A8=20Remove=20use=20of=20nump?= =?UTF-8?q?y=20random=20number=20generator.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/cellmap_data/dataset.py | 30 ++++++++++++++++-------------- src/cellmap_data/multidataset.py | 6 +++++- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index 8c638a5..149bea5 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -69,7 +69,7 @@ def __init__( is_train: bool = False, axis_order: str = "zyx", context: Optional[tensorstore.Context] = None, # type: ignore - rng: Optional[np.random.Generator] = None, + rng: Optional[torch.Generator] = None, force_has_data: bool = False, ): """Initializes the CellMapDataset class. @@ -101,7 +101,7 @@ def __init__( target_value_transforms (Optional[Callable | Sequence[Callable] | dict[str, Callable]], optional): A function to convert the ground truth data to target arrays. Defaults to None. Example is to convert the ground truth data to a signed distance transform. May be a single function, a list of functions, or a dictionary of functions for each class. In the case of a list of functions, it is assumed that the functions correspond to each class in the classes list in order. is_train (bool, optional): Whether the dataset is for training. Defaults to False. context (Optional[tensorstore.Context], optional): The context for the image data. Defaults to None. - rng (Optional[np.random.Generator], optional): A random number generator. Defaults to None. + rng (Optional[torch.Generator], optional): A random number generator. Defaults to None. force_has_data (bool, optional): Whether to force the dataset to report that it has data. Defaults to False. """ self.raw_path = raw_path @@ -116,7 +116,7 @@ def __init__( self.is_train = is_train self.axis_order = axis_order self.context = context - self._rng = rng + self.rng = rng self.force_has_data = force_has_data self._bounding_box = None self._bounding_box_shape = None @@ -217,8 +217,6 @@ def __getitem__(self, idx): spatial_transforms ) array = self.target_sources[array_name][label][center] - # if array.shape[0] != 1: - # array = array[None, ...] class_arrays.append(array) outputs[array_name] = torch.stack(class_arrays) return outputs @@ -402,10 +400,7 @@ def to(self, device): def generate_spatial_transforms(self) -> Optional[dict[str, Any]]: """Generates spatial transforms for the dataset.""" - # TODO: use torch random number generator so accerlerators can synchronize across workers - if self._rng is None: - self._rng = np.random.default_rng() - rng = self._rng + # *TODO: use torch random number generator so accerlerators can synchronize across workers if not self.is_train or self.spatial_transforms is None: return None @@ -416,20 +411,27 @@ def generate_spatial_transforms(self) -> Optional[dict[str, Any]]: # output: {"mirror": ["x", "y"]} spatial_transforms[transform] = [] for axis, prob in params["axes"].items(): - if rng.random() < prob: + if torch.rand(1, generator=self.rng).item() < prob: spatial_transforms[transform].append(axis) elif transform == "transpose": # only reorder axes specified in params # input: "transpose": {"axes": ["x", "z"]} - # output: {"transpose": {"x": 2, "y": 1, "z": 0}} + # params["axes"] = ["x", "z"] + # axes = {"x": 0, "y": 1, "z": 2} axes = {axis: i for i, axis in enumerate(self.axis_order)} - shuffled_axes = rng.permutation( - [axes[a] for a in params["axes"]] - ) # shuffle indices + # shuffled_axes = [0, 2] + shuffled_axes = [axes[a] for a in params["axes"]] + # shuffled_axes = [2, 0] + shuffled_axes = shuffled_axes[ + torch.randperm(len(shuffled_axes), generator=self.rng) + ] # shuffle indices + # shuffled_axes = {"x": 2, "z": 0} shuffled_axes = { axis: shuffled_axes[i] for i, axis in enumerate(params["axes"]) } # reassign axes + # axes = {"x": 2, "y": 1, "z": 0} axes.update(shuffled_axes) + # output: {"transpose": {"x": 2, "y": 1, "z": 0}} spatial_transforms[transform] = axes else: raise ValueError(f"Unknown spatial transform: {transform}") diff --git a/src/cellmap_data/multidataset.py b/src/cellmap_data/multidataset.py index fc48145..b42b532 100644 --- a/src/cellmap_data/multidataset.py +++ b/src/cellmap_data/multidataset.py @@ -41,7 +41,11 @@ def __repr__(self) -> str: @property def class_counts(self): - if not hasattr(self, "_class_counts"): + if ( + not hasattr(self, "_class_counts") + or self._class_counts is None + or len(self._class_counts) == 0 + ): class_counts = {} for c in self.classes: class_counts[c] = {}