From 88fd65c1586fa2578dbb21b2d7488d5d19de9a9d Mon Sep 17 00:00:00 2001
From: quantumjot
Date: Thu, 16 Nov 2023 10:58:12 +0000
Subject: [PATCH] improvements

---
 src/umetrics/core.py  | 171 ++++++++++++++++++++++++++----------------
 tests/test_metrics.py |   8 ++
 2 files changed, 114 insertions(+), 65 deletions(-)

diff --git a/src/umetrics/core.py b/src/umetrics/core.py
index cea7d98..64f47c7 100644
--- a/src/umetrics/core.py
+++ b/src/umetrics/core.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import enum
 import numpy as np
 import numpy.typing as npt
 
@@ -14,6 +15,18 @@
 from . import render
 
 
+class Metrics(str, enum.Enum):
+    N_TRUE_LABELS = "n_true_labels"
+    N_PRED_LABELS = "n_pred_labels"
+    N_TRUE_POSITIVES = "n_true_positives"
+    N_FALSE_POSITIVES = "n_false_positives"
+    N_FALSE_NEGATIVES = "n_false_negatives"
+    IOU = "IoU"
+    JACCARD = "Jaccard"
+    PIXEL_IDENTITY = "pixel_identity"
+    LOCALIZATION_ERROR = "localization_error"
+
+
 METRICS = (
     "n_true_labels",
     "n_pred_labels",
@@ -27,7 +40,7 @@
 )
 
 
-def _IoU(ref, pred) -> float:
+def _IoU(ref: npt.NDArray, pred: npt.NDArray) -> float:
     """Calculate the IoU between two binary masks."""
     intersection = np.sum(np.logical_and(ref, pred))
     union = np.sum(np.logical_or(ref, pred))
@@ -35,7 +48,13 @@ def _IoU(ref, pred) -> float:
     return iou
 
 
-def find_matches(ref: LabeledSegmentation, pred: LabeledSegmentation) -> Dict:
+def find_matches(
+    ref: LabeledSegmentation,
+    pred: LabeledSegmentation,
+    *,
+    strict: bool = False,
+    iou_threshold: float = 0.5,
+) -> Dict:
     """Perform matching between the reference and the predicted image.
 
     Parameters
@@ -52,11 +71,12 @@ def find_matches(ref: LabeledSegmentation, pred: LabeledSegmentation) -> Dict:
     """
 
-    # return a dictionary of found matches
+    # return a default dictionary of no matches
     matches = {
         "true_matches": [],
-        "in_ref_only": [],
-        "in_pred_only": [],
+        "true_matches_IoU": [],
+        "in_ref_only": set(ref.labels),
+        "in_pred_only": set(pred.labels),
     }
 
     # make an infinite cost matrix, so that we only consider matches where
     # there is some overlap
@@ -69,8 +89,14 @@ def find_matches(ref: LabeledSegmentation, pred: LabeledSegmentation) -> Dict:
         for pred_label in _matches:
             p_id = pred.labels.index(pred_label)
             reward = _IoU(mask, pred.labeled == pred_label)
+            if reward < iou_threshold and strict:
+                continue
             cost_matrix[r_id, p_id] = 1.0 - reward
 
+    # if strict, every finite entry must correspond to an IoU above the threshold
+    if strict:
+        assert np.all(1.0 - cost_matrix[np.isfinite(cost_matrix)] >= iou_threshold)
+
     try:
         sol_row, sol_col = linear_sum_assignment(cost_matrix)
     except ValueError:
@@ -83,12 +109,13 @@ def find_matches(ref: LabeledSegmentation, pred: LabeledSegmentation) -> Dict:
     true_matches = list(zip(used_ref, used_pred))
 
     # find the labels that haven't been used
-    in_ref_only = list(set(ref.labels).difference(used_ref))
-    in_pred_only = list(set(pred.labels).difference(used_pred))
+    in_ref_only = set(ref.labels).difference(used_ref)
+    in_pred_only = set(pred.labels).difference(used_pred)
 
     # return a dictionary of found matches
     matches = {
         "true_matches": true_matches,
+        "true_matches_IoU": 1.0 - cost_matrix[sol_row, sol_col],
         "in_ref_only": in_ref_only,
         "in_pred_only": in_pred_only,
     }
@@ -118,20 +145,20 @@ def __getattr__(self, key):
         return getattr(self._metrics, key)
 
     @property
-    def n_images(self):
+    def n_images(self) -> int:
         if any([getattr(self, m) is None for m in self._agg]):
             return 0
         else:
             return self._images
 
-    def __add__(self, result):
+    def __add__(self, result: MetricResults) -> MetricResults:
         assert isinstance(result, MetricResults)
         for m in self._agg:
             setattr(self, m, getattr(result, m) + getattr(self, m))
         self._images += 1
         return self
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         title = f" Segmentation Metrics (n={self.n_images})\n"
         hbar = "=" * len(title) + "\n"
         r = hbar + title + hbar
@@ -146,15 +173,15 @@ def __repr__(self):
         return r
 
     @property
-    def localization_error(self):
+    def localization_error(self) -> float:
         return np.mean(self.per_object_localization_error)
 
     @property
-    def IoU(self):
+    def IoU(self) -> float:
         return np.mean(self.per_object_IoU)
 
     @property
-    def Jaccard(self):
+    def Jaccard(self) -> float:
         """Jaccard metric"""
         tp = self.n_true_positives
         fn = self.n_false_negatives
@@ -162,11 +189,11 @@ def Jaccard(self):
         return tp / (tp + fn + fp)
 
     @property
-    def pixel_identity(self):
+    def pixel_identity(self) -> float:
         return np.mean(self.per_image_pixel_identity)
 
     @staticmethod
-    def merge(results):
+    def merge(results) -> MetricResults:
        """merge n results together and return a single object"""
        assert isinstance(results, list)
        merged = results.pop(0)
@@ -178,48 +205,57 @@
 
 
 class SegmentationMetrics:
-    """SegmentationMetrics
-
-    A class for calculating various segmentation metrics to assess the
+    """A class for calculating various segmentation metrics to assess the
     accuracy of a trained model.
 
-    Args:
-        reference - a numpy array (wxh) containing labeled objects from the
-            ground truth
-        predicted - a numpy array (wxh) containing labeled objects from the
-            segmentation algorithm
-
-        strict - (bool) whether to disregard matches with a low IoU score
-        iou_threshold - (float) threshold for strict matching
-
-    Properties:
-        Jaccard: the Jaccard index calculated according to the notes below
-        IoU: the Intersection over Union metric
-        localisation_precision:
-
-        true_positives:
-        false_positives:
-        false_negatives:
+    Parameters
+    ----------
+    reference : array
+        An array containing labeled objects from the ground truth.
+    predicted : array
+        An array containing labeled objects from the segmentation algorithm.
+    strict : bool
+        Whether to disregard matches with a low IoU score.
+    iou_threshold : float
+        Threshold IoU for strict matching.
+
+    Properties
+    ----------
+    Jaccard : float
+        The Jaccard index calculated according to the notes below.
+    IoU : float
+        The Intersection over Union metric.
+    localisation_precision : float
+        The localisation precision.
+    true_positives : int
+        Number of TP predictions.
+    false_positives : int
+        Number of FP predictions.
+    false_negatives : int
+        Number of FN predictions.
 
-    Notes:
-        The Jaccard metric is calculated accordingly:
+    Notes
+    -----
+    The Jaccard metric is calculated accordingly:
 
-        FP = number of objects in predicted but not in reference
-        TP = number of objects in both
-        TN = background correctly segmented (not used)
-        FN = number of objects in true but not in predicted
+    FP = number of objects in predicted but not in reference
+    TP = number of objects in both
+    TN = background correctly segmented (not used)
+    FN = number of objects in true but not in predicted
 
-        J = TP / (TP+FP+FN)
+    J = TP / (TP+FP+FN)
 
-        The IoU is calculated as the intersection of the binary segmentation
-        divided by the union.
+    The IoU is calculated as the intersection of the binary segmentation
+    divided by the union.
 
-        TODO(arl): need to address undersegmentation detection
+    TODO(arl): need to address undersegmentation detection
     """
 
-    def __init__(self, reference, predicted, **kwargs):
+    def __init__(
+        self, reference: LabeledSegmentation, predicted: LabeledSegmentation, **kwargs
+    ):
         assert isinstance(predicted, LabeledSegmentation)
         assert isinstance(reference, LabeledSegmentation)
 
@@ -228,31 +264,39 @@ def __init__(self, reference, predicted, **kwargs):
         self._strict = kwargs.get("strict", False)
         self._iou_threshold = kwargs.get("iou_threshold", 0.5)
 
-        assert self.iou_threshold >= 0.0 and self.iou_threshold <= 1.0
+        if self.iou_threshold < 0.0 or self.iou_threshold > 1.0:
+            raise ValueError(
+                f"IoU threshold should be in [0, 1], found: {self.iou_threshold:.2f}"
+            )
         assert isinstance(self.strict, bool)
 
         # find the matches
-        self._matches = find_matches(self._reference, self._predicted)
+        self._matches = find_matches(
+            self._reference,
+            self._predicted,
+            strict=self.strict,
+            iou_threshold=self.iou_threshold,
+        )
 
-        # if we're in strict mode, prune the matches
-        if self.strict:
-            iou = self.per_object_IoU
-            tp = [
-                self.true_positives[i]
-                for i, ov in enumerate(iou)
-                if ov > self.iou_threshold
-            ]
-            fp = list(set(self.true_positives).difference(tp))
+        # # if we're in strict mode, prune the matches
+        # if self.strict:
+        #     iou = self.per_object_IoU
+        #     tp = [
+        #         self.true_positives[i]
+        #         for i, ov in enumerate(iou)
+        #         if ov > self.iou_threshold
+        #     ]
+        #     fp = list(set(self.true_positives).difference(tp))
 
-            self._matches["true_matches"] = tp
-            self._matches["in_pred_only"] += [m[1] for m in fp]
+        #     self._matches["true_matches"] = tp
+        #     self._matches["in_pred_only"] += [m[1] for m in fp]
 
     @property
-    def strict(self):
+    def strict(self) -> bool:
         return self._strict
 
     @property
-    def iou_threshold(self):
+    def iou_threshold(self) -> float:
         return self._iou_threshold
 
     @property
@@ -314,10 +358,7 @@ def per_object_IoU(self):
             mask_ref = self._reference.labeled == m[0]
             mask_pred = self._predicted.labeled == m[1]
 
-            intersection = np.logical_and(mask_ref, mask_pred)
-            union = np.logical_or(mask_ref, mask_pred)
-
-            iou.append(np.sum(intersection) / np.sum(union))
+            iou.append(_IoU(mask_ref, mask_pred))
         return iou
 
     @property
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 97b4d52..1ba782a 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -14,6 +14,14 @@ def test_calculate(image_pair, strict):
     # calculate the real number of true postives based on strict matching
     real_tp = int(IoU > result.iou_threshold) if strict else int(IoU > 0)
 
+    import matplotlib.pyplot as plt
+
+    fig, ax = plt.subplots(1, 2)
+    ax[0].imshow(y_true)
+    ax[1].imshow(y_pred)
+    ax[1].set_title(f"IoU: {IoU}, threshold: {result.iou_threshold}")
+    plt.show()
+
     assert result.n_true_labels == 1
     assert result.n_pred_labels == 1
     assert result.n_true_positives == real_tp
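
Usage note (illustrative sketch, not part of the patch): the snippet below
shows how the new `strict` / `iou_threshold` keywords flow from
`SegmentationMetrics` through to `find_matches`. Constructing a
`LabeledSegmentation` directly from a label array is an assumption here (the
constructor is not shown in this diff), and the toy masks are made up.

    import numpy as np

    from umetrics.core import LabeledSegmentation, SegmentationMetrics

    # one 4x4 reference object and a partially overlapping 4x4 prediction:
    # intersection = 9 px, union = 23 px, so IoU = 9 / 23 ~ 0.39
    y_true = np.zeros((8, 8), dtype=int)
    y_true[2:6, 2:6] = 1
    y_pred = np.zeros((8, 8), dtype=int)
    y_pred[3:7, 3:7] = 1

    ref = LabeledSegmentation(y_true)  # hypothetical constructor
    pred = LabeledSegmentation(y_pred)

    # strict matching rejects the pair (0.39 < 0.5), so TP = 0, FP = FN = 1
    # and Jaccard = 0 / (0 + 1 + 1) = 0.0
    strict_result = SegmentationMetrics(ref, pred, strict=True, iou_threshold=0.5)

    # lenient matching keeps any overlapping pair, so TP = 1, FP = FN = 0,
    # Jaccard = 1 / (1 + 0 + 0) = 1.0, while the mean per-object IoU stays ~ 0.39
    lenient_result = SegmentationMetrics(ref, pred)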