UCF-101 dataset mean and proper normalization
tomrunia committed Nov 14, 2018
1 parent d50b7a3 commit d2a9797
Showing 6 changed files with 73 additions and 33 deletions.
6 changes: 5 additions & 1 deletion config.py
@@ -14,7 +14,10 @@ def parse_opts():
    parser.add_argument('--dataset', type=str, required=True, help='Dataset string (kinetics | activitynet | ucf101 | blender)')
    parser.add_argument('--num_val_samples', type=int, default=1, help='Number of validation samples for each activity')
    parser.add_argument('--norm_value', default=255, type=int, help='Divide inputs by 255 or 1')
    parser.add_argument('--no_dataset_mean', action='store_true', help="Don't use the dataset mean but normalize to zero mean")
    parser.add_argument('--no_dataset_std', action='store_true', help="Don't use the dataset std but normalize to unit std")
    parser.add_argument('--num_classes', default=400, type=int, help='Number of classes (activitynet: 200, kinetics: 400, ucf101: 101, hmdb51: 51)')
    parser.set_defaults(no_dataset_std=True)

    # Preprocessing pipeline
    parser.add_argument('--spatial_size', default=224, type=int, help='Height and width of inputs')
@@ -59,12 +62,13 @@ def parse_opts():
    parser.add_argument('--checkpoint_frequency', type=int, default=1, help='Save checkpoint after this number of epochs')
    parser.add_argument('--checkpoints_num_keep', type=int, default=5, help='Number of checkpoints to keep')
    parser.add_argument('--log_frequency', type=int, default=5, help='Logging frequency in number of steps')
    parser.add_argument('--log_image_frequency', type=int, default=200, help='Logging images frequency in number of steps')
    parser.add_argument('--no_tensorboard', action='store_true', default=False, help='Disable the use of TensorboardX')

    # Misc
    parser.add_argument('--device', default='cuda:0', help='Device string cpu | cuda:0')
    parser.add_argument('--history_steps', default=25, type=int, help='History of running average meters')
    parser.add_argument('--num_workers', default=4, type=int, help='Number of threads for multi-thread loading')
    parser.add_argument('--num_workers', default=6, type=int, help='Number of threads for multi-thread loading')
    parser.add_argument('--no_eval', action='store_true', default=False, help='Disable evaluation')

    return parser.parse_args()
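Note on the new flags: combined with parser.set_defaults(no_dataset_std=True), the store_true action means --no_dataset_std is already True by default and cannot be switched back off from the command line; re-enabling the dataset std requires editing the default. A minimal standalone sketch of this interaction (illustrative only, not from this commit):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no_dataset_mean', action='store_true')
parser.add_argument('--no_dataset_std', action='store_true')
parser.set_defaults(no_dataset_std=True)

print(parser.parse_args([]))                     # Namespace(no_dataset_mean=False, no_dataset_std=True)
print(parser.parse_args(['--no_dataset_mean']))  # both flags True -> inputs stay in [0, 1]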
15 changes: 6 additions & 9 deletions datasets/ucf101.py
@@ -11,14 +11,15 @@

from utils.utils import load_value_file

##########################################################################################
##########################################################################################

def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')


def accimage_loader(path):
    try:
        import accimage
@@ -27,15 +28,13 @@ def accimage_loader(path):
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def get_default_image_loader():
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader
    else:
        return pil_loader


def video_loader(video_dir_path, frame_indices, image_loader):
    video = []
    for i in frame_indices:
@@ -47,17 +46,14 @@ def video_loader(video_dir_path, frame_indices, image_loader):

    return video


def get_default_video_loader():
    image_loader = get_default_image_loader()
    return functools.partial(video_loader, image_loader=image_loader)


def load_annotation_data(data_file_path):
    with open(data_file_path, 'r') as data_file:
        return json.load(data_file)


def get_class_labels(data):
    class_labels_map = {}
    index = 0
@@ -66,7 +62,6 @@ def get_class_labels(data):
        index += 1
    return class_labels_map


def get_video_names_and_annotations(data, subset):
    video_names = []
    annotations = []
@@ -80,6 +75,8 @@

    return video_names, annotations

##########################################################################################
##########################################################################################

def make_dataset(root_path, annotation_path, subset, n_samples_for_each_video,
                 sample_duration):
@@ -143,8 +140,8 @@ def make_dataset(root_path, annotation_path, subset, n_samples_for_each_video,

    return dataset, idx_to_class

############################################################################
############################################################################
##########################################################################################
##########################################################################################

class UCF101(data.Dataset):
    """
22 changes: 21 additions & 1 deletion epoch_iterators.py
@@ -45,7 +45,7 @@ def train_epoch(config, model, criterion, optimizer, device,
        optimizer.zero_grad()

        # Move inputs to GPU memory
        clips = clips.to(device)
        targets = targets.to(device)
        if config.model == 'i3d':
            targets = torch.unsqueeze(targets, -1)
@@ -97,6 +97,16 @@ def train_epoch(config, model, criterion, optimizer, device,
            summary_writer.add_scalar('train/learning_rate', current_learning_rate(optimizer), global_step)
            summary_writer.add_scalar('train/weight_decay', current_weight_decay(optimizer), global_step)

        if summary_writer and step % config.log_image_frequency == 0:
            # TensorboardX video summary
            for example_idx in range(4):
                clip_for_display = clips[example_idx].clone().cpu()
                min_val = float(clip_for_display.min())
                max_val = float(clip_for_display.max())
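                # Min-max rescale into [0, 1] so TensorboardX can render the clip;
                # the clamp_ below uses the clip's own min/max (effectively a no-op)
                # and the 1e-5 guards against division by zero on constant clips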
                clip_for_display.clamp_(min=min_val, max=max_val)
                clip_for_display.add_(-min_val).div_(max_val - min_val + 1e-5)
                summary_writer.add_video('train_clips/{:04d}'.format(example_idx), clip_for_display.unsqueeze(0), global_step)

    # Epoch statistics
    epoch_duration = float(time.time() - epoch_start_time)
    epoch_avg_loss = np.mean(losses)
@@ -159,6 +169,16 @@ def validation_epoch(config, model, criterion, device, data_loader, epoch, summary_writer):
                step, steps_in_epoch, examples_per_second,
                accuracies[step], losses[step]))

        if summary_writer and step == 0:
            # TensorboardX video summary
            for example_idx in range(4):
                clip_for_display = clips[example_idx].clone().cpu()
                min_val = float(clip_for_display.min())
                max_val = float(clip_for_display.max())
                clip_for_display.clamp_(min=min_val, max=max_val)
                clip_for_display.add_(-min_val).div_(max_val - min_val + 1e-5)
                summary_writer.add_video('validation_clips/{:04d}'.format(example_idx), clip_for_display.unsqueeze(0), epoch*steps_in_epoch)

    # Epoch statistics
    epoch_duration = float(time.time() - epoch_start_time)
    epoch_avg_loss = np.mean(losses)
4 changes: 2 additions & 2 deletions factory/data_factory.py
@@ -209,8 +209,8 @@ def get_data_loaders(config, train_transforms, validation_transforms=None):
    if not config.no_eval and validation_transforms:

        dataset_validation = get_validation_set(
            config, train_transforms['spatial'],
            train_transforms['temporal'], train_transforms['target'])
            config, validation_transforms['spatial'],
            validation_transforms['temporal'], validation_transforms['target'])

        print('Found {} validation examples'.format(len(dataset_validation)))

38 changes: 31 additions & 7 deletions train.py
@@ -21,15 +21,14 @@
from datetime import datetime

import torch.nn as nn
import torch.backends.cudnn as cudnn

from transforms.spatial_transforms import Compose, Normalize, RandomHorizontalFlip, \
    RandomVerticalFlip, MultiScaleRandomCrop, ToTensor, CenterCrop
from transforms.spatial_transforms import Compose, Normalize, RandomHorizontalFlip, MultiScaleRandomCrop, ToTensor, CenterCrop
from transforms.temporal_transforms import TemporalRandomCrop
from transforms.target_transforms import ClassLabel

from epoch_iterators import train_epoch, validation_epoch
from utils.utils import *
import utils.mean_values
import factory.data_factory as data_factory
import factory.model_factory as model_factory
from config import parse_opts
@@ -43,6 +42,9 @@
config = init_cropping_scales(config)
config = set_lr_scheduling_policy(config)

config.image_mean = utils.mean_values.get_mean(config.norm_value, config.dataset)
config.image_std = utils.mean_values.get_std(config.norm_value)

print_config(config)
write_config(config, os.path.join(config.save_dir, 'config.json'))

@@ -75,17 +77,39 @@
####################################################################
# Setup of data transformations

if config.no_dataset_mean and config.no_dataset_std:
    # No dataset statistics: inputs keep the [0, 1] range from ToTensor
    print('Data normalization: no dataset mean, no dataset std')
    norm_method = Normalize([0, 0, 0], [1, 1, 1])
elif not config.no_dataset_mean and config.no_dataset_std:
    # Subtract the dataset mean, leave the scale untouched
    print('Data normalization: use dataset mean, no dataset std')
    norm_method = Normalize(config.image_mean, [1, 1, 1])
else:
    # Subtract the dataset mean and divide by the dataset std
    print('Data normalization: use dataset mean, use dataset std')
    norm_method = Normalize(config.image_mean, config.image_std)

train_transforms = {
    'spatial': Compose([MultiScaleRandomCrop(config.scales, config.spatial_size),
                        RandomHorizontalFlip(), RandomVerticalFlip(),
                        ToTensor(config.norm_value), Normalize([0, 0, 0], [1, 1, 1])]),
                        RandomHorizontalFlip(),
                        ToTensor(config.norm_value),
                        norm_method]),
    'temporal': TemporalRandomCrop(config.sample_duration),
    'target': ClassLabel()
}

# print('WARNING: setting train transforms for dataset statistics')
# train_transforms = {
# 'spatial': Compose([ToTensor(1.0)]),
# 'temporal': TemporalRandomCrop(64),
# 'target': ClassLabel()
# }

validation_transforms = {
    'spatial': Compose([CenterCrop(config.spatial_size), ToTensor(config.norm_value),
                        Normalize([0, 0, 0], [1, 1, 1])]),
    'spatial': Compose([CenterCrop(config.spatial_size),
                        ToTensor(config.norm_value),
                        norm_method]),
    'temporal': TemporalRandomCrop(config.sample_duration),
    'target': ClassLabel()
}
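A quick numeric sanity check on the new pipeline (a sketch, assuming the repo's Normalize follows the usual channel-wise (x - mean) / std convention): with norm_value=255, ToTensor maps pixels into [0, 1], and under the new default (no_dataset_std=True) the std term is 1, so a red pixel sitting exactly at the UCF-101 mean comes out as zero:

mean_r = 101.00131 / 255   # red-channel mean from utils/mean_values.py below
pixel = 101.00131 / 255    # a mean-valued red pixel after ToTensor(255)
print((pixel - mean_r) / 1.0)  # -> 0.0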
21 changes: 8 additions & 13 deletions utils/mean_values.py
@@ -1,23 +1,18 @@
# Source: https://github.com/kenshohara/3D-ResNets-PyTorch/blob/master/mean.py

def get_mean(norm_value=255, dataset='activitynet'):
    assert dataset in ['activitynet', 'kinetics']

    # Below values are in RGB order
    assert dataset in ['activitynet', 'kinetics', 'ucf101']

    if dataset == 'activitynet':
        return [
            114.7748 / norm_value, 107.7354 / norm_value, 99.4750 / norm_value
        ]
        return [114.7748/norm_value, 107.7354/norm_value, 99.4750/norm_value]
    elif dataset == 'kinetics':
        # Kinetics (10 videos for each class)
        return [
            110.63666788 / norm_value, 103.16065604 / norm_value,
            96.29023126 / norm_value
        ]

        return [110.63666788/norm_value, 103.16065604/norm_value, 96.29023126/norm_value]
    elif dataset == 'ucf101':
        return [101.00131/norm_value, 97.3644226/norm_value, 89.42114168/norm_value]

def get_std(norm_value=255):
    # Kinetics (10 videos for each class)
    return [
        38.7568578 / norm_value, 37.88248729 / norm_value,
        40.02898126 / norm_value
    ]
    return [38.7568578/norm_value, 37.88248729/norm_value, 40.02898126/norm_value]
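The commented-out "dataset statistics" transforms in train.py above hint at how such constants are obtained: load clips unscaled with ToTensor(1.0) and accumulate per-channel statistics over the whole dataset. A rough, hypothetical sketch (the helper name and the (batch, C, T, H, W) clip layout are assumptions, not from this commit):

import torch

def estimate_channel_stats(data_loader):
    # Accumulate per-channel sums over every pixel of every clip
    n_pixels = 0
    channel_sum = torch.zeros(3)
    channel_sq_sum = torch.zeros(3)
    for clips, _ in data_loader:
        # clips: (batch, C, T, H, W) floats in [0, 255] when loaded via ToTensor(1.0)
        pixels = clips.transpose(0, 1).reshape(3, -1)
        n_pixels += pixels.shape[1]
        channel_sum += pixels.sum(dim=1)
        channel_sq_sum += (pixels ** 2).sum(dim=1)
    mean = channel_sum / n_pixels
    # Var[x] = E[x^2] - E[x]^2
    std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()
    return mean, std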
