improvements

quantumjot committed Nov 16, 2023
1 parent d5b528f commit 88fd65c

Showing 2 changed files with 114 additions and 65 deletions.

171 changes: 106 additions & 65 deletions src/umetrics/core.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import enum
import numpy as np
import numpy.typing as npt

@@ -14,6 +15,18 @@
from . import render


class Metrics(str, enum.Enum):
N_TRUE_LABELS = "n_true_labels"
N_PRED_LABELS = "n_pred_labels"
N_TRUE_POSITIVES = "n_true_positives"
N_FALSE_POSITIVES = "n_false_positives"
N_FALSE_NEGATIVES = "n_false_negatives"
IOU = "IoU"
JACCARD = "Jaccard"
PIXEL_IDENTITY = "pixel_identity"
LOCALIZATION_ERROR = "localization_error"


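Since Metrics mixes str into the enum, each member compares equal to its raw string value, so members interoperate directly with the plain strings in the METRICS tuple below. A minimal standalone sketch of that behaviour:

    import enum

    class Metrics(str, enum.Enum):
        IOU = "IoU"
        JACCARD = "Jaccard"

    # str-mixin members compare equal to their underlying string values
    assert Metrics.IOU == "IoU"
    assert "Jaccard" == Metrics.JACCARD
    # lookup by value also returns the member itself
    assert Metrics("Jaccard") is Metrics.JACCARD
    assert Metrics.IOU.value == "IoU"
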
METRICS = (
"n_true_labels",
"n_pred_labels",
@@ -27,15 +40,21 @@
)


def _IoU(ref, pred) -> float:
def _IoU(ref: npt.NDArray, pred: npt.NDArray) -> float:
"""Calculate the IoU between two binary masks."""
intersection = np.sum(np.logical_and(ref, pred))
union = np.sum(np.logical_or(ref, pred))
iou = 0.0 if union == 0 else intersection / union
return iou

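As a quick sanity check, the _IoU function above behaves as follows on a pair of small boolean masks (the shapes and values here are illustrative only):

    import numpy as np

    ref = np.zeros((4, 4), dtype=bool)
    pred = np.zeros((4, 4), dtype=bool)
    ref[0:2, 0:2] = True   # 4 pixels
    pred[1:3, 0:3] = True  # 6 pixels, 2 of which overlap ref

    # intersection = 2, union = 4 + 6 - 2 = 8, so IoU = 0.25
    assert _IoU(ref, pred) == 2 / 8

    # two empty masks have an empty union, defined above as IoU = 0.0
    assert _IoU(np.zeros((4, 4)), np.zeros((4, 4))) == 0.0
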

def find_matches(ref: LabeledSegmentation, pred: LabeledSegmentation) -> Dict:
def find_matches(
ref: LabeledSegmentation,
pred: LabeledSegmentation,
*,
strict: bool = False,
iou_threshold: float = 0.5,
) -> Dict:
"""Perform matching between the reference and the predicted image.
Parameters
@@ -52,11 +71,12 @@ def find_matches(ref: LabeledSegmentation, pred: LabeledSegmentation) -> Dict:
"""

# return a dictionary of found matches
# build a default dictionary representing no matches
matches = {
"true_matches": [],
"in_ref_only": [],
"in_pred_only": [],
"true_matches_IoU": [],
"in_ref_only": set(ref.labels),
"in_pred_only": set(pred.labels),
}

# make an infinite cost matrix, so that we only consider matches where
@@ -69,8 +89,14 @@ def find_matches(ref: LabeledSegmentation, pred: LabeledSegmentation) -> Dict:
for pred_label in _matches:
p_id = pred.labels.index(pred_label)
reward = _IoU(mask, pred.labeled == pred_label)
if reward < iou_threshold and strict:
continue
cost_matrix[r_id, p_id] = 1.0 - reward

# in strict mode every retained (finite-cost) entry must meet the IoU
# threshold, i.e. cost = 1 - IoU <= 1 - iou_threshold
if strict:
    assert np.all(np.isinf(cost_matrix) | (cost_matrix <= 1.0 - iou_threshold))

try:
sol_row, sol_col = linear_sum_assignment(cost_matrix)
except ValueError:
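The try/except above covers the case where no complete finite-cost assignment exists, since scipy raises a ValueError for an infeasible cost matrix. A standalone illustration with made-up costs:

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    # every complete assignment hits an inf entry, so scipy raises ValueError
    try:
        linear_sum_assignment(np.array([[np.inf, np.inf], [np.inf, 0.0]]))
    except ValueError as err:
        print(err)  # e.g. "cost matrix is infeasible"
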
@@ -83,12 +109,13 @@ def find_matches(ref: LabeledSegmentation, pred: LabeledSegmentation) -> Dict:
true_matches = list(zip(used_ref, used_pred))

# find the labels that haven't been used
in_ref_only = list(set(ref.labels).difference(used_ref))
in_pred_only = list(set(pred.labels).difference(used_pred))
in_ref_only = set(ref.labels).difference(used_ref)
in_pred_only = set(pred.labels).difference(used_pred)

# return a dictionary of found matches
matches = {
"true_matches": true_matches,
"true_matches_IoU": 1.0 - cost_matrix[sol_row, sol_col],
"in_ref_only": in_ref_only,
"in_pred_only": in_pred_only,
}
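
The matching itself is a one-to-one assignment computed by the Hungarian algorithm on a cost matrix of 1 - IoU, so maximising total overlap is recast as minimising total cost. A small self-contained demo with made-up overlaps:

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    # cost = 1 - IoU, so lower cost means better overlap
    cost_matrix = np.array([
        [0.1, 0.9],  # ref object 0: IoU 0.9 with pred 0, IoU 0.1 with pred 1
        [0.8, 0.2],  # ref object 1: IoU 0.2 with pred 0, IoU 0.8 with pred 1
    ])
    sol_row, sol_col = linear_sum_assignment(cost_matrix)

    print([(int(r), int(c)) for r, c in zip(sol_row, sol_col)])  # [(0, 0), (1, 1)]
    print(1.0 - cost_matrix[sol_row, sol_col])                   # per-match IoU: [0.9 0.8]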
@@ -118,20 +145,20 @@ def __getattr__(self, key):
return getattr(self._metrics, key)

@property
def n_images(self):
def n_images(self) -> int:
if any([getattr(self, m) is None for m in self._agg]):
return 0
else:
return self._images

def __add__(self, result):
def __add__(self, result: MetricResults) -> MetricResults:
assert isinstance(result, MetricResults)
for m in self._agg:
setattr(self, m, getattr(result, m) + getattr(self, m))
self._images += 1
return self

def __repr__(self):
def __repr__(self) -> str:
title = f" Segmentation Metrics (n={self.n_images})\n"
hbar = "=" * len(title) + "\n"
r = hbar + title + hbar
@@ -146,27 +173,27 @@ def __repr__(self):
return r

@property
def localization_error(self):
def localization_error(self) -> float:
return np.mean(self.per_object_localization_error)

@property
def IoU(self):
def IoU(self) -> float:
return np.mean(self.per_object_IoU)

@property
def Jaccard(self):
def Jaccard(self) -> float:
"""Jaccard metric"""
tp = self.n_true_positives
fn = self.n_false_negatives
fp = self.n_false_positives
return tp / (tp + fn + fp)
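
As a worked example with illustrative counts: 8 true positives, 1 false negative and 2 false positives give J = 8 / (8 + 1 + 2) = 8/11, roughly 0.727.

    # illustrative counts, not taken from a real evaluation
    tp, fn, fp = 8, 1, 2
    jaccard = tp / (tp + fn + fp)  # 8 / 11, approximately 0.727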

@property
def pixel_identity(self):
def pixel_identity(self) -> float:
return np.mean(self.per_image_pixel_identity)

@staticmethod
def merge(results):
def merge(results) -> MetricResults:
"""merge n results together and return a single object"""
assert isinstance(results, list)
merged = results.pop(0)
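
merge pairs with the __add__ method above: the remaining results are presumably folded into the first one, summing the aggregated counts and bumping the image counter. A hedged usage sketch, where r1, r2 and r3 are hypothetical per-image MetricResults instances:

    merged = MetricResults.merge([r1, r2, r3])
    print(merged.n_images)  # expected to report all three images
    print(merged.Jaccard)   # Jaccard over the pooled TP/FP/FN counts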
@@ -178,48 +205,57 @@ def merge(results):


class SegmentationMetrics:
"""SegmentationMetrics
A class for calculating various segmentation metrics to assess the
"""A class for calculating various segmentation metrics to assess the
accuracy of a trained model.
Args:
reference - a numpy array (wxh) containing labeled objects from the
ground truth
predicted - a numpy array (wxh) containing labeled objects from the
segmentation algorithm
strict - (bool) whether to disregard matches with a low IoU score
iou_threshold - (float) threshold for strict matching
Properties:
Jaccard: the Jaccard index calculated according to the notes below
IoU: the Intersection over Union metric
localisation_precision:
true_positives:
false_positives:
false_negatives:
Parameters
----------
reference : array
An array containing labeled objects from the ground truth.
predicted : array
An array containing labeled objects from the segmentation algorithm.
strict : bool
Whether to disregard matches with a low IoU score.
iou_threshold : float
Threshold IoU for strict matching.
Properties
----------
Jaccard : float
The Jaccard index calculated according to the notes below.
IoU : float
The Intersection over Union metric.
localisation_precision : float
The localisation precision.
true_positives : int
Number of TP predictions.
false_positives : int
Number of FP predictions.
false_negatives : int
Number of FN predictions.
Notes:
The Jaccard metric is calculated accordingly:
Notes
-----
The Jaccard metric is calculated accordingly:
FP = number of objects in predicted but not in reference
TP = number of objects in both
TN = background correctly segmented (not used)
FN = number of objects in true but not in predicted
FP = number of objects in predicted but not in reference
TP = number of objects in both
TN = background correctly segmented (not used)
FN = number of objects in true but not in predicted
J = TP / (TP+FP+FN)
J = TP / (TP+FP+FN)
The IoU is calculated as the intersection of the binary segmentation
divided by the union.
The IoU is calculated as the intersection of the binary segmentation
divided by the union.
TODO(arl): need to address undersegmentation detection
TODO(arl): need to address undersegmentation detection
"""

def __init__(self, reference, predicted, **kwargs):
def __init__(
self, reference: LabeledSegmentation, predicted: LabeledSegmentation, **kwargs
):
assert isinstance(predicted, LabeledSegmentation)
assert isinstance(reference, LabeledSegmentation)

@@ -228,31 +264,39 @@ def __init__(self, reference, predicted, **kwargs):
self._strict = kwargs.get("strict", False)
self._iou_threshold = kwargs.get("iou_threshold", 0.5)

assert self.iou_threshold >= 0.0 and self.iou_threshold <= 1.0
if self.iou_threshold < 0.0 or self.iou_threshold > 1.0:
raise ValueError(
f"IoU threshold should be in [0, 1], found: {self.iou_threshold:.2f}"
)
assert isinstance(self.strict, bool)

# find the matches
self._matches = find_matches(self._reference, self._predicted)
self._matches = find_matches(
self._reference,
self._predicted,
strict=self.strict,
iou_threshold=self.iou_threshold,
)

# if we're in strict mode, prune the matches
if self.strict:
iou = self.per_object_IoU
tp = [
self.true_positives[i]
for i, ov in enumerate(iou)
if ov > self.iou_threshold
]
fp = list(set(self.true_positives).difference(tp))
# # if we're in strict mode, prune the matches
# if self.strict:
# iou = self.per_object_IoU
# tp = [
# self.true_positives[i]
# for i, ov in enumerate(iou)
# if ov > self.iou_threshold
# ]
# fp = list(set(self.true_positives).difference(tp))

self._matches["true_matches"] = tp
self._matches["in_pred_only"] += [m[1] for m in fp]
# self._matches["true_matches"] = tp
# self._matches["in_pred_only"] += [m[1] for m in fp]

@property
def strict(self):
def strict(self) -> bool:
return self._strict

@property
def iou_threshold(self):
def iou_threshold(self) -> float:
return self._iou_threshold

@property
@@ -314,10 +358,7 @@ def per_object_IoU(self):
mask_ref = self._reference.labeled == m[0]
mask_pred = self._predicted.labeled == m[1]

intersection = np.logical_and(mask_ref, mask_pred)
union = np.logical_or(mask_ref, mask_pred)

iou.append(np.sum(intersection) / np.sum(union))
iou.append(_IoU(mask_ref, mask_pred))
return iou

@property
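Putting the changes together, end-to-end usage might look like the sketch below. The import path and the LabeledSegmentation constructor are assumptions inferred from the asserts in __init__, not confirmed by this diff:

    import numpy as np
    from umetrics.core import LabeledSegmentation, SegmentationMetrics  # assumed import path

    # toy label images: 0 is background, positive integers are object labels
    y_true = np.zeros((32, 32), dtype=int)
    y_pred = np.zeros((32, 32), dtype=int)
    y_true[4:12, 4:12] = 1
    y_pred[5:13, 5:13] = 1

    # wrapping the raw arrays in LabeledSegmentation is an assumption
    result = SegmentationMetrics(
        LabeledSegmentation(y_true),
        LabeledSegmentation(y_pred),
        strict=True,
        iou_threshold=0.5,
    )
    print(result.Jaccard, result.IoU)
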
8 changes: 8 additions & 0 deletions tests/test_metrics.py
@@ -14,6 +14,14 @@ def test_calculate(image_pair, strict):
# calculate the real number of true positives based on strict matching
real_tp = int(IoU > result.iou_threshold) if strict else int(IoU > 0)

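# debug visualisation of the image pair under test; note that plt.show()
# may block non-interactive test runs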
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 2)
ax[0].imshow(y_true)
ax[1].imshow(y_pred)
ax[1].set_title(f"IoU: {IoU}, threshold: {result.iou_threshold}")
plt.show()

assert result.n_true_labels == 1
assert result.n_pred_labels == 1
assert result.n_true_positives == real_tp
