From 7c37049b625455632827e9d2d9be7338e4fe76f9 Mon Sep 17 00:00:00 2001
From: Lorenz Lamm <34575029+LorenzLamm@users.noreply.github.com>
Date: Sat, 11 Jan 2025 15:04:41 +0100
Subject: [PATCH] Benchmark metrics (#86)

* replace print with logging

* replace print with logging

* add benchmarking scripts

* docs for benchmark
---
 docs/Usage/Benchmarking.md                    |  35 ++
 src/membrain_seg/benchmark/__init__.py        |   1 +
 src/membrain_seg/benchmark/compute_stats.py   | 339 ++++++++++++++++++
 src/membrain_seg/benchmark/metrics.py         |  92 +++++
 src/membrain_seg/segmentation/cli/ske_cli.py  |  18 +-
 .../segmentation/skeletonization/eig3d.py     |  12 +-
 src/membrain_seg/segmentation/skeletonize.py  |  21 +-
 7 files changed, 491 insertions(+), 27 deletions(-)
 create mode 100644 docs/Usage/Benchmarking.md
 create mode 100644 src/membrain_seg/benchmark/__init__.py
 create mode 100644 src/membrain_seg/benchmark/compute_stats.py
 create mode 100644 src/membrain_seg/benchmark/metrics.py

diff --git a/docs/Usage/Benchmarking.md b/docs/Usage/Benchmarking.md
new file mode 100644
index 0000000..7a481b3
--- /dev/null
+++ b/docs/Usage/Benchmarking.md
@@ -0,0 +1,35 @@
+# Benchmarking
+
+Our well-annotated dataset can not only be used to train a model, but also to benchmark other models trained on different datasets.
+To do so, we provide a script that can be used to evaluate the performance of a model on our dataset.
+
+## Usage
+### Preparations
+1. Have the MemBrain-seg dataset ready, i.e. the imagesTest and labelsTest folders.
+2. Perform predictions on the imagesTest folder using your model. Store the predictions in a folder of your choice.
+
+**Important**: To match the predictions and ground truth labels, each prediction should have the same name as the corresponding input patch, but without the "_0000" suffix.
+E.g. if the input patch is called "patch_0000.nii.gz", the prediction should be called "patch.nii.gz" (but also .mrc is possible).
+
+### Running the benchmarking script
+The benchmarking script does not provide a command-line interface. Instead, it is a Python function that you can integrate into your own code:
+
+    ```python
+    from membrain_seg.benchmark.compute_stats import compute_stats
+
+    dir_gt = "path/to/ground_truth"
+    dir_pred = "path/to/predictions"
+    out_dir = "path/to/output"
+    out_file_token = "stats"
+
+    compute_stats(dir_gt, dir_pred, out_dir, out_file_token)
+    ```
+
+This will compute the statistics for the segmentations on the entire dataset and store the results in the specified output directory.
+As metrics, the script computes the surface dice and the dice score for each segmentation. To learn more about the surface dice, please refer to our [manuscript](https://www.biorxiv.org/content/10.1101/2024.01.05.574336v1).
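+
+### Output
+The results are stored as `<out_file_token>.csv` in the specified output directory, with one row per dataset group (groups are derived from the filename prefix, e.g. "cts" and "polnet" patches are pooled into a "synthetic" group).
+The columns are `Dataset`, `Surface Dice`, `Global Surface Dice`, `Dice`, and `Global Dice`; the "global" variants aggregate voxel counts over all patches of a group before computing the score, instead of averaging the per-patch scores.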
+
+
+
+
+
+
diff --git a/src/membrain_seg/benchmark/__init__.py b/src/membrain_seg/benchmark/__init__.py
new file mode 100644
index 0000000..911b3ea
--- /dev/null
+++ b/src/membrain_seg/benchmark/__init__.py
@@ -0,0 +1 @@
+"""empty init."""
diff --git a/src/membrain_seg/benchmark/compute_stats.py b/src/membrain_seg/benchmark/compute_stats.py
new file mode 100644
index 0000000..3b2dd09
--- /dev/null
+++ b/src/membrain_seg/benchmark/compute_stats.py
@@ -0,0 +1,339 @@
+"""
+Compute statistics for segmentations on the entire dataset.
+
+Workflow:
+1. Load predictions and ground truth segmentations from two separate directories.
+   Filenames (except for extension) must match. (predictions can be .nii.gz or .mrc)
+2. Skeletonize all segmentations.
+3. Compute both dice and surface-dice scores for each pair of predictions and ground
+   truth segmentations.
+4. Compute also global dice and surface-dice by aggregating all segmentations /
+   skeletons.
+"""
+
+import csv
+import os
+
+import numpy as np
+from tqdm import tqdm
+
+from membrain_seg.benchmark.metrics import masked_dice, masked_surface_dice
+from membrain_seg.segmentation.dataloading.data_utils import (
+    load_tomogram,
+    read_nifti,
+)
+from membrain_seg.segmentation.skeletonize import skeletonization
+
+ds_dict = {}
+
+
+def reset_ds_dict():
+    """Reset the dataset dictionary."""
+    global ds_dict
+    ds_dict = {}
+
+
+def get_filepaths(dir_gt: str, dir_pred: str):
+    """
+    Get filepaths for all ground truth segmentations and predictions.
+
+    Parameters
+    ----------
+    dir_gt : str
+        Directory containing ground truth segmentations.
+    dir_pred : str
+        Directory containing predictions.
+
+    Returns
+    -------
+    gt_files : list
+        List of ground truth segmentation filepaths.
+    pred_files : list
+        List of prediction filepaths.
+    """
+    # Load all segmentations and skeletons
+    gt_files = os.listdir(dir_gt)
+    # take all files with .nii.gz extension
+    gt_files = [f for f in gt_files if f.endswith(".nii.gz")]
+
+    # check whether predictions are in .mrc or .nii.gz format
+    is_mrc = False
+    pred_files = os.listdir(dir_pred)
+    pred_files_mrc = [f for f in pred_files if f.endswith(".mrc")]
+    pred_files_nii = [f for f in pred_files if f.endswith(".nii.gz")]
+    if len(pred_files_mrc) > 0:
+        pred_files = pred_files_mrc
+        is_mrc = True
+    elif len(pred_files_nii) > 0:
+        pred_files = pred_files_nii
+    else:
+        raise ValueError("No predictions found in .mrc or .nii.gz format.")
+
+    # check whether the number of predictions and ground truth segmentations match
+    if len(gt_files) != len(pred_files):
+        raise ValueError(
+            "Number of ground truth segmentations and predictions do not match."
+        )
+
+    # sort all files alphabetically
+    gt_files.sort()
+    pred_files.sort()
+
+    # make sure gt_files and pred_files are the same before the extension
+    for gt, pred in zip(gt_files, pred_files):
+        if is_mrc:
+            assert gt[:-7] == pred[:-4]
+        else:
+            assert gt[:-7] == pred[:-7]
+    return gt_files, pred_files
+
+
+def read_nifti_or_mrc(file_path: str):
+    """Read a nifti or mrc file.
+
+    Parameters
+    ----------
+    file_path : str
+        Path to the file.
+
+    Returns
+    -------
+    np.ndarray
+        The data.
+    """
+    if file_path.endswith(".mrc"):
+        return load_tomogram(file_path)
+    else:
+        return read_nifti(file_path)
+
+
+def get_ds_token(filename):
+    """Get the dataset token from the filename.
+
+    Parameters
+    ----------
+    filename : str
+        The filename of the patch.
+
+    Returns
+    -------
+    str
+        The dataset token.
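+
+    Examples
+    --------
+    Illustrative filenames (the numeric suffixes are placeholders):
+
+    >>> get_ds_token("atty_0021.nii.gz")
+    'collaborators'
+    >>> get_ds_token("polnet_0003.nii.gz")
+    'synthetic'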
+    """
+    token = filename.split("_")[0]
+    if token in ["atty", "benedikt", "rory", "virly"]:
+        return "collaborators"
+    elif token in ["cts", "polnet"]:
+        return "synthetic"
+    else:
+        return token
+
+
+def initialize_ds_dict_entry(ds_token):
+    """Initialize a dataset dictionary entry.
+
+    Parameters
+    ----------
+    ds_token : str
+        The dataset token.
+    """
+    if ds_token not in ds_dict:
+        ds_dict[ds_token] = {
+            "surf_dice": [],
+            "tp_pred_sdice": 0,
+            "tp_gt_sdice": 0,
+            "all_pred_sdice": 0,
+            "all_gt_sdice": 0,
+            "dice": [],
+            "tp": 0,
+            "fp": 0,
+            "fn": 0,
+        }
+
+
+def update_ds_dict_entry(ds_token, surf_dice, confusion_dict, dice, dice_dict):
+    """Update the dataset dictionary entry.
+
+    Parameters
+    ----------
+    ds_token : str
+        The dataset token.
+    surf_dice : float
+        Surface dice score.
+    confusion_dict : dict
+        Dictionary containing the following keys:
+        - tp_pred: True positives in the prediction.
+        - tp_gt: True positives in the ground truth.
+        - all_pred: All positives in the prediction.
+        - all_gt: All positives in the ground truth.
+    dice : float
+        Dice score.
+    dice_dict : dict
+        Dictionary containing the following keys:
+        - tp: True positives.
+        - fp: False positives.
+        - fn: False negatives.
+
+    """
+    ds_dict[ds_token]["surf_dice"].append(surf_dice)
+    ds_dict[ds_token]["tp_pred_sdice"] += confusion_dict["tp_pred"]
+    ds_dict[ds_token]["tp_gt_sdice"] += confusion_dict["tp_gt"]
+    ds_dict[ds_token]["all_pred_sdice"] += confusion_dict["all_pred"]
+    ds_dict[ds_token]["all_gt_sdice"] += confusion_dict["all_gt"]
+    ds_dict[ds_token]["dice"].append(dice)
+    ds_dict[ds_token]["tp"] += dice_dict["tp"]
+    ds_dict[ds_token]["fp"] += dice_dict["fp"]
+    ds_dict[ds_token]["fn"] += dice_dict["fn"]
+
+
+def print_ds_dict():
+    """Print the dataset dictionary."""
+    print("")
+    print("Dataset statistics:")
+    for ds_token in ds_dict:
+        print(f"Dataset: {ds_token}")
+        print(f"Surface dice: {np.mean(ds_dict[ds_token]['surf_dice'])}")
+        print(f"Global surface dice: {get_global_stats(ds_token, s_dice=True)}")
+        print(f"Dice: {np.mean(ds_dict[ds_token]['dice'])}")
+        print(f"Global dice: {get_global_stats(ds_token, s_dice=False)}")
+        print("")
+
+
+def get_global_stats(
+    ds_token,
+    s_dice: bool,
+):
+    """Aggregates global statistics for a dataset.
+
+    Parameters
+    ----------
+    ds_token : str
+        The dataset token.
+    s_dice : bool
+        Whether to compute surface dice or dice.
+
+    Returns
+    -------
+    float
+        The global statistic.
+    """
+    if s_dice:
+        global_dice = (
+            2.0
+            * (
+                ds_dict[ds_token]["tp_pred_sdice"]
+                / (ds_dict[ds_token]["all_pred_sdice"] + 1e-6)
+            )
+            * (
+                ds_dict[ds_token]["tp_gt_sdice"]
+                / (ds_dict[ds_token]["all_gt_sdice"] + 1e-6)
+            )
+            / (
+                ds_dict[ds_token]["tp_pred_sdice"]
+                / (ds_dict[ds_token]["all_pred_sdice"] + 1e-6)
+                + ds_dict[ds_token]["tp_gt_sdice"]
+                / (ds_dict[ds_token]["all_gt_sdice"] + 1e-6)
+            )
+        )
+    else:
+        global_dice = (
+            2.0
+            * (
+                ds_dict[ds_token]["tp"]
+                / (ds_dict[ds_token]["tp"] + ds_dict[ds_token]["fp"])
+            )
+            * (
+                ds_dict[ds_token]["tp"]
+                / (ds_dict[ds_token]["tp"] + ds_dict[ds_token]["fn"])
+            )
+            / (
+                ds_dict[ds_token]["tp"]
+                / (ds_dict[ds_token]["tp"] + ds_dict[ds_token]["fp"])
+                + ds_dict[ds_token]["tp"]
+                / (ds_dict[ds_token]["tp"] + ds_dict[ds_token]["fn"])
+            )
+        )
+    return global_dice
+
+
+def store_stats(out_file):
+    """Store the dataset dictionary in a csv file.
+
+    Parameters
+    ----------
+    out_file : str
+        Path to the output file.
+    """
+    # store ds_dict in a csv file
+    header = [
+        "Dataset",
+        "Surface Dice",
+        "Global Surface Dice",
+        "Dice",
+        "Global Dice",
+    ]
+    with open(out_file, "w", newline="") as f:
+        writer = csv.writer(f)
+        writer.writerow(header)
+        for ds_token in ds_dict:
+            row = [
+                ds_token,
+                np.mean(ds_dict[ds_token]["surf_dice"]),
+                get_global_stats(ds_token, s_dice=True),
+                np.mean(ds_dict[ds_token]["dice"]),
+                get_global_stats(ds_token, s_dice=False),
+            ]
+            writer.writerow(row)
+
+
+def compute_stats(
+    dir_gt: str,
+    dir_pred: str,
+    out_dir: str,
+    out_file_token: str = "stats",
+):
+    """
+    Compute statistics for segmentations on the entire dataset.
+
+    Parameters
+    ----------
+    dir_gt : str
+        Directory containing ground truth segmentations.
+    dir_pred : str
+        Directory containing predictions.
+    out_dir : str
+        Directory to save the results.
+    out_file_token : str
+        Token used to name the output CSV file ("<out_file_token>.csv").
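+
+    Examples
+    --------
+    Minimal illustrative call; the directory paths below are placeholders:
+
+    >>> compute_stats(
+    ...     dir_gt="path/to/ground_truth",
+    ...     dir_pred="path/to/predictions",
+    ...     out_dir="path/to/output",
+    ...     out_file_token="stats",
+    ... )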
+    """
+    reset_ds_dict()
+    gt_files, pred_files = get_filepaths(dir_gt, dir_pred)
+
+    length = len(gt_files)
+    for i in tqdm(range(length)):
+        gt_file = gt_files[i]
+        pred_file = pred_files[i]
+
+        ds_token = get_ds_token(gt_file)
+        initialize_ds_dict_entry(ds_token)
+
+        # Load ground truth and prediction
+        gt = read_nifti_or_mrc(os.path.join(dir_gt, gt_file))
+        pred = read_nifti_or_mrc(os.path.join(dir_pred, pred_file))
+        # Skeletonize both segmentations
+        gt_skeleton = skeletonization(gt == 1, batch_size=100000)
+        pred_skeleton = skeletonization(pred, batch_size=100000)
+        mask = gt != 2
+
+        # Compute surface dice
+        surf_dice, confusion_dict = masked_surface_dice(
+            pred_skeleton, gt_skeleton, pred, gt, mask
+        )
+        dice, dice_dict = masked_dice(pred, gt, mask)
+        update_ds_dict_entry(ds_token, surf_dice, confusion_dict, dice, dice_dict)
+        print_ds_dict()
+    os.makedirs(out_dir, exist_ok=True)
+    out_file = os.path.join(out_dir, f"{out_file_token}.csv")
+    store_stats(out_file)
diff --git a/src/membrain_seg/benchmark/metrics.py b/src/membrain_seg/benchmark/metrics.py
new file mode 100644
index 0000000..fa3b98b
--- /dev/null
+++ b/src/membrain_seg/benchmark/metrics.py
@@ -0,0 +1,92 @@
+import numpy as np
+
+
+def masked_surface_dice(
+    pred_skel: np.ndarray,
+    gt_skel: np.ndarray,
+    pred: np.ndarray,
+    gt: np.ndarray,
+    mask: np.ndarray,
+) -> tuple:
+    """Compute surface dice score for two skeletons.
+
+    Parameters
+    ----------
+    pred_skel : np.ndarray
+        Skeleton of the prediction.
+    gt_skel : np.ndarray
+        Skeleton of the ground truth.
+    pred : np.ndarray
+        Prediction.
+    gt : np.ndarray
+        Ground truth.
+    mask : np.ndarray
+        Mask to ignore certain labels.
+
+    Returns
+    -------
+    float
+        Surface dice score.
+    dict
+        Dictionary containing the following keys:
+        - tp_pred: True positives in the prediction.
+        - tp_gt: True positives in the ground truth.
+        - all_pred: All positives in the prediction.
+        - all_gt: All positives in the ground truth.
+    """
+    # Mask out ignore labels
+    pred_skel[~mask] = 0
+    gt_skel[~mask] = 0
+
+    tp_pred = np.sum(np.multiply(pred_skel, gt))
+    tp_gt = np.sum(np.multiply(gt_skel, pred))
+    all_pred = np.sum(pred_skel)
+    all_gt = np.sum(gt_skel)
+
+    tprec = tp_pred / (all_pred + 1e-6)
+    tsens = tp_gt / (all_gt + 1e-6)
+
+    surf_dice = 2.0 * (tprec * tsens) / (tprec + tsens + 1e-6)
+    return surf_dice, {
+        "tp_pred": tp_pred,
+        "tp_gt": tp_gt,
+        "all_pred": all_pred,
+        "all_gt": all_gt,
+    }
+
+
+def masked_dice(pred: np.ndarray, gt: np.ndarray, mask: np.ndarray) -> tuple:
+    """Compute dice score for two segmentations.
+
+    Parameters
+    ----------
+    pred : np.ndarray
+        Prediction.
+    gt : np.ndarray
+        Ground truth.
+    mask : np.ndarray
+        Mask to ignore certain labels.
+
+    Returns
+    -------
+    float
+        Dice score.
+    dict
+        Dictionary containing the following keys:
+        - tp: True positives.
+        - fp: False positives.
+        - fn: False negatives.
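+
+    Notes
+    -----
+    With the ignore mask applied, this is the standard Dice score
+    2*TP / (2*TP + FP + FN), computed here from precision and recall
+    (each stabilized with a small epsilon).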
+
+    """
+    pred[~mask] = 0
+    gt[~mask] = 0
+    tp = np.sum(np.logical_and(pred == 1, gt == 1))
+    fp = np.sum(np.logical_and(pred == 1, gt == 0))
+    fn = np.sum(np.logical_and(pred == 0, gt == 1))
+    tprec = tp / (tp + fp + 1e-6)
+    tsens = tp / (tp + fn + 1e-6)
+    # return also dict
+    out_dict = {"tp": tp, "fp": fp, "fn": fn}
+    dice = 2.0 * (tprec * tsens) / (tprec + tsens + 1e-6)
+    return dice, out_dict
diff --git a/src/membrain_seg/segmentation/cli/ske_cli.py b/src/membrain_seg/segmentation/cli/ske_cli.py
index 212773a..d59a5c7 100644
--- a/src/membrain_seg/segmentation/cli/ske_cli.py
+++ b/src/membrain_seg/segmentation/cli/ske_cli.py
@@ -1,3 +1,4 @@
+import logging
 import os
 
 from typer import Option
@@ -7,18 +8,19 @@
     store_tomogram,
 )
 
-
 from ..skeletonize import skeletonization as _skeletonization
 from .cli import cli
 
 
 @cli.command(name="skeletonize", no_args_is_help=True)
 def skeletonize(
-    label_path: str = Option(..., help="Specifies the path for skeletonization."),
-    out_folder: str = Option(
+    label_path: str = Option(  # noqa: B008
+        ..., help="Specifies the path for skeletonization."
+    ),
+    out_folder: str = Option(  # noqa: B008
         "./predictions", help="Directory to save the resulting skeletons."
     ),
-    batch_size: int = Option(
+    batch_size: int = Option(  # noqa: B008
         None,
         help="Optional batch size for processing the tomogram. If not specified, "
         "the entire volume is processed at once. If operating with limited GPU "
@@ -58,10 +60,10 @@ def skeletonize(
 
     segmentation = load_tomogram(label_path)
     ske = _skeletonization(segmentation=segmentation.data, batch_size=batch_size)
-    # Update the segmentation data with the skeletonized output while preserving the original header and voxel_size
+    # Update the segmentation data with the skeletonized output while preserving the
+    # original header and voxel_size
    segmentation.data = ske
 
-
     if not os.path.exists(out_folder):
         os.makedirs(out_folder)
@@ -69,6 +71,6 @@ def skeletonize(
         out_folder,
         os.path.splitext(os.path.basename(label_path))[0] + "_skel.mrc",
     )
-
+
     store_tomogram(filename=out_file, tomogram=segmentation)
-    print("Skeleton saved to ", out_file)
+    logging.info("Skeleton saved to " + out_file)
diff --git a/src/membrain_seg/segmentation/skeletonization/eig3d.py b/src/membrain_seg/segmentation/skeletonization/eig3d.py
index 91eb8fd..9d034b2 100644
--- a/src/membrain_seg/segmentation/skeletonization/eig3d.py
+++ b/src/membrain_seg/segmentation/skeletonization/eig3d.py
@@ -8,6 +8,7 @@
 # to the original licensing agreements. For details on the original license, refer to
 # the publication: https://www.sciencedirect.com/science/article/pii/S1047847714000495.
 # ---------------------------------------------------------------------------------
+import logging
 from typing import List, Tuple
 
 import numpy as np
@@ -18,8 +19,7 @@ def batch_mask_eigendecomposition_3d(
     filtered_hessian: List[torch.Tensor], batch_size: int, labels: np.ndarray
 ) -> Tuple[np.ndarray, np.ndarray]:
     """
-    Perform batch eigendecomposition on a 3D Hessian matrix using a binary mask to
-    select voxels.
+    Perform batch eigendecomposition on a 3D Hessian matrix using a binary mask.
 
     This function processes only those voxels where the label is set to 1,
     computing the largest eigenvalue and its corresponding eigenvector for
@@ -54,7 +54,7 @@ def batch_mask_eigendecomposition_3d(
     # if no specified batch size is given
     if batch_size is None:
         batch_size = Nx * Ny * Nz
-    print("batch_size=", batch_size)
+    logging.info("batch_size=" + str(batch_size))
 
     # Identify coordinates where computation is needed
     active_voxel_coords = np.where(labels == 1)
@@ -77,7 +77,6 @@ def batch_mask_eigendecomposition_3d(
         dim=-1,
     ).view(-1, 3, 3)
     del hessianXX, hessianYY, hessianZZ, hessianXY, hessianXZ, hessianYZ
-    print("Hessian component matrix shape:", hessian_components.shape)
 
     # Initialize output arrays
     first_eigenvalues = np.zeros((Nx, Ny, Nz), dtype=np.float32)
@@ -88,9 +87,6 @@ def batch_mask_eigendecomposition_3d(
     for i in range(0, num_active_voxels, batch_size):
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
-        # print('i=', i)
-        # print(f"Allocated: {torch.cuda.memory_allocated(0)/1e9:.2f} GB")
-        # print(f"Cached: {torch.cuda.memory_reserved(0)/1e9:.2f} GB")
 
         i_end = min(i + batch_size, num_active_voxels)
         batch_matrix = hessian_components[i:i_end, :, :]
@@ -108,7 +104,7 @@ def batch_mask_eigendecomposition_3d(
         # Store results back to CPU to save cuda memory
         first_eigenvalues[
             x_indices[i:i_end], y_indices[i:i_end], z_indices[i:i_end]
-        ] = batch_first_eigenvalues.cpu().numpy().real
+        ] = (batch_first_eigenvalues.cpu().numpy().real)
         first_eigenvectors[
             x_indices[i:i_end], y_indices[i:i_end], z_indices[i:i_end], :
         ] = (batch_first_eigenvectors.view(-1, 3).cpu().numpy()).real
diff --git a/src/membrain_seg/segmentation/skeletonize.py b/src/membrain_seg/segmentation/skeletonize.py
index f4048d5..7fa14ba 100644
--- a/src/membrain_seg/segmentation/skeletonize.py
+++ b/src/membrain_seg/segmentation/skeletonize.py
@@ -8,11 +8,12 @@
 # to the original licensing agreements. For details on the original license, refer to
 # the publication: https://www.sciencedirect.com/science/article/pii/S1047847714000495.
 # ---------------------------------------------------------------------------------
+import logging
+
 import numpy as np
 import scipy.ndimage as ndimage
 import torch
 
-
 from membrain_seg.segmentation.skeletonization.diff3d import (
     compute_gradients,
     compute_hessian,
@@ -24,7 +25,6 @@
 from membrain_seg.segmentation.training.surface_dice import apply_gaussian_filter
 
 
-
 def skeletonization(segmentation: np.ndarray, batch_size: int) -> np.ndarray:
     """
     Perform skeletonization on a tomogram segmentation.
@@ -61,19 +61,18 @@ def skeletonization(segmentation: np.ndarray, batch_size: int) -> np.ndarray:
         --batch-size 1000000
     This command runs the skeletonization process from the command line.
     """
-
     # Convert non-zero segmentation values to 1.0
     labels = (segmentation > 0) * 1.0
 
-    print("Distance transform on original labels.")
+    logging.info("Distance transform on original labels.")
     labels_dt = ndimage.distance_transform_edt(labels) * (-1)
 
     # Calculates partial derivative along 3 dimensions.
-    print("Computing partial derivative.")
+    logging.info("Computing partial derivative.")
     gradX, gradY, gradZ = compute_gradients(labels_dt)
 
     # Calculates Hessian tensor
-    print("Computing Hessian tensor.")
+    logging.info("Computing Hessian tensor.")
     hessianXX, hessianYY, hessianZZ, hessianXY, hessianXZ, hessianYZ = compute_hessian(
         gradX, gradY, gradZ
     )
@@ -81,10 +80,10 @@ def skeletonization(segmentation: np.ndarray, batch_size: int) -> np.ndarray:
     del gradX, gradY, gradZ
 
     # Apply Gaussian filter with the same sigma value for all dimensions
-    print("Applying Gaussian filtering.")
+    logging.info("Applying Gaussian filtering.")
     # Load hessian tensors on GPU
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    print(f"Using device: {device}")
+    logging.info(f"Using device: {device}")
 
     filtered_hessian = [
         apply_gaussian_filter(
@@ -98,8 +97,8 @@ def skeletonization(segmentation: np.ndarray, batch_size: int) -> np.ndarray:
     ]
     # Solve Eigen problem
-    print("Computing Eigenvalues and Eigenvectors.")
-    print(
+    logging.info("Computing Eigenvalues and Eigenvectors.")
+    logging.info(
         "In case the execution of the program is terminated unexpectedly, "
         "attempt to rerun it using smaller segmentation patches"
         "or give a specified batch size as input, e.g. batch_size=1000000."
     )
@@ -109,7 +108,7 @@ def skeletonization(segmentation: np.ndarray, batch_size: int) -> np.ndarray:
     )
 
     # Non-maximum suppression
-    print("Genration of skeleton based on non-maximum suppression algorithm.")
+    logging.info("Generation of skeleton based on non-maximum suppression algorithm.")
     first_eigenvalue = ndimage.gaussian_filter(first_eigenvalue, sigma=1)
     skeleton = nonmaxsup(
         first_eigenvalue,