Skip to content

Commit

Permalink
feat: ✨ Add Cellpose value transform.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Apr 5, 2024
1 parent 8dd846d commit ca76c7f
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"

# https://peps.python.org/pep-0621/
[project]
name = "cellmap.data"
name = "cellmap-data"
description = "Utility for loading CellMap data for machine learning training, utilizing PyTorch, and Jackson Lab's ZarrDataset."
readme = "README.md"
requires-python = ">=3.8"
Expand Down Expand Up @@ -58,8 +58,8 @@ all = [
]

[project.urls]
homepage = "https://github.com/rhoadesScholar/cellmap.data"
repository = "https://github.com/rhoadesScholar/cellmap.data"
homepage = "https://github.com/rhoadesScholar/cellmap-data"
repository = "https://github.com/rhoadesScholar/cellmap-data"

# same as console_scripts entry point
# [project.scripts]
Expand Down
2 changes: 1 addition & 1 deletion src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def construct(self):
self._iter_coords = None
self._current_center = None
self._current_spatial_transforms = None
self._rng = None
self.input_sources = {}
for array_name, array_info in self.input_arrays.items():
self.input_sources[array_name] = CellMapImage(
Expand Down Expand Up @@ -235,6 +234,7 @@ def construct(self):

def generate_spatial_transforms(self) -> Optional[dict[str, any]]:
"""Generates spatial transforms for the dataset."""
# TODO: use torch random number generator so accerlerators can synchronize across workers
if self._rng is None:
self._rng = np.random.default_rng()
rng = self._rng
Expand Down
30 changes: 29 additions & 1 deletion src/cellmap_data/transforms/targets/cellpose.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,30 @@
from cellpose.dynamics import masks_to_flows_gpu_3d
from cellpose.dynamics import masks_to_flows_gpu_3d, masks_to_flows
from cellpose.dynamics import masks_to_flows_gpu as masks_to_flows_gpu_2d
import torch


class CellposeFlow:
def __init__(self, ndim: int, device: str | None = None):
self.ndim = ndim
if device is None:
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
_device = torch.device(device)
if device == "cuda" or device == "mps":
if ndim == 3:
flows_func = lambda x: masks_to_flows_gpu_3d(x, device=_device)
elif ndim == 2:
flows_func = lambda x: masks_to_flows_gpu_2d(x, device=_device)
else:
raise ValueError(f"Unsupported dimension {ndim}")
else:
flows_func = lambda x: masks_to_flows(x, device=_device)
self.flows_func = flows_func
self.device = _device

def __call__(self, masks):
return self.flows_func(masks)

0 comments on commit ca76c7f

Please sign in to comment.