diff --git a/src/scip/segmentation/cellpose.py b/src/scip/segmentation/cellpose.py index d500672..c8c2688 100644 --- a/src/scip/segmentation/cellpose.py +++ b/src/scip/segmentation/cellpose.py @@ -22,6 +22,20 @@ import torch +def _get_gpu_device(worker): + if isinstance(get_client().cluster, LocalCluster): + gpu_id = '0' + else: + gpu_workers = [ + address + for address, w in get_client().scheduler_info()["workers"].items() + if "cellpose" in w["resources"] + ] + gpu_id = gpu_workers.index(worker.address) + + return torch.device(f'cuda:{gpu_id}') + + def segment_block( events: List[Mapping[str, Any]], *, @@ -53,27 +67,16 @@ def segment_block( if len(events) == 0: return events - w = get_worker() - if hasattr(w, "cellpose"): - model = w.cellpose + worker = get_worker() + if hasattr(worker, "cellpose"): + model = worker.cellpose else: if gpu_accelerated: - - if isinstance(get_client().cluster, LocalCluster): - gpu_id = '0' - else: - gpu_workers = [ - address - for address, w in get_client().scheduler_info()["workers"].items() - if "cellpose" in w["resources"] - ] - gpu_id = gpu_workers.index(w.address) - - device = torch.device(f'cuda:{gpu_id}') + device = _get_gpu_device(worker) model = models.Cellpose(gpu=True, device=device, model_type='cyto2') else: model = models.Cellpose(gpu=False, model_type='cyto2') - w.cellpose = model + worker.cellpose = model parents, _, _, _ = model.eval( x=[e["pixels"][[parent_channel_index, dapi_channel_index]] for e in events],