From ca76c7f82723da5e776bf14639c5aa873c6f36de Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 5 Apr 2024 14:59:30 -0400 Subject: [PATCH] =?UTF-8?q?feat:=20=E2=9C=A8=20Add=20Cellpose=20value=20tr?= =?UTF-8?q?ansform.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 6 ++-- src/cellmap_data/dataset.py | 2 +- .../transforms/targets/cellpose.py | 30 ++++++++++++++++++- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2796160..2e6dfa9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "hatchling.build" # https://peps.python.org/pep-0621/ [project] -name = "cellmap.data" +name = "cellmap-data" description = "Utility for loading CellMap data for machine learning training, utilizing PyTorch, and Jackson Lab's ZarrDataset." readme = "README.md" requires-python = ">=3.8" @@ -58,8 +58,8 @@ all = [ ] [project.urls] -homepage = "https://github.com/rhoadesScholar/cellmap.data" -repository = "https://github.com/rhoadesScholar/cellmap.data" +homepage = "https://github.com/rhoadesScholar/cellmap-data" +repository = "https://github.com/rhoadesScholar/cellmap-data" # same as console_scripts entry point # [project.scripts] diff --git a/src/cellmap_data/dataset.py b/src/cellmap_data/dataset.py index 394a251..e4724bc 100644 --- a/src/cellmap_data/dataset.py +++ b/src/cellmap_data/dataset.py @@ -192,7 +192,6 @@ def construct(self): self._iter_coords = None self._current_center = None self._current_spatial_transforms = None - self._rng = None self.input_sources = {} for array_name, array_info in self.input_arrays.items(): self.input_sources[array_name] = CellMapImage( @@ -235,6 +234,7 @@ def construct(self): def generate_spatial_transforms(self) -> Optional[dict[str, any]]: """Generates spatial transforms for the dataset.""" + # TODO: use torch random number generator so accerlerators can synchronize across workers if self._rng is None: self._rng = np.random.default_rng() rng = self._rng diff --git a/src/cellmap_data/transforms/targets/cellpose.py b/src/cellmap_data/transforms/targets/cellpose.py index b39a234..0419133 100644 --- a/src/cellmap_data/transforms/targets/cellpose.py +++ b/src/cellmap_data/transforms/targets/cellpose.py @@ -1,2 +1,30 @@ -from cellpose.dynamics import masks_to_flows_gpu_3d +from cellpose.dynamics import masks_to_flows_gpu_3d, masks_to_flows from cellpose.dynamics import masks_to_flows_gpu as masks_to_flows_gpu_2d +import torch + + +class CellposeFlow: + def __init__(self, ndim: int, device: str | None = None): + self.ndim = ndim + if device is None: + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + _device = torch.device(device) + if device == "cuda" or device == "mps": + if ndim == 3: + flows_func = lambda x: masks_to_flows_gpu_3d(x, device=_device) + elif ndim == 2: + flows_func = lambda x: masks_to_flows_gpu_2d(x, device=_device) + else: + raise ValueError(f"Unsupported dimension {ndim}") + else: + flows_func = lambda x: masks_to_flows(x, device=_device) + self.flows_func = flows_func + self.device = _device + + def __call__(self, masks): + return self.flows_func(masks)