
Commit 918c438
Revert "use regressed pts in WHD loss"
This reverts commit 92f4a1a [formerly f8b10aa3fca6f234c1d80b0f6546e5c1d540d884].


Former-commit-id: 24dc20167a08ff7fe60a7e619fa057631f328462
Javier Ribera committed Nov 4, 2018
1 parent 5c00abd commit 918c438
Showing 2 changed files with 38 additions and 49 deletions.
54 changes: 23 additions & 31 deletions object-locator/losses.py
@@ -108,18 +108,17 @@ def forward(self, set1, set2):
         return res
 
 
-class WeightedHausdorffLoss(nn.Module):
+class WeightedHausdorffDistance(nn.Module):
     def __init__(self,
                  resized_height, resized_width,
                  p=-9,
-                 regression_loss=F.smooth_l1_loss,
-                 return_3_terms=False,
+                 return_2_terms=False,
                  device=torch.device('cpu')):
         """
         :param resized_height: Number of rows in the image.
         :param resized_width: Number of columns in the image.
         :param p: Exponent in the generalized mean. -inf makes it the minimum.
-        :param return_3_terms: Whether to return the 3 terms
+        :param return_2_terms: Whether to return the 2 terms
                                of the WHD instead of their sum.
                                Default: False.
         :param device: Device where all Tensors will reside.
@@ -140,28 +139,23 @@ def __init__(self,
         self.all_img_locations = torch.tensor(self.all_img_locations,
                                               dtype=torch.get_default_dtype()).to(device)
 
-        self.return_3_terms = return_3_terms
-
-        self.regression_loss = regression_loss
+        self.return_2_terms = return_2_terms
 
         self.p = p
 
-    def forward(self, prob_map, count_estim, gt_loc, gt_count, orig_sizes):
+    def forward(self, prob_map, gt, orig_sizes):
         """
         Compute the Weighted Hausdorff Distance function
         between the estimated probability map and ground truth points.
         The output is the WHD averaged through all the batch.
-        It includes the two terms plus the regression term.
         :param prob_map: (B x H x W) Tensor of the probability map of the estimation.
                          B is batch size, H is height and W is width.
                          Values must be between 0 and 1.
-        :param count_estim: (B, )-sized tensor with count estimates.
-        :param gt_loc: List of Tensors of the Ground Truth points.
-                       Must be of size B as in prob_map.
-                       Each element in the list must be a 2D Tensor,
-                       where each row is the (y, x), i.e, (row, col) of a GT point.
-        :param gt_count: (B, )-sized tensor with count labels.
+        :param gt: List of Tensors of the Ground Truth points.
+                   Must be of size B as in prob_map.
+                   Each element in the list must be a 2D Tensor,
+                   where each row is the (y, x), i.e, (row, col) of a GT point.
         :param orig_sizes: Bx2 Tensor containing the size of the original images.
                            B is batch size. The size must be in (height, width) format.
         :param orig_widths: List of the original width for each image in the batch.
@@ -170,8 +164,7 @@ def forward(self, prob_map, count_estim, gt_loc, gt_count, orig_sizes):
                  the two terms of the Weighted Hausdorff Distance.
         """
 
-        _assert_no_grad(gt_loc)
-        _assert_no_grad(gt_count)
+        _assert_no_grad(gt)
 
         assert prob_map.dim() == 3, 'The probability map must be (B x H x W)'
         assert prob_map.size()[1:3] == (self.height, self.width), \
@@ -180,62 +173,61 @@ def forward(self, prob_map, count_estim, gt_loc, gt_count, orig_sizes):
             % str(prob_map.size())
 
         batch_size = prob_map.shape[0]
-        assert batch_size == len(gt_loc)
+        assert batch_size == len(gt)
 
         terms_1 = []
         terms_2 = []
         for b in range(batch_size):
 
             # One by one
             prob_map_b = prob_map[b, :, :]
-            gt_loc_b = gt_loc[b]
-            count_estim_b = count_estim[b]
+            gt_b = gt[b]
             orig_size_b = orig_sizes[b, :]
             norm_factor = (orig_size_b/self.resized_size).unsqueeze(0)
-            n_gt_pts = gt_loc_b.size()[0]
+            n_gt_pts = gt_b.size()[0]
 
             # Corner case: no GT points
-            if gt_loc_b.ndimension() == 1 and (gt_loc_b < 0).all().item() == 0:
+            if gt_b.ndimension() == 1 and (gt_b < 0).all().item() == 0:
                 terms_1.append(torch.tensor([0],
                                             dtype=torch.get_default_dtype()))
                 terms_2.append(torch.tensor([self.max_dist],
                                             dtype=torch.get_default_dtype()))
                 continue
 
             # Pairwise distances between all possible locations and the GTed locations
-            n_gt_pts = gt_loc_b.size()[0]
+            n_gt_pts = gt_b.size()[0]
             normalized_x = norm_factor.repeat(self.n_pixels, 1) *\
                 self.all_img_locations
-            normalized_y = norm_factor.repeat(len(gt_loc_b), 1)*gt_loc_b
+            normalized_y = norm_factor.repeat(len(gt_b), 1)*gt_b
             d_matrix = cdist(normalized_x, normalized_y)
 
             # Reshape probability map as a long column vector,
             # and prepare it for multiplication
             p = prob_map_b.view(prob_map_b.nelement())
+            n_est_pts = p.sum()
             p_replicated = p.view(-1, 1).repeat(1, n_gt_pts)
 
             # Weighted Hausdorff Distance
-            term_1 = (1 / (count_estim_b + 1e-6)) * \
+            term_1 = (1 / (n_est_pts + 1e-6)) * \
                 torch.sum(p * torch.min(d_matrix, 1)[0])
             weighted_d_matrix = (1 - p_replicated)*self.max_dist + p_replicated*d_matrix
             minn = generaliz_mean(weighted_d_matrix,
                                   p=self.p,
                                   dim=0, keepdim=False)
             term_2 = torch.mean(minn)
 
             # terms_1[b] = term_1
             # terms_2[b] = term_2
             terms_1.append(term_1)
             terms_2.append(term_2)
 
         terms_1 = torch.stack(terms_1)
         terms_2 = torch.stack(terms_2)
 
-        # Regression term
-        term3_mean = self.regression_loss(count_estim, gt_count)
-
-        if self.return_3_terms:
-            res = terms_1.mean(), terms_2.mean(), term3_mean
+        if self.return_2_terms:
+            res = terms_1.mean(), terms_2.mean()
         else:
-            res = terms_1.mean() + terms_2.mean() + term3_mean
+            res = terms_1.mean() + terms_2.mean()
 
         return res
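For reference, the restored term_1 and term_2 compute the two parts of the Weighted Hausdorff Distance between the probability map and the ground-truth points. Written out (the notation is mine, not from the commit: Ω is the set of pixel locations, Y the set of GT points, p_x the probability-map value at pixel x, d(x, y) the Euclidean distance between x and y after rescaling to the original image size, d_max = self.max_dist, ε = 1e-6, and the inner bracket is the generalized mean with exponent p):

\mathrm{WHD}(p, Y) \;=\;
  \underbrace{\frac{1}{\epsilon + \sum_{x \in \Omega} p_x}
              \sum_{x \in \Omega} p_x \, \min_{y \in Y} d(x, y)}_{\texttt{term\_1}}
  \;+\;
  \underbrace{\frac{1}{|Y|} \sum_{y \in Y}
              \Bigl( \frac{1}{|\Omega|} \sum_{x \in \Omega}
                     \bigl[ (1 - p_x)\, d_{\max} + p_x\, d(x, y) \bigr]^{p}
              \Bigr)^{1/p}}_{\texttt{term\_2}}

The reverted commit had replaced the normalizer Σ p_x (n_est_pts) in term_1 with the regressed count count_estim and folded a Smooth L1 count-regression term into the same class; this revert restores the original two-term form and moves the regression back into train.py.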
33 changes: 15 additions & 18 deletions object-locator/train.py
@@ -113,11 +113,12 @@
 model.to(device)
 
 # Loss functions
-loss_fctn = losses.WeightedHausdorffLoss(resized_height=args.height,
-                                         resized_width=args.width,
-                                         p=args.p,
-                                         return_3_terms=True,
-                                         device=device)
+loss_regress = nn.SmoothL1Loss()
+loss_loc = losses.WeightedHausdorffDistance(resized_height=args.height,
+                                            resized_width=args.width,
+                                            p=args.p,
+                                            return_2_terms=True,
+                                            device=device)
 
 # Optimization strategy
 if args.optimizer == 'sgd':
@@ -195,13 +196,12 @@
         # One training step
         optimizer.zero_grad()
         est_maps, est_counts = model.forward(imgs)
+        term1, term2 = loss_loc.forward(est_maps,
+                                        target_locations,
+                                        target_orig_sizes)
         est_counts = est_counts.view(-1)
         target_counts = target_counts.view(-1)
-        term1, term2, term3 = loss_fctn.forward(est_maps,
-                                                est_counts,
-                                                target_locations,
-                                                target_counts,
-                                                target_orig_sizes)
+        term3 = loss_regress.forward(est_counts, target_counts)
         term3 *= args.lambdaa
         loss = term1 + term2 + term3
         loss.backward()
@@ -344,13 +344,10 @@
 
         # The 3 terms
         with torch.no_grad():
-            est_counts = est_counts.view(-1)
-            target_counts = target_counts.view(-1)
-            term1, term2, term3 = loss_fctn.forward(est_maps,
-                                                    est_counts,
-                                                    target_locations,
-                                                    target_counts,
-                                                    target_orig_sizes)
+            term1, term2 = loss_loc.forward(est_maps,
+                                            target_locations,
+                                            target_orig_sizes)
+            term3 = loss_regress.forward(est_counts, target_counts)
             term3 *= args.lambdaa
             sum_term1 += term1.item()
             sum_term2 += term2.item()
@@ -377,7 +374,7 @@
         target_locations_wrt_orig = normalzr.unnormalize(target_locations_np,
                                                          orig_img_size=target_orig_size_np)
         judge.feed_points(centroids_wrt_orig, target_locations_wrt_orig,
-                          max_ahd=loss_fctn.max_dist)
+                          max_ahd=loss_loc.max_dist)
         judge.feed_count(est_count_int, target_count_int)
 
         if time.time() > tic_val + args.log_interval:
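As a usage illustration, here is a minimal sketch of how the two loss objects are wired together after this revert, mirroring the train.py hunks above. It is not part of the commit: it assumes losses.py is importable as `losses` (e.g., run from inside the object-locator directory), uses dummy tensors in place of the data loader and model outputs, and `lambdaa` is a hypothetical stand-in for args.lambdaa.

import torch
from torch import nn

import losses  # object-locator/losses.py, as patched by this commit (assumed importable)

device = torch.device('cpu')
height, width = 64, 64

# After the revert, the WHD loss only sees the probability map and GT points...
loss_loc = losses.WeightedHausdorffDistance(resized_height=height,
                                            resized_width=width,
                                            p=-9,
                                            return_2_terms=True,
                                            device=device)
# ...and the count regression lives outside it, in train.py.
loss_regress = nn.SmoothL1Loss()

# Dummy stand-ins for the model outputs and ground truth of a batch of size 1.
est_maps = torch.full((1, height, width), 0.5)        # probability map with values in [0, 1]
est_counts = torch.tensor([3.0])                      # regressed object count
target_locations = [torch.tensor([[20.0, 30.0]])]     # one GT point per image, (row, col)
target_counts = torch.tensor([1.0])
target_orig_sizes = torch.tensor([[float(height), float(width)]])

term1, term2 = loss_loc.forward(est_maps, target_locations, target_orig_sizes)
term3 = loss_regress.forward(est_counts, target_counts)

lambdaa = 1.0  # hypothetical value for args.lambdaa
loss = term1 + term2 + lambdaa * term3
print(loss.item())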
