diff --git a/losses.py b/losses.py
new file mode 100644
index 0000000..d96d045
--- /dev/null
+++ b/losses.py
@@ -0,0 +1,107 @@
+import math
+import torch
+from sklearn.utils.extmath import cartesian
+import numpy as np
+from torch.nn import functional as F
+import os
+import time
+from sklearn.metrics.pairwise import pairwise_distances
+from sklearn.neighbors.kde import KernelDensity
+import skimage.io
+from matplotlib import pyplot as plt
+from torch import nn
+from torch.autograd import Variable
+
+"""
+We recommend copying this file to any project you need.
+"""
+
+
+def _assert_no_grad(variable):
+    assert not variable.requires_grad, \
+        "nn criterions don't compute the gradient w.r.t. targets - please " \
+        "mark these variables as volatile or not requiring gradients"
+
+
+def cdist(x, y):
+    '''
+    Input: x is an Nxd Tensor
+           y is an Mxd Tensor
+    Output: dist is an NxM matrix where dist[i,j] is the norm
+            between x[i,:] and y[j,:],
+            i.e., dist[i,j] = ||x[i,:]-y[j,:]||
+    '''
+    differences = x.unsqueeze(1) - y.unsqueeze(0)
+    distances = torch.sum(differences**2, -1).sqrt()
+    return distances
+
+
+class ModifiedChamferLoss(nn.Module):
+    def __init__(self, height, width, return_2_terms=False):
+        """
+        :param height: Number of rows in the image.
+        :param width: Number of columns in the image.
+        :param return_2_terms: Whether to return the 2 terms of the CD instead of their sum. Default: False.
+        """
+        super(ModifiedChamferLoss, self).__init__()
+
+        # Prepare all possible (row, col) locations in the image
+        self.height, self.width = height, width
+        self.max_dist = math.sqrt(height**2 + width**2)
+        self.n_pixels = height * width
+        self.all_img_locations = torch.from_numpy(cartesian([np.arange(height),
+                                                             np.arange(width)]))
+        self.all_img_locations = self.all_img_locations.type(torch.FloatTensor)
+        self.all_img_locations = self.all_img_locations.cuda()
+        self.all_img_locations = Variable(self.all_img_locations)
+
+        self.return_2_terms = return_2_terms
+
+    def forward(self, prob_map, gt):
+        """
+        Compute the Modified Chamfer Distance
+        between the estimated probability map and the ground truth points.
+
+        :param prob_map: Tensor of the probability map of the estimation, must be between 0 and 1.
+        :param gt: Tensor where each row is the (y, x), i.e., (row, col), of a GT point.
+        :return: Value of the Modified Chamfer Distance, or its 2 terms as a tuple.
+        """
+        _assert_no_grad(gt)
+
+        assert prob_map.size()[0:2] == (self.height, self.width), \
+            'You must configure the ModifiedChamferLoss with the height and width of the ' \
+            'probability map that you are using, got a probability map of size (%s, %s)' \
+            % prob_map.size()
+
+        # Pairwise distances between all possible locations and the GT locations
+        gt = gt.squeeze()
+        n_gt_pts = gt.size()[0]
+        d2_matrix = cdist(self.all_img_locations, gt)
+
+        # Reshape probability map as a long column vector,
+        # and prepare it for multiplication
+        p = prob_map.view(prob_map.nelement())
+        # Think of the next lines as a soft equivalent of thresholding at 0.5 to {0, 1}
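For context on how `ModifiedChamferLoss` is meant to be driven, here is a minimal usage sketch. It assumes a CUDA device (the constructor moves the pixel grid to the GPU unconditionally), and the inputs are made up:

```python
# Minimal sketch: ModifiedChamferLoss on a fake probability map.
# Assumes a CUDA device; inputs are illustrative only.
import torch
from torch.autograd import Variable

from losses import ModifiedChamferLoss

loss_fn = ModifiedChamferLoss(height=256, width=256, return_2_terms=True)

# Probability map in [0, 1]; gradients flow only into the map, not the GT.
prob_map = Variable(torch.rand(256, 256).cuda(), requires_grad=True)
# Three GT points, one (row, col) per row.
gt_points = Variable(torch.FloatTensor([[10, 20], [100, 50], [200, 240]]).cuda())

term_1, term_2 = loss_fn(prob_map, gt_points)
(term_1 + term_2).backward()
```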
+        # Hard threshold
+        # p_thresh = F.threshold(p, 0.1, 0)/p
+        n_est_pts = p.sum()
+        p_replicated = p.view(-1, 1).repeat(1, n_gt_pts)
+        # p_thresh_replicated = p_thresh.view(-1, 1).repeat(1, n_gt_pts)
+
+        eps = 1e-6
+
+        # Modified Chamfer Loss
+        term_1 = (1 / (n_est_pts + eps)) * \
+            torch.sum(p * torch.min(d2_matrix, 1)[0])
+        d_div_p = torch.min((d2_matrix + eps) /
+                            (p_replicated**4 + eps / self.max_dist), 0)[0]
+        d_div_p = torch.clamp(d_div_p, 0, self.max_dist)
+        term_2 = 1 * torch.mean(d_div_p, 0)[0]
+
+        if self.return_2_terms:
+            res = (term_1, term_2)
+        else:
+            res = term_1 + term_2
+
+        return res
+
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..6cf310a
--- /dev/null
+++ b/main.py
@@ -0,0 +1,509 @@
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+import time
+import shutil
+
+import numpy as np
+import pandas as pd
+import skimage.io
+import torch
+import torch.optim as optim
+import visdom
+import skimage.draw
+import utils
+from torch import nn
+from torch.autograd import Variable
+from torch.utils import data
+from torchvision import datasets
+from torchvision import transforms
+import torchvision as tv
+from torchvision.models import inception_v3
+import unet
+import losses
+import unet_model
+
+# Training settings
+parser = argparse.ArgumentParser(description='Plant Location with PyTorch')
+parser.add_argument('--train-dir', required=True,
+                    help='Directory with training images')
+parser.add_argument('--val-dir', required=True,
+                    help='Directory with validation images')
+parser.add_argument('--test-dir', required=True,
+                    help='Directory with testing images')
+parser.add_argument('--batch-size', type=int, default=300, metavar='N',
+                    help='input batch size for training')
+parser.add_argument('--eval-batch-size', type=int, default=120, metavar='N',
+                    help='input batch size for validation and testing')
+parser.add_argument('--epochs', type=int, default=np.inf, metavar='N',
+                    help='number of epochs to train')
+parser.add_argument('--nThreads', '-j', default=4, type=int, metavar='N',
+                    help='number of data loading threads (default: 4)')
+parser.add_argument('--lr', type=float, default=4e-5, metavar='LR',
+                    help='learning rate (default: 4e-5)')
+parser.add_argument('--no-cuda', action='store_true', default=False,
+                    help='disables CUDA training')
+parser.add_argument('--seed', type=int, default=1, metavar='S',
+                    help='random seed (default: 1)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='path to latest checkpoint (default: none)')
+parser.add_argument('--save', default='', type=str, metavar='PATH',
+                    help='where to save the model after each epoch')
+parser.add_argument('--log-interval', type=float, default=3, metavar='N',
+                    help='time to wait between logging training status (in seconds)')
+parser.add_argument('--max-trainset-size', type=int, default=np.inf, metavar='N',
+                    help='only use the first N images of the training dataset')
+parser.add_argument('--max-valset-size', type=int, default=np.inf, metavar='N',
+                    help='only use the first N images of the validation dataset')
+parser.add_argument('--max-testset-size', type=int, default=np.inf, metavar='N',
+                    help='only use the first N images of the testing dataset')
+parser.add_argument('--out-test-csv', type=str,
+                    help='path where to store the results of analyzing the test set')
+args = parser.parse_args()
+args.cuda = not args.no_cuda and torch.cuda.is_available()
+
+# Make sure we are not overwriting a checkpoint without resuming from it
+if args.save and \
+        os.path.isfile(args.save) and \
+        not (args.resume and args.resume == args.save):
+    print("E: Don't overwrite a checkpoint without resuming from it (if you want that, remove it manually).")
+    exit(1)
+
+# Create directory for checkpoint to be saved
+if args.save:
+    os.makedirs(os.path.split(args.save)[0], exist_ok=True)
+
+
+# Set seeds
+np.random.seed(0)
+torch.manual_seed(args.seed)
+if args.cuda:
+    torch.cuda.manual_seed_all(args.seed)
+
+# Visdom setup
+viz = visdom.Visdom(env='Pure U-Net')
+viz_train_input_win, viz_val_input_win = None, None
+viz_train_loss_win, viz_val_loss_win = None, None
+viz_train_gt_win, viz_val_gt_win = None, None
+viz_train_est_win, viz_val_est_win = None, None
+
+
+class PlantDataset(data.Dataset):
+    def __init__(self, root_dir, transform=None, max_dataset_size=np.inf):
+        """
+        Args:
+            root_dir (string): Directory with all the images.
+            transform (callable, optional): Optional transform to be applied
+                on a sample.
+            max_dataset_size: If the dataset is bigger than this integer,
+                ignore additional samples.
+        """
+
+        # Get groundtruth from CSV file
+        csv_filename = None
+        for filename in os.listdir(root_dir):
+            if filename.endswith('.csv'):
+                csv_filename = filename
+                break
+        if csv_filename is None:
+            raise ValueError(
+                'The root directory %s does not have a CSV file with groundtruth' % root_dir)
+        self.csv_df = pd.read_csv(os.path.join(root_dir, csv_filename))
+
+        # Make the dataset smaller
+        self.csv_df = self.csv_df[0:min(len(self.csv_df), max_dataset_size)]
+
+        self.root_dir = root_dir
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.csv_df)
+
+    def __getitem__(self, idx):
+        img_path = os.path.join(self.root_dir, self.csv_df.ix[idx, 0])
+        img = skimage.io.imread(img_path)
+        dictionary = dict(self.csv_df.ix[idx])
+
+        if self.transform:
+            transformed = self.transform(img)
+        else:
+            transformed = img
+
+        return (transformed, dictionary)
+
+
+# Data loading code
+trainset = PlantDataset(args.train_dir,
+                        transform=transforms.Compose([
+                            transforms.ToTensor(),
+                            transforms.Normalize((0.5, 0.5, 0.5),
+                                                 (0.5, 0.5, 0.5)),
+                        ]),
+                        max_dataset_size=args.max_trainset_size)
+valset = PlantDataset(args.val_dir,
+                      transform=transforms.Compose([
+                          transforms.ToTensor(),
+                          transforms.Normalize((0.5, 0.5, 0.5),
+                                               (0.5, 0.5, 0.5)),
+                      ]),
+                      max_dataset_size=args.max_valset_size)
+testset = PlantDataset(args.test_dir,
+                       transform=transforms.Compose([
+                           transforms.ToTensor(),
+                           transforms.Normalize((0.5, 0.5, 0.5),
+                                                (0.5, 0.5, 0.5)),
+                       ]),
+                       max_dataset_size=args.max_testset_size)
+trainset_loader = data.DataLoader(trainset,
+                                  batch_size=args.batch_size,
+                                  shuffle=True,
+                                  num_workers=args.nThreads)
+valset_loader = data.DataLoader(valset,
+                                batch_size=args.eval_batch_size,
+                                shuffle=True,
+                                num_workers=args.nThreads)
+testset_loader = data.DataLoader(testset,
+                                 batch_size=args.eval_batch_size,
+                                 num_workers=args.nThreads)
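`PlantDataset` only requires that the first CSV it finds in the root directory has the image filename in its first column; the training loop below additionally reads `plant_count` and `eval()`s each `plant_locations` string into a list of `(row, col)` tuples. A sketch of a groundtruth CSV that satisfies both (filenames and values are made up):

```python
# Hypothetical groundtruth CSV for PlantDataset: 'filename' is read as
# column 0 (csv_df.ix[idx, 0]); 'plant_locations' must eval() to a list
# of (row, col) tuples; 'plant_count' is the regression target.
import pandas as pd

pd.DataFrame({
    'filename': ['plot_0001.png', 'plot_0002.png'],
    'plant_count': [3, 1],
    'plant_locations': ['[(10, 20), (100, 50), (200, 240)]',
                        '[(128, 128)]'],
}).to_csv('gt.csv', index=False)
```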
+
+# Model
+print('Building network... ', end='')
+#model = unet.UnetGenerator(input_nc=3, output_nc=1, num_downs=8)
+model = unet_model.UNet(3, 1)
+print('DONE')
+print(model)
+model = nn.DataParallel(model)
+if args.cuda:
+    model.cuda()
+
+# Loss function
+l1_loss = nn.L1Loss()
+chamfer_loss = losses.ModifiedChamferLoss(256, 256, return_2_terms=True)
+criterion_training = chamfer_loss
+
+# Optimization strategy
+optimizer = optim.SGD(model.parameters(),
+                      lr=args.lr)
+
+start_epoch = 0
+lowest_avg_loss_val = np.infty
+
+# Restore saved checkpoint (model weights + epoch + optimizer state)
+if args.resume:
+    print("Loading checkpoint '{}' ...".format(args.resume))
+    if os.path.isfile(args.resume):
+        checkpoint = torch.load(args.resume)
+        start_epoch = checkpoint['epoch']
+        lowest_avg_loss_val = checkpoint['lowest_avg_loss_val']
+        model.load_state_dict(checkpoint['model'])
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        print("╰─ loaded checkpoint '{}' (now on epoch {})"
+              .format(args.resume, checkpoint['epoch']))
+    else:
+        print("╰─ E: no checkpoint found at '{}'".format(args.resume))
+        exit(-1)
+
+# Time at the last evaluation
+tic_train = -np.infty
+tic_val = -np.infty
+
+epoch = start_epoch
+it_num = 0
+while epoch < args.epochs:
+
+    for batch_idx, (data, dictionary) in enumerate(trainset_loader):
+        # === TRAIN ===
+
+        # Set the module in training mode
+        model.train()
+
+        # Pull info from this sample image
+        gt_plant_locations = [eval(el) for el in dictionary['plant_locations']]
+        target_n_plants = dictionary['plant_count']
+        # We cannot deal with images with 0 plants (CD is not defined)
+        if any(len(target_one_img) == 0 for target_one_img in gt_plant_locations):
+            continue
+
+        # Read image with GT dots from disk
+        gt_img_numpy = skimage.io.imread(
+            os.path.join('/home/jprat/projects/phenosorg/data/plant_counts_dots/20160613_F54_training_256x256_white_bigdots',
+                         dictionary['filename'][0]))
+        dots_img_tensor = torch.from_numpy(gt_img_numpy).permute(
+            2, 0, 1)[0, :, :].type(torch.FloatTensor) / 255
+
+        if criterion_training is chamfer_loss:
+            target = gt_plant_locations
+        else:
+            target = dots_img_tensor
+
+        # Prepare data and target
+        data, target, target_n_plants = data.type(
+            torch.FloatTensor), torch.FloatTensor(target), target_n_plants.type(torch.FloatTensor)
+        if args.cuda:
+            data, target, target_n_plants = data.cuda(), target.cuda(), target_n_plants.cuda()
+        data, target, target_n_plants = Variable(
+            data), Variable(target), Variable(target_n_plants)
+
+        # One training step
+        optimizer.zero_grad()
+        est_map, est_n_plants = model.forward(data)
+        est_map = est_map.squeeze()
+        term1, term2 = criterion_training.forward(est_map, target)
+        term3 = l1_loss.forward(est_n_plants, target_n_plants) / \
+            target_n_plants.type(torch.cuda.FloatTensor)
+        loss = term1 + term2 + term3
+        loss.backward()
+        optimizer.step()
+
+        # Log training error
+        if time.time() > tic_train + args.log_interval:
+            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                epoch, batch_idx * len(data), len(trainset_loader.dataset),
+                100. * batch_idx / len(trainset_loader), loss.data[0]))
+            tic_train = time.time()
+
+            # Send training loss to Visdom
+            win_train_loss = viz.updateTrace(Y=torch.cat([term1, term2, term3, loss / 3]).view(1, -1).data.cpu(),
+                                             X=torch.Tensor(
+                                                 [it_num]).repeat(1, 4),
+                                             opts=dict(title='(Training) Chamfer',
+                                                       legend=[
+                                                           'Term 1', 'Term 2', 'Term 3', 'Sum/3'],
+                                                       ylabel='Loss', xlabel='Iteration'),
+                                             append=True,
+                                             win='0')
+            if win_train_loss == 'win does not exist':
+                win_train_loss = viz.line(Y=torch.cat([term1, term2, term3, loss / 3]).view(1, -1).data.cpu(),
+                                          X=torch.Tensor(
+                                              [it_num]).repeat(1, 4),
+                                          opts=dict(title='(Training) Chamfer',
+                                                    legend=[
+                                                        'Term 1', 'Term 2', 'Term 3', 'Sum/3'],
+                                                    ylabel='Loss', xlabel='Iteration'),
+                                          win='0')
+
+            # Send input image to Visdom
+            viz.image(((data.data + 1) / 2.0 * 255.0).squeeze().cpu().numpy(),
+                      opts=dict(title='(Training) Input'),
+                      win=1)
+            # Send estimated image to Visdom
+            viz.image(est_map.data.unsqueeze(0).cpu().numpy(),
+                      opts=dict(title='(Training) U-Net output'),
+                      win=2)
+            # Send GT image to Visdom
+            viz.image(np.moveaxis(gt_img_numpy, 2, 0),
+                      opts=dict(title='(Training) Ground Truth'),
+                      win=3)
+
+        it_num += 1
+
+    # At the end of each epoch: validate, and save a checkpoint if the validation error decreased
+    if len(valset_loader) == 0:
+        continue
+
+    # === VALIDATION ===
+    print("\nValidating... ")
+
+    # Set the module in evaluation mode
+    model.eval()
+
+    sum_loss = 0
+    for batch_idx, (data, dictionary) in enumerate(valset_loader):
+
+        # Pull info from this sample image
+        gt_plant_locations = [eval(el) for el in dictionary['plant_locations']]
+        target_n_plants = dictionary['plant_count']
+        # We cannot deal with images with 0 plants (CD is not defined)
+        if any(len(target_one_img) == 0 for target_one_img in gt_plant_locations):
+            continue
+
+        # Read image with GT dots from disk
+        gt_img_numpy = skimage.io.imread(
+            os.path.join('/home/jprat/projects/phenosorg/data/plant_counts_dots/20160613_F54_validation_256x256_white_bigdots',
+                         dictionary['filename'][0]))
+        dots_img_tensor = torch.from_numpy(gt_img_numpy).permute(
+            2, 0, 1)[0, :, :].type(torch.FloatTensor) / 255
+
+        if criterion_training is chamfer_loss:
+            target = gt_plant_locations
+        else:
+            target = dots_img_tensor
+
+        # Prepare data and target
+        data, target, target_n_plants = data.type(
+            torch.FloatTensor), torch.FloatTensor(target), target_n_plants.type(torch.FloatTensor)
+        if args.cuda:
+            data, target, target_n_plants = data.cuda(), target.cuda(), target_n_plants.cuda()
+        data, target, target_n_plants = Variable(data, volatile=True), Variable(
+            target, volatile=True), Variable(target_n_plants, volatile=True)
+
+        # One forward pass (no parameter update during validation)
+        est_map, est_n_plants = model.forward(data)
+        est_map = est_map.squeeze()
+        term1, term2 = criterion_training.forward(est_map, target)
+        term3 = l1_loss.forward(est_n_plants, target_n_plants) / \
+            target_n_plants.type(torch.cuda.FloatTensor)
+        loss = term1 + term2 + term3
+
+        sum_loss += loss
+
+    avg_loss_val = sum_loss / len(valset_loader)
+    avg_loss_val_float = avg_loss_val.data.cpu().numpy()[0]
+
+    print('╰─ Loss: {:.4f}'.format(avg_loss_val_float))
+
+    # Send stuff to Visdom every X seconds
+    if time.time() > tic_val + args.log_interval:
+        tic_val = time.time()
+
+        # Send validation loss to Visdom
+        win_val_loss = viz.updateTrace(Y=torch.cat([term1, term2, term3, loss / 3]).view(1, -1).data.cpu(),
+                                       X=torch.Tensor([epoch]).repeat(1, 4),
+                                       opts=dict(title='(Validation) Chamfer',
+                                                 legend=[
+                                                     'Term 1', 'Term 2', 'Term 3', 'Sum/3'],
+                                                 ylabel='Loss', xlabel='Epoch'),
+                                       append=True,
+                                       win='4')
+        if win_val_loss == 'win does not exist':
+            win_val_loss = viz.line(Y=torch.cat([term1, term2, term3, loss / 3]).view(1, -1).data.cpu(),
+                                    X=torch.Tensor([epoch]).repeat(1, 4),
+                                    opts=dict(title='(Validation) Chamfer',
+                                              legend=['Term 1', 'Term 2',
+                                                      'Term 3', 'Sum/3'],
+                                              ylabel='Loss', xlabel='Epoch'),
+                                    win='4')
+
+        # Send input image to Visdom
+        viz.image(((data.data + 1) / 2.0 * 255.0).squeeze().cpu().numpy(),
+                  opts=dict(title='(Validation) Input'),
+                  win=5)
+        # Send estimated image to Visdom
+        viz.image(est_map.data.unsqueeze(0).cpu().numpy(),
+                  opts=dict(title='(Validation) U-Net output'),
+                  win=6)
+        # Send GT image to Visdom
+        viz.image(np.moveaxis(gt_img_numpy, 2, 0),
+                  opts=dict(title='(Validation) Ground Truth'),
+                  win=7)
+
+    # If this is the best epoch so far (in terms of validation error)
+    if avg_loss_val_float < lowest_avg_loss_val:
+        # Keep the best model
+        lowest_avg_loss_val = avg_loss_val_float
+        if args.save:
+            name, ext = os.path.splitext(args.save)
+            best_ckpt_path = name + '-best' + ext
+            torch.save({'epoch': epoch + 1,  # when resuming, we will start at the next epoch
+                        'model': model.state_dict(),
+                        'lowest_avg_loss_val': lowest_avg_loss_val,
+                        'optimizer': optimizer.state_dict(),
+                        }, best_ckpt_path)
+            print("Saved best checkpoint so far in %s " % best_ckpt_path)
+
+    epoch += 1
+
+    # # === TESTING ===
+    # print("\nTesting... ")
+
+    # # Set the module in evaluation mode
+    # model.eval()
+
+    # sum_loss = 0
+    # for data, dictionary in testset_loader:
+
+    #     # Pull info from this sample image
+    #     gt_plant_locations = [eval(el) for el in dictionary['plant_locations']]
+    #     # We cannot deal with images with 0 plants (CD is not defined)
+    #     if any(len(target_one_img) == 0 for target_one_img in gt_plant_locations):
+    #         continue
+
+    #     # Prepare data and target
+    #     data, gt_plant_locations = data.type(
+    #         torch.FloatTensor), torch.FloatTensor(gt_plant_locations)
+    #     if args.cuda:
+    #         data, gt_plant_locations = data.cuda(), gt_plant_locations.cuda()
+    #     data, gt_plant_locations = Variable(data, volatile=True), Variable(gt_plant_locations, volatile=True)
+
+    #     # Inference
+    #     est_map = model.forward(data)
+    #     est_map = est_map.squeeze()
+    #     loss = criterion_training.forward(est_map, gt_plant_locations)
+    #     sum_loss += loss
+
+    # avg_loss_test = sum_loss / len(testset_loader)
+    # avg_loss_test_float = avg_loss_test.data.cpu().numpy()[0]
+
+    # print('╰─ Chamfer: {:.4f}'.format(avg_loss_test_float))
+    # cc_test.add_scalar_value("Chamfer", avg_loss_test.cpu().data[0], step=epoch)
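The checkpoint that `main.py` writes (and that `train_and_validate.py` below reads back) is a plain dict with four keys. A short inspection sketch; the path is a placeholder for whatever was passed to `--save`:

```python
# Sketch: inspect a checkpoint produced by main.py (path is a placeholder).
import torch

checkpoint = torch.load('unet-best.ckpt')
print(sorted(checkpoint.keys()))
# ['epoch', 'lowest_avg_loss_val', 'model', 'optimizer']
# 'model' holds the weights of the nn.DataParallel-wrapped UNet, so load it
# into nn.DataParallel(unet_model.UNet(3, 1)), as both scripts do.
```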
") + + # df_out = testset.csv_df.copy() + + # sum_ape = 0 + # sum_ae = 0 + # for data, dictionary in testset_loader: + + # # Pull info from this sample image + # theres_gt = True if (dictionary['plant_count'][0] > 0) else False + # if theres_gt: + # target = dictionary['plant_locations'] + + # # Prepare data and target + # data = data.type(torch.FloatTensor) + # if theres_gt: + # target = target.type(torch.FloatTensor) + # if args.cuda: + # data = data.cuda() + # if theres_gt: + # target = target.cuda() + # data = Variable(data, volatile=True) + # if theres_gt: + # target = Variable(target, volatile=True) + + # # Compute Absolute Error of each image + # logits = model(data) + # logits = torch.squeeze(logits) + # if theres_gt: + # print('(Estimate, GT): \n %s' % torch.stack((logits.round(), + # target.round()), dim=1)) + # ae = torch.abs(logits - target) + + # # Compute Absolute Percent Error of each image + # ape = 100. * ae / target + + # # Make APE 0 when target is 0 + # inf_indices = Variable(torch.nonzero( + # ape.data == np.infty).squeeze()) + # if len(inf_indices) > 0: + # ape.index_fill_(dim=0, + # index=inf_indices, + # value=0) + # sum_ape += ape.sum() + # sum_ae += ae.sum() + # else: + # print('(Filename, Estimate): \n %s' % np.stack([np.array(dictionary['filename']), + # logits.cpu().data.numpy().astype(str)]).T) + + # if args.out_test_csv: + # # Put estimated plant counts into the data frame + # for filename, plant_count in zip(dictionary['filename'], logits.cpu().data.numpy().tolist()): + # df_out.ix[df_out['filename'] == filename, + # df_out.columns.get_loc('plant_locations')] = plant_count + + # if theres_gt: + # # MAPE (Mean Absolute Percent Error), and MAE (Mean Absolute Error) + # mape = sum_ape.data[0] / (len(valset_loader) * args.eval_batch_size) + # mae = sum_ae.data[0] / (len(valset_loader) * args.eval_batch_size) + + # print('╰─ MAE: {:.4f} plants'.format(mae)) + # print('╰─ MAPE: {:.4f} %'.format(mape)) + # cc_test.add_scalar_value("MAPE", mape, step=epoch) + + # if not theres_gt and args.out_test_csv: + # # Store estimated plant count to CSV file + # df_out.to_csv(args.out_test_csv) diff --git a/train_and_validate.py b/train_and_validate.py new file mode 100644 index 0000000..a2f1206 --- /dev/null +++ b/train_and_validate.py @@ -0,0 +1,199 @@ +from __future__ import print_function + +import argparse +import os +import sys +import time +import shutil + +from tqdm import tqdm +import numpy as np +import pandas as pd +import skimage.io +import torch +import torch.optim as optim +import visdom +import skimage.draw +import utils +from torch import nn +from torch.autograd import Variable +from torch.utils import data +from torchvision import datasets +from torchvision import transforms +import torchvision as tv +from torchvision.models import inception_v3 +import unet +import losses +import unet_model + +# Training settings +parser = argparse.ArgumentParser(description='Plant Location with PyTorch') +parser.add_argument('--test-dir', required=True, + help='Directory with testing images') +parser.add_argument('--eval-batch-size', type=int, default=120, metavar='N', + help='input batch size for validation and testing') +parser.add_argument('--nThreads', '-j', default=4, type=int, metavar='N', + help='number of data loading threads (default: 4)') +parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') +parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') 
+parser.add_argument('--checkpoint', default='', type=str, required=True, metavar='PATH',
+                    help='path to the checkpoint to evaluate')
+parser.add_argument('--max-testset-size', type=int, default=np.inf, metavar='N',
+                    help='only use the first N images of the testing dataset')
+parser.add_argument('--out-dir', type=str,
+                    help='path where to store the results of analyzing the test set \
+                          (images and CSV file)')
+args = parser.parse_args()
+args.cuda = not args.no_cuda and torch.cuda.is_available()
+
+# Set seeds
+np.random.seed(0)
+torch.manual_seed(args.seed)
+if args.cuda:
+    torch.cuda.manual_seed_all(args.seed)
+
+os.makedirs(args.out_dir, exist_ok=True)
+
+
+class PlantDataset(data.Dataset):
+    def __init__(self, root_dir, transform=None, max_dataset_size=np.inf):
+        """
+        Args:
+            root_dir (string): Directory with all the images.
+            transform (callable, optional): Optional transform to be applied
+                on a sample.
+            max_dataset_size: If the dataset is bigger than this integer,
+                ignore additional samples.
+        """
+
+        # Get groundtruth from CSV file
+        csv_filename = None
+        for filename in os.listdir(root_dir):
+            if filename.endswith('.csv'):
+                csv_filename = filename
+                break
+        if csv_filename is None:
+            raise ValueError(
+                'The root directory %s does not have a CSV file with groundtruth' % root_dir)
+        self.csv_df = pd.read_csv(os.path.join(root_dir, csv_filename))
+
+        # Make the dataset smaller
+        self.csv_df = self.csv_df[0:min(len(self.csv_df), max_dataset_size)]
+
+        self.root_dir = root_dir
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.csv_df)
+
+    def __getitem__(self, idx):
+        img_path = os.path.join(self.root_dir, self.csv_df.ix[idx, 0])
+        img = skimage.io.imread(img_path)
+        dictionary = dict(self.csv_df.ix[idx])
+
+        if self.transform:
+            transformed = self.transform(img)
+        else:
+            transformed = img
+
+        return (transformed, dictionary)
+
+
+# Data loading code
+testset = PlantDataset(args.test_dir,
+                       transform=transforms.Compose([
+                           transforms.ToTensor(),
+                           transforms.Normalize((0.5, 0.5, 0.5),
+                                                (0.5, 0.5, 0.5)),
+                       ]),
+                       max_dataset_size=args.max_testset_size)
+testset_loader = data.DataLoader(testset,
+                                 batch_size=args.eval_batch_size,
+                                 num_workers=args.nThreads)
+
+# Model
+print('Building network... ', end='')
+#model = unet.UnetGenerator(input_nc=3, output_nc=1, num_downs=8)
+model = unet_model.UNet(3, 1)
+print('DONE')
+print(model)
+model = nn.DataParallel(model)
+if args.cuda:
+    model.cuda()
+
+# Loss function
+l1_loss = nn.L1Loss()
+criterion_training = losses.ModifiedChamferLoss(256, 256, return_2_terms=True)
+
+# Restore saved checkpoint (model weights + epoch)
+print("Loading checkpoint '{}' ...".format(args.checkpoint))
+if os.path.isfile(args.checkpoint):
+    checkpoint = torch.load(args.checkpoint)
+    start_epoch = checkpoint['epoch']
+    lowest_avg_loss_val = checkpoint['lowest_avg_loss_val']
+    model.load_state_dict(checkpoint['model'])
+    print("╰─ loaded checkpoint '{}' (now on epoch {})"
+          .format(args.checkpoint, checkpoint['epoch']))
+else:
+    print("╰─ E: no checkpoint found at '{}'".format(args.checkpoint))
+    exit(-1)
+
+tic = time.time()
+
+
+# === Testing ===
") + +# Empty output CSV +df_out = pd.DataFrame(columns=['plant_count']) + +# Set the module in evaluation mode +model.eval() + +sum_loss = 0 +for batch_idx, (data, dictionary) in tqdm(enumerate(testset_loader), total=len(testset_loader)): + + # Pull info from this sample image + gt_plant_locations = [eval(el) for el in dictionary['plant_locations']] + target_n_plants = dictionary['plant_count'] + # We cannot deal with images with 0 plants (CD is not defined) + if any(len(target_one_img) == 0 for target_one_img in gt_plant_locations): + continue + + target = gt_plant_locations + + # Prepare data and target + data, target, target_n_plants = data.type( + torch.FloatTensor), torch.FloatTensor(target), target_n_plants.type(torch.FloatTensor) + if args.cuda: + data, target, target_n_plants = data.cuda(), target.cuda(), target_n_plants.cuda() + data, target, target_n_plants = Variable(data, volatile=True), Variable( + target, volatile=True), Variable(target_n_plants, volatile=True) + + # One forward + est_map, est_n_plants = model.forward(data) + est_map = est_map.squeeze() + term1, term2 = criterion_training.forward(est_map, target) + term3 = l1_loss.forward(est_n_plants, + target_n_plants.type(torch.cuda.FloatTensor))/ \ + target_n_plants.type(torch.cuda.FloatTensor) + loss = term1 + term2 + term3 + + sum_loss += loss + + # Save estimation to disk and append to CSV + tv.utils.save_image(est_map.data, os.path.join(args.out_dir, dictionary['filename'][0])) + df = pd.DataFrame(data=[est_n_plants.data.cpu().numpy()[0]], + index=[dictionary['filename'][0]], + columns=['plant_count']) + df_out = df_out.append(df) + +avg_loss_test = sum_loss / len(testset_loader) +avg_loss_test_float = avg_loss_test.data.cpu().numpy()[0] + +# Write CSV to disk +df_out.to_csv(os.path.join(args.out_dir, 'estimations.csv')) + +print('╰─ Average Loss for all the testing set: {:.4f}'.format(avg_loss_test_float)) +print('It took %s seconds to evaluate all the testing set.' 
diff --git a/unet_model.py b/unet_model.py
new file mode 100644
index 0000000..17ec1c3
--- /dev/null
+++ b/unet_model.py
@@ -0,0 +1,60 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from unet_parts import *
+
+
+class UNet(nn.Module):
+    def __init__(self, n_channels, n_classes):
+        super(UNet, self).__init__()
+        self.inc = inconv(n_channels, 64)
+        self.down1 = down(64, 128)
+        self.down2 = down(128, 256)
+        self.down3 = down(256, 512)
+        self.down4 = down(512, 512)
+        self.down5 = down(512, 512)
+        self.down6 = down(512, 512)
+        self.down7 = down(512, 512)
+        self.down8 = down(512, 512)
+        self.up1 = up(1024, 512)
+        self.up2 = up(1024, 512)
+        self.up3 = up(1024, 512)
+        self.up4 = up(1024, 512)
+        self.up5 = up(1024, 256)
+        self.up6 = up(512, 128)
+        self.up7 = up(256, 64)
+        self.up8 = up(128, 64)
+        self.outc = outconv(64, n_classes)
+        self.out_nonlin = nn.Sigmoid()
+
+        self.regressor = nn.Linear(256*256, 1)
+        self.regressor_nonlin = nn.Softplus()
+
+    def forward(self, x):
+        x1 = self.inc(x)
+        x2 = self.down1(x1)
+        x3 = self.down2(x2)
+        x4 = self.down3(x3)
+        x5 = self.down4(x4)
+        x6 = self.down5(x5)
+        x7 = self.down6(x6)
+        x8 = self.down7(x7)
+        x9 = self.down8(x8)
+        x = self.up1(x9, x8)
+        x = self.up2(x, x7)
+        x = self.up3(x, x6)
+        x = self.up4(x, x5)
+        x = self.up5(x, x4)
+        x = self.up6(x, x3)
+        x = self.up7(x, x2)
+        x = self.up8(x, x1)
+        x = self.outc(x)
+        x = self.out_nonlin(x)
+
+        # NOTE: flattening to (1, -1) assumes batch size 1 and 256x256 inputs
+        x_flat = x.view(1, -1)
+
+        regression = self.regressor(x_flat)
+        regression = self.regressor_nonlin(regression)
+
+        return x, regression
diff --git a/unet_parts.py b/unet_parts.py
new file mode 100644
index 0000000..7efb750
--- /dev/null
+++ b/unet_parts.py
@@ -0,0 +1,73 @@
+# sub-parts of the U-Net model
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class double_conv(nn.Module):
+    def __init__(self, in_ch, out_ch):
+        super(double_conv, self).__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_ch, out_ch, 3, padding=1),
+            nn.BatchNorm2d(out_ch),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_ch, out_ch, 3, padding=1),
+            nn.BatchNorm2d(out_ch),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
+class inconv(nn.Module):
+    def __init__(self, in_ch, out_ch):
+        super(inconv, self).__init__()
+        self.conv = double_conv(in_ch, out_ch)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
+class down(nn.Module):
+    def __init__(self, in_ch, out_ch):
+        super(down, self).__init__()
+        self.mpconv = nn.Sequential(
+            nn.MaxPool2d(2),
+            double_conv(in_ch, out_ch)
+        )
+
+    def forward(self, x):
+        x = self.mpconv(x)
+        return x
+
+
+class up(nn.Module):
+    def __init__(self, in_ch, out_ch):
+        super(up, self).__init__()
+        self.up = nn.UpsamplingBilinear2d(scale_factor=2)
+        # self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
+        self.conv = double_conv(in_ch, out_ch)
+
+    def forward(self, x1, x2):
+        x1 = self.up(x1)
+        # Pad the skip connection (x2) so it matches the upsampled size
+        diffX = x1.size()[2] - x2.size()[2]
+        diffY = x1.size()[3] - x2.size()[3]
+        x2 = F.pad(x2, (diffX // 2, int(diffX / 2),
+                        diffY // 2, int(diffY / 2)))
+        x = torch.cat([x2, x1], dim=1)
+        x = self.conv(x)
+        return x
+
+
+class outconv(nn.Module):
+    def __init__(self, in_ch, out_ch):
+        super(outconv, self).__init__()
+        self.conv = nn.Conv2d(in_ch, out_ch, 1)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
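Because `UNet.forward` flattens its output with `x.view(1, -1)` into a `Linear(256*256, 1)` regressor, the model as written assumes 256x256 inputs and an effective batch size of 1 per GPU. A quick shape check under those assumptions:

```python
# Shape check for unet_model.UNet, assuming a single 3x256x256 input.
import torch
from torch.autograd import Variable

from unet_model import UNet

net = UNet(n_channels=3, n_classes=1)
x = Variable(torch.rand(1, 3, 256, 256))
est_map, est_count = net(x)
print(est_map.size())    # (1, 1, 256, 256): sigmoid probability map
print(est_count.size())  # (1, 1): Softplus plant-count estimate
```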
diff --git a/unet_pix2pix.py b/unet_pix2pix.py
new file mode 100644
index 0000000..9399bcc
--- /dev/null
+++ b/unet_pix2pix.py
@@ -0,0 +1,97 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import functools
+from torch.optim import lr_scheduler
+import numpy as np
+
+
+# Defines the U-Net generator.
+# |num_downs|: number of downsamplings in the U-Net. For example,
+# if |num_downs| == 7, an image of size 128x128 will become 1x1
+# at the bottleneck.
+class UnetGenerator(nn.Module):
+    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
+                 norm_layer=nn.BatchNorm2d, use_dropout=True, gpu_ids=[]):
+        super(UnetGenerator, self).__init__()
+        self.gpu_ids = gpu_ids
+
+        # construct the U-Net structure, from the innermost block outwards
+        unet_block_innermost = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
+        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block_innermost, norm_layer=norm_layer)
+        for i in range(num_downs - 4):
+            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
+
+        self.model = unet_block
+        self.model.regressor = unet_block_innermost.regressor
+
+    def forward(self, input):
+        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
+            res = nn.parallel.data_parallel(self.model, input, self.gpu_ids)
+        else:
+            res = self.model(input)
+        # res = res/(res.max() + 1e-12)
+        # res = F.threshold(res, 3/255, 0)/res
+        return res
+
+
+# Defines the submodule with skip connection.
+# X -------------------identity---------------------- X
+#   |-- downsampling -- |submodule| -- upsampling --|
+class UnetSkipConnectionBlock(nn.Module):
+    def __init__(self, outer_nc, inner_nc, input_nc=None,
+                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=True):
+        super(UnetSkipConnectionBlock, self).__init__()
+        self.outermost = outermost
+        if type(norm_layer) == functools.partial:
+            use_bias = norm_layer.func == nn.InstanceNorm2d
+        else:
+            use_bias = norm_layer == nn.InstanceNorm2d
+        if input_nc is None:
+            input_nc = outer_nc
+        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
+                             stride=2, padding=1, bias=use_bias)
+        downrelu = nn.LeakyReLU(0.2, True)
+        downnorm = norm_layer(inner_nc)
+        uprelu = nn.ReLU(True)
+        upnorm = norm_layer(outer_nc)
+
+        if outermost:
+            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1)
+            down = [downconv]
+            up = [uprelu, upconv, nn.Sigmoid()]
+            model = down + [submodule] + up
+        elif innermost:
+            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, bias=use_bias)
+            down = [downrelu, downconv]
+            up = [uprelu, upconv, upnorm]
+            model = down + up
+
+            self.regressor = nn.Sequential(*down, nn.Linear(1, 1, bias=True))
+        else:
+            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, bias=use_bias)
+            down = [downrelu, downconv, downnorm]
+            up = [uprelu, upconv, upnorm]
+
+            if use_dropout:
+                model = down + [submodule] + up + [nn.Dropout(0.5)]
+            else:
+                model = down + [submodule] + up
+
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        if self.outermost:
+            return (self.model(x), self.regressor(x))
+        else:
+            return torch.cat([x, self.model(x)], 1)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..1af8f30
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,11 @@
+import numpy as np
+
+
+# Converts a Tensor into a Numpy array
+# |imtype|: the desired type of the converted numpy array
+def tensor2im(image_tensor, imtype=np.uint8):
+    image_numpy = image_tensor.cpu().float().numpy()
+    if image_numpy.shape[0] == 1:
+        image_numpy = np.tile(image_numpy, (3, 1, 1))
+    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
+    return image_numpy.astype(imtype)
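`tensor2im` assumes a CxHxW tensor normalized to [-1, 1], matching the `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` transforms used by the data loaders. A usage sketch with a made-up tensor:

```python
# Sketch: convert a [-1, 1]-normalized CxHxW tensor back to a uint8 image.
import torch
import utils

fake = torch.rand(3, 256, 256) * 2 - 1  # stand-in for a normalized input
img = utils.tensor2im(fake)             # HxWxC array in [0, 255]
print(img.shape, img.dtype)             # (256, 256, 3) uint8
```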