wrap clustering into a function
Former-commit-id: e1d8af6e4be1350796852479fc17fc27db9fad95
Javi Ribera committed Sep 17, 2018
1 parent f709cb2 commit 7e3523f
Showing 3 changed files with 35 additions and 32 deletions.
16 changes: 1 addition & 15 deletions object-locator/locate.py
@@ -25,7 +25,6 @@
 from torchvision import transforms
 import torchvision as tv
 from torchvision.models import inception_v3
-from sklearn import mixture
 import skimage.transform
 from peterpy import peter
 from ballpark import ballpark
@@ -217,20 +216,7 @@
 # The estimated map must be thresholded to obtain estimated points
 for tau, df_out in zip(args.taus, df_outs):
     mask, _ = utils.threshold(est_map_numpy_origsize, tau)
-    coord = np.where(mask > 0)
-    y = coord[0].reshape((-1, 1))
-    x = coord[1].reshape((-1, 1))
-    c = np.concatenate((y, x), axis=1)
-    if len(c) == 0:
-        centroids_wrt_orig = np.array([])
-    else:
-        # If the estimation is horrible, we cannot fit a GMM if n_components > n_samples
-        n_components = max(min(est_count_int, x.size), 1)
-        centroids_wrt_orig = mixture.GaussianMixture(n_components=n_components,
-                                                     n_init=1,
-                                                     covariance_type='full').\
-            fit(c).means_.astype(np.int)
-
+    centroids_wrt_orig = utils.cluster(mask, est_count_int)
 
     # Save thresholded map to disk
     os.makedirs(os.path.join(args.out_dir, 'estimated_map_thresholded', f'tau={tau}'),
19 changes: 2 additions & 17 deletions object-locator/train.py
@@ -18,7 +18,6 @@
 from torchvision.models import inception_v3
 from torchvision import transforms
 from torch.utils.data import DataLoader
-from sklearn import mixture
 import matplotlib
 matplotlib.use('Agg')
 import skimage.transform
@@ -333,24 +332,10 @@
     output_shape=origsize,
     mode='constant')
 mask, _ = utils.threshold(est_map_numpy_origsize, tau=-1)
+# Obtain centroids of the mask
+centroids_wrt_orig = utils.cluster(mask, est_count_int)
 
 # Validation metrics
-coord = np.where(mask > 0)
-y = coord[0].reshape((-1, 1))
-x = coord[1].reshape((-1, 1))
-c = np.concatenate((y, x), axis=1)
-if len(c) == 0:
-    centroids_wrt_orig = []
-    est_count = 0
-    print('len(c) == 0')
-else:
-    # If the estimation is horrible, we cannot fit a GMM if n_components > n_samples
-    n_components = max(min(est_count_int, x.size), 1)
-    centroids_wrt_orig = mixture.GaussianMixture(n_components=n_components,
-                                                 n_init=1,
-                                                 covariance_type='full').\
-        fit(c).means_.astype(np.int)
-
 target_locations_wrt_orig = normalzr.unnormalize(target_locations_np,
                                                  orig_img_size=target_orig_size_np)
 judge.feed_points(centroids_wrt_orig, target_locations_wrt_orig,
32 changes: 32 additions & 0 deletions object-locator/utils.py
@@ -1,5 +1,6 @@
 import torch
 import numpy as np
+import sklearn.mixture
 import cv2
 
 class Normalizer():
@@ -66,3 +67,34 @@ def threshold(array, tau):
     return mask, tau
 
 
+def cluster(array, n_clusters):
+    """
+    Cluster a 2-D binary array.
+    Applies a Gaussian Mixture Model on the positive elements of the array
+    and returns the centroids of the resulting clusters.
+
+    :param array: Binary array.
+    :param n_clusters: Number of clusters (GMM components) to fit.
+    :return: Centroids of the clusters, as (y, x) coordinates in the input array.
+    """
+
+    array = np.array(array)
+
+    assert array.ndim == 2
+
+    coord = np.where(array > 0)
+    y = coord[0].reshape((-1, 1))
+    x = coord[1].reshape((-1, 1))
+    c = np.concatenate((y, x), axis=1)
+    if len(c) == 0:
+        centroids = np.array([])
+    else:
+        # If the estimation is horrible, we cannot fit a GMM if n_components > n_samples
+        n_components = max(min(n_clusters, x.size), 1)
+        centroids = sklearn.mixture.GaussianMixture(n_components=n_components,
+                                                    n_init=1,
+                                                    covariance_type='full').\
+            fit(c).means_.astype(np.int)
+
+    return centroids

