Skip to content

Commit

Permalink
encapsulate thresholding into a function
Browse files Browse the repository at this point in the history
Former-commit-id: d78ecf6b20b15e9e622ce4b1cb08bf151b8c0483
  • Loading branch information
Javi Ribera committed Sep 16, 2018
1 parent 82c59ee commit 257f37a
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 23 deletions.
13 changes: 1 addition & 12 deletions object-locator/locate.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,18 +216,7 @@

# The estimated map must be thresholded to obtain estimated points
for tau, df_out in zip(args.taus, df_outs):
if tau == -1:
# Otsu thresholding
minn, maxx = est_map_numpy_origsize.min(), est_map_numpy_origsize.max()
est_map_origsize_scaled = ((est_map_numpy_origsize - minn)/(maxx - minn)*255) \
.round().astype(np.uint8).squeeze()
tau_otsu, mask = cv2.threshold(est_map_origsize_scaled,
0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
tau_otsu = minn + (tau_otsu/255)*(maxx - minn)
# print(f'Otsu selected tau={tau_otsu}')
else:
# Thresholding with a fixed threshold tau
mask = cv2.inRange(est_map_numpy_origsize, tau, 1)
mask, _ = utils.threshold(est_map_numpy_origsize, tau)
coord = np.where(mask > 0)
y = coord[0].reshape((-1, 1))
x = coord[1].reshape((-1, 1))
Expand Down
12 changes: 1 addition & 11 deletions object-locator/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,17 +328,7 @@

# The estimated map must be thresholed to obtain estimated points
# Otsu thresholding
est_map_numpy = est_maps[0, :, :].to(device_cpu).numpy()
est_map_numpy_origsize = \
skimage.transform.resize(est_map_numpy,
output_shape=origsize,
mode='constant')
minn, maxx = est_map_numpy_origsize.min(), est_map_numpy_origsize.max()
est_map_origsize_scaled = ((est_map_numpy_origsize - minn)/(maxx - minn)*255) \
.round().astype(np.uint8).squeeze()
tau_otsu, mask = cv2.threshold(est_map_origsize_scaled,
0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
tau_otsu = minn + (tau_otsu/255)*(maxx - minn)
mask, _ = utils.threshold(est_map_numpy_origsize, tau=-1)

# Validation metrics
coord = np.where(mask > 0)
Expand Down
30 changes: 30 additions & 0 deletions object-locator/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import numpy as np
import cv2

class Normalizer():
def __init__(self, new_size_height, new_size_width):
Expand Down Expand Up @@ -36,3 +37,32 @@ def unnormalize(self, coordinates_yx_normalized, orig_img_size):

return coordinates_yx_unnormalized

def threshold(array, tau):
"""
Threshold an array using either hard thresholding or Otsu thresholding.
:param array: Array to threshold.
:param tau: (float) Threshold to use.
Values above tau become 1, and values below tau become 0.
If -1, use Otsu thresholding.
:return: Tuple, where first element is the binary mask, and the second one
is the threshold used. When using Otsu thresholding, this threshold will be
is obtained adaptively according to the values of the input array.
"""
if tau == -1:
# Otsu thresholding
minn, maxx = array.min(), array.max()
array_scaled = ((array - minn)/(maxx - minn)*255) \
.round().astype(np.uint8).squeeze()
tau, mask = cv2.threshold(array_scaled,
0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
tau = minn + (tau/255)*(maxx - minn)
# print(f'Otsu selected tau={tau_otsu}')
else:
# Thresholding with a fixed threshold tau
mask = cv2.inRange(array, tau, 1)

return mask, tau


0 comments on commit 257f37a

Please sign in to comment.