diff --git a/.travis.yml b/.travis.yml index 02b0a7707..f3bb33d93 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,7 +19,7 @@ before_script: # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - - flake8 . --count --exit-zero --max-complexity=32 --max-line-length=127 --ignore=F403,F405,E126,E127 --exclude=./pytorch/pytorchcv/models/others,./tensorflow_ --statistics + - flake8 . --count --exit-zero --max-complexity=32 --max-line-length=127 --ignore=F403,F405,E126,E127 --exclude=./pytorch/pytorchcv/models/others,./tensorflow_/others --statistics script: - true # pytest --capture=sys # add others tests here notifications: diff --git a/eval_tf.py b/eval_tf.py index ec591e736..8d56404a1 100644 --- a/eval_tf.py +++ b/eval_tf.py @@ -1,12 +1,14 @@ import argparse +import tqdm import time import logging -import mxnet as mx +from tensorpack.predict import PredictConfig, FeedfreePredictor +from tensorpack.utils.stats import RatioCounter +from tensorpack.input_source import QueueInput, StagingInput from common.logger_utils import initialize_logging -from tensorflow_.utils import prepare_tf_context, prepare_model, get_data_loader, calc_net_weight_count,\ - validate +from tensorflow_.utils import prepare_tf_context, prepare_model, get_data def parse_args(): @@ -78,29 +80,30 @@ def parse_args(): def test(net, - val_data, - batch_fn, - use_rec, - dtype, - ctx, - calc_weight_count=False, + session_init, + val_dataflow, extended_log=False): - acc_top1 = mx.metric.Accuracy() - acc_top5 = mx.metric.TopKAccuracy(5) + + pred_config = PredictConfig( + model=net, + session_init=session_init, + input_names=['input', 'label'], + output_names=['wrong-top1', 'wrong-top5'] + ) + err_top1 = RatioCounter() + err_top5 = RatioCounter() tic = time.time() - err_top1_val, err_top5_val = validate( - acc_top1=acc_top1, - 
acc_top5=acc_top5, - net=net, - val_data=val_data, - batch_fn=batch_fn, - use_rec=use_rec, - dtype=dtype, - ctx=ctx) - if calc_weight_count: - weight_count = calc_net_weight_count(net) - logging.info('Model: {} trainable parameters'.format(weight_count)) + pred = FeedfreePredictor(pred_config, StagingInput(QueueInput(val_dataflow), device='/gpu:0')) + for _ in tqdm.trange(val_dataflow.size()): + err_top1_val, err_top5_val = pred() + batch_size = err_top1_val.shape[0] + err_top1.feed(err_top1_val.sum(), batch_size) + err_top5.feed(err_top5_val.sum(), batch_size) + + err_top1_val = err_top1.ratio + err_top5_val = err_top5.ratio + if extended_log: logging.info('Test: err-top1={top1:.4f} ({top1})\terr-top5={top5:.4f} ({top5})'.format( top1=err_top1_val, top5=err_top5_val)) @@ -125,25 +128,20 @@ def main(): num_gpus=args.num_gpus, batch_size=args.batch_size) - net = prepare_model( + net, inputs_desc = prepare_model( model_name=args.model, pretrained_model_file_path=args.resume.strip()) - train_data, val_data, batch_fn = get_data_loader( - data_dir=args.data_dir, + val_dataflow = get_data( + is_train=False, batch_size=batch_size, - num_workers=args.num_workers) + data_dir_path=args.data_dir) assert (args.use_pretrained or args.resume.strip()) test( net=net, - val_data=val_data, - batch_fn=batch_fn, - use_rec=args.use_rec, - dtype=args.dtype, - # ctx=ctx, - # calc_weight_count=(not log_file_exist), - calc_weight_count=True, + session_init=inputs_desc, + val_dataflow=val_dataflow, extended_log=True) diff --git a/tensorflow_/models/shufflenet.py b/tensorflow_/models/shufflenet.py index 50aae9a3e..26d1fd029 100755 --- a/tensorflow_/models/shufflenet.py +++ b/tensorflow_/models/shufflenet.py @@ -1,25 +1,151 @@ -import argparse -import numpy as np -import math -#import os -import cv2 -#import tensorflow as tf +__all__ = ['ShufflenetModel'] + +import math +from abc import abstractmethod from tensorpack import * -from tensorpack.dataflow import imgaug -from tensorpack.tfutils 
import argscope, get_model_loader, model_utils +from tensorpack.tfutils import argscope from tensorpack.tfutils.scope_utils import under_name_scope -from tensorpack.utils.gpu import get_num_gpu +from tensorpack.tfutils.summary import add_moving_summary from tensorpack.utils import logger -from .imagenet_utils import get_imagenet_dataflow, ImageNetModel, GoogleNetResize, eval_on_ILSVRC12 + +class ImageNetModel(ModelDesc): + + def __init__(self, + **kwargs): + super(ImageNetModel, self).__init__(**kwargs) + + self.image_shape = 224 + + """ + uint8 instead of float32 is used as input type to reduce copy overhead. + It might hurt the performance a little bit. + The pretrained models were trained with float32. + """ + self.image_dtype = tf.uint8 + + """ + Either 'NCHW' or 'NHWC' + """ + self.data_format = 'NCHW' + + """ + Whether the image is BGR or RGB. If using DataFlow, then it should be BGR. + """ + self.image_bgr = True + + self.weight_decay = 1e-4 + + """ + To apply on normalization parameters, use '.*/W|.*/gamma|.*/beta' + """ + self.weight_decay_pattern = '.*/W' + + """ + Scale the loss, for whatever reasons (e.g., gradient averaging, fp16 training, etc) + """ + self.loss_scale = 1. + + """ + Label smoothing (See tf.losses.softmax_cross_entropy) + """ + self.label_smoothing = 0. 
+ + def inputs(self): + return [tf.placeholder(self.image_dtype, [None, self.image_shape, self.image_shape, 3], 'input'), + tf.placeholder(tf.int32, [None], 'label')] + + def build_graph(self, image, label): + image = self.image_preprocess(image) + assert self.data_format in ['NCHW', 'NHWC'] + if self.data_format == 'NCHW': + image = tf.transpose(image, [0, 3, 1, 2]) + + logits = self.get_logits(image) + loss = ImageNetModel.compute_loss_and_error( + logits, label, label_smoothing=self.label_smoothing) + + if self.weight_decay > 0: + wd_loss = regularize_cost(self.weight_decay_pattern, + tf.contrib.layers.l2_regularizer(self.weight_decay), + name='l2_regularize_loss') + add_moving_summary(loss, wd_loss) + total_cost = tf.add_n([loss, wd_loss], name='cost') + else: + total_cost = tf.identity(loss, name='cost') + add_moving_summary(total_cost) + + if self.loss_scale != 1.: + logger.info("Scaling the total loss by {} ...".format(self.loss_scale)) + return total_cost * self.loss_scale + else: + return total_cost + + @abstractmethod + def get_logits(self, image): + """ + Args: + image: 4D tensor of ``self.image_shape`` in ``self.data_format`` + + Returns: + Nx#class logits + """ + + def optimizer(self): + lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False) + tf.summary.scalar('learning_rate-summary', lr) + return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True) + + def image_preprocess(self, image): + with tf.name_scope('image_preprocess'): + if image.dtype.base_dtype != tf.float32: + image = tf.cast(image, tf.float32) + mean = [0.485, 0.456, 0.406] # rgb + std = [0.229, 0.224, 0.225] + if self.image_bgr: + mean = mean[::-1] + std = std[::-1] + image_mean = tf.constant(mean, dtype=tf.float32) * 255. + image_std = tf.constant(std, dtype=tf.float32) * 255. 
+ image = (image - image_mean) / image_std + return image + + @staticmethod + def compute_loss_and_error(logits, label, label_smoothing=0.): + if label_smoothing == 0.: + loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label) + else: + nclass = logits.shape[-1] + loss = tf.losses.softmax_cross_entropy( + tf.one_hot(label, nclass), + logits, label_smoothing=label_smoothing) + loss = tf.reduce_mean(loss, name='xentropy-loss') + + def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'): + with tf.name_scope('prediction_incorrect'): + x = tf.logical_not(tf.nn.in_top_k(logits, label, topk)) + return tf.cast(x, tf.float32, name=name) + + wrong = prediction_incorrect(logits, label, 1, name='wrong-top1') + add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1')) + + wrong = prediction_incorrect(logits, label, 5, name='wrong-top5') + add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5')) + return loss @layer_register(log_shape=True) -def DepthConv(x, out_channel, kernel_shape, padding='SAME', stride=1, - W_init=None, activation=tf.identity): +def DepthConv(x, + out_channel, + kernel_shape, + padding='SAME', + stride=1, + W_init=None, + activation=tf.identity): + in_shape = x.get_shape().as_list() in_channel = in_shape[1] assert out_channel % in_channel == 0, (out_channel, in_channel) @@ -36,86 +162,106 @@ def DepthConv(x, out_channel, kernel_shape, padding='SAME', stride=1, @under_name_scope() -def channel_shuffle(l, group): - in_shape = l.get_shape().as_list() +def channel_shuffle(xl, + group): + + in_shape = xl.get_shape().as_list() in_channel = in_shape[1] assert in_channel % group == 0, in_channel - l = tf.reshape(l, [-1, in_channel // group, group] + in_shape[-2:]) - l = tf.transpose(l, [0, 2, 1, 3, 4]) - l = tf.reshape(l, [-1, in_channel] + in_shape[-2:]) - return l + xl = tf.reshape(xl, [-1, in_channel // group, group] + in_shape[-2:]) + xl = tf.transpose(xl, [0, 2, 1, 3, 4]) + xl = tf.reshape(xl, 
[-1, in_channel] + in_shape[-2:]) + return xl @layer_register() -def shufflenet_unit(l, out_channel, group, stride): - in_shape = l.get_shape().as_list() +def shufflenet_unit(xl, + out_channel, + group, + stride): + + in_shape = xl.get_shape().as_list() in_channel = in_shape[1] - shortcut = l + shortcut = xl # "We do not apply group convolution on the first pointwise layer # because the number of input channels is relatively small." first_split = group if in_channel > 24 else 1 - l = Conv2D('conv1', l, out_channel // 4, 1, split=first_split, activation=BNReLU) - l = channel_shuffle(l, group) - l = DepthConv('dconv', l, out_channel // 4, 3, stride=stride) - l = BatchNorm('dconv_bn', l) - - l = Conv2D('conv2', l, - out_channel if stride == 1 else out_channel - in_channel, - 1, split=group) - l = BatchNorm('conv2_bn', l) + xl = Conv2D('conv1', xl, out_channel // 4, 1, split=first_split, activation=BNReLU) + xl = channel_shuffle(xl, group) + xl = DepthConv('dconv', xl, out_channel // 4, 3, stride=stride) + xl = BatchNorm('dconv_bn', xl) + + xl = Conv2D('conv2', xl, + out_channel if stride == 1 else out_channel - in_channel, + 1, split=group) + xl = BatchNorm('conv2_bn', xl) if stride == 1: # unit (b) - output = tf.nn.relu(shortcut + l) + output = tf.nn.relu(shortcut + xl) else: # unit (c) shortcut = AvgPooling('avgpool', shortcut, 3, 2, padding='SAME') - output = tf.concat([shortcut, tf.nn.relu(l)], axis=1) + output = tf.concat([shortcut, tf.nn.relu(xl)], axis=1) return output @layer_register() -def shufflenet_unit_v2(l, out_channel, stride): +def shufflenet_unit_v2(xl, + out_channel, + stride): + if stride == 1: - shortcut, l = tf.split(l, 2, axis=1) + shortcut, xl = tf.split(xl, 2, axis=1) else: - shortcut, l = l, l + shortcut, xl = xl, xl shortcut_channel = int(shortcut.shape[1]) - l = Conv2D('conv1', l, out_channel // 2, 1, activation=BNReLU) - l = DepthConv('dconv', l, out_channel // 2, 3, stride=stride) - l = BatchNorm('dconv_bn', l) - l = Conv2D('conv2', l, 
out_channel - shortcut_channel, 1, activation=BNReLU) + xl = Conv2D('conv1', xl, out_channel // 2, 1, activation=BNReLU) + xl = DepthConv('dconv', xl, out_channel // 2, 3, stride=stride) + xl = BatchNorm('dconv_bn', xl) + xl = Conv2D('conv2', xl, out_channel - shortcut_channel, 1, activation=BNReLU) if stride == 2: shortcut = DepthConv('shortcut_dconv', shortcut, shortcut_channel, 3, stride=2) shortcut = BatchNorm('shortcut_dconv_bn', shortcut) shortcut = Conv2D('shortcut_conv', shortcut, shortcut_channel, 1, activation=BNReLU) - output = tf.concat([shortcut, l], axis=1) + output = tf.concat([shortcut, xl], axis=1) output = channel_shuffle(output, 2) return output @layer_register(log_shape=True) -def shufflenet_stage(input, channel, num_blocks, group): - l = input +def shufflenet_stage(input, channel, num_blocks, group, v2): + xl = input for i in range(num_blocks): name = 'block{}'.format(i) - if args.v2: - l = shufflenet_unit_v2(name, l, channel, 2 if i == 0 else 1) + if v2: + xl = shufflenet_unit_v2(name, xl, channel, 2 if i == 0 else 1) else: - l = shufflenet_unit(name, l, channel, group, 2 if i == 0 else 1) - return l + xl = shufflenet_unit(name, xl, channel, group, 2 if i == 0 else 1) + return xl class ShufflenetModel(ImageNetModel): - weight_decay = 4e-5 + + def __init__(self, + v2, + ratio, + group, + **kwargs): + super(ShufflenetModel, self).__init__(**kwargs) + + self.v2 = v2 + self.ratio = ratio + self.group = group + self.weight_decay = 4e-5 def get_logits(self, image): with argscope([Conv2D, MaxPooling, AvgPooling, GlobalAvgPooling, BatchNorm], data_format='channels_first'), \ argscope(Conv2D, use_bias=False): - group = args.group - if not args.v2: + group = self.group + if not self.v2: # Copied from the paper channels = { 3: [240, 480, 960], @@ -123,161 +269,34 @@ def get_logits(self, image): 8: [384, 768, 1536] } mul = group * 4 # #chan has to be a multiple of this number - channels = [int(math.ceil(x * args.ratio / mul) * mul) + channels = 
[int(math.ceil(x * self.ratio / mul) * mul) for x in channels[group]] # The first channel must be a multiple of group - first_chan = int(math.ceil(24 * args.ratio / group) * group) + first_chan = int(math.ceil(24 * self.ratio / group) * group) else: # Copied from the paper channels = { 0.5: [48, 96, 192], 1.: [116, 232, 464] - }[args.ratio] + }[self.ratio] first_chan = 24 logger.info("#Channels: " + str([first_chan] + channels)) - l = Conv2D('conv1', image, first_chan, 3, strides=2, activation=BNReLU) - l = MaxPooling('pool1', l, 3, 2, padding='SAME') + xl = Conv2D('conv1', image, first_chan, 3, strides=2, activation=BNReLU) + xl = MaxPooling('pool1', xl, 3, 2, padding='SAME') - l = shufflenet_stage('stage2', l, channels[0], 4, group) - l = shufflenet_stage('stage3', l, channels[1], 8, group) - l = shufflenet_stage('stage4', l, channels[2], 4, group) + xl = shufflenet_stage('stage2', xl, channels[0], 4, group, self.v2) + xl = shufflenet_stage('stage3', xl, channels[1], 8, group, self.v2) + xl = shufflenet_stage('stage4', xl, channels[2], 4, group, self.v2) - if args.v2: - l = Conv2D('conv5', l, 1024, 1, activation=BNReLU) + if self.v2: + xl = Conv2D('conv5', xl, 1024, 1, activation=BNReLU) - l = GlobalAvgPooling('gap', l) - logits = FullyConnected('linear', l, 1000) + xl = GlobalAvgPooling('gap', xl) + logits = FullyConnected('linear', xl, 1000) return logits -def get_data(name, batch): - isTrain = name == 'train' - - if isTrain: - augmentors = [ - # use lighter augs if model is too small - GoogleNetResize(crop_area_fraction=0.49 if args.ratio < 1 else 0.08), - imgaug.RandomOrderAug( - [imgaug.BrightnessScale((0.6, 1.4), clip=False), - imgaug.Contrast((0.6, 1.4), clip=False), - imgaug.Saturation(0.4, rgb=False), - # rgb-bgr conversion for the constants copied from fb.resnet.torch - imgaug.Lighting(0.1, - eigval=np.asarray( - [0.2175, 0.0188, 0.0045][::-1]) * 255.0, - eigvec=np.array( - [[-0.5675, 0.7192, 0.4009], - [-0.5808, -0.0045, -0.8140], - [-0.5836, -0.6948, 
0.4203]], - dtype='float32')[::-1, ::-1] - )]), - imgaug.Flip(horiz=True), - ] - else: - augmentors = [ - imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC), - imgaug.CenterCrop((224, 224)), - ] - return get_imagenet_dataflow( - args.data, name, batch, augmentors) - - -def get_config(model, nr_tower): - batch = TOTAL_BATCH_SIZE // nr_tower - - logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch)) - dataset_train = get_data('train', batch) - dataset_val = get_data('val', batch) - - step_size = 1280000 // TOTAL_BATCH_SIZE - max_iter = 3 * 10**5 - max_epoch = (max_iter // step_size) + 1 - callbacks = [ - ModelSaver(), - ScheduledHyperParamSetter('learning_rate', - [(0, 0.5), (max_iter, 0)], - interp='linear', step_based=True), - EstimatedTimeLeft() - ] - infs = [ClassificationError('wrong-top1', 'val-error-top1'), - ClassificationError('wrong-top5', 'val-error-top5')] - if nr_tower == 1: - # single-GPU inference with queue prefetch - callbacks.append(InferenceRunner(QueueInput(dataset_val), infs)) - else: - # multi-GPU inference (with mandatory queue prefetch) - callbacks.append(DataParallelInferenceRunner( - dataset_val, infs, list(range(nr_tower)))) - - return TrainConfig( - model=model, - dataflow=dataset_train, - callbacks=callbacks, - steps_per_epoch=step_size, - max_epoch=max_epoch, - ) - - if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--gpu', default='0', help='comma separated list of GPU(s) to use.') - parser.add_argument('--data', default='../../imgclsmob_data/imagenet/', help='ILSVRC dataset dir') - parser.add_argument('-r', '--ratio', type=float, default=0.5, choices=[1., 0.5]) - parser.add_argument('--group', type=int, default=8, choices=[3, 4, 8], - help="Number of groups for ShuffleNetV1") - parser.add_argument('--v2', action='store_true', help='Use ShuffleNetV2') - parser.add_argument('--batch', type=int, default=1024, help='total batch size') - parser.add_argument('--load', 
help='path to load a model from') - parser.add_argument('--eval', action='store_true') - parser.add_argument('--flops', action='store_true', help='print flops and exit') - args = parser.parse_args() - - if args.gpu: - os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu - - if args.v2 and args.group != parser.get_default('group'): - logger.error("group= is not used in ShuffleNetV2!") - - if args.batch != 1024: - logger.warn("Total batch size != 1024, you need to change other hyperparameters to get the same results.") - TOTAL_BATCH_SIZE = args.batch - - model = ShufflenetModel() - - if args.eval: - batch = 128 # something that can run on one gpu - ds = get_data('val', batch) - eval_on_ILSVRC12(model, get_model_loader(args.load), ds) - elif args.flops: - # manually build the graph with batch=1 - input_desc = [ - InputDesc(tf.float32, [1, 224, 224, 3], 'input'), - InputDesc(tf.int32, [1], 'label') - ] - input = PlaceholderInput() - input.setup(input_desc) - with TowerContext('', is_training=False): - model.build_graph(*input.get_input_tensors()) - model_utils.describe_trainable_vars() - - tf.profiler.profile( - tf.get_default_graph(), - cmd='op', - options=tf.profiler.ProfileOptionBuilder.float_operation()) - logger.info("Note that TensorFlow counts flops in a different way from the paper.") - logger.info("TensorFlow counts multiply+add as two flops, however the paper counts them " - "as 1 flop because it can be executed in one instruction.") - else: - if args.v2: - name = "ShuffleNetV2-{}x".format(args.ratio) - else: - name = "ShuffleNetV1-{}x-g{}".format(args.ratio, args.group) - logger.set_logger_dir(os.path.join('train_log', name)) - - nr_tower = max(get_num_gpu(), 1) - config = get_config(model, nr_tower) - if args.load: - config.session_init = get_model_loader(args.load) - launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_tower)) + pass diff --git a/tensorflow_/utils.py b/tensorflow_/utils.py index e141b717b..4cfe63afb 100644 --- 
a/tensorflow_/utils.py +++ b/tensorflow_/utils.py @@ -1,11 +1,81 @@ import logging import os +import multiprocessing +import numpy as np +import cv2 from tensorpack.tfutils import get_model_loader +from tensorpack.dataflow import imgaug, dataset, AugmentImageComponent, PrefetchDataZMQ, MultiThreadMapData, BatchData from .model_provider import get_model +class GoogleNetResize(imgaug.ImageAugmentor): + """ + crop 8%~100% of the original image + See `Going Deeper with Convolutions` by Google. + """ + def __init__(self, crop_area_fraction=0.08, + aspect_ratio_low=0.75, aspect_ratio_high=1.333, + target_shape=224): + self._init(locals()) + + def _augment(self, img, _): + h, w = img.shape[:2] + area = h * w + for _ in range(10): + targetArea = self.rng.uniform(self.crop_area_fraction, 1.0) * area + aspectR = self.rng.uniform(self.aspect_ratio_low, self.aspect_ratio_high) + ww = int(np.sqrt(targetArea * aspectR) + 0.5) + hh = int(np.sqrt(targetArea / aspectR) + 0.5) + if self.rng.uniform() < 0.5: + ww, hh = hh, ww + if hh <= h and ww <= w: + x1 = 0 if w == ww else self.rng.randint(0, w - ww) + y1 = 0 if h == hh else self.rng.randint(0, h - hh) + out = img[y1:y1 + hh, x1:x1 + ww] + out = cv2.resize(out, (self.target_shape, self.target_shape), interpolation=cv2.INTER_CUBIC) + return out + out = imgaug.ResizeShortestEdge(self.target_shape, interp=cv2.INTER_CUBIC).augment(img) + out = imgaug.CenterCrop(self.target_shape).augment(out) + return out + + +def get_imagenet_dataflow(datadir, + is_train, + batch_size, + augmentors, + parallel=None): + """ + See explanations in the tutorial: + http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html + """ + assert datadir is not None + assert isinstance(augmentors, list) + if parallel is None: + parallel = min(40, multiprocessing.cpu_count() // 2) # assuming hyperthreading + if is_train: + ds = dataset.ILSVRC12(datadir, 'train', shuffle=True) + ds = AugmentImageComponent(ds, augmentors, copy=False) + if parallel < 
16: + logging.warning("DataFlow may become the bottleneck when too few processes are used.") + ds = PrefetchDataZMQ(ds, parallel) + ds = BatchData(ds, batch_size, remainder=False) + else: + ds = dataset.ILSVRC12Files(datadir, 'val', shuffle=False) + aug = imgaug.AugmentorList(augmentors) + + def mapf(dp): + fname, cls = dp + im = cv2.imread(fname, cv2.IMREAD_COLOR) + im = aug.augment(im) + return im, cls + ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True) + ds = BatchData(ds, batch_size, remainder=True) + ds = PrefetchDataZMQ(ds, 1) + return ds + + def prepare_tf_context(num_gpus, batch_size): batch_size *= max(1, num_gpus) @@ -24,3 +94,35 @@ def prepare_model(model_name, inputs_desc = get_model_loader(pretrained_model_file_path) return net, inputs_desc + + +def get_data(is_train, + batch_size, + data_dir_path): + + if is_train: + augmentors = [ + GoogleNetResize(crop_area_fraction=0.08), + imgaug.RandomOrderAug([ + imgaug.BrightnessScale((0.6, 1.4), clip=False), + imgaug.Contrast((0.6, 1.4), clip=False), + imgaug.Saturation(0.4, rgb=False), + # rgb-bgr conversion for the constants copied from fb.resnet.torch + imgaug.Lighting( + 0.1, + eigval=np.asarray([0.2175, 0.0188, 0.0045][::-1]) * 255.0, + eigvec=np.array([ + [-0.5675, 0.7192, 0.4009], + [-0.5808, -0.0045, -0.8140], + [-0.5836, -0.6948, 0.4203]], dtype='float32')[::-1, ::-1])]), + imgaug.Flip(horiz=True)] + else: + augmentors = [ + imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC), + imgaug.CenterCrop((224, 224))] + + return get_imagenet_dataflow( + datadir=data_dir_path, + is_train=is_train, + batch_size=batch_size, + augmentors=augmentors)