From 257f37aaefc31dc26d7724e6f87b0169a623a66a Mon Sep 17 00:00:00 2001 From: Javi Ribera Date: Sun, 16 Sep 2018 19:39:03 -0400 Subject: [PATCH] encapsulate thresholding into a function Former-commit-id: d78ecf6b20b15e9e622ce4b1cb08bf151b8c0483 --- object-locator/locate.py | 13 +------------ object-locator/train.py | 12 +----------- object-locator/utils.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/object-locator/locate.py b/object-locator/locate.py index 75c4770..5f5c15c 100644 --- a/object-locator/locate.py +++ b/object-locator/locate.py @@ -216,18 +216,7 @@ # The estimated map must be thresholded to obtain estimated points for tau, df_out in zip(args.taus, df_outs): - if tau == -1: - # Otsu thresholding - minn, maxx = est_map_numpy_origsize.min(), est_map_numpy_origsize.max() - est_map_origsize_scaled = ((est_map_numpy_origsize - minn)/(maxx - minn)*255) \ - .round().astype(np.uint8).squeeze() - tau_otsu, mask = cv2.threshold(est_map_origsize_scaled, - 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) - tau_otsu = minn + (tau_otsu/255)*(maxx - minn) - # print(f'Otsu selected tau={tau_otsu}') - else: - # Thresholding with a fixed threshold tau - mask = cv2.inRange(est_map_numpy_origsize, tau, 1) + mask, _ = utils.threshold(est_map_numpy_origsize, tau) coord = np.where(mask > 0) y = coord[0].reshape((-1, 1)) x = coord[1].reshape((-1, 1)) diff --git a/object-locator/train.py b/object-locator/train.py index ce92dbc..5b4f765 100644 --- a/object-locator/train.py +++ b/object-locator/train.py @@ -328,17 +328,7 @@ # The estimated map must be thresholed to obtain estimated points # Otsu thresholding - est_map_numpy = est_maps[0, :, :].to(device_cpu).numpy() - est_map_numpy_origsize = \ - skimage.transform.resize(est_map_numpy, - output_shape=origsize, - mode='constant') - minn, maxx = est_map_numpy_origsize.min(), est_map_numpy_origsize.max() - est_map_origsize_scaled = ((est_map_numpy_origsize - minn)/(maxx - minn)*255) \ - .round().astype(np.uint8).squeeze() - tau_otsu, mask = cv2.threshold(est_map_origsize_scaled, - 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) - tau_otsu = minn + (tau_otsu/255)*(maxx - minn) + mask, _ = utils.threshold(est_map_numpy_origsize, tau=-1) # Validation metrics coord = np.where(mask > 0) diff --git a/object-locator/utils.py b/object-locator/utils.py index 397f39f..a12ccf4 100644 --- a/object-locator/utils.py +++ b/object-locator/utils.py @@ -1,5 +1,6 @@ import torch import numpy as np +import cv2 class Normalizer(): def __init__(self, new_size_height, new_size_width): @@ -36,3 +37,32 @@ def unnormalize(self, coordinates_yx_normalized, orig_img_size): return coordinates_yx_unnormalized +def threshold(array, tau): + """ + Threshold an array using either hard thresholding or Otsu thresholding. + + :param array: Array to threshold. + :param tau: (float) Threshold to use. + Values above tau become 1, and values below tau become 0. + If -1, use Otsu thresholding. + :return: Tuple, where first element is the binary mask, and the second one + is the threshold used. When using Otsu thresholding, this threshold will be + is obtained adaptively according to the values of the input array. + + """ + if tau == -1: + # Otsu thresholding + minn, maxx = array.min(), array.max() + array_scaled = ((array - minn)/(maxx - minn)*255) \ + .round().astype(np.uint8).squeeze() + tau, mask = cv2.threshold(array_scaled, + 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + tau = minn + (tau/255)*(maxx - minn) + # print(f'Otsu selected tau={tau_otsu}') + else: + # Thresholding with a fixed threshold tau + mask = cv2.inRange(array, tau, 1) + + return mask, tau + +