Surface dice loss (#45)
* Surface-Dice functionalities

* Adjust losses to be compatible with Surface-Dice exclusions

* Adjust training routine and include surface dice loss

* Add dataset labels to dataloading

* Pass Surface-Dice arguments to training routine

* Update CLI to include advanced options for Surface-Dice

* pre-commit formatting

* make list readable by passing argument multiple times

* remove redundant import

* Compatibility with updated masked_surface_dice function

* Add training summary and remove wandb logging

* remove redundant print statements and include ds_labels in for-loop

* Implement Gaussian smoothing with torch to compute everything on GPU

* Training summary printing

* add dataset token to CLI

* Add dataset token to filename

* Update warnings

* Fix bug for accuracy masking

* Fix Dice reduction to scalar

* Make test compatible with CombinedLoss

* Fix default path

* Raise Error when reduction is not defined

* Add required dimensions to docstrings
LorenzLamm authored Jan 22, 2024
1 parent 3fd4de9 commit 49aa798
Showing 12 changed files with 902 additions and 54 deletions.
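
Before the per-file diffs, a minimal sketch of the idea behind the new loss (function and variable names here are illustrative assumptions, not the exact membrain-seg implementation): boundaries of prediction and target are softened with a torch-native Gaussian kernel so everything runs on the GPU, ignore-labeled voxels are masked out, and a Dice score is computed on the boundary maps and reduced to a scalar.

import torch
import torch.nn.functional as F

def gaussian_kernel_3d(sigma: float = 1.0, size: int = 5) -> torch.Tensor:
    """Normalized 3D Gaussian kernel of shape (1, 1, size, size, size)."""
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2.0
    g = torch.exp(-(coords**2) / (2 * sigma**2))
    g = g / g.sum()
    return (g[:, None, None] * g[None, :, None] * g[None, None, :])[None, None]

def soft_masked_surface_dice(pred, target, mask, sigma=1.0, eps=1e-6):
    """Soft Surface-Dice for binary volumes of shape (B, 1, D, H, W).

    `mask` is 1 where voxels count and 0 for ignore regions.
    """
    kernel = gaussian_kernel_3d(sigma).to(pred.device)
    pad = kernel.shape[-1] // 2
    # Soft boundary maps: smoothing leaves the interior nearly unchanged, so
    # the absolute difference is large only near surfaces.
    pred_surf = torch.abs(pred - F.conv3d(pred, kernel, padding=pad)) * mask
    targ_surf = torch.abs(target - F.conv3d(target, kernel, padding=pad)) * mask
    overlap = (pred_surf * targ_surf).sum()
    total = pred_surf.sum() + targ_surf.sum()
    return 1.0 - (2.0 * overlap + eps) / (total + eps)  # scalar loss

The gating of this loss by dataset token, also introduced in this commit, is sketched further below at the train CLI diff.
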
6 changes: 6 additions & 0 deletions src/membrain_seg/annotations/extract_patch_cli.py
@@ -21,6 +21,11 @@ def extract_patches(
help="Path to the folder where extracted patches should be stored. \
(subdirectories will be created)",
),
ds_token: str = Option( # noqa: B008
"other",
help="Dataset token. Important for distinguishing between different \
datasets. Should NOT contain underscores!",
),
coords_file: str = Option( # noqa: B008
None,
help="Path to a file containing coordinates for patch extraction. The file \
@@ -93,6 +98,7 @@ def extract_patches(
coords=coords,
out_dir=out_folder,
idx_add=idx_add,
ds_token=ds_token,
token=token,
pad_value=pad_value,
)
31 changes: 21 additions & 10 deletions src/membrain_seg/annotations/extract_patches.py
@@ -51,7 +51,7 @@ def pad_labels(patch, padding, pad_value=2.0):


def get_out_files_and_patch_number(
-    token, out_folder_raw, out_folder_lab, patch_nr, idx_add
+    ds_token, token, out_folder_raw, out_folder_lab, patch_nr, idx_add
):
"""
Create filenames and corrected patch numbers.
@@ -62,8 +62,10 @@
Parameters
----------
ds_token : str
The dataset identifier used as a part of the filename.
token : str
-        The unique identifier used as a part of the filename.
+        The tomogram identifier used as a part of the filename.
out_folder_raw : str
The directory path where raw data patches are stored.
out_folder_lab : str
@@ -96,27 +98,34 @@
"""
patch_nr += idx_add
out_file_patch = os.path.join(
-        out_folder_raw, token + "_patch" + str(patch_nr) + "_raw.nii.gz"
+        out_folder_raw, ds_token + "_" + token + "_patch" + str(patch_nr) + ".nii.gz"
)
out_file_patch_label = os.path.join(
-        out_folder_lab, token + "_patch" + str(patch_nr) + "_labels.nii.gz"
+        out_folder_lab, ds_token + "_" + token + "_patch" + str(patch_nr) + ".nii.gz"
)
exist_add = 0
while os.path.isfile(out_file_patch):
exist_add += 1
out_file_patch = os.path.join(
out_folder_raw,
token + "_patch" + str(patch_nr + exist_add) + "_raw.nii.gz",
ds_token + "_" + token + "_patch" + str(patch_nr + exist_add) + ".nii.gz",
)
out_file_patch_label = os.path.join(
out_folder_lab,
token + "_patch" + str(patch_nr + exist_add) + "_labels.nii.gz",
ds_token + "_" + token + "_patch" + str(patch_nr + exist_add) + ".nii.gz",
)
return patch_nr + exist_add, out_file_patch, out_file_patch_label


def extract_patches(
-    tomo_path, seg_path, coords, out_dir, idx_add=0, token=None, pad_value=2.0
+    tomo_path,
+    seg_path,
+    coords,
+    out_dir,
+    ds_token="other",
+    token=None,
+    idx_add=0,
+    pad_value=2.0,
):
"""
Extracts 3D patches from a given tomogram and corresponding segmentation.
@@ -133,11 +142,13 @@
List of tuples where each tuple represents the 3D coordinates of a patch center.
out_dir : str
The output directory where the extracted patches will be saved.
-    idx_add : int, optional
-        The index addition for patch numbering, default is 0.
ds_token : str, optional
Dataset token to uniquely identify the dataset, default is 'other'.
token : str, optional
Token to uniquely identify the tomogram, default is None. If None,
the base name of the tomogram file path is used.
idx_add : int, optional
The index addition for patch numbering, default is 0.
pad_value: float, optional
Borders of extracted patch are padded with this value ("ignore" label)
@@ -170,7 +181,7 @@

for patch_nr, cur_coords in enumerate(coords):
patch_nr, out_file_patch, out_file_patch_label = get_out_files_and_patch_number(
-            token, out_folder_raw, out_folder_lab, patch_nr, idx_add
+            ds_token, token, out_folder_raw, out_folder_lab, patch_nr, idx_add
)
print("Extracting patch nr", patch_nr, "from tomo", token)
try:
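
To make the naming convention above concrete, here is what get_out_files_and_patch_number now produces, with illustrative values (ds_token="ds1", token="tomo17", patch_nr=3):

import os

out_file_patch = os.path.join("imgs", "ds1" + "_" + "tomo17" + "_patch" + str(3) + ".nii.gz")
print(out_file_patch)  # imgs/ds1_tomo17_patch3.nii.gz
# The dataset token is later recovered via basename.split("_")[0], which is
# why the CLI warns that ds_token must not contain underscores.
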
8 changes: 5 additions & 3 deletions src/membrain_seg/annotations/merge_corrections.py
@@ -46,13 +46,15 @@ def get_corrections_from_folder(folder_name, orig_pred_file):
or filename.startswith("Ignore")
or filename.startswith("ignore")
):
print("ATTENTION! Not processing", filename)
print("Is this intended?")
print(
"File does not fit into Add/Remove/Ignore naming! " "Not processing",
filename,
)
continue
readdata = sitk.GetArrayFromImage(
sitk.ReadImage(os.path.join(folder_name, filename))
)
print("Adding file", filename, "<--")
print("Adding file", filename)

if filename.startswith("Add") or filename.startswith("add"):
add_patch += readdata
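
For context, the folder layout get_corrections_from_folder expects (filenames illustrative; only the Add/Remove/Ignore prefixes matter, and the Remove/Ignore branches presumably follow in the truncated part of the diff):

# corrections_tomo17/
#   Add_missing_membrane.nii.gz    -> added to the prediction (add_patch += readdata)
#   Remove_false_positive.nii.gz   -> voxels to take out of the prediction
#   Ignore_gold_beads.nii.gz       -> voxels to exclude from training
#   notes.txt                      -> skipped with the warning printed above
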
30 changes: 29 additions & 1 deletion src/membrain_seg/segmentation/cli/train_cli.py
@@ -1,4 +1,7 @@
from typing import List, Optional

from typer import Option
from typing_extensions import Annotated

from ..train import train as _train
from .cli import OPTION_PROMPT_KWARGS as PKWARGS
@@ -70,7 +73,7 @@ def train_advanced(
help="Batch size for training.",
),
num_workers: int = Option( # noqa: B008
-        1,
+        8,
help="Number of worker threads for loading data",
),
max_epochs: int = Option( # noqa: B008
@@ -84,6 +87,22 @@
but also severely increases training time.\
Pass "True" or "False".',
),
use_surface_dice: bool = Option( # noqa: B008
False, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".'
),
surface_dice_weight: float = Option( # noqa: B008
        1.0, help="Scaling factor for the Surface-Dice loss."
),
surface_dice_tokens: Annotated[
Optional[List[str]],
Option(
help='List of tokens to \
use for the Surface-Dice loss. \
Pass tokens separately:\
For example, train_advanced --surface_dice_tokens "ds1" \
--surface_dice_tokens "ds2"'
),
] = None,
use_deep_supervision: bool = Option( # noqa: B008
True, help='Whether to use deep supervision. Pass "True" or "False".'
),
@@ -119,6 +138,12 @@
If set to False, data augmentation still happens, but not as frequently.
More data augmentation can lead to a better performance, but also increases the
training time substantially.
use_surface_dice : bool
        Determines whether to use Surface-Dice loss, by default False.
surface_dice_weight : float
Scaling factor for the Surface-Dice loss, by default 1.0.
surface_dice_tokens : list
        List of tokens to use for the Surface-Dice loss, by default None (all datasets are used).
use_deep_supervision : bool
Determines whether to use deep supervision, by default True.
project_name : str
@@ -140,6 +165,9 @@
max_epochs=max_epochs,
aug_prob_to_one=aug_prob_to_one,
use_deep_supervision=use_deep_supervision,
use_surf_dice=use_surface_dice,
surf_dice_weight=surface_dice_weight,
surf_dice_tokens=surface_dice_tokens,
project_name=project_name,
sub_name=sub_name,
)
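
The commit messages mention a CombinedLoss; below is a rough sketch of how the three new keyword arguments could interact in such a wrapper (an assumption about its shape, not the actual implementation): the Surface-Dice term is scaled by surf_dice_weight and only computed for batch samples whose dataset token appears in surf_dice_tokens.

import torch
from torch import nn

class CombinedLoss(nn.Module):
    """Hypothetical sketch: base Dice/CE loss plus a token-gated Surface-Dice term."""

    def __init__(self, base_loss, surf_dice_loss, surf_dice_weight=1.0,
                 surf_dice_tokens=None):
        super().__init__()
        self.base_loss = base_loss
        self.surf_dice_loss = surf_dice_loss
        self.surf_dice_weight = surf_dice_weight
        self.surf_dice_tokens = surf_dice_tokens  # None -> use all datasets

    def forward(self, pred, target, ds_labels):
        loss = self.base_loss(pred, target)
        if self.surf_dice_tokens is None:
            keep = torch.ones(pred.shape[0], dtype=torch.bool, device=pred.device)
        else:
            # ds_labels comes from the "dataset" entry added to each batch below.
            keep = torch.tensor([t in self.surf_dice_tokens for t in ds_labels],
                                device=pred.device)
        if keep.any():
            loss = loss + self.surf_dice_weight * self.surf_dice_loss(
                pred[keep], target[keep])
        return loss  # reduced to a scalar, matching "Fix Dice reduction to scalar"
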
24 changes: 23 additions & 1 deletion src/membrain_seg/segmentation/dataloading/memseg_dataset.py
@@ -1,7 +1,6 @@
import os
from typing import Dict

-# from skimage import io
import imageio as io
import numpy as np
from torch.utils.data import Dataset
@@ -102,6 +101,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
"label": np.expand_dims(self.labels[idx], 0),
}
idx_dict = self.transforms(idx_dict)
idx_dict["dataset"] = self.dataset_labels[idx]
return idx_dict

def __len__(self) -> int:
@@ -126,6 +126,7 @@ def load_data(self) -> None:
print("Loading images into dataset.")
self.imgs = []
self.labels = []
self.dataset_labels = []
for entry in self.data_paths:
label = read_nifti(
entry[1]
@@ -137,6 +138,7 @@
img = np.transpose(img, (1, 2, 0))
self.imgs.append(img)
self.labels.append(label)
self.dataset_labels.append(get_dataset_token(entry[0]))

def initialize_imgs_paths(self) -> None:
"""
@@ -190,3 +192,23 @@ def test(self, test_folder: str, num_files: int = 20) -> None:
os.path.join(test_folder, f"test_mask_ds2_{i}_group{num_mask}.png"),
test_sample["label"][1][0, :, :, num_mask],
)


def get_dataset_token(patch_name):
"""
    Get the dataset token from the patch name.

    Parameters
    ----------
    patch_name : str
        The name of the patch.

    Returns
    -------
    str
        The dataset token.
"""
basename = os.path.basename(patch_name)
dataset_token = basename.split("_")[0]
return dataset_token
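
A quick check of the convention with illustrative paths:

print(get_dataset_token("/data/patches/imgs/ds1_tomo17_patch3.nii.gz"))  # ds1
print(get_dataset_token("other_tomo2_patch0.nii.gz"))                    # other
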
