Skip to content

Commit

Permalink
Includes modification to estimate plants without regressor model
Browse files Browse the repository at this point in the history
  • Loading branch information
David committed Nov 8, 2017
1 parent 095b3a1 commit 29bdd8c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
1 change: 0 additions & 1 deletion losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,3 @@ def forward(self, prob_map, gt):
res = term_1 + term_2

return res

17 changes: 11 additions & 6 deletions train_and_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import time
import shutil
from itertools import chain

import numpy as np
import pandas as pd
Expand All @@ -25,6 +26,7 @@
import losses
import unet_model


# Training settings
parser = argparse.ArgumentParser(description='Plant Location with PyTorch')
parser.add_argument('--train-dir', required=True,
Expand Down Expand Up @@ -169,7 +171,8 @@ def __getitem__(self, idx):
criterion_training = chamfer_loss

# Optimization strategy
optimizer = optim.SGD(model.parameters(),
alpha = Variable(torch.FloatTensor([1]).cuda(), requires_grad=True)
optimizer = optim.SGD(chain([alpha], model.parameters()),
lr=args.lr)

start_epoch = 0
Expand Down Expand Up @@ -213,7 +216,7 @@ def __getitem__(self, idx):

# Read image with GT dots from disk
gt_img_numpy = skimage.io.imread(
os.path.join('/home/jprat/projects/phenosorg/data/plant_counts_dots/20160613_F54_training_256x256_white_bigdots',
os.path.join('/home/dgueraco/cvpr/plant-data/plant_counts_dots/20160613_F54_training_256x256_white_bigdots',
dictionary['filename'][0]))
dots_img_tensor = torch.from_numpy(gt_img_numpy).permute(
2, 0, 1)[0, :, :].type(torch.FloatTensor) / 255
Expand All @@ -233,10 +236,13 @@ def __getitem__(self, idx):

# One training step
optimizer.zero_grad()
est_map, est_n_plants = model.forward(data)
est_map, _ = model.forward(data)
est_map = est_map.squeeze()
term1, term2 = criterion_training.forward(est_map, target)
term3 = l1_loss.forward(est_n_plants, target_n_plants) / \
sum_est_map = torch.sum(est_map)
# print(alpha)
# print(sum_est_map)
term3 = l1_loss.forward(alpha*sum_est_map, target_n_plants) / \
target_n_plants.type(torch.cuda.FloatTensor)
loss = term1 + term2 + term3
loss.backward()
Expand Down Expand Up @@ -306,7 +312,7 @@ def __getitem__(self, idx):

# Read image with GT dots from disk
gt_img_numpy = skimage.io.imread(
os.path.join('/home/jprat/projects/phenosorg/data/plant_counts_dots/20160613_F54_validation_256x256_white_bigdots',
os.path.join('/home/dgueraco/cvpr/plant-data/plant_counts_dots/20160613_F54_validation_256x256_white_bigdots',
dictionary['filename'][0]))
dots_img_tensor = torch.from_numpy(gt_img_numpy).permute(
2, 0, 1)[0, :, :].type(torch.FloatTensor) / 255
Expand Down Expand Up @@ -389,4 +395,3 @@ def __getitem__(self, idx):
print("Saved best checkpoint so far in %s " % best_ckpt_path)

epoch += 1

0 comments on commit 29bdd8c

Please sign in to comment.