Benchmark metrics (#86)
* replace print with logging

* replace print with logging

* add benchmarking scripts

* docs for benchmark
LorenzLamm authored Jan 11, 2025
1 parent 215b3d5 commit 7c37049
Showing 7 changed files with 491 additions and 27 deletions.
35 changes: 35 additions & 0 deletions docs/Usage/Benchmarking.md
@@ -0,0 +1,35 @@
# Benchmarking

Our well-annotated dataset can be used not only to train a model, but also to benchmark models trained on other datasets.
To this end, we provide a script that evaluates the performance of a model on our dataset.

## Usage
### Preparations
1. Have the MemBrain-seg dataset ready, i.e. the imagesTest and labelsTest folders.
2. Perform predictions on the imagesTest folder using your model. Store the predictions in a folder of your choice.

**Important**: To match predictions and ground truth labels, each prediction must have the same name as its corresponding input patch, but without the "_0000" suffix.
E.g. if the input patch is called "patch_0000.nii.gz", the prediction should be called "patch.nii.gz" (predictions in .mrc format are also possible).
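
If your model writes its outputs under the input patch names, a small helper like the following can copy them under the expected names. This is a hypothetical convenience snippet, not part of membrain-seg; adjust the paths and extension to your setup:

```python
from pathlib import Path
import shutil

def strip_channel_suffix(pred_dir: str, out_dir: str) -> None:
    """Copy predictions, dropping the "_0000" suffix from their names."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for pred in Path(pred_dir).glob("*_0000.nii.gz"):
        # e.g. "patch_0000.nii.gz" -> "patch.nii.gz"
        shutil.copy(pred, out / pred.name.replace("_0000", ""))

strip_channel_suffix("my_predictions", "my_predictions_renamed")
```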

### Running the benchmarking script
The benchmarking functionality does not provide a command-line interface. Instead, it is a Python function that you can integrate into your own code:

```python
from membrain_seg.benchmark.compute_stats import compute_stats

dir_gt = "path/to/ground_truth"
dir_pred = "path/to/predictions"
out_dir = "path/to/output"
out_file_token = "stats"

compute_stats(dir_gt, dir_pred, out_dir, out_file_token)
```

This will compute the statistics for the segmentations on the entire dataset and store the results in the specified output directory.
As metrics, the script computes the surface dice and the dice score for each segmentation. To learn more about the surface dice, please refer to our [manuscript](https://www.biorxiv.org/content/10.1101/2024.01.05.574336v1).
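
To inspect the results programmatically, you can, for example, read the CSV back in (assuming the default `out_file_token="stats"`; the column names match the header the script writes):

```python
import csv

with open("path/to/output/stats.csv", newline="") as f:
    for row in csv.DictReader(f):
        print(row["Dataset"], row["Surface Dice"], row["Global Surface Dice"])
```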

1 change: 1 addition & 0 deletions src/membrain_seg/benchmark/__init__.py
@@ -0,0 +1 @@
"""empty init."""
339 changes: 339 additions & 0 deletions src/membrain_seg/benchmark/compute_stats.py
@@ -0,0 +1,339 @@
"""
Compute statistics for segmentations on the entire dataset.

Workflow:
1. Load predictions and ground truth segmentations from two separate directories.
   Filenames (except for the extension) must match; predictions can be .nii.gz
   or .mrc.
2. Skeletonize all segmentations.
3. Compute both dice and surface-dice scores for each pair of prediction and
   ground truth segmentation.
4. Compute global dice and surface-dice scores by aggregating all
   segmentations / skeletons.
"""

import csv
import os

import numpy as np
from tqdm import tqdm

from membrain_seg.benchmark.metrics import masked_dice, masked_surface_dice
from membrain_seg.segmentation.dataloading.data_utils import (
load_tomogram,
read_nifti,
)
from membrain_seg.segmentation.skeletonize import skeletonization

ds_dict = {}


def reset_ds_dict():
"""Reset the dataset dictionary."""
global ds_dict
ds_dict = {}


def get_filepaths(dir_gt: str, dir_pred: str):
"""
Get filepaths for all ground truth segmentations and predictions.
Parameters
----------
dir_gt : str
Directory containing ground truth segmentations.
dir_pred : str
Directory containing predictions.
Returns
-------
gt_files : list
List of ground truth segmentation filepaths.
pred_files : list
List of prediction filepaths.
"""
# Load all segmentations and skeletons
gt_files = os.listdir(dir_gt)
    # keep only files with the .nii.gz extension
gt_files = [f for f in gt_files if f.endswith(".nii.gz")]

# check whether predictions are in .mrc or .nii.gz format
is_mrc = False
pred_files = os.listdir(dir_pred)
pred_files_mrc = [f for f in pred_files if f.endswith(".mrc")]
pred_files_nii = [f for f in pred_files if f.endswith(".nii.gz")]
if len(pred_files_mrc) > 0:
pred_files = pred_files_mrc
is_mrc = True
elif len(pred_files_nii) > 0:
pred_files = pred_files_nii
else:
raise ValueError("No predictions found in .mrc or .nii.gz format.")

# check whether the number of predictions and ground truth segmentations match
if len(gt_files) != len(pred_files):
raise ValueError(
"Number of ground truth segmentations and predictions do not match."
)

# sort all files alphabetically
gt_files.sort()
pred_files.sort()

    # make sure gt and pred filenames match once extensions are stripped
for gt, pred in zip(gt_files, pred_files):
if is_mrc:
assert gt[:-7] == pred[:-4]
else:
assert gt[:-7] == pred[:-7]
return gt_files, pred_files
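
# Illustrative pairing (hypothetical filenames): the ground truth
# "polnet_patch_001.nii.gz" ([:-7] -> "polnet_patch_001") is matched against
# the prediction "polnet_patch_001.mrc" ([:-4] -> "polnet_patch_001") or
# "polnet_patch_001.nii.gz", depending on the detected prediction format.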


def read_nifti_or_mrc(file_path: str):
"""Read a nifti or mrc file.
Parameters
----------
file_path : str
Path to the file.
Returns
-------
np.ndarray
The data.
"""
if file_path.endswith(".mrc"):
return load_tomogram(file_path)
else:
return read_nifti(file_path)


def get_ds_token(filename: str) -> str:
"""Get the dataset token from the filename.
Parameters
----------
filename : str
The filename of the patch.
Returns
-------
str
The dataset token.
"""
token = filename.split("_")[0]
if token in ["atty", "benedikt", "rory", "virly"]:
return "collaborators"
elif token in ["cts", "polnet"]:
return "synthetic"
else:
return token
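
# Examples of the token mapping above (filenames are hypothetical):
#   get_ds_token("polnet_patch_001.nii.gz") -> "synthetic"
#   get_ds_token("rory_patch_004.nii.gz")   -> "collaborators"
#   get_ds_token("membrane_patch_002.nii.gz") -> "membrane" (token kept as-is)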


def initialize_ds_dict_entry(ds_token):
"""Initialize a dataset dictionary entry.
Parameters
----------
ds_token : str
The dataset token.
"""
if ds_token not in ds_dict:
ds_dict[ds_token] = {
"surf_dice": [],
"tp_pred_sdice": 0,
"tp_gt_sdice": 0,
"all_pred_sdice": 0,
"all_gt_sdice": 0,
"dice": [],
"tp": 0,
"fp": 0,
"fn": 0,
}


def update_ds_dict_entry(ds_token, surf_dice, confusion_dict, dice, dice_dict):
"""Update the dataset dictionary entry.
Parameters
----------
ds_token : str
The dataset token.
surf_dice : float
Surface dice score.
confusion_dict : dict
Dictionary containing the following
keys:
- tp_pred: True positives in the prediction.
- tp_gt: True positives in the ground truth.
- all_pred: All positives in the prediction.
- all_gt: All positives in the ground truth.
dice : float
Dice score.
dice_dict : dict
Dictionary containing the following
keys:
- tp: True positives.
- fp: False positives.
- fn: False negatives.
"""
ds_dict[ds_token]["surf_dice"].append(surf_dice)
ds_dict[ds_token]["tp_pred_sdice"] += confusion_dict["tp_pred"]
ds_dict[ds_token]["tp_gt_sdice"] += confusion_dict["tp_gt"]
ds_dict[ds_token]["all_pred_sdice"] += confusion_dict["all_pred"]
ds_dict[ds_token]["all_gt_sdice"] += confusion_dict["all_gt"]
ds_dict[ds_token]["dice"].append(dice)
ds_dict[ds_token]["tp"] += dice_dict["tp"]
ds_dict[ds_token]["fp"] += dice_dict["fp"]
ds_dict[ds_token]["fn"] += dice_dict["fn"]


def print_ds_dict():
"""Print the dataset dictionary."""
print("")
print("Dataset statistics:")
for ds_token in ds_dict:
print(f"Dataset: {ds_token}")
print(f"Surface dice: {np.mean(ds_dict[ds_token]['surf_dice'])}")
print(f"Global surface dice: {get_global_stats(ds_token, s_dice=True)}")
print(f"Dice: {np.mean(ds_dict[ds_token]['dice'])}")
print(f"Global dice: {get_global_stats(ds_token, s_dice=False)}")
print("")


def get_global_stats(
ds_token,
s_dice: bool,
):
"""Aggregates global statistics for a dataset.
Parameters
----------
ds_token : str
The dataset token.
s_dice : bool
Whether to compute surface dice or dice.
Returns
-------
float
The global statistic.
"""
if s_dice:
global_dice = (
2.0
* (
ds_dict[ds_token]["tp_pred_sdice"]
/ (ds_dict[ds_token]["all_pred_sdice"] + 1e-6)
)
* (
ds_dict[ds_token]["tp_gt_sdice"]
/ (ds_dict[ds_token]["all_gt_sdice"] + 1e-6)
)
/ (
ds_dict[ds_token]["tp_pred_sdice"]
/ (ds_dict[ds_token]["all_pred_sdice"] + 1e-6)
+ ds_dict[ds_token]["tp_gt_sdice"]
/ (ds_dict[ds_token]["all_gt_sdice"] + 1e-6)
)
)
else:
global_dice = (
2.0
* (
ds_dict[ds_token]["tp"]
/ (ds_dict[ds_token]["tp"] + ds_dict[ds_token]["fp"])
)
* (
ds_dict[ds_token]["tp"]
/ (ds_dict[ds_token]["tp"] + ds_dict[ds_token]["fn"])
)
/ (
ds_dict[ds_token]["tp"]
/ (ds_dict[ds_token]["tp"] + ds_dict[ds_token]["fp"])
+ ds_dict[ds_token]["tp"]
/ (ds_dict[ds_token]["tp"] + ds_dict[ds_token]["fn"])
)
)
return global_dice
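
# Note: both branches above compute the harmonic mean of precision and recall,
# i.e. 2 * P * R / (P + R). For the plain dice branch this is algebraically
# equal to the familiar 2 * tp / (2 * tp + fp + fn). The surface-dice branch
# guards its precision/recall denominators with 1e-6; if a dataset contained
# no positives at all, the final (P + R) denominator could still be zero.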


def store_stats(out_file):
"""Store the dataset dictionary in a csv file.
Parameters
----------
out_file : str
Path to the output file.
"""
# store ds_dict in a csv file
header = [
"Dataset",
"Surface Dice",
"Global Surface Dice",
"Dice",
"Global Dice",
]
with open(out_file, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(header)
for ds_token in ds_dict:
row = [
ds_token,
np.mean(ds_dict[ds_token]["surf_dice"]),
get_global_stats(ds_token, s_dice=True),
np.mean(ds_dict[ds_token]["dice"]),
get_global_stats(ds_token, s_dice=False),
]
writer.writerow(row)
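
# Schematic output, one row per dataset token (values are placeholders):
#   Dataset,Surface Dice,Global Surface Dice,Dice,Global Dice
#   synthetic,0.93,0.94,0.89,0.90
#   collaborators,0.91,0.92,0.87,0.88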


def compute_stats(
dir_gt: str,
dir_pred: str,
out_dir: str,
out_file_token: str = "stats",
):
"""
Compute statistics for segmentations on the entire dataset.
Parameters
----------
dir_gt : str
Directory containing ground truth segmentations.
dir_pred : str
Directory containing predictions.
out_dir : str
Directory to save the results.
out_file_token : str
        Basename for the output CSV; results are written to
        `<out_dir>/<out_file_token>.csv`.
"""
reset_ds_dict()
gt_files, pred_files = get_filepaths(dir_gt, dir_pred)

    # iterate over matched (ground truth, prediction) pairs
    for gt_file, pred_file in tqdm(
        zip(gt_files, pred_files), total=len(gt_files)
    ):

ds_token = get_ds_token(gt_file)
initialize_ds_dict_entry(ds_token)

# Load ground truth and prediction
gt = read_nifti_or_mrc(os.path.join(dir_gt, gt_file))
pred = read_nifti_or_mrc(os.path.join(dir_pred, pred_file))
        # Skeletonize both segmentations; in the ground truth, only voxels
        # labeled 1 are treated as membrane (label 2 marks regions to ignore,
        # masked out below)
        gt_skeleton = skeletonization(gt == 1, batch_size=100000)
        pred_skeleton = skeletonization(pred, batch_size=100000)
        # exclude ignore-label voxels (gt == 2) from all metrics
        mask = gt != 2

# Compute surface dice
surf_dice, confusion_dict = masked_surface_dice(
pred_skeleton, gt_skeleton, pred, gt, mask
)
dice, dice_dict = masked_dice(pred, gt, mask)
update_ds_dict_entry(ds_token, surf_dice, confusion_dict, dice, dice_dict)
print_ds_dict()
os.makedirs(out_dir, exist_ok=True)
out_file = os.path.join(out_dir, f"{out_file_token}.csv")
store_stats(out_file)
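
# Example usage (mirrors docs/Usage/Benchmarking.md; paths are placeholders):
#   from membrain_seg.benchmark.compute_stats import compute_stats
#   compute_stats("path/to/ground_truth", "path/to/predictions",
#                 "path/to/output", out_file_token="stats")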