move data classes to new file and do RandomFlip
Former-commit-id: 3740c0388c10e3cec5173d67323a8b273875e6e2
Javi Ribera committed Feb 25, 2018
1 parent e4a65a4 commit b8fe6dc
Showing 2 changed files with 183 additions and 80 deletions.
147 changes: 147 additions & 0 deletions data.py
@@ -0,0 +1,147 @@
import ast
import inspect
import os
import random

from PIL import Image
import skimage
import pandas as pd
import torch
from torch.utils import data
from torchvision import datasets
from torchvision import transforms


class PlantDataset(data.Dataset):
def __init__(self, root_dir, transform=None, max_dataset_size=float('inf')):
"""
Args:
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
max_dataset_size: If the dataset is bigger than this integer,
ignore additional samples.
"""

# Get groundtruth from CSV file
csv_filename = None
for filename in os.listdir(root_dir):
if filename.endswith('.csv'):
csv_filename = filename
break
if csv_filename is None:
raise ValueError(
'The root directory %s does not have a CSV file with groundtruth' % root_dir)
self.csv_df = pd.read_csv(os.path.join(root_dir, csv_filename))

# Make the dataset smaller
self.csv_df = self.csv_df[0:min(len(self.csv_df), max_dataset_size)]

self.root_dir = root_dir
self.transforms = transform

def __len__(self):
return len(self.csv_df)

def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.csv_df.iloc[idx, 0])
        img = Image.open(img_path)
        dictionary = dict(self.csv_df.iloc[idx])

        # str -> lists (the CSV stores the locations as a string)
        dictionary['plant_locations'] = ast.literal_eval(
            dictionary['plant_locations'])
dictionary['plant_locations'] = [
list(loc) for loc in dictionary['plant_locations']]

# list --> Tensors
dictionary['plant_locations'] = torch.FloatTensor(
dictionary['plant_locations'])
dictionary['plant_count'] = torch.FloatTensor(
[dictionary['plant_count']])

img_transformed = img
transformed_dictionary = dictionary

        # Apply all the provided transforms. A transform that declares
        # 'modifies_label' also receives and returns the GT dictionary,
        # so the point annotations stay consistent with the image.
if self.transforms is not None:
for transform in self.transforms.transforms:
if hasattr(transform, 'modifies_label'):
img_transformed, transformed_dictionary = \
transform(img_transformed, transformed_dictionary)
else:
img_transformed = transform(img_transformed)

        # Prevents a crash when making a batch out of an empty tensor
        if transformed_dictionary['plant_count'][0] == 0:
            transformed_dictionary['plant_locations'] = \
                torch.FloatTensor([-1, -1])

return (img_transformed, transformed_dictionary)


class RandomHorizontalFlipImageAndLabel(object):
""" Horizontally flip a numpy array image and the GT with probability p """

def __init__(self, p):
self.modifies_label = True
self.p = p

def __call__(self, img, dictionary):
transformed_img = img
transformed_dictionary = dictionary

if random.random() < self.p:
            transformed_img = hflip(img)
            width = img.size[0]  # PIL's Image.size is (width, height)
            for l, loc in enumerate(dictionary['plant_locations']):
                # locations are stored as (y, x): mirror the x coordinate
                dictionary['plant_locations'][l][1] = (width - 1) - loc[1]

return transformed_img, transformed_dictionary


class RandomVerticalFlipImageAndLabel(object):
""" Vertically flip a numpy array image and the GT with probability p """

def __init__(self, p):
self.modifies_label = True
self.p = p

def __call__(self, img, dictionary):
transformed_img = img
transformed_dictionary = dictionary

if random.random() < self.p:
            transformed_img = vflip(img)
            height = img.size[1]  # PIL's Image.size is (width, height)
            for l, loc in enumerate(dictionary['plant_locations']):
                # locations are stored as (y, x): mirror the y coordinate
                dictionary['plant_locations'][l][0] = (height - 1) - loc[0]

return transformed_img, transformed_dictionary


def hflip(img):
"""Horizontally flip the given PIL Image.
Args:
img (PIL Image): Image to be flipped.
Returns:
        PIL Image: Horizontally flipped image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

return img.transpose(Image.FLIP_LEFT_RIGHT)


def vflip(img):
"""Vertically flip the given PIL Image.
Args:
img (PIL Image): Image to be flipped.
Returns:
PIL Image: Vertically flipped image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

return img.transpose(Image.FLIP_TOP_BOTTOM)


def _is_pil_image(img):
return isinstance(img, Image.Image)
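
For context, here is a minimal usage sketch (illustrative only, not part of this commit) showing how the new PlantDataset is meant to be combined with the label-aware flips, mirroring what train_and_validate.py below does. The directory name 'plants' is hypothetical, and the CSV layout is inferred from what __getitem__ reads: the first column holds the image filename, alongside 'plant_count' and 'plant_locations' columns.

from torch.utils.data import DataLoader
from torchvision import transforms

from data import PlantDataset
from data import RandomHorizontalFlipImageAndLabel
from data import RandomVerticalFlipImageAndLabel

trainset = PlantDataset('plants',
                        transform=transforms.Compose([
                            # Label-aware flips go first, while the sample is
                            # still a PIL Image and the GT can be updated.
                            RandomHorizontalFlipImageAndLabel(p=0.5),
                            RandomVerticalFlipImageAndLabel(p=0.5),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5),
                                                 (0.5, 0.5, 0.5)),
                        ]))

# batch_size=1 because each image may contain a different number of plants.
loader = DataLoader(trainset, batch_size=1, shuffle=True)
img, gt = next(iter(loader))
print(img.size(), gt['plant_count'], gt['plant_locations'])
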
116 changes: 36 additions & 80 deletions train_and_validate.py
@@ -10,32 +10,31 @@
from tqdm import tqdm

import numpy as np
import pandas as pd
import skimage.io
import torch
import torch.optim as optim
import visdom
import skimage.draw
import utils
from torch import nn
from torch.autograd import Variable
from torch.utils import data
from torchvision import datasets
from torchvision import transforms
import torchvision as tv
from torchvision.models import inception_v3
from sklearn import mixture
import unet_pix2pix
import losses
import unet_model
from eval_precision_recall import Judge
from torchvision import transforms
from torch.utils.data import DataLoader
from data import PlantDataset
from data import RandomHorizontalFlipImageAndLabel
from data import RandomVerticalFlipImageAndLabel


# Training settings
parser = argparse.ArgumentParser(description='Plant Location with PyTorch')
parser.add_argument('--train-dir', required=True,
help='Directory with training images')
parser.add_argument('--val-dir',
help='Directory with validation images. If left blank no validation will be done.')
parser.add_argument('--batch-size', type=int, default=1, metavar='N',
help='input batch size for training')
@@ -64,13 +63,13 @@
parser.add_argument('--env-name', default='Pure U-Net', type=str, metavar='NAME',
help='Name of the environment in Visdom')
parser.add_argument('--paint', default=False, action="store_true",
help='Paint red circles at estimated locations in Validation? '\
'It takes an enormous amount of time!')
help='Paint red circles at estimated locations in Validation? '
'It takes an enormous amount of time!')
parser.add_argument('--radius', type=int, default=5, metavar='R',
help='Default radius to consider a object detection as "match".')
parser.add_argument('--n-points', type=int, default=None, metavar='N',
help='If you know the number of points (e.g, just one pupil), set it.' \
'Otherwise it will be estimated by adding a L1 cost term.')
                    help='If you know the number of points (e.g., just one pupil), set it. '
                    'Otherwise it will be estimated by adding an L1 cost term.')
parser.add_argument('--lambdaa', type=float, default=1, metavar='L',
help='Weight that will multiply the MAPE term in the loss function.')

@@ -101,63 +100,20 @@
viz_train_gt_win, viz_val_gt_win = None, None
viz_train_est_win, viz_val_est_win = None, None


class PlantDataset(data.Dataset):
def __init__(self, root_dir, transform=None, max_dataset_size=np.inf):
"""
Args:
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
max_dataset_size: If the dataset is bigger than this integer,
ignore additional samples.
"""

# Get groundtruth from CSV file
csv_filename = None
for filename in os.listdir(root_dir):
if filename.endswith('.csv'):
csv_filename = filename
break
if csv_filename is None:
raise ValueError(
'The root directory %s does not have a CSV file with groundtruth' % root_dir)
self.csv_df = pd.read_csv(os.path.join(root_dir, csv_filename))

# Make the dataset smaller
self.csv_df = self.csv_df[0:min(len(self.csv_df), max_dataset_size)]

self.root_dir = root_dir
self.transform = transform

def __len__(self):
return len(self.csv_df)

def __getitem__(self, idx):
img_path = os.path.join(self.root_dir, self.csv_df.ix[idx, 0])
img = skimage.io.imread(img_path)
dictionary = dict(self.csv_df.ix[idx])

if self.transform:
transformed = self.transform(img)
else:
transformed = img

return (transformed, dictionary)


# Data loading code
trainset = PlantDataset(args.train_dir,
transform=transforms.Compose([
RandomHorizontalFlipImageAndLabel(p=0.5),
RandomVerticalFlipImageAndLabel(p=0.5),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5)),
]),
max_dataset_size=args.max_trainset_size)
trainset_loader = data.DataLoader(trainset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.nThreads)
trainset_loader = DataLoader(trainset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.nThreads)
if args.val_dir:
valset = PlantDataset(args.val_dir,
transform=transforms.Compose([
@@ -166,10 +122,10 @@ def __getitem__(self, idx):
(0.5, 0.5, 0.5)),
]),
max_dataset_size=args.max_valset_size)
valset_loader = data.DataLoader(valset,
batch_size=args.eval_batch_size,
shuffle=True,
num_workers=args.nThreads)
valset_loader = DataLoader(valset,
batch_size=args.eval_batch_size,
shuffle=True,
num_workers=args.nThreads)

# Model
print('Building network... ', end='')
@@ -223,21 +179,18 @@ def __getitem__(self, idx):
model.train()

# Pull info from this sample image
gt_plant_locations = [eval(el) for el in dictionary['plant_locations']]
gt_plant_locations = dictionary['plant_locations']
target_n_plants = dictionary['plant_count']

# We cannot deal with images with 0 plants (CD is not defined)
if any(len(target_one_img) == 0 for target_one_img in gt_plant_locations):
if target_n_plants[0][0] == 0:
continue


if criterion_training is chamfer_loss:
target = gt_plant_locations
else:
target = dots_img_tensor

# Prepare data and target
data, target, target_n_plants = data.type(
torch.FloatTensor), torch.FloatTensor(target), target_n_plants.type(torch.FloatTensor)
if args.cuda:
data, target, target_n_plants = data.cuda(), target.cuda(), target_n_plants.cuda()
data, target, target_n_plants = Variable(
@@ -258,8 +211,9 @@ def __getitem__(self, idx):
# Log training error
if time.time() > tic_train + args.log_interval:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(trainset_loader.dataset),
100. * batch_idx / len(trainset_loader), loss.data[0]))
epoch, batch_idx *
args.batch_size, len(trainset_loader.dataset),
100. * batch_idx / len(trainset_loader), loss.data[0][0]))
tic_train = time.time()

# Send training loss to Visdom
@@ -408,14 +362,16 @@ def __getitem__(self, idx):
if args.paint:
# Send original image with a cross at the estimated centroids to Visdom
image_with_x = torch.cuda.FloatTensor(data.data.squeeze().size()).\
        copy_(data.data.squeeze())
image_with_x = ((image_with_x + 1) / 2.0 * 255.0)
image_with_x = image_with_x.cpu().numpy()
image_with_x = np.moveaxis(image_with_x, 0, 2).copy()
for y, x in centroids:
image_with_x = cv2.circle(image_with_x, (x, y), 3, [255, 0, 0], -1)
image_with_x = cv2.circle(
image_with_x, (x, y), 3, [255, 0, 0], -1)
viz.image(np.moveaxis(image_with_x, 2, 0),
opts=dict(title='(Validation) Estimated centers @ crossings'),
opts=dict(
title='(Validation) Estimated centers @ crossings'),
win=8)

avg_term1_val = sum_term1 / len(valset_loader)
@@ -427,28 +383,28 @@
prec, rec = torch.cuda.FloatTensor([prec]), torch.cuda.FloatTensor([rec])

# Send validation loss to Visdom
win_val_loss = viz.updateTrace(Y=torch.stack((avg_term1_val, avg_term2_val, avg_term3_val, avg_loss_val/3, avg_ahd_val, prec, rec)).view(1, -1).data.cpu(),
win_val_loss = viz.updateTrace(Y=torch.stack((avg_term1_val, avg_term2_val, avg_term3_val, avg_loss_val / 3, avg_ahd_val, prec, rec)).view(1, -1).data.cpu(),
X=torch.Tensor([epoch]).repeat(1, 7),
opts=dict(title='Validation',
legend=['Term 1', 'Term 2',
                                          'Term 3', 'Sum/3', 'AHD', 'Precision', 'Recall'],
ylabel='Loss', xlabel='Epoch'),
append=True,
win='4')
if win_val_loss == 'win does not exist':
win_val_loss = viz.line(Y=torch.stack((avg_term1_val, avg_term2_val, avg_term3_val, avg_loss_val/3, avg_ahd_val, prec, rec)).view(1, -1).data.cpu(),
win_val_loss = viz.line(Y=torch.stack((avg_term1_val, avg_term2_val, avg_term3_val, avg_loss_val / 3, avg_ahd_val, prec, rec)).view(1, -1).data.cpu(),
X=torch.Tensor([epoch]).repeat(1, 7),
opts=dict(title='Validation',
legend=['Term 1', 'Term 2',
                                          'Term 3', 'Sum/3', 'AHD', 'Precision', 'Recall'],
ylabel='Loss', xlabel='Epoch'),
win='4')

# If this is the best epoch (in terms of validation error)
avg_ahd_val_float = avg_ahd_val.cpu().numpy()[0]
if avg_ahd_val_float < lowest_avg_ahd_val:
# Keep the best model
        lowest_avg_ahd_val = avg_ahd_val_float
if args.save:
torch.save({'epoch': epoch + 1, # when resuming, we will start at the next epoch
'model': model.state_dict(),
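
Aside: the torch.save call above stores a dictionary with (at least) 'epoch' and 'model' keys; the remaining keys are cut off in this diff. Resuming from such a checkpoint would look roughly like this sketch, where the filename 'checkpoint.ckpt' is hypothetical and `model` is the network built earlier in the script:

checkpoint = torch.load('checkpoint.ckpt')
model.load_state_dict(checkpoint['model'])
start_epoch = checkpoint['epoch']  # saved as epoch + 1, i.e. the next epoch to run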
