Skip to content

Commit

Permalink
use item() from pythorch0.4
Browse files Browse the repository at this point in the history
Former-commit-id: 83b3628fb1141578ba958fff0433d0839c38ad7a
  • Loading branch information
Javi Ribera committed Sep 16, 2018
1 parent 905fce6 commit 639ba94
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
9 changes: 4 additions & 5 deletions object-locator/locate.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@
dictionaries[0]['filename']),
est_map_numpy_origsize)

# Convert to scalar
est_count = est_count.to(device_cpu).numpy()[0][0]
# Tensor -> int
est_count_int = int(round(est_count.item()))

# The estimated map must be thresholded to obtain estimated points
for tau, df_out in zip(args.taus, df_outs):
Expand All @@ -236,9 +236,8 @@
ahd = criterion_training.max_dist
centroids_wrt_orig = np.array([])
else:
n_components = int(round(est_count))
# If the estimation is horrible, we cannot fit a GMM if n_components > n_samples
n_components = max(min(n_components, x.size), 1)
n_components = max(min(est_count_int, x.size), 1)
centroids_wrt_orig = mixture.GaussianMixture(n_components=n_components,
n_init=1,
covariance_type='full').\
Expand Down Expand Up @@ -282,7 +281,7 @@
continue
judge.feed_points(centroids_wrt_orig, target_locations_wrt_orig,
max_ahd=math.sqrt(origsize[0]**2 + origsize[1]**2))
judge.feed_count(est_count, target_count)
judge.feed_count(est_count_int, target_count)

# Save a new line in the CSV corresonding to the resuls of this img
res_dict = dictionaries[0]
Expand Down
2 changes: 1 addition & 1 deletion object-locator/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@
image_with_x = tensortype(imgs[0, :, :].squeeze().size()).\
copy_(imgs[0, :, :].squeeze())
image_with_x = ((image_with_x + 1) / 2.0 * 255.0)
image_with_x = image_with_x.cpu().numpy()
image_with_x = image_with_x.to(device_cpu).numpy()
image_with_x = np.moveaxis(image_with_x, 0, 2).copy()
for y, x in centroids:
image_with_x = cv2.circle(
Expand Down

0 comments on commit 639ba94

Please sign in to comment.