Skip to content

Commit

Permalink
Merge pull request #5 from TheJacksonLaboratory/change_cellpose_models
Browse files Browse the repository at this point in the history
Change cellpose models
  • Loading branch information
fercer authored Aug 6, 2024
2 parents d35b2a3 + 477a71f commit 6543033
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 16 deletions.
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@ classifiers = [
]
requires-python = ">=3.10"
dependencies = [
"napari",
"numpy",
"dask[array]",
"magicgui",
"qtpy",
"scikit-image",
"tensorstore==0.1.59",
"ome-zarr==0.9.0",
"zarr",
"zarrdataset>=0.2.0",
]

Expand All @@ -49,10 +52,12 @@ testing = [
[project.entry-points."napari.manifest"]
napari-activelearning = "napari_activelearning:napari.yaml"

[project.urls]
Homepage = "https://github.com/TheJacksonLaboratory/activelearning"


[build-system]
requires = ["setuptools>=42.0.0", "wheel", "setuptools_scm"]
requires = ["setuptools>=42.0.0", "setuptools_scm"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
Expand Down
2 changes: 0 additions & 2 deletions setup.cfg

This file was deleted.

7 changes: 3 additions & 4 deletions src/napari_activelearning/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,9 @@ def _get_transform(self):
return self._transform

def _fine_tune(self, train_data, train_labels, test_data, test_labels):
if self._model is None:
self._model_init()
self._model_init()

model_path = train.train_seg(
self._pretrained_model = train.train_seg(
self._model.net,
train_data=train_data,
train_labels=train_labels,
Expand Down Expand Up @@ -163,7 +162,7 @@ def _fine_tune(self, train_data, train_labels, test_data, test_labels):
model_name=self._model_name
)

self._model_init(pretrained_model=model_path)
self._model_init(pretrained_model=self._pretrained_model)

USING_CELLPOSE = True

Expand Down
25 changes: 16 additions & 9 deletions src/napari_activelearning/_models_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,11 @@ class CellposeTunableWidget(CellposeTunable, QWidget):
def __init__(self):
super().__init__()

(segmentation_parameters,
(self._segmentation_parameters,
segmentation_parameter_names) =\
cellpose_segmentation_parameters_widget()

(finetuning_parameters,
(self._finetuning_parameters,
finetuning_parameter_names) =\
cellpose_finetuning_parameters_widget()

Expand Down Expand Up @@ -166,7 +166,7 @@ def __init__(self):
Qt.ScrollBarAlwaysOff
)
self._segmentation_parameters_scr.setWidget(
segmentation_parameters.native
self._segmentation_parameters.native
)

self._finetuning_parameters_scr = QScrollArea()
Expand All @@ -175,16 +175,17 @@ def __init__(self):
Qt.ScrollBarAlwaysOff
)
self._finetuning_parameters_scr.setWidget(
finetuning_parameters.native
self._finetuning_parameters.native
)

for par_name in segmentation_parameter_names:
segmentation_parameters.__getattr__(par_name).changed.connect(
self._segmentation_parameters.__getattr__(par_name)\
.changed.connect(
partial(self._set_parameter, parameter_key="_" + par_name)
)

for par_name in finetuning_parameter_names:
finetuning_parameters.__getattr__(par_name).changed.connect(
self._finetuning_parameters.__getattr__(par_name).changed.connect(
partial(self._set_parameter, parameter_key="_" + par_name)
)

Expand Down Expand Up @@ -220,6 +221,12 @@ def _show_segmentation_parameters(self, show: bool):
def _show_finetuning_parameters(self, show: bool):
self._finetuning_parameters_scr.setVisible(show)

def _fine_tune(self, train_data, train_labels, test_data, test_labels):
super()._fine_tune(train_data, train_labels, test_data,
test_labels)
self._segmentation_parameters.pretrained_model.value =\
self._pretrained_model


def simple_segmentation_parameters_widget():
@magicgui(auto_call=True)
Expand All @@ -245,7 +252,7 @@ class SimpleTunableWidget(SimpleTunable, QWidget):
def __init__(self):
super().__init__()

(segmentation_parameters,
(self._segmentation_parameters,
segmentation_parameter_names) =\
simple_segmentation_parameters_widget()

Expand All @@ -264,11 +271,11 @@ def __init__(self):
Qt.ScrollBarAlwaysOff
)
self._segmentation_parameters_scr.setWidget(
segmentation_parameters.native
self._segmentation_parameters.native
)

for par_name in segmentation_parameter_names:
segmentation_parameters.__getattr__(par_name).changed.connect(
self._segmentation_parameters.__getattr__(par_name).changed.connect(
partial(self._set_parameter, parameter_key="_" + par_name)
)

Expand Down

0 comments on commit 6543033

Please sign in to comment.