
Commit
update locate (test)
Former-commit-id: f7c172405c6e2b7175d0eb1bd4a11f5c72099979
Javi Ribera committed Mar 6, 2018
1 parent a860c77 commit 561311d
Showing 1 changed file with 57 additions and 135 deletions.
192 changes: 57 additions & 135 deletions plant-locator/locate.py
@@ -1,3 +1,5 @@
from __future__ import print_function

import argparse
import os
import sys
@@ -19,26 +21,25 @@
import torchvision as tv
from torchvision.models import inception_v3
from sklearn import mixture
from .data import CSVDataset
from .data import csv_collator

from . import losses
from . import unet_model
from .models import unet_model
from .eval_precision_recall import Judge
from . import utils

# Testing settings
parser = argparse.ArgumentParser(description='Plant Location with PyTorch (inference/test only)',
parser = argparse.ArgumentParser(description='BoundingBox-less Location with PyTorch (inference/test only)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--dataset', required=True,
help='REQUIRED. Directory with test images.\n')
# parser.add_argument('--eval-batch-size', type=int, default=1, metavar='N',
# help='Input batch size.')
parser.add_argument('--model', type=str, metavar='PATH',
default='unet_256x256_sorghum',
help='Checkpoint with the CNN model.\n')
parser.add_argument('--out-dir', type=str, required=True,
help='REQUIRED. Directory where results will be stored (images+CSV).')
# parser.add_argument('--imgsize', type=str, default='256x256', metavar='HxW',
# help='Size of the input images (heightxwidth).')
# help='Size of the input images (heightxwidth).')
parser.add_argument('--radius', type=int, default=5, metavar='R',
help='Detections at dist <= R to a GT pt are True Positives.')
parser.add_argument('--paint', default=True, action="store_true",
@@ -82,108 +83,29 @@
print("\__ E: The input --imgsize must be in format WxH, got '{}'".format(args.imgsize))
exit(-1)


class CSVDataset(data.Dataset):
def __init__(self, directory, transform=None, max_dataset_size=np.inf):
"""CSVDataset.
The sample images of this dataset must all be inside one directory.
Inside that directory, there must be one CSV file.
This file must contain one row per image.
It can contain as many columns as desired, e.g., filename, count...
:param directory: Directory with all the images and the CSV file.
:param transform: Transform to be applied to each image.
:param max_dataset_size: Only use the first N images in the directory.
"""

self.root_dir = directory
self.transform = transform

# Get groundtruth from CSV file
listfiles = os.listdir(directory)
csv_filename = None
for filename in listfiles:
if filename.endswith('.csv'):
csv_filename = filename
break

self.there_is_gt = csv_filename is not None

# CSV does not exist (no GT available)
if not self.there_is_gt:
print('W: The dataset directory %s does not contain a CSV file with groundtruth. \n' \
' Metrics will not be evaluated. Only estimations will be returned.' % directory)
self.csv_df = None
self.listfiles = listfiles

# Make dataset smaller
self.listfiles = self.listfiles[0:min(len(self.listfiles), max_dataset_size)]

# CSV does exist (GT is available)
else:
self.csv_df = pd.read_csv(os.path.join(directory, csv_filename))

# Make dataset smaller
self.csv_df = self.csv_df[0:min(len(self.csv_df), max_dataset_size)]

def __len__(self):
if self.there_is_gt:
return len(self.csv_df)
else:
return len(self.listfiles)

def __getitem__(self, idx):
"""Get one element of the dataset.
Returns a tuple. The first element is the image.
The second element is a dictionary where the keys are the columns of the CSV.
If the CSV did not exist in the dataset directory,
the dictionary will only contain the filename of the image.
:param idx: Index of the image in the dataset to get.
"""

if self.there_is_gt:
img_abspath = os.path.join(self.root_dir, self.csv_df.ix[idx, 0])
dictionary = dict(self.csv_df.ix[idx])
else:
img_abspath = os.path.join(self.root_dir, self.listfiles[idx])
dictionary = {'filename': self.listfiles[idx]}

img = skimage.io.imread(img_abspath)

if self.transform:
transformed = self.transform(img)
else:
transformed = img

return (transformed, dictionary)


# Force batchsize == 1
args.eval_batch_size = 1
if args.eval_batch_size != 1:
raise NotImplementedError('Only a batch size of 1 is implemented for now, got %s'
% args.eval_batch_size)
# Tensor type to use, select CUDA or not
tensortype = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor
tensortype_cpu = torch.FloatTensor
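# The dataset below builds its tensors with the CPU type; 'tensortype'
# (CUDA when --cuda is given) is what the model inputs, targets, and the
# loss criterion are cast to further down.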

# Data loading code
testset = CSVDataset(args.dataset,
transform=transforms.Compose([
transforms=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5)),
]),
max_dataset_size=args.max_testset_size)
max_dataset_size=args.max_testset_size,
tensortype=tensortype_cpu)
testset_loader = data.DataLoader(testset,
batch_size=args.eval_batch_size,
num_workers=args.nThreads)

# Tensor type to use, select CUDA or not
tensortype = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor
batch_size=1,
num_workers=args.nThreads,
collate_fn=csv_collator)

# Loss function
l1_loss = nn.L1Loss()
l1_loss = nn.L1Loss(reduce=False)
criterion_training = losses.WeightedHausdorffDistance(height=height, width=width,
return_2_terms=True)
return_2_terms=True,
tensortype=tensortype)
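# return_2_terms=True asks the criterion to return the two terms of the
# Weighted Hausdorff Distance separately. In this test script the criterion
# object is mainly used for its max_dist attribute, which is reported as the
# AHD when no points at all are detected (see the empty-mask case below).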

# Restore saved checkpoint (model weights)
print("Loading checkpoint '{}' ...".format(args.model))
@@ -233,7 +155,7 @@ def __getitem__(self, idx):


# Empty output CSV
df_out = pd.DataFrame(columns=['plant_count'])
df_out = pd.DataFrame(columns=['count'])

# Set the module in evaluation mode
model.eval()
@@ -243,41 +165,33 @@ def __getitem__(self, idx):
sum_ahd = 0
sum_ape = 0

for batch_idx, (data, dictionary) in tqdm(enumerate(testset_loader),
for batch_idx, (imgs, dictionaries) in tqdm(enumerate(testset_loader),
total=len(testset_loader)):

# Prepare data
data = data.type(tensortype)
data = Variable(data, volatile=True)
imgs = Variable(imgs.type(tensortype), volatile=True)

if testset.there_is_gt:
# Pull info from this sample image
gt_plant_locations = [eval(el) for el in dictionary['plant_locations']]
target_n_plants = dictionary['plant_count']

# We cannot deal with images with 0 plants (HD is not defined)
if any(len(target_one_img) == 0 for target_one_img in gt_plant_locations):
continue

target = gt_plant_locations
# Pull info from this batch
target_locations = [dictt['locations'] for dictt in dictionaries]
target_count = torch.stack([dictt['count']
for dictt in dictionaries])

# Prepare targets
target_n_plants = target_n_plants.type(tensortype)
target = torch.FloatTensor(target).type(tensortype)
target, target_n_plants = Variable(target, volatile=True), \
Variable(target_n_plants, volatile=True)
target = target.squeeze()
target_locations = [Variable(t.type(tensortype), volatile=True)
for t in target_locations]
target_count = Variable(target_count.type(tensortype), volatile=True)

# Feed forward
est_map, est_n_plants = model.forward(data)
est_map = est_map.squeeze()
est_map, est_count = model.forward(imgs)

# Save estimated map to disk
tv.utils.save_image(est_map.data,
os.path.join(args.out_dir, 'est_map', dictionary['filename'][0]))
tv.utils.save_image(est_map.data[0, :, :],
os.path.join(args.out_dir,
'est_map',
dictionaries[0]['filename']))

# The estimated map must be thresholded to obtain estimated points
est_map_numpy = est_map.data.cpu().numpy()
est_map_numpy = est_map.data[0, :, :].cpu().numpy()
mask = cv2.inRange(est_map_numpy, 2 / 255, 1)
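# The threshold keeps pixels whose estimated value is at least 2/255,
# i.e. anything that would remain non-zero after quantization to 8 bits;
# np.where then turns the binary mask into the (y, x) candidate coordinates
# that are clustered below.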
coord = np.where(mask > 0)
y = coord[0].reshape((-1, 1))
@@ -288,7 +202,7 @@ def __getitem__(self, idx):
ahd = criterion_training.max_dist
centroids = []
else:
n_components = int(torch.round(est_n_plants).data.cpu().numpy()[0])
n_components = int(torch.round(est_count[0]).data.cpu().numpy()[0])
# If the estimation is horrible, we cannot fit a GMM if n_components > n_samples
n_components = max(min(n_components, x.size), 1)
centroids = mixture.GaussianMixture(n_components=n_components,
@@ -297,45 +211,53 @@ def __getitem__(self, idx):
fit(c).means_.astype(np.int)
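# The rounded estimated count sets the number of GMM components, and the
# fitted component means (cast to integer pixel coordinates) are used as
# the estimated object centers.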

# Save thresholded map to disk
cv2.imwrite(os.path.join(args.out_dir, 'est_map_thresholded', dictionary['filename'][0]),
cv2.imwrite(os.path.join(args.out_dir,
'est_map_thresholded',
dictionaries[0]['filename']),
mask)


# Paint red dots if user asked for it
if args.paint:
# Paint a circle in the original image at the estimated location
image_with_x = tensortype(data.data.squeeze().size()).\
copy_(data.data.squeeze())
image_with_x = tensortype(imgs.data[0, :, :].squeeze().size()).\
copy_(imgs.data[0, :, :].squeeze())
image_with_x = ((image_with_x + 1) / 2.0 * 255.0)
image_with_x = image_with_x.cpu().numpy()
image_with_x = np.moveaxis(image_with_x, 0, 2).copy()
for y, x in centroids:
image_with_x = cv2.circle(image_with_x, (x, y), 3, [255, 0, 0], -1)
# Save original image with circle to disk
image_with_x = image_with_x[:, :, ::-1]
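# The channel axis is reversed because cv2.imwrite expects BGR order,
# while the painted array above is in RGB.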
cv2.imwrite(os.path.join(args.out_dir, 'painted', dictionary['filename'][0]),
cv2.imwrite(os.path.join(args.out_dir,
'painted',
dictionaries[0]['filename']),
image_with_x)

if testset.there_is_gt:
# Evaluate Average Percent Error for this image
ape = 100 * l1_loss.forward(est_n_plants, target_n_plants) / target_n_plants
if bool((target_count==0).data.cpu().numpy()[0][0]):
ape = 100 * l1_loss.forward(est_count, target_count)
else:
ape = 100 * l1_loss.forward(est_count,
target_count) / target_count
ape = ape.data.cpu().numpy()[0]
sum_ape += ape
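# APE is the absolute count error as a percentage of the ground-truth
# count; when the ground-truth count is zero, the unnormalized L1 error is
# used instead to avoid dividing by zero.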

# Evaluation using the Averaged Hausdorff Distance
target = target.data.cpu().numpy().reshape(-1, 2)
ahd = losses.averaged_hausdorff_distance(centroids, target)
target_locations = \
target_locations[0].data.cpu().numpy().reshape(-1, 2)
ahd = losses.averaged_hausdorff_distance(
centroids, target_locations)

sum_ahd += ahd
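# The averaged Hausdorff distance is, roughly, the mean distance from each
# estimated point to its closest ground-truth point plus the mean distance
# from each ground-truth point to its closest estimated point
# (as implemented in losses.averaged_hausdorff_distance).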

# Validation using Precision and Recall
for judge in judges:
judge.evaluate_sample(centroids, target)

judge.evaluate_sample(centroids, target_locations)

df = pd.DataFrame(data=[est_n_plants.data.cpu().numpy()[0]],
index=[dictionary['filename'][0]],
columns=['plant_count'])
df = pd.DataFrame(data=[est_count.data[0, 0]],
index=[dictionaries[0]['filename']],
columns=['count'])
df_out = df_out.append(df)

if testset.there_is_gt:
