Skip to content

Commit

Permalink
volatile Variables during inference
Browse files Browse the repository at this point in the history
Former-commit-id: d378ec7426df20029fdcd6b8650b23006b4c824f
  • Loading branch information
Javi Ribera committed Mar 5, 2018
1 parent e803292 commit bb749d5
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions plant-locator/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,10 @@
target_count = torch.stack([dictt['plant_count']
for dictt in dictionaries])

imgs = Variable(imgs.type(tensortype))
target_locations = [Variable(t.type(tensortype))
imgs = Variable(imgs.type(tensortype), volatile=True)
target_locations = [Variable(t.type(tensortype), volatile=True)
for t in target_locations]
target_count = Variable(target_count.type(tensortype))
target_count = Variable(target_count.type(tensortype), volatile=True)

# Feed-forward
est_map, est_count = model.forward(imgs)
Expand Down Expand Up @@ -330,7 +330,7 @@
).numpy().reshape(-1, 2)
ahd = losses.averaged_hausdorff_distance(
centroids, target_locations)
ahd = Variable(tensortype([ahd]))
ahd = tensortype([ahd])
sum_ahd += ahd

# Validation using Precision and Recall
Expand Down Expand Up @@ -376,7 +376,7 @@
avg_loss_val = sum_loss / len(valset_loader)
avg_ahd_val = sum_ahd / len(valset_loader)
prec, rec = judge.get_p_n_r()
prec, rec = Variable(tensortype([prec])), Variable(tensortype([rec]))
prec, rec = tensortype([prec]), tensortype([rec])

# Log validation losses
log.val_losses(terms=(avg_term1_val,
Expand All @@ -396,7 +396,7 @@
'Recall (%)'])

# If this is the best epoch (in terms of validation error)
avg_ahd_val_float = avg_ahd_val.data.cpu().numpy()[0]
avg_ahd_val_float = avg_ahd_val.cpu().numpy()[0]
if avg_ahd_val_float < lowest_avg_ahd_val:
# Keep the best model
lowest_avg_ahd_val = avg_ahd_val_float
Expand Down

0 comments on commit bb749d5

Please sign in to comment.