normalize centroids in a new class
Former-commit-id: 23b4c4109d3f7594d0580c945b8932ec98bdc953
Javi Ribera committed Sep 15, 2018
1 parent 9efadac commit d5b5e31
Showing 1 changed file with 11 additions and 10 deletions.
object-locator/locate.py
@@ -189,7 +189,10 @@
         target_count = target_count.item()
         target_locations = \
             target_locations[0].to(device_cpu).numpy().reshape(-1, 2)
+        target_orig_size = \
+            target_orig_sizes[0].to(device_cpu).numpy().reshape(2)
 
+        normalzr = utils.Normalizer(args.height, args.width)
 
         # Feed forward
         with torch.no_grad():
@@ -230,16 +233,17 @@
             c = np.concatenate((y, x), axis=1)
             if len(c) == 0:
                 ahd = criterion_training.max_dist
-                centroids = np.array([])
+                centroids_wrt_orig = np.array([])
             else:
                 n_components = int(round(est_count))
                 # If the estimation is horrible, we cannot fit a GMM if n_components > n_samples
                 n_components = max(min(n_components, x.size), 1)
-                centroids = mixture.GaussianMixture(n_components=n_components,
+                centroids_wrt_orig = mixture.GaussianMixture(n_components=n_components,
                                                     n_init=1,
                                                     covariance_type='full').\
                     fit(c).means_.astype(np.int)
+
 
             # Save thresholded map to disk
             os.makedirs(os.path.join(args.out_dir, 'estimated_map_thresholded', f'tau={tau}'),
                         exist_ok=True)
@@ -257,7 +261,7 @@
                                       output_shape=origsize,
                                       mode='constant')
             image_with_x = ((image_with_x + 1) / 2.0 * 255.0)
-            for y, x in centroids:
+            for y, x in centroids_wrt_orig:
                 image_with_x = cv2.circle(image_with_x, (x, y), 3, [255, 0, 0], -1)
             # Save original image with circle to disk
             image_with_x = image_with_x[:, :, ::-1]
@@ -268,24 +272,21 @@
 
 
         if args.evaluate:
-            # Normalize to use locations in the original image
-            norm_factor = target_orig_sizes[0].unsqueeze(0).cpu().numpy() \
-                / resized_size
-            norm_factor = norm_factor.repeat(len(target_locations), axis=0)
-            target_locations_wrt_orig = norm_factor*target_locations
+            target_locations_wrt_orig = normalzr.unnormalize(target_locations,
+                                                             orig_img_size=target_orig_size)
 
             # Compute metrics for each value of r (for each Judge)
             for judge in judges:
                 if judge.th != tau:
                     continue
-                judge.feed_points(centroids, target_locations_wrt_orig,
+                judge.feed_points(centroids_wrt_orig, target_locations_wrt_orig,
                                   max_ahd=math.sqrt(origsize[0]**2 + origsize[1]**2))
                 judge.feed_count(est_count, target_count)
 
         # Save a new line in the CSV corresponding to the results of this img
         res_dict = dictionaries[0]
         res_dict['count'] = est_count
-        res_dict['locations'] = str(centroids.tolist())
+        res_dict['locations'] = str(centroids_wrt_orig.tolist())
         for key, val in res_dict.copy().items():
             if 'height' in key or 'width' in key:
                 del res_dict[key]
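For reference, both new call sites (utils.Normalizer(args.height, args.width) and normalzr.unnormalize(target_locations, orig_img_size=target_orig_size)) rely on a Normalizer class that lives outside this diff. The sketch below is a hypothetical reconstruction of that class, inferred only from those two call sites and from the deleted inline math (norm_factor = original size / resized size, applied to every point); the actual implementation in object-locator/utils.py may differ.

# Hypothetical sketch of utils.Normalizer, inferred from this commit's
# call sites and the deleted inline code; not the repository's actual code.
import numpy as np


class Normalizer():
    def __init__(self, new_size_height, new_size_width):
        # Size (height, width) of the resized image the network works on,
        # i.e., (args.height, args.width) in locate.py.
        self.new_size = np.array([new_size_height, new_size_width])

    def unnormalize(self, coordinates_yx, orig_img_size):
        # Map (y, x) locations w.r.t. the resized image back to the
        # original image. Same arithmetic as the deleted lines:
        # norm_factor = original size / resized size, broadcast per point.
        orig_img_size = np.asarray(orig_img_size).reshape(2)
        norm_factor = orig_img_size / self.new_size
        return norm_factor * np.asarray(coordinates_yx).reshape(-1, 2)

Under this assumption, the five deleted norm_factor lines in the last hunk collapse to the single unnormalize call, which matches the intent of the commit message ("normalize centroids in a new class").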
