Skip to content

Commit

Permalink
wrap data loader generator in a function
Browse files Browse the repository at this point in the history
Former-commit-id: 7b8acc8ffce0810c62fa4d086755ca8cc277192a
  • Loading branch information
Javi Ribera committed Dec 6, 2018
1 parent 2b3ce85 commit afcd771
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 50 deletions.
8 changes: 5 additions & 3 deletions object-locator/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ def parse_command_args(training_or_testing):
required=True,
help='Directory with training images.')
parser.add_argument('--val-dir',
help='Directory with validation images. '
'If left blank no validation will be done. '
'If not provided, will not do validation')
help="Directory with validation images. "
"If 'auto', 20%% of the training samples "
"will be removed from training "
"and used for validation. "
"If left blank no validation will be done.")
parser.add_argument('--imgsize',
type=str,
default='256x256',
Expand Down
120 changes: 116 additions & 4 deletions object-locator/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,138 @@
import os
import random
from collections import OrderedDict
import copy

from PIL import Image
import skimage
import numpy as np
import pandas as pd
import torch
from torch.utils import data
from torchvision import datasets
from torchvision import transforms
import xmltodict
from parse import parse
from ballpark import ballpark

from . import get_image_size

IMG_EXTENSIONS = ['.png', '.jpeg', '.jpg', '.tiff']

torch.set_default_dtype(torch.float32)

def get_train_val_loaders(train_dir,
                          collate_fn,
                          height,
                          width,
                          no_data_augmentation=False,
                          max_trainset_size=np.infty,
                          seed=0,
                          batch_size=1,
                          drop_last_batch=False,
                          shuffle=True,
                          num_workers=0,
                          val_dir=None,
                          max_valset_size=np.infty):
    """
    Create a training data loader and a validation data loader.

    If ``val_dir`` is ``'auto'``, 20% of the training samples are removed
    from the training set and used for validation instead.
    If ``val_dir`` is ``None``, no validation loader is created and
    ``None`` is returned in its place.

    :param train_dir: Directory with all the training images and the CSV file.
    :param collate_fn: Function to assemble samples into batches.
    :param height: Resize the images to this height.
    :param width: Resize the images to this width.
    :param no_data_augmentation: Do not perform data augmentation.
    :param max_trainset_size: Only use the first N images for training.
    :param seed: Random seed.
    :param batch_size: Number of samples in a batch, for training.
    :param drop_last_batch: Drop the last incomplete batch during training.
    :param shuffle: Randomly shuffle the training set before each epoch.
    :param num_workers: Number of subprocesses dedicated to data loading.
    :param val_dir: Directory with all the validation images and the CSV
                    file, or the string 'auto', or None.
    :param max_valset_size: Only use the first N images for validation.
    :return: Tuple (trainset_loader, valset_loader).
             valset_loader is None when no validation set was selected.
    """

    # Data augmentation for training
    training_transforms = []
    if not no_data_augmentation:
        training_transforms += [RandomHorizontalFlipImageAndLabel(p=0.5, seed=seed)]
        training_transforms += [RandomVerticalFlipImageAndLabel(p=0.5, seed=seed)]
    training_transforms += [ScaleImageAndLabel(size=(height, width))]
    training_transforms += [transforms.ToTensor()]
    training_transforms += [transforms.Normalize((0.5, 0.5, 0.5),
                                                 (0.5, 0.5, 0.5))]

    # Transforms for validation (no data augmentation, only resize+normalize)
    validation_transforms = transforms.Compose([
        ScaleImageAndLabel(size=(height, width)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5),
                             (0.5, 0.5, 0.5)),
    ])

    # Training dataset
    trainset = CSVDataset(train_dir,
                          transforms=transforms.Compose(training_transforms),
                          max_dataset_size=max_trainset_size,
                          seed=seed)

    # Validation dataset
    if val_dir is not None:
        if val_dir == 'auto':
            # Create a dataset over the same directory as training,
            # but with the validation transforms.
            valset = CSVDataset(train_dir,
                                transforms=validation_transforms,
                                max_dataset_size=max_trainset_size)

            # Split 80% for training, 20% for validation.
            # Both datasets were built in the same order,
            # so slicing them at the same index yields disjoint subsets.
            n_imgs_for_training = int(round(0.8*len(trainset)))
            if trainset.there_is_gt:
                trainset.csv_df = trainset.csv_df[:n_imgs_for_training]
                valset.csv_df = valset.csv_df[n_imgs_for_training:].reset_index()
            else:
                trainset.listfiles = trainset.listfiles[:n_imgs_for_training]
                valset.listfiles = valset.listfiles[n_imgs_for_training:]
        else:
            valset = CSVDataset(val_dir,
                                transforms=validation_transforms,
                                max_dataset_size=max_valset_size)
    else:
        valset = None

    print(f'# images for training: '
          f'{ballpark(len(trainset))}')
    if valset is not None:
        print(f'# images for validation: '
              f'{ballpark(len(valset))}')
    else:
        print('W: no validation set was selected!')

    # Build data loaders from the datasets.
    # NOTE: use the collate_fn and shuffle arguments the caller provided
    # instead of hard-coding csv_collator / True.
    trainset_loader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=batch_size,
                                                  drop_last=drop_last_batch,
                                                  shuffle=shuffle,
                                                  num_workers=num_workers,
                                                  collate_fn=collate_fn)
    if valset is not None:
        valset_loader = torch.utils.data.DataLoader(valset,
                                                    batch_size=1,
                                                    shuffle=True,
                                                    num_workers=num_workers,
                                                    collate_fn=collate_fn)
    else:
        valset_loader = None

    return trainset_loader, valset_loader


class CSVDataset(torch.utils.data.Dataset):
def __init__(self,
directory,
transforms=None,
Expand Down Expand Up @@ -106,7 +219,6 @@ def __getitem__(self, idx):
The second element is a dictionary where the keys are the columns of the CSV.
If the CSV did not exist in the dataset directory,
the dictionary will only contain the filename of the image.
_
:param idx: Index of the image in the dataset to get.
"""

Expand Down Expand Up @@ -309,7 +421,7 @@ def _is_pil_image(img):
return isinstance(img, Image.Image)


class XMLDataset(data.Dataset):
class XMLDataset(torch.utils.data.Dataset):
def __init__(self,
directory,
transforms=None,
Expand Down
65 changes: 22 additions & 43 deletions object-locator/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@
from . import losses
from .models import unet_model
from .metrics import Judge
from .data import CSVDataset
from . import logger
from . import argparser
from . import utils
from . import data
from .data import csv_collator
from .data import RandomHorizontalFlipImageAndLabel
from .data import RandomVerticalFlipImageAndLabel
from .data import ScaleImageAndLabel
from . import logger
from . import argparser
from . import utils


# Parse command line arguments
Expand All @@ -68,41 +68,20 @@
port=args.visdom_port,
env_name=args.visdom_env)

# Data loading code
training_transforms = []
if not args.no_data_augm:
training_transforms += [RandomHorizontalFlipImageAndLabel(p=0.5, seed=args.seed)]
# training_transforms += [RandomVerticalFlipImageAndLabel(p=0.5, seed=args.seed)]
training_transforms += [ScaleImageAndLabel(size=(args.height, args.width))]
training_transforms += [transforms.ToTensor()]
training_transforms += [transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
trainset = CSVDataset(args.train_dir,
transforms=transforms.Compose(training_transforms),
max_dataset_size=args.max_trainset_size,
seed=args.seed)
print(f'# images for training: {len(trainset)}')
trainset_loader = DataLoader(trainset,
batch_size=args.batch_size,
drop_last=args.drop_last_batch,
shuffle=True,
num_workers=args.nThreads,
collate_fn=csv_collator)
if args.val_dir:
valset = CSVDataset(args.val_dir,
transforms=transforms.Compose([
ScaleImageAndLabel(size=(args.height, args.width)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5)),
]),
max_dataset_size=args.max_valset_size)
print(f'# images for validation: {len(valset)}')
valset_loader = DataLoader(valset,
batch_size=args.eval_batch_size,
shuffle=True,

# Create data loaders (return data in batches)
trainset_loader, valset_loader = \
data.get_train_val_loaders(train_dir=args.train_dir,
max_trainset_size=args.max_trainset_size,
collate_fn=csv_collator,
height=args.height,
width=args.width,
seed=args.seed,
batch_size=args.batch_size,
drop_last_batch=args.drop_last_batch,
num_workers=args.nThreads,
collate_fn=csv_collator)
val_dir=args.val_dir,
max_valset_size=args.max_valset_size)

# Model
with peter('Building network'):
Expand Down Expand Up @@ -171,7 +150,7 @@

loss_avg_this_epoch = 0
iter_train = tqdm(trainset_loader,
desc=f'Epoch {epoch} ({len(trainset)} images)')
desc=f'Epoch {epoch} ({len(trainset_loader.dataset)} images)')

# === TRAIN ===

Expand Down Expand Up @@ -305,7 +284,7 @@
sum_term3 = 0
sum_loss = 0
iter_val = tqdm(valset_loader,
desc=f'Validating Epoch {epoch} ({len(valset)} images)')
desc=f'Validating Epoch {epoch} ({len(valset_loader.dataset)} images)')
for batch_idx, (imgs, dictionaries) in enumerate(iter_val):

# Pull info from this batch and move to device
Expand All @@ -325,8 +304,8 @@
target_orig_widths = torch.stack(target_orig_widths)
target_orig_sizes = torch.stack((target_orig_heights,
target_orig_widths)).transpose(0, 1)
origsize = (dictionaries[0]['orig_height'].item(),
dictionaries[0]['orig_width'].item())
orig_shape = (dictionaries[0]['orig_height'].item(),
dictionaries[0]['orig_width'].item())

# Tensor -> float & numpy
target_count_int = int(round(target_counts.item()))
Expand Down Expand Up @@ -370,7 +349,7 @@
# BMM thresholding
est_map_numpy = est_maps[0, :, :].to(device_cpu).numpy()
est_map_numpy_origsize = skimage.transform.resize(est_map_numpy,
output_shape=origsize,
output_shape=orig_shape,
mode='constant')
mask, _ = utils.threshold(est_map_numpy_origsize, tau=-1)
# Obtain centroids of the mask
Expand Down

0 comments on commit afcd771

Please sign in to comment.