Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
quantumjot committed Nov 20, 2023
1 parent 1c22d78 commit f1d4bd1
Show file tree
Hide file tree
Showing 10 changed files with 323 additions and 187 deletions.
25 changes: 16 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ TODO:
### Single image usage

```python
import umetrics
import umetrix
from skimage.io import imread

y_true = imread('true.tif')
Expand All @@ -29,7 +29,12 @@ y_pred = imread('pred.tif')

# can now make the calculation strict, by only considering objects that have
# an IoU above a theshold as being true positives
result = umetrics.calculate(y_true, y_pred, strict=True, iou_threshold=0.5)
result = umetrix.calculate(
y_true,
y_pred,
strict=True,
iou_threshold=0.5
)

print(result.results)
```
Expand All @@ -56,14 +61,16 @@ localization_error: 0.010
### Batch processing

```python
import umetrics
import umetrix

# provide a list of file pairs ('true', 'prediction')
files = [('true0.tif', 'pred0.tif'),
('true1.tif', 'pred1.tif'),
('true2.tif', 'pred2.tif')]
files = [
('true0.tif', 'pred0.tif'),
('true1.tif', 'pred1.tif'),
('true2.tif', 'pred2.tif')
]

batch_result = umetrics.batch(files)
batch_result = umetrix.batch(files)
```

Returns aggregate statistics over the batch. Jaccard index is calculated over
Expand All @@ -79,8 +86,8 @@ $ git clone https://github.com/quantumjot/unet_segmentation_metrics.git

2. (Optional, but advised) Create a conda environment:
```sh
$ conda create -n umetrics python=3.7
$ conda activate umetrics
$ conda create -n umetrix python=3.9
$ conda activate umetrix
```

3. Install the package
Expand Down
67 changes: 30 additions & 37 deletions notebooks/unet_segmentation_metrics-napari.ipynb

Large diffs are not rendered by default.

201 changes: 145 additions & 56 deletions notebooks/unet_segmentation_metrics.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ dynamic = ["version"]

[tool.setuptools.packages.find]
where = ["src"]
include = ["umetrics*"]
include = ["umetrix*"]

[tool.setuptools_scm]
local_scheme = "no-local-version"
write_to = "src/umetrics/_version.py"
write_to = "src/umetrix/_version.py"
File renamed without changes.
112 changes: 57 additions & 55 deletions src/umetrics/core.py → src/umetrix/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
from scipy.ndimage import find_objects
from scipy.optimize import linear_sum_assignment

from typing import Dict
from typing import Dict, Tuple

from . import render
from umetrix import render


DEFAULT_MAXIMUM_COST = 1e8


class Metrics(str, enum.Enum):
Expand Down Expand Up @@ -40,7 +43,7 @@ class Metrics(str, enum.Enum):
)


def _IoU(ref: npt.NDArray, pred: npt.NDArray) -> float:
def IoU(ref: npt.NDArray, pred: npt.NDArray) -> float:
"""Calculate the IoU between two binary masks."""
intersection = np.sum(np.logical_and(ref, pred))
union = np.sum(np.logical_or(ref, pred))
Expand Down Expand Up @@ -73,33 +76,37 @@ def find_matches(
------
matches : dict
A dictionary of matches between the two images.
"""

# make an infinite cost matrix, so that we only consider matches where
# there is some overlap in the masks
cost_matrix = np.full((len(ref.labels), len(pred.labels)), 1e8)
cost_matrix = np.full((len(ref.labels), len(pred.labels)), DEFAULT_MAXIMUM_COST)

for r_id, ref_label in enumerate(ref.labels):
mask = ref.labeled == ref_label
_matches = [m for m in np.unique(pred.labeled[mask]) if m > 0]
for pred_label in _matches:
p_id = pred.labels.index(pred_label)
reward = _IoU(mask, pred.labeled == pred_label)
if reward < iou_threshold and strict:
reward = IoU(mask, pred.labeled == pred_label)
if (reward < iou_threshold) and strict:
continue
cost_matrix[r_id, p_id] = 1.0 - reward

# if it's strict, make sure every element is above the threshold
if strict:
cost_threshold = 1.0 - iou_threshold
assert np.all(cost_matrix >= cost_threshold), cost_matrix
cost_mask = cost_matrix == DEFAULT_MAXIMUM_COST
assert np.all(cost_matrix[~cost_mask] <= cost_threshold)

# solve
# solve it using JV
sol_row, sol_col = linear_sum_assignment(cost_matrix)

# remove infeasible solutions
edges = [(r, c) for r, c in zip(sol_row, sol_col) if cost_matrix[r, c] <= 1]
edges = [
(ref.labels[r], pred.labels[c], 1.0 - cost_matrix[r, c])
for r, c in zip(sol_row, sol_col)
if cost_matrix[r, c] <= 1
]

# return a default dictionary if there are no matches
if not edges:
Expand All @@ -111,22 +118,15 @@ def find_matches(
}
return matches

sol_row, sol_col = zip(*edges)

# now that we've solved the LAP, find the matches that have been made
used_ref = [ref.labels[row] for row in sol_row]
used_pred = [pred.labels[col] for col in sol_col]
assert len(used_ref) == len(used_pred)
true_matches = list(zip(used_ref, used_pred))

# find the labels that haven't been used
used_ref, used_pred, IoUs = zip(*edges)
in_ref_only = set(ref.labels).difference(used_ref)
in_pred_only = set(pred.labels).difference(used_pred)

# return a dictionary of found matches
matches = {
"true_matches": true_matches,
"true_matches_IoU": 1.0 - cost_matrix[sol_row, sol_col],
"true_matches": set(zip(used_ref, used_pred)),
"true_matches_IoU": IoUs,
"in_ref_only": in_ref_only,
"in_pred_only": in_pred_only,
}
Expand Down Expand Up @@ -183,6 +183,28 @@ def __repr__(self) -> str:
r += f"{m}: {mval}\n"
return r

def _repr_html_(self):
try:
import pandas as pd
except ImportError:
return (
"<b>Install pandas for nicer, tabular rendering.</b> <br>"
+ self.__repr__()
)

data = {
"metrics": (
"N",
"strict",
"IoU_threshold",
)
+ METRICS,
"values": [self.n_images, self.strict, self.iou_threshold]
+ [getattr(self, metric) for metric in METRICS],
}

return pd.DataFrame.from_dict(data=data, orient="columns").to_html()

@property
def localization_error(self) -> float:
return np.mean(self.per_object_localization_error)
Expand All @@ -204,9 +226,8 @@ def pixel_identity(self) -> float:
return np.mean(self.per_image_pixel_identity)

@staticmethod
def merge(results) -> MetricResults:
"""merge n results together and return a single object"""
assert isinstance(results, list)
def merge(results: list) -> MetricResults:
"""Merge n results together and return a single object."""
merged = results.pop(0)
for result in results:
assert isinstance(result, MetricResults)
Expand Down Expand Up @@ -259,9 +280,6 @@ class SegmentationMetrics:
The IoU is calculated as the intersection of the binary segmentation
divided by the union.
TODO(arl): need to address undersegmentation detection
"""

def __init__(
Expand Down Expand Up @@ -289,19 +307,6 @@ def __init__(
iou_threshold=self.iou_threshold,
)

# # if we're in strict mode, prune the matches
# if self.strict:
# iou = self.per_object_IoU
# tp = [
# self.true_positives[i]
# for i, ov in enumerate(iou)
# if ov > self.iou_threshold
# ]
# fp = list(set(self.true_positives).difference(tp))

# self._matches["true_matches"] = tp
# self._matches["in_pred_only"] += [m[1] for m in fp]

@property
def strict(self) -> bool:
return self._strict
Expand Down Expand Up @@ -336,17 +341,17 @@ def n_pred_labels(self):

@property
def true_positives(self):
"""only one match between reference and predicted"""
"""Only one match between reference and predicted."""
return self._matches["true_matches"]

@property
def false_negatives(self):
"""no match in predicted for reference object"""
"""No match in predicted for reference object."""
return self._matches["in_ref_only"]

@property
def false_positives(self):
"""combination of non unique matches and unmatched objects"""
"""Combination of non unique matches and unmatched objects."""
return self._matches["in_pred_only"]

@property
Expand All @@ -364,23 +369,17 @@ def n_false_positives(self):
@property
def per_object_IoU(self):
"""Intersection over Union (IoU) metric"""
# iou = []
# for m in self.true_positives:
# mask_ref = self._reference.labeled == m[0]
# mask_pred = self._predicted.labeled == m[1]

# iou.append(_IoU(mask_ref, mask_pred))
# return iou
return self._matches["true_matches_IoU"]

@property
def per_image_pixel_identity(self):
"""Calculate the per-image pixel identity."""
n_tot = np.prod(self._reference.image.shape)
return [np.sum(self._reference.image == self._predicted.image) / n_tot]

@property
def per_object_localization_error(self):
"""localization error"""
"""Calculate the per-object localization error."""
ref_centroids = self._reference.centroids
tgt_centroids = self._predicted.centroids
positional_error = []
Expand All @@ -397,21 +396,24 @@ def plot(self):
def to_napari(self):
return render.render_metrics_napari(self)

def __repr__(self):
return self.results.__repr__()

class LabeledSegmentation:
"""LabeledSegmentation
def _repr_html_(self):
return self.results._repr_html_()

A helper class to enable simple calculation of accuracy statistics for
image segmentation output.

class LabeledSegmentation:
"""A helper class to enable simple calculation of accuracy statistics for
image segmentation output.
"""

def __init__(self, image: npt.NDArray):
self.image = image
self.labeled, self.n_labels = label(image.astype(bool))

@property
def shape(self):
def shape(self) -> Tuple[int]:
return self.image.shape

@property
Expand Down
5 changes: 0 additions & 5 deletions src/umetrics/render.py → src/umetrix/render.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
__author__ = "Alan R. Lowe"
__email__ = "a.lowe@ucl.ac.uk"


import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
Expand All @@ -17,7 +13,6 @@ def plot_metrics(seg_metrics):
iou[tp[0] - 1] = "{:.2f}".format(IoU[i])

fig, ax = plt.subplots(1, figsize=(16, 12))
# plt.imshow(J_image)
ax.imshow(seg_metrics.image_overlay)

for i, (sy, sx) in enumerate(ref.bboxes):
Expand Down
Loading

0 comments on commit f1d4bd1

Please sign in to comment.