diff --git a/data.py b/data.py
new file mode 100644
index 0000000..bab0276
--- /dev/null
+++ b/data.py
@@ -0,0 +1,147 @@
+import ast
+import os
+import random
+
+from PIL import Image
+import pandas as pd
+import torch
+from torch.utils import data
+
+
+class PlantDataset(data.Dataset):
+    def __init__(self, root_dir, transform=None, max_dataset_size=float('inf')):
+        """
+        Args:
+            root_dir (string): Directory with all the images.
+            transform (callable, optional): Optional transform to be applied
+                on a sample. Expected to be a torchvision Compose, so that
+                its .transforms attribute can be traversed one by one.
+            max_dataset_size: If the dataset is bigger than this integer,
+                ignore additional samples.
+        """
+
+        # Get groundtruth from CSV file
+        csv_filename = None
+        for filename in os.listdir(root_dir):
+            if filename.endswith('.csv'):
+                csv_filename = filename
+                break
+        if csv_filename is None:
+            raise ValueError(
+                'The root directory %s does not have a CSV file with groundtruth' % root_dir)
+        self.csv_df = pd.read_csv(os.path.join(root_dir, csv_filename))
+
+        # Make the dataset smaller
+        self.csv_df = self.csv_df[0:min(len(self.csv_df), max_dataset_size)]
+
+        self.root_dir = root_dir
+        self.transforms = transform
+
+    def __len__(self):
+        return len(self.csv_df)
+
+    def __getitem__(self, idx):
+        img_path = os.path.join(self.root_dir, self.csv_df.iloc[idx, 0])
+        img = Image.open(img_path)
+        dictionary = dict(self.csv_df.iloc[idx])
+
+        # str -> lists (literal_eval is a safe alternative to eval)
+        dictionary['plant_locations'] = ast.literal_eval(
+            dictionary['plant_locations'])
+        dictionary['plant_locations'] = [
+            list(loc) for loc in dictionary['plant_locations']]
+
+        # lists -> Tensors
+        dictionary['plant_locations'] = torch.FloatTensor(
+            dictionary['plant_locations'])
+        dictionary['plant_count'] = torch.FloatTensor(
+            [dictionary['plant_count']])
+
+        img_transformed = img
+        transformed_dictionary = dictionary
+
+        # Apply all transformations provided. Transforms that declare
+        # modifies_label receive (and must return) both the image and
+        # the groundtruth dictionary.
+        if self.transforms is not None:
+            for transform in self.transforms.transforms:
+                if hasattr(transform, 'modifies_label'):
+                    img_transformed, transformed_dictionary = \
+                        transform(img_transformed, transformed_dictionary)
+                else:
+                    img_transformed = transform(img_transformed)
+
+        # Prevents crash when making a batch out of an empty tensor
+        if transformed_dictionary['plant_count'][0] == 0:
+            transformed_dictionary['plant_locations'] = \
+                torch.FloatTensor([-1, -1])
+
+        return (img_transformed, transformed_dictionary)
+
+
+class RandomHorizontalFlipImageAndLabel(object):
+    """ Horizontally flip a PIL Image and the GT with probability p """
+
+    def __init__(self, p):
+        self.modifies_label = True
+        self.p = p
+
+    def __call__(self, img, dictionary):
+        transformed_img = img
+        transformed_dictionary = dictionary
+
+        if random.random() < self.p:
+            transformed_img = hflip(img)
+            # PIL's size is (width, height); locations are stored as (y, x)
+            width = img.size[0]
+            for l, loc in enumerate(transformed_dictionary['plant_locations']):
+                transformed_dictionary['plant_locations'][l][1] = \
+                    (width - 1) - loc[1]
+
+        return transformed_img, transformed_dictionary
+
+
+class RandomVerticalFlipImageAndLabel(object):
+    """ Vertically flip a PIL Image and the GT with probability p """
+
+    def __init__(self, p):
+        self.modifies_label = True
+        self.p = p
+
+    def __call__(self, img, dictionary):
+        transformed_img = img
+        transformed_dictionary = dictionary
+
+        if random.random() < self.p:
+            transformed_img = vflip(img)
+            # PIL's size is (width, height); locations are stored as (y, x)
+            height = img.size[1]
+            for l, loc in enumerate(transformed_dictionary['plant_locations']):
+                transformed_dictionary['plant_locations'][l][0] = \
+                    (height - 1) - loc[0]
+
+        return transformed_img, transformed_dictionary
+
+
+def hflip(img):
+    """Horizontally flip the given PIL Image.
+    Args:
+        img (PIL Image): Image to be flipped.
+    Returns:
+        PIL Image: Horizontally flipped image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    return img.transpose(Image.FLIP_LEFT_RIGHT)
+
+
+def vflip(img):
+    """Vertically flip the given PIL Image.
+    Args:
+        img (PIL Image): Image to be flipped.
+    Returns:
+        PIL Image: Vertically flipped image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    return img.transpose(Image.FLIP_TOP_BOTTOM)
+
+
+def _is_pil_image(img):
+    return isinstance(img, Image.Image)
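How the pieces above are meant to be wired together, as a minimal usage sketch (the directory path is illustrative; the transform stack mirrors the trainset setup in train_and_validate.py below). Note that the label-aware flips must come before ToTensor(), since they operate on the PIL image and update the (y, x) groundtruth locations in step:

    from torchvision import transforms
    from data import (PlantDataset, RandomHorizontalFlipImageAndLabel,
                      RandomVerticalFlipImageAndLabel)

    dataset = PlantDataset('plants/train',  # illustrative path
                           transform=transforms.Compose([
                               RandomHorizontalFlipImageAndLabel(p=0.5),
                               RandomVerticalFlipImageAndLabel(p=0.5),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5),
                                                    (0.5, 0.5, 0.5)),
                           ]))
    img, gt = dataset[0]  # img: 3xHxW tensor in [-1, 1]; gt: dict with
                          # 'plant_locations' (Nx2) and 'plant_count' (1,)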
"""Horizontally flip the given PIL Image. + Args: + img (PIL Image): Image to be flipped. + Returns: + PIL Image: Horizontall flipped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return img.transpose(Image.FLIP_LEFT_RIGHT) + + +def vflip(img): + """Vertically flip the given PIL Image. + Args: + img (PIL Image): Image to be flipped. + Returns: + PIL Image: Vertically flipped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return img.transpose(Image.FLIP_TOP_BOTTOM) + + +def _is_pil_image(img): + return isinstance(img, Image.Image) diff --git a/train_and_validate.py b/train_and_validate.py index d85cb5e..55a7dd5 100644 --- a/train_and_validate.py +++ b/train_and_validate.py @@ -10,18 +10,12 @@ from tqdm import tqdm import numpy as np -import pandas as pd -import skimage.io import torch import torch.optim as optim import visdom -import skimage.draw import utils from torch import nn from torch.autograd import Variable -from torch.utils import data -from torchvision import datasets -from torchvision import transforms import torchvision as tv from torchvision.models import inception_v3 from sklearn import mixture @@ -29,13 +23,18 @@ import losses import unet_model from eval_precision_recall import Judge +from torchvision import transforms +from torch.utils.data import DataLoader +from data import PlantDataset +from data import RandomHorizontalFlipImageAndLabel +from data import RandomVerticalFlipImageAndLabel # Training settings parser = argparse.ArgumentParser(description='Plant Location with PyTorch') parser.add_argument('--train-dir', required=True, help='Directory with training images') -parser.add_argument('--val-dir', +parser.add_argument('--val-dir', help='Directory with validation images. If left blank no validation will be done.') parser.add_argument('--batch-size', type=int, default=1, metavar='N', help='input batch size for training') @@ -64,13 +63,13 @@ parser.add_argument('--env-name', default='Pure U-Net', type=str, metavar='NAME', help='Name of the environment in Visdom') parser.add_argument('--paint', default=False, action="store_true", - help='Paint red circles at estimated locations in Validation? '\ - 'It takes an enormous amount of time!') + help='Paint red circles at estimated locations in Validation? ' + 'It takes an enormous amount of time!') parser.add_argument('--radius', type=int, default=5, metavar='R', help='Default radius to consider a object detection as "match".') parser.add_argument('--n-points', type=int, default=None, metavar='N', - help='If you know the number of points (e.g, just one pupil), set it.' \ - 'Otherwise it will be estimated by adding a L1 cost term.') + help='If you know the number of points (e.g, just one pupil), set it.' + 'Otherwise it will be estimated by adding a L1 cost term.') parser.add_argument('--lambdaa', type=float, default=1, metavar='L', help='Weight that will multiply the MAPE term in the loss function.') @@ -101,63 +100,20 @@ viz_train_gt_win, viz_val_gt_win = None, None viz_train_est_win, viz_val_est_win = None, None - -class PlantDataset(data.Dataset): - def __init__(self, root_dir, transform=None, max_dataset_size=np.inf): - """ - Args: - root_dir (string): Directory with all the images. - transform (callable, optional): Optional transform to be applied - on a sample. - max_dataset_size: If the dataset is bigger than this integer, - ignore additional samples. 
- """ - - # Get groundtruth from CSV file - csv_filename = None - for filename in os.listdir(root_dir): - if filename.endswith('.csv'): - csv_filename = filename - break - if csv_filename is None: - raise ValueError( - 'The root directory %s does not have a CSV file with groundtruth' % root_dir) - self.csv_df = pd.read_csv(os.path.join(root_dir, csv_filename)) - - # Make the dataset smaller - self.csv_df = self.csv_df[0:min(len(self.csv_df), max_dataset_size)] - - self.root_dir = root_dir - self.transform = transform - - def __len__(self): - return len(self.csv_df) - - def __getitem__(self, idx): - img_path = os.path.join(self.root_dir, self.csv_df.ix[idx, 0]) - img = skimage.io.imread(img_path) - dictionary = dict(self.csv_df.ix[idx]) - - if self.transform: - transformed = self.transform(img) - else: - transformed = img - - return (transformed, dictionary) - - # Data loading code trainset = PlantDataset(args.train_dir, transform=transforms.Compose([ + RandomHorizontalFlipImageAndLabel(p=0.5), + RandomVerticalFlipImageAndLabel(p=0.5), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]), max_dataset_size=args.max_trainset_size) -trainset_loader = data.DataLoader(trainset, - batch_size=args.batch_size, - shuffle=True, - num_workers=args.nThreads) +trainset_loader = DataLoader(trainset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.nThreads) if args.val_dir: valset = PlantDataset(args.val_dir, transform=transforms.Compose([ @@ -166,10 +122,10 @@ def __getitem__(self, idx): (0.5, 0.5, 0.5)), ]), max_dataset_size=args.max_valset_size) - valset_loader = data.DataLoader(valset, - batch_size=args.eval_batch_size, - shuffle=True, - num_workers=args.nThreads) + valset_loader = DataLoader(valset, + batch_size=args.eval_batch_size, + shuffle=True, + num_workers=args.nThreads) # Model print('Building network... ', end='') @@ -223,21 +179,18 @@ def __getitem__(self, idx): model.train() # Pull info from this sample image - gt_plant_locations = [eval(el) for el in dictionary['plant_locations']] + gt_plant_locations = dictionary['plant_locations'] target_n_plants = dictionary['plant_count'] + # We cannot deal with images with 0 plants (CD is not defined) - if any(len(target_one_img) == 0 for target_one_img in gt_plant_locations): + if target_n_plants[0][0] == 0: continue - if criterion_training is chamfer_loss: target = gt_plant_locations else: target = dots_img_tensor - # Prepare data and target - data, target, target_n_plants = data.type( - torch.FloatTensor), torch.FloatTensor(target), target_n_plants.type(torch.FloatTensor) if args.cuda: data, target, target_n_plants = data.cuda(), target.cuda(), target_n_plants.cuda() data, target, target_n_plants = Variable( @@ -258,8 +211,9 @@ def __getitem__(self, idx): # Log training error if time.time() > tic_train + args.log_interval: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( - epoch, batch_idx * len(data), len(trainset_loader.dataset), - 100. * batch_idx / len(trainset_loader), loss.data[0])) + epoch, batch_idx * + args.batch_size, len(trainset_loader.dataset), + 100. 
             tic_train = time.time()
 
             # Send training loss to Visdom
@@ -408,14 +362,16 @@ def __getitem__(self, idx):
             if args.paint:
                 # Send original image with a cross at the estimated centroids to Visdom
                 image_with_x = torch.cuda.FloatTensor(data.data.squeeze().size()).\
-                        copy_(data.data.squeeze())
+                    copy_(data.data.squeeze())
                 image_with_x = ((image_with_x + 1) / 2.0 * 255.0)
                 image_with_x = image_with_x.cpu().numpy()
                 image_with_x = np.moveaxis(image_with_x, 0, 2).copy()
                 for y, x in centroids:
-                    image_with_x = cv2.circle(image_with_x, (x, y), 3, [255, 0, 0], -1)
+                    image_with_x = cv2.circle(
+                        image_with_x, (x, y), 3, [255, 0, 0], -1)
                 viz.image(np.moveaxis(image_with_x, 2, 0),
-                          opts=dict(title='(Validation) Estimated centers @ crossings'),
+                          opts=dict(
+                              title='(Validation) Estimated centers @ crossings'),
                           win=8)
 
         avg_term1_val = sum_term1 / len(valset_loader)
@@ -427,20 +383,20 @@ def __getitem__(self, idx):
         prec, rec = torch.cuda.FloatTensor([prec]), torch.cuda.FloatTensor([rec])
 
         # Send validation loss to Visdom
-        win_val_loss = viz.updateTrace(Y=torch.stack((avg_term1_val, avg_term2_val, avg_term3_val, avg_loss_val/3, avg_ahd_val, prec, rec)).view(1, -1).data.cpu(),
+        win_val_loss = viz.updateTrace(Y=torch.stack((avg_term1_val, avg_term2_val, avg_term3_val, avg_loss_val / 3, avg_ahd_val, prec, rec)).view(1, -1).data.cpu(),
                                        X=torch.Tensor([epoch]).repeat(1, 7),
                                        opts=dict(title='Validation',
                                                  legend=['Term 1', 'Term 2',
-                                                       'Term 3', 'Sum/3', 'AHD', 'Precision', 'Recall'],
+                                                         'Term 3', 'Sum/3', 'AHD', 'Precision', 'Recall'],
                                                  ylabel='Loss',
                                                  xlabel='Epoch'),
                                        append=True,
                                        win='4')
         if win_val_loss == 'win does not exist':
-            win_val_loss = viz.line(Y=torch.stack((avg_term1_val, avg_term2_val, avg_term3_val, avg_loss_val/3, avg_ahd_val, prec, rec)).view(1, -1).data.cpu(),
+            win_val_loss = viz.line(Y=torch.stack((avg_term1_val, avg_term2_val, avg_term3_val, avg_loss_val / 3, avg_ahd_val, prec, rec)).view(1, -1).data.cpu(),
                                     X=torch.Tensor([epoch]).repeat(1, 7),
                                     opts=dict(title='Validation',
                                               legend=['Term 1', 'Term 2',
-                                                    'Term 3', 'Sum/3', 'AHD', 'Precision', 'Recall'],
+                                                      'Term 3', 'Sum/3', 'AHD', 'Precision', 'Recall'],
                                               ylabel='Loss',
                                               xlabel='Epoch'),
                                     win='4')
@@ -448,7 +404,7 @@ def __getitem__(self, idx):
         avg_ahd_val_float = avg_ahd_val.cpu().numpy()[0]
         if avg_ahd_val_float < lowest_avg_ahd_val:
             # Keep the best model
-            lowest_avg_ahd_val = avg_ahd_val_float 
+            lowest_avg_ahd_val = avg_ahd_val_float
             if args.save:
                 torch.save({'epoch': epoch + 1,  # when resuming, we will start at the next epoch
                             'model': model.state_dict(),
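A note on batching that falls out of this patch: plant_locations now has a different length for every image, so PyTorch's default collate function can only stack samples when --batch-size is 1 (the default above). Raising the batch size would need a custom collate_fn along these lines; this is an illustrative sketch (the name plant_collate is hypothetical, not part of the patch):

    import torch

    def plant_collate(batch):
        # Stack the images and the per-image counts; keep plant_locations
        # as a list, since each image can have a different number of points.
        imgs = torch.stack([img for img, _ in batch])
        counts = torch.stack([d['plant_count'] for _, d in batch])
        locations = [d['plant_locations'] for _, d in batch]
        return imgs, {'plant_count': counts, 'plant_locations': locations}

    # e.g.: DataLoader(trainset, batch_size=4, collate_fn=plant_collate)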