From 3d6de9563d05d660d5203da4d4a81b64087d4904 Mon Sep 17 00:00:00 2001
From: osmr
Date: Sat, 13 Oct 2018 17:36:48 +0300
Subject: [PATCH] Work on TF-convert script for ResNet

---
 convert_models.py             |  63 ++++++++
 eval_tf.py                    |   2 +-
 tensorflow_/model_provider.py |   4 +-
 tensorflow_/utils.py          | 293 ++--------------------------------
 tensorflow_/utils_tp.py       | 291 +++++++++++++++++++++++++++++++++
 train_tf.py                   |   2 +-
 6 files changed, 369 insertions(+), 286 deletions(-)
 create mode 100644 tensorflow_/utils_tp.py

diff --git a/convert_models.py b/convert_models.py
index ec4934083..7acc45e43 100644
--- a/convert_models.py
+++ b/convert_models.py
@@ -203,6 +203,15 @@ def prepare_dst_model(dst_fwk,
                 if weight.name:
                     dst_params.setdefault(weight.name, []).append(weight)
                     dst_params[weight.name] = (layer, weight)
+    elif dst_fwk == "tensorflow":
+        import tensorflow as tf
+        from tensorflow_.utils import prepare_model as prepare_model_tf
+        dst_net = prepare_model_tf(
+            model_name=dst_model,
+            classes=num_classes,
+            use_pretrained=False)
+        dst_param_keys = [v.name for v in tf.global_variables()]
+        dst_params = {v.name: v for v in tf.global_variables()}
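+        # NOTE (illustrative): the collected names follow TF's
+        # "<scope>/<var>:0" convention, e.g. something like
+        # 'features/stage1/unit1/conv1/conv/kernel:0'; the exact scopes
+        # depend on how the TF models in this repo name their layers.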
     else:
         raise ValueError("Unsupported dst fwk: {}".format(dst_fwk))
 
@@ -443,6 +452,50 @@ def process_width(src_key, dst_key, src_weight):
     dst_net.save_weights(dst_params_file_path)
 
 
+def convert_gl2tf(dst_net,
+                  dst_params_file_path,
+                  dst_params,
+                  dst_param_keys,
+                  src_params,
+                  src_param_keys):
+    dst_param_keys = [key.replace('/kernel:', '/weight:') for key in dst_param_keys]
+
+    src_param_keys.sort()
+    src_param_keys.sort(key=lambda var: ['{:10}'.format(int(x)) if
+                                         x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])
+
+    dst_param_keys.sort()
+    dst_param_keys.sort(key=lambda var: ['{:10}'.format(int(x)) if
+                                         x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])
+
+    dst_param_keys = [key.replace('/weight:', '/kernel:') for key in dst_param_keys]
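+
+    # NOTE (illustrative): the two-pass sort above is a natural sort, so
+    # numbered scopes order as ['stage1', 'stage2', 'stage10'] rather than
+    # the plain lexicographic ['stage1', 'stage10', 'stage2']; renaming
+    # '/kernel:' to '/weight:' first makes both frameworks' keys sort the
+    # same way, so the zip below pairs corresponding parameters.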
- """ - self.image_bgr = True - - self.weight_decay = 1e-4 - - """ - To apply on normalization parameters, use '.*/W|.*/gamma|.*/beta' - """ - self.weight_decay_pattern = '.*/W' - - """ - Scale the loss, for whatever reasons (e.g., gradient averaging, fp16 training, etc) - """ - self.loss_scale = 1.0 - - """ - Label smoothing (See tf.losses.softmax_cross_entropy) - """ - self.label_smoothing = 0.0 - - def inputs(self): - return [tf.placeholder(self.image_dtype, [None, self.image_shape, self.image_shape, 3], 'input'), - tf.placeholder(tf.int32, [None], 'label')] - - def build_graph(self, - image, - label): - - image = self.image_preprocess(image) - assert self.data_format in ['NCHW', 'NHWC'] - if self.data_format == 'NCHW': - image = tf.transpose(image, [0, 3, 1, 2]) - - logits = self.get_logits(image) - loss = ImageNetModel.compute_loss_and_error( - logits, label, label_smoothing=self.label_smoothing) - - if self.weight_decay > 0: - wd_loss = regularize_cost(self.weight_decay_pattern, - tf.contrib.layers.l2_regularizer(self.weight_decay), - name='l2_regularize_loss') - add_moving_summary(loss, wd_loss) - total_cost = tf.add_n([loss, wd_loss], name='cost') - else: - total_cost = tf.identity(loss, name='cost') - add_moving_summary(total_cost) - - if self.loss_scale != 1.: - logger.info("Scaling the total loss by {} ...".format(self.loss_scale)) - return total_cost * self.loss_scale - else: - return total_cost - - def get_logits(self, - image): - """ - Args: - image: 4D tensor of ``self.input_shape`` in ``self.data_format`` - - Returns: - Nx#class logits - """ - return self.model_lambda(image) - - def optimizer(self): - lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False) - tf.summary.scalar('learning_rate-summary', lr) - return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True) - - def image_preprocess(self, - image): - - with tf.name_scope('image_preprocess'): - if image.dtype.base_dtype != tf.float32: - image = tf.cast(image, tf.float32) - mean = [0.485, 0.456, 0.406] # rgb - std = [0.229, 0.224, 0.225] - if self.image_bgr: - mean = mean[::-1] - std = std[::-1] - image_mean = tf.constant(mean, dtype=tf.float32) * 255. - image_std = tf.constant(std, dtype=tf.float32) * 255. - image = (image - image_mean) / image_std - return image - - @staticmethod - def compute_loss_and_error(logits, - label, - label_smoothing=0.0): - - if label_smoothing == 0.0: - loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label) - else: - nclass = logits.shape[-1] - loss = tf.losses.softmax_cross_entropy( - tf.one_hot(label, nclass), - logits, label_smoothing=label_smoothing) - loss = tf.reduce_mean(loss, name='xentropy-loss') - - def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'): - with tf.name_scope('prediction_incorrect'): - x = tf.logical_not(tf.nn.in_top_k(logits, label, topk)) - return tf.cast(x, tf.float32, name=name) - - wrong = prediction_incorrect(logits, label, 1, name='wrong-top1') - add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1')) - - wrong = prediction_incorrect(logits, label, 5, name='wrong-top5') - add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5')) - return loss - - -class GoogleNetResize(imgaug.ImageAugmentor): - """ - crop 8%~100% of the original image - See `Going Deeper with Convolutions` by Google. 
- """ - def __init__(self, crop_area_fraction=0.08, - aspect_ratio_low=0.75, aspect_ratio_high=1.333, - target_shape=224): - self._init(locals()) - - def _augment(self, img, _): - h, w = img.shape[:2] - area = h * w - for _ in range(10): - targetArea = self.rng.uniform(self.crop_area_fraction, 1.0) * area - aspectR = self.rng.uniform(self.aspect_ratio_low, self.aspect_ratio_high) - ww = int(np.sqrt(targetArea * aspectR) + 0.5) - hh = int(np.sqrt(targetArea / aspectR) + 0.5) - if self.rng.uniform() < 0.5: - ww, hh = hh, ww - if hh <= h and ww <= w: - x1 = 0 if w == ww else self.rng.randint(0, w - ww) - y1 = 0 if h == hh else self.rng.randint(0, h - hh) - out = img[y1:y1 + hh, x1:x1 + ww] - out = cv2.resize(out, (self.target_shape, self.target_shape), interpolation=cv2.INTER_CUBIC) - return out - out = imgaug.ResizeShortestEdge(self.target_shape, interp=cv2.INTER_CUBIC).augment(img) - out = imgaug.CenterCrop(self.target_shape).augment(out) - return out - - -def get_imagenet_dataflow(datadir, - is_train, - batch_size, - augmentors, - parallel=None): - """ - See explanations in the tutorial: - http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html - """ - assert datadir is not None - assert isinstance(augmentors, list) - if parallel is None: - parallel = min(40, multiprocessing.cpu_count() // 2) # assuming hyperthreading - if is_train: - ds = dataset.ILSVRC12(datadir, 'train', shuffle=True) - ds = AugmentImageComponent(ds, augmentors, copy=False) - if parallel < 16: - logging.warning("DataFlow may become the bottleneck when too few processes are used.") - ds = PrefetchDataZMQ(ds, parallel) - ds = BatchData(ds, batch_size, remainder=False) - else: - ds = dataset.ILSVRC12Files(datadir, 'val', shuffle=False) - aug = imgaug.AugmentorList(augmentors) - - def mapf(dp): - fname, cls = dp - im = cv2.imread(fname, cv2.IMREAD_COLOR) - im = aug.augment(im) - return im, cls - ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True) - ds = BatchData(ds, batch_size, remainder=True) - ds = PrefetchDataZMQ(ds, 1) - return ds - - -def prepare_tf_context(num_gpus, - batch_size): - batch_size *= max(1, num_gpus) - return batch_size - - def prepare_model(model_name, - pretrained_model_file_path): - - net = get_model(model_name) - net = ImageNetModel(model_lambda=net) - - inputs_desc = None - if pretrained_model_file_path: - assert (os.path.isfile(pretrained_model_file_path)) - logging.info('Loading model: {}'.format(pretrained_model_file_path)) - inputs_desc = get_model_loader(pretrained_model_file_path) - - return net, inputs_desc - - -def get_data(is_train, - batch_size, - data_dir_path): - - if is_train: - augmentors = [ - GoogleNetResize(crop_area_fraction=0.08), - imgaug.RandomOrderAug([ - imgaug.BrightnessScale((0.6, 1.4), clip=False), - imgaug.Contrast((0.6, 1.4), clip=False), - imgaug.Saturation(0.4, rgb=False), - # rgb-bgr conversion for the constants copied from fb.resnet.torch - imgaug.Lighting( - 0.1, - eigval=np.asarray([0.2175, 0.0188, 0.0045][::-1]) * 255.0, - eigvec=np.array([ - [-0.5675, 0.7192, 0.4009], - [-0.5808, -0.0045, -0.8140], - [-0.5836, -0.6948, 0.4203]], dtype='float32')[::-1, ::-1])]), - imgaug.Flip(horiz=True)] - else: - augmentors = [ - imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC), - imgaug.CenterCrop((224, 224))] - - return get_imagenet_dataflow( - datadir=data_dir_path, - is_train=is_train, - batch_size=batch_size, - augmentors=augmentors) + classes, + use_pretrained): + kwargs = {'pretrained': use_pretrained, + 'classes': classes} + net = 
 
-
-def calc_flops(model):
-    # manually build the graph with batch=1
-    input_desc = [
-        InputDesc(tf.float32, [1, 224, 224, 3], 'input'),
-        InputDesc(tf.int32, [1], 'label')
-    ]
-    input = PlaceholderInput()
-    input.setup(input_desc)
-    with TowerContext('', is_training=False):
-        model.build_graph(*input.get_input_tensors())
-        model_utils.describe_trainable_vars()
+    x = tf.placeholder(
+        dtype=tf.float32,
+        shape=(None, 3, 224, 224),
+        name='xx')
+    y_net = net(x)
 
-    tf.profiler.profile(
-        tf.get_default_graph(),
-        cmd='op',
-        options=tf.profiler.ProfileOptionBuilder.float_operation())
-    logger.info("Note that TensorFlow counts flops in a different way from the paper.")
-    logger.info("TensorFlow counts multiply+add as two flops, however the paper counts them "
-                "as 1 flop because it can be executed in one instruction.")
+    return y_net
diff --git a/tensorflow_/utils_tp.py b/tensorflow_/utils_tp.py
new file mode 100644
index 000000000..815e567d8
--- /dev/null
+++ b/tensorflow_/utils_tp.py
@@ -0,0 +1,291 @@
+import logging
+import os
+import multiprocessing
+import numpy as np
+import cv2
+
+import tensorflow as tf
+from tensorpack.models import regularize_cost
+from tensorpack.tfutils.summary import add_moving_summary
+from tensorpack import ModelDesc
+from tensorpack import InputDesc, PlaceholderInput, TowerContext
+from tensorpack.tfutils import get_model_loader, model_utils
+from tensorpack.dataflow import imgaug, dataset, AugmentImageComponent, PrefetchDataZMQ, MultiThreadMapData, BatchData
+from tensorpack.utils import logger
+
+from .model_provider import get_model
+
+
+class ImageNetModel(ModelDesc):
+
+    def __init__(self,
+                 model_lambda,
+                 **kwargs):
+        super(ImageNetModel, self).__init__(**kwargs)
+        self.model_lambda = model_lambda
+        self.image_shape = 224
+
+        """
+        uint8 instead of float32 is used as input type to reduce copy overhead.
+        It might hurt the performance a liiiitle bit.
+        The pretrained models were trained with float32.
+        """
+        self.image_dtype = tf.uint8
+
+        """
+        Either 'NCHW' or 'NHWC'
+        """
+        self.data_format = 'NCHW'
+
+        """
+        Whether the image is BGR or RGB. If using DataFlow, then it should be BGR.
+        """
+        self.image_bgr = True
+
+        self.weight_decay = 1e-4
+
+        """
+        To apply on normalization parameters, use '.*/W|.*/gamma|.*/beta'
+        """
+        self.weight_decay_pattern = '.*/W'
+
+        """
+        Scale the loss, for whatever reasons (e.g., gradient averaging, fp16 training, etc)
+        """
+        self.loss_scale = 1.0
+
+        """
+        Label smoothing (See tf.losses.softmax_cross_entropy)
+        """
+        self.label_smoothing = 0.0
+
+    def inputs(self):
+        return [tf.placeholder(self.image_dtype, [None, self.image_shape, self.image_shape, 3], 'input'),
+                tf.placeholder(tf.int32, [None], 'label')]
+
+    def build_graph(self,
+                    image,
+                    label):
+
+        image = self.image_preprocess(image)
+        assert self.data_format in ['NCHW', 'NHWC']
+        if self.data_format == 'NCHW':
+            image = tf.transpose(image, [0, 3, 1, 2])
+
+        logits = self.get_logits(image)
+        loss = ImageNetModel.compute_loss_and_error(
+            logits, label, label_smoothing=self.label_smoothing)
+
+        if self.weight_decay > 0:
+            wd_loss = regularize_cost(self.weight_decay_pattern,
+                                      tf.contrib.layers.l2_regularizer(self.weight_decay),
+                                      name='l2_regularize_loss')
+            add_moving_summary(loss, wd_loss)
+            total_cost = tf.add_n([loss, wd_loss], name='cost')
+        else:
+            total_cost = tf.identity(loss, name='cost')
+            add_moving_summary(total_cost)
+
+        if self.loss_scale != 1.:
+            logger.info("Scaling the total loss by {} ...".format(self.loss_scale))
+            return total_cost * self.loss_scale
+        else:
+            return total_cost
+
+    def get_logits(self,
+                   image):
+        """
+        Args:
+            image: 4D tensor of ``self.input_shape`` in ``self.data_format``
+
+        Returns:
+            Nx#class logits
+        """
+        return self.model_lambda(image)
+
+    def optimizer(self):
+        lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False)
+        tf.summary.scalar('learning_rate-summary', lr)
+        return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
+
+    def image_preprocess(self,
+                         image):
+
+        with tf.name_scope('image_preprocess'):
+            if image.dtype.base_dtype != tf.float32:
+                image = tf.cast(image, tf.float32)
+            mean = [0.485, 0.456, 0.406]  # rgb
+            std = [0.229, 0.224, 0.225]
+            if self.image_bgr:
+                mean = mean[::-1]
+                std = std[::-1]
+            image_mean = tf.constant(mean, dtype=tf.float32) * 255.
+            image_std = tf.constant(std, dtype=tf.float32) * 255.
+            image = (image - image_mean) / image_std
+            return image
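+
+    # Illustrative arithmetic for image_preprocess above: an R-channel value
+    # of 128 maps to (128 - 0.485 * 255) / (0.229 * 255) ≈ 0.07; when
+    # self.image_bgr is True, the mean/std lists are simply reversed first.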
+ """ + self.image_bgr = True + + self.weight_decay = 1e-4 + + """ + To apply on normalization parameters, use '.*/W|.*/gamma|.*/beta' + """ + self.weight_decay_pattern = '.*/W' + + """ + Scale the loss, for whatever reasons (e.g., gradient averaging, fp16 training, etc) + """ + self.loss_scale = 1.0 + + """ + Label smoothing (See tf.losses.softmax_cross_entropy) + """ + self.label_smoothing = 0.0 + + def inputs(self): + return [tf.placeholder(self.image_dtype, [None, self.image_shape, self.image_shape, 3], 'input'), + tf.placeholder(tf.int32, [None], 'label')] + + def build_graph(self, + image, + label): + + image = self.image_preprocess(image) + assert self.data_format in ['NCHW', 'NHWC'] + if self.data_format == 'NCHW': + image = tf.transpose(image, [0, 3, 1, 2]) + + logits = self.get_logits(image) + loss = ImageNetModel.compute_loss_and_error( + logits, label, label_smoothing=self.label_smoothing) + + if self.weight_decay > 0: + wd_loss = regularize_cost(self.weight_decay_pattern, + tf.contrib.layers.l2_regularizer(self.weight_decay), + name='l2_regularize_loss') + add_moving_summary(loss, wd_loss) + total_cost = tf.add_n([loss, wd_loss], name='cost') + else: + total_cost = tf.identity(loss, name='cost') + add_moving_summary(total_cost) + + if self.loss_scale != 1.: + logger.info("Scaling the total loss by {} ...".format(self.loss_scale)) + return total_cost * self.loss_scale + else: + return total_cost + + def get_logits(self, + image): + """ + Args: + image: 4D tensor of ``self.input_shape`` in ``self.data_format`` + + Returns: + Nx#class logits + """ + return self.model_lambda(image) + + def optimizer(self): + lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False) + tf.summary.scalar('learning_rate-summary', lr) + return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True) + + def image_preprocess(self, + image): + + with tf.name_scope('image_preprocess'): + if image.dtype.base_dtype != tf.float32: + image = tf.cast(image, tf.float32) + mean = [0.485, 0.456, 0.406] # rgb + std = [0.229, 0.224, 0.225] + if self.image_bgr: + mean = mean[::-1] + std = std[::-1] + image_mean = tf.constant(mean, dtype=tf.float32) * 255. + image_std = tf.constant(std, dtype=tf.float32) * 255. + image = (image - image_mean) / image_std + return image + + @staticmethod + def compute_loss_and_error(logits, + label, + label_smoothing=0.0): + + if label_smoothing == 0.0: + loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label) + else: + nclass = logits.shape[-1] + loss = tf.losses.softmax_cross_entropy( + tf.one_hot(label, nclass), + logits, label_smoothing=label_smoothing) + loss = tf.reduce_mean(loss, name='xentropy-loss') + + def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'): + with tf.name_scope('prediction_incorrect'): + x = tf.logical_not(tf.nn.in_top_k(logits, label, topk)) + return tf.cast(x, tf.float32, name=name) + + wrong = prediction_incorrect(logits, label, 1, name='wrong-top1') + add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1')) + + wrong = prediction_incorrect(logits, label, 5, name='wrong-top5') + add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5')) + return loss + + +class GoogleNetResize(imgaug.ImageAugmentor): + """ + crop 8%~100% of the original image + See `Going Deeper with Convolutions` by Google. 
+ """ + def __init__(self, crop_area_fraction=0.08, + aspect_ratio_low=0.75, aspect_ratio_high=1.333, + target_shape=224): + self._init(locals()) + + def _augment(self, img, _): + h, w = img.shape[:2] + area = h * w + for _ in range(10): + targetArea = self.rng.uniform(self.crop_area_fraction, 1.0) * area + aspectR = self.rng.uniform(self.aspect_ratio_low, self.aspect_ratio_high) + ww = int(np.sqrt(targetArea * aspectR) + 0.5) + hh = int(np.sqrt(targetArea / aspectR) + 0.5) + if self.rng.uniform() < 0.5: + ww, hh = hh, ww + if hh <= h and ww <= w: + x1 = 0 if w == ww else self.rng.randint(0, w - ww) + y1 = 0 if h == hh else self.rng.randint(0, h - hh) + out = img[y1:y1 + hh, x1:x1 + ww] + out = cv2.resize(out, (self.target_shape, self.target_shape), interpolation=cv2.INTER_CUBIC) + return out + out = imgaug.ResizeShortestEdge(self.target_shape, interp=cv2.INTER_CUBIC).augment(img) + out = imgaug.CenterCrop(self.target_shape).augment(out) + return out + + +def get_imagenet_dataflow(datadir, + is_train, + batch_size, + augmentors, + parallel=None): + """ + See explanations in the tutorial: + http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html + """ + assert datadir is not None + assert isinstance(augmentors, list) + if parallel is None: + parallel = min(40, multiprocessing.cpu_count() // 2) # assuming hyperthreading + if is_train: + ds = dataset.ILSVRC12(datadir, 'train', shuffle=True) + ds = AugmentImageComponent(ds, augmentors, copy=False) + if parallel < 16: + logging.warning("DataFlow may become the bottleneck when too few processes are used.") + ds = PrefetchDataZMQ(ds, parallel) + ds = BatchData(ds, batch_size, remainder=False) + else: + ds = dataset.ILSVRC12Files(datadir, 'val', shuffle=False) + aug = imgaug.AugmentorList(augmentors) + + def mapf(dp): + fname, cls = dp + im = cv2.imread(fname, cv2.IMREAD_COLOR) + im = aug.augment(im) + return im, cls + ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True) + ds = BatchData(ds, batch_size, remainder=True) + ds = PrefetchDataZMQ(ds, 1) + return ds + + +def prepare_tf_context(num_gpus, + batch_size): + batch_size *= max(1, num_gpus) + return batch_size + + +def prepare_model(model_name, + pretrained_model_file_path): + + net = get_model(model_name) + net = ImageNetModel(model_lambda=net) + + inputs_desc = None + if pretrained_model_file_path: + assert (os.path.isfile(pretrained_model_file_path)) + logging.info('Loading model: {}'.format(pretrained_model_file_path)) + inputs_desc = get_model_loader(pretrained_model_file_path) + + return net, inputs_desc + + +def get_data(is_train, + batch_size, + data_dir_path): + + if is_train: + augmentors = [ + GoogleNetResize(crop_area_fraction=0.08), + imgaug.RandomOrderAug([ + imgaug.BrightnessScale((0.6, 1.4), clip=False), + imgaug.Contrast((0.6, 1.4), clip=False), + imgaug.Saturation(0.4, rgb=False), + # rgb-bgr conversion for the constants copied from fb.resnet.torch + imgaug.Lighting( + 0.1, + eigval=np.asarray([0.2175, 0.0188, 0.0045][::-1]) * 255.0, + eigvec=np.array([ + [-0.5675, 0.7192, 0.4009], + [-0.5808, -0.0045, -0.8140], + [-0.5836, -0.6948, 0.4203]], dtype='float32')[::-1, ::-1])]), + imgaug.Flip(horiz=True)] + else: + augmentors = [ + imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC), + imgaug.CenterCrop((224, 224))] + + return get_imagenet_dataflow( + datadir=data_dir_path, + is_train=is_train, + batch_size=batch_size, + augmentors=augmentors) + + +def calc_flops(model): + # manually build the graph with batch=1 + input_desc = [ + 
+
+
+def prepare_model(model_name,
+                  pretrained_model_file_path):
+
+    net = get_model(model_name)
+    net = ImageNetModel(model_lambda=net)
+
+    inputs_desc = None
+    if pretrained_model_file_path:
+        assert (os.path.isfile(pretrained_model_file_path))
+        logging.info('Loading model: {}'.format(pretrained_model_file_path))
+        inputs_desc = get_model_loader(pretrained_model_file_path)
+
+    return net, inputs_desc
+
+
+def get_data(is_train,
+             batch_size,
+             data_dir_path):
+
+    if is_train:
+        augmentors = [
+            GoogleNetResize(crop_area_fraction=0.08),
+            imgaug.RandomOrderAug([
+                imgaug.BrightnessScale((0.6, 1.4), clip=False),
+                imgaug.Contrast((0.6, 1.4), clip=False),
+                imgaug.Saturation(0.4, rgb=False),
+                # rgb-bgr conversion for the constants copied from fb.resnet.torch
+                imgaug.Lighting(
+                    0.1,
+                    eigval=np.asarray([0.2175, 0.0188, 0.0045][::-1]) * 255.0,
+                    eigvec=np.array([
+                        [-0.5675, 0.7192, 0.4009],
+                        [-0.5808, -0.0045, -0.8140],
+                        [-0.5836, -0.6948, 0.4203]], dtype='float32')[::-1, ::-1])]),
+            imgaug.Flip(horiz=True)]
+    else:
+        augmentors = [
+            imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
+            imgaug.CenterCrop((224, 224))]
+
+    return get_imagenet_dataflow(
+        datadir=data_dir_path,
+        is_train=is_train,
+        batch_size=batch_size,
+        augmentors=augmentors)
+
+
+def calc_flops(model):
+    # manually build the graph with batch=1
+    input_desc = [
+        InputDesc(tf.float32, [1, 224, 224, 3], 'input'),
+        InputDesc(tf.int32, [1], 'label')
+    ]
+    input = PlaceholderInput()
+    input.setup(input_desc)
+    with TowerContext('', is_training=False):
+        model.build_graph(*input.get_input_tensors())
+        model_utils.describe_trainable_vars()
+
+    tf.profiler.profile(
+        tf.get_default_graph(),
+        cmd='op',
+        options=tf.profiler.ProfileOptionBuilder.float_operation())
+    logger.info("Note that TensorFlow counts flops in a different way from the paper.")
+    logger.info("TensorFlow counts multiply+add as two flops, however the paper counts them "
+                "as 1 flop because it can be executed in one instruction.")
diff --git a/train_tf.py b/train_tf.py
index b597ab6c5..490f683e5 100644
--- a/train_tf.py
+++ b/train_tf.py
@@ -9,7 +9,7 @@
     DataParallelInferenceRunner, TrainConfig, SyncMultiGPUTrainerParameterServer, launch_train_with_config
 
 from common.logger_utils import initialize_logging
-from tensorflow_.utils import prepare_tf_context, prepare_model, get_data
+from tensorflow_.utils_tp import prepare_tf_context, prepare_model, get_data
 
 
 def parse_args():