diff --git a/object-locator/argparser.py b/object-locator/argparser.py
index 09ef25e..38e11f1 100644
--- a/object-locator/argparser.py
+++ b/object-locator/argparser.py
@@ -33,9 +33,11 @@ def parse_command_args(training_or_testing):
                         required=True,
                         help='Directory with training images.')
     parser.add_argument('--val-dir',
-                        help='Directory with validation images. '
-                             'If left blank no validation will be done. '
-                             'If not provided, will not do validation')
+                        help="Directory with validation images. "
+                             "If 'auto', 20%% of the training samples "
+                             "will be removed from training "
+                             "and used for validation. "
+                             "If left blank no validation will be done.")
     parser.add_argument('--imgsize',
                         type=str,
                         default='256x256',
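The new 'auto' mode gives `--val-dir` three distinct cases. As a quick illustration (a hypothetical snippet, not part of the patch; the real branching lives in `data.get_train_val_loaders` in the next diff):

```python
# Hypothetical illustration of the three --val-dir cases; the actual
# handling is implemented in data.get_train_val_loaders below.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--val-dir')  # defaults to None when omitted

for argv, meaning in [(['--val-dir', '/data/val'], 'dedicated validation set'),
                      (['--val-dir', 'auto'], '80/20 split of the training set'),
                      ([], 'no validation')]:
    args = parser.parse_args(argv)
    print(f'{args.val_dir!r} -> {meaning}')
```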
+ """ -class CSVDataset(data.Dataset): + # Data augmentation for training + training_transforms = [] + if not no_data_augmentation: + training_transforms += [RandomHorizontalFlipImageAndLabel(p=0.5, seed=seed)] + training_transforms += [RandomVerticalFlipImageAndLabel(p=0.5, seed=seed)] + training_transforms += [ScaleImageAndLabel(size=(height, width))] + training_transforms += [transforms.ToTensor()] + training_transforms += [transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + + # Data augmentation for validation + validation_transforms = transforms.Compose([ + ScaleImageAndLabel(size=(height, width)), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5)), + ]) + + # Training dataset + trainset = CSVDataset(train_dir, + transforms=transforms.Compose(training_transforms), + max_dataset_size=max_trainset_size, + seed=seed) + + # Validation dataset + if val_dir is not None: + if val_dir == 'auto': + # Create a dataset just as in training + valset = CSVDataset(train_dir, + transforms=validation_transforms, + max_dataset_size=max_trainset_size) + + # Split 80% for training, 20% for validation + n_imgs_for_training = int(round(0.8*len(trainset))) + if trainset.there_is_gt: + trainset.csv_df = trainset.csv_df[:n_imgs_for_training] + valset.csv_df = valset.csv_df[n_imgs_for_training:].reset_index() + else: + trainset.listfiles = trainset.listfiles[:n_imgs_for_training] + valset.listfiles = valset.listfiles[n_imgs_for_training:] + + else: + valset = CSVDataset(val_dir, + transforms=validation_transforms, + max_dataset_size=max_valset_size) + valset_loader = torch.utils.data.DataLoader(valset, + batch_size=1, + shuffle=True, + num_workers=num_workers, + collate_fn=csv_collator) + else: + valset, valset_loader = None, None + + print(f'# images for training: ' + f'{ballpark(len(trainset))}') + if valset is not None: + print(f'# images for validation: ' + f'{ballpark(len(valset))}') + else: + print('W: no validation set was selected!') + + # Build data loaders from the datasets + trainset_loader = torch.utils.data.DataLoader(trainset, + batch_size=batch_size, + drop_last=drop_last_batch, + shuffle=True, + num_workers=num_workers, + collate_fn=csv_collator) + if valset is not None: + valset_loader = torch.utils.data.DataLoader(valset, + batch_size=1, + shuffle=True, + num_workers=num_workers, + collate_fn=csv_collator) + + return trainset_loader, valset_loader + + +class CSVDataset(torch.utils.data.Dataset): def __init__(self, directory, transforms=None, @@ -106,7 +219,6 @@ def __getitem__(self, idx): The second element is a dictionary where the keys are the columns of the CSV. If the CSV did not exist in the dataset directory, the dictionary will only contain the filename of the image. -_ :param idx: Index of the image in the dataset to get. """ @@ -309,7 +421,7 @@ def _is_pil_image(img): return isinstance(img, Image.Image) -class XMLDataset(data.Dataset): +class XMLDataset(torch.utils.data.Dataset): def __init__(self, directory, transforms=None, diff --git a/object-locator/train.py b/object-locator/train.py index 0bf609f..9b89727 100644 --- a/object-locator/train.py +++ b/object-locator/train.py @@ -35,14 +35,14 @@ from . import losses from .models import unet_model from .metrics import Judge -from .data import CSVDataset +from . import logger +from . import argparser +from . import utils +from . 
diff --git a/object-locator/train.py b/object-locator/train.py
index 0bf609f..9b89727 100644
--- a/object-locator/train.py
+++ b/object-locator/train.py
@@ -35,14 +35,14 @@
 from . import losses
 from .models import unet_model
 from .metrics import Judge
-from .data import CSVDataset
+from . import logger
+from . import argparser
+from . import utils
+from . import data
 from .data import csv_collator
 from .data import RandomHorizontalFlipImageAndLabel
 from .data import RandomVerticalFlipImageAndLabel
 from .data import ScaleImageAndLabel
-from . import logger
-from . import argparser
-from . import utils
 
 
 # Parse command line arguments
@@ -68,41 +68,20 @@
                        port=args.visdom_port,
                        env_name=args.visdom_env)
 
-# Data loading code
-training_transforms = []
-if not args.no_data_augm:
-    training_transforms += [RandomHorizontalFlipImageAndLabel(p=0.5, seed=args.seed)]
-    # training_transforms += [RandomVerticalFlipImageAndLabel(p=0.5, seed=args.seed)]
-training_transforms += [ScaleImageAndLabel(size=(args.height, args.width))]
-training_transforms += [transforms.ToTensor()]
-training_transforms += [transforms.Normalize((0.5, 0.5, 0.5),
-                                             (0.5, 0.5, 0.5))]
-trainset = CSVDataset(args.train_dir,
-                      transforms=transforms.Compose(training_transforms),
-                      max_dataset_size=args.max_trainset_size,
-                      seed=args.seed)
-print(f'# images for training: {len(trainset)}')
-trainset_loader = DataLoader(trainset,
-                             batch_size=args.batch_size,
-                             drop_last=args.drop_last_batch,
-                             shuffle=True,
-                             num_workers=args.nThreads,
-                             collate_fn=csv_collator)
-if args.val_dir:
-    valset = CSVDataset(args.val_dir,
-                        transforms=transforms.Compose([
-                            ScaleImageAndLabel(size=(args.height, args.width)),
-                            transforms.ToTensor(),
-                            transforms.Normalize((0.5, 0.5, 0.5),
-                                                 (0.5, 0.5, 0.5)),
-                        ]),
-                        max_dataset_size=args.max_valset_size)
-    print(f'# images for validation: {len(valset)}')
-    valset_loader = DataLoader(valset,
-                               batch_size=args.eval_batch_size,
-                               shuffle=True,
+
+# Create data loaders (return data in batches)
+trainset_loader, valset_loader = \
+    data.get_train_val_loaders(train_dir=args.train_dir,
+                               max_trainset_size=args.max_trainset_size,
+                               collate_fn=csv_collator,
+                               height=args.height,
+                               width=args.width,
+                               seed=args.seed,
+                               batch_size=args.batch_size,
+                               drop_last_batch=args.drop_last_batch,
                                num_workers=args.nThreads,
-                               collate_fn=csv_collator)
+                               val_dir=args.val_dir,
+                               max_valset_size=args.max_valset_size)
 
 # Model
 with peter('Building network'):
@@ -171,7 +150,7 @@
     loss_avg_this_epoch = 0
     iter_train = tqdm(trainset_loader,
-                      desc=f'Epoch {epoch} ({len(trainset)} images)')
+                      desc=f'Epoch {epoch} ({len(trainset_loader.dataset)} images)')
 
     # === TRAIN ===
@@ -305,7 +284,7 @@
         sum_term3 = 0
         sum_loss = 0
         iter_val = tqdm(valset_loader,
-                        desc=f'Validating Epoch {epoch} ({len(valset)} images)')
+                        desc=f'Validating Epoch {epoch} ({len(valset_loader.dataset)} images)')
 
         for batch_idx, (imgs, dictionaries) in enumerate(iter_val):
             # Pull info from this batch and move to device
@@ -325,8 +304,8 @@
             target_orig_widths = torch.stack(target_orig_widths)
             target_orig_sizes = torch.stack((target_orig_heights,
                                              target_orig_widths)).transpose(0, 1)
-            origsize = (dictionaries[0]['orig_height'].item(),
-                        dictionaries[0]['orig_width'].item())
+            orig_shape = (dictionaries[0]['orig_height'].item(),
+                          dictionaries[0]['orig_width'].item())
 
             # Tensor -> float & numpy
             target_count_int = int(round(target_counts.item()))
@@ -370,7 +349,7 @@
             # BMM thresholding
             est_map_numpy = est_maps[0, :, :].to(device_cpu).numpy()
             est_map_numpy_origsize = skimage.transform.resize(est_map_numpy,
-                                                              output_shape=orig_shape,
+                                                              output_shape=orig_shape,
                                                               mode='constant')
             mask, _ = utils.threshold(est_map_numpy_origsize, tau=-1)
             # Obtain centroids of the mask
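Downstream, both loops consume the loaders exactly as before. Continuing the sketch above (assumes `trainset_loader` and `valset_loader` from the previous snippet), the batch structure produced by `csv_collator` matches the `(imgs, dictionaries)` unpacking in train.py:

```python
# Sketch of consuming the loaders returned by get_train_val_loaders.
# With csv_collator, each batch is a pair (imgs, dictionaries):
#   imgs         -- float tensor of shape (batch_size, 3, height, width)
#   dictionaries -- one dict per image with the CSV columns (or just
#                   the filename when no ground-truth CSV exists)
for imgs, dictionaries in trainset_loader:
    pass  # forward pass, loss, optimizer step

if valset_loader is not None:  # None when --val-dir was omitted
    for imgs, dictionaries in valset_loader:
        pass  # accumulate validation metrics
```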