From 7f99c7af1bf99ccefaee872c4269392cbde78462 Mon Sep 17 00:00:00 2001 From: osmr Date: Mon, 15 Oct 2018 17:46:04 +0300 Subject: [PATCH] Work on TF branch, 2 --- convert_models.py | 12 ++-- eval_tf.py | 3 + tensorflow_/models/model_store.py | 67 +++++++++++++++++++- tensorflow_/models/others/resnet_.py | 2 +- tensorflow_/models/resnet.py | 94 +++++++++++++++++++++++----- tensorflow_/utils.py | 8 +++ tensorflow_/utils_tp.py | 6 +- train_tf.py | 3 + 8 files changed, 172 insertions(+), 23 deletions(-) diff --git a/convert_models.py b/convert_models.py index 7acc45e43..b72591788 100644 --- a/convert_models.py +++ b/convert_models.py @@ -480,7 +480,7 @@ def convert_gl2tf(dst_net, # if not (tuple(dst_params[dst_key].get_shape().as_list()[::-1]) == src_params[src_key].shape): # a = 1 assert (tuple(dst_params[dst_key].get_shape().as_list()[::-1]) == src_params[src_key].shape) - src_value = np.transpose(src_params[src_key]._data[0].asnumpy(), axes=(3, 2, 1, 0)) + src_value = np.transpose(src_params[src_key]._data[0].asnumpy(), axes=(2, 3, 1, 0)) elif len(src_value.shape) == 2: assert (tuple(dst_params[dst_key].get_shape().as_list()[::-1]) == src_params[src_key].shape) src_value = np.transpose(src_params[src_key]._data[0].asnumpy(), axes=(1, 0)) @@ -490,10 +490,14 @@ def convert_gl2tf(dst_net, assert (tuple(dst_params[dst_key].get_shape().as_list()) == src_params[src_key].shape) sess.run(dst_params[dst_key].assign(src_value)) # print(dst_params[dst_key].eval(sess)) - saver = tf.train.Saver() - saver.save( + # saver = tf.train.Saver() + # saver.save( + # sess=sess, + # save_path=dst_params_file_path) + from tensorflow_.utils import save_model_params + save_model_params( sess=sess, - save_path=dst_params_file_path) + file_path=dst_params_file_path) def convert_pt2pt(dst_params_file_path, diff --git a/eval_tf.py b/eval_tf.py index daebe8ff6..ad56b09f2 100644 --- a/eval_tf.py +++ b/eval_tf.py @@ -137,8 +137,11 @@ def main(): num_gpus=args.num_gpus, batch_size=args.batch_size) + classes = 1000 net, inputs_desc = prepare_model( model_name=args.model, + classes=classes, + use_pretrained=args.use_pretrained, pretrained_model_file_path=args.resume.strip()) val_dataflow = get_data( diff --git a/tensorflow_/models/model_store.py b/tensorflow_/models/model_store.py index a30d5f624..600c4ab8d 100644 --- a/tensorflow_/models/model_store.py +++ b/tensorflow_/models/model_store.py @@ -2,7 +2,7 @@ Model store which provides pretrained models. """ -__all__ = ['get_model_file'] +__all__ = ['get_model_file', 'load_model', 'download_model'] import os import zipfile @@ -31,7 +31,7 @@ def get_model_file(model_name, ---------- model_name : str Name of the model. - local_model_store_dir_path : str, default $KERAS_HOME/models + local_model_store_dir_path : str, default $TENSORFLOW_HOME/models Location for keeping the model parameters. Returns @@ -183,3 +183,66 @@ def _check_sha1(filename, sha1_hash): sha1.update(data) return sha1.hexdigest() == sha1_hash + + +def load_model(sess, + file_path, + ignore_extra=True): + """ + Load model state dictionary from a file. + + Parameters + ---------- + sess: Session or None, default None + A Session to use to load the weights. + file_path : str + Path to the file. + ignore_extra : bool, default True + Whether to silently ignore parameters from the file that are not present in this Module. + """ + import numpy as np + import tensorflow as tf + + assert os.path.exists(file_path) and os.path.isfile(file_path) + if file_path.endswith('.npy'): + src_params = np.load(file_path, encoding='latin1').item() + elif file_path.endswith('.npz'): + src_params = dict(np.load(file_path)) + else: + raise NotImplementedError + dst_params = {v.name: v for v in tf.global_variables()} + if sess is None: + sess = tf.Session() + sess.run(tf.global_variables_initializer()) + for src_key in src_params.keys(): + if src_key in dst_params.keys(): + assert (src_params[src_key].shape == tuple(dst_params[src_key].get_shape().as_list())) + sess.run(dst_params[src_key].assign(src_params[src_key])) + elif not ignore_extra: + raise Exception("The file `{}` is incompatible with the model".format(file_path)) + + +def download_model(sess, + model_name, + local_model_store_dir_path=os.path.join('~', '.tensorflow', 'models'), + ignore_extra=True): + """ + Load model state dictionary from a file with downloading it if necessary. + + Parameters + ---------- + sess: Session or None, default None + A Session to use to load the weights. + model_name : str + Name of the model. + local_model_store_dir_path : str, default $TENSORFLOW_HOME/models + Location for keeping the model parameters. + ignore_extra : bool, default True + Whether to silently ignore parameters from the file that are not present in this Module. + """ + load_model( + sess=sess, + file_path=get_model_file( + model_name=model_name, + local_model_store_dir_path=local_model_store_dir_path), + ignore_extra=ignore_extra) diff --git a/tensorflow_/models/others/resnet_.py b/tensorflow_/models/others/resnet_.py index 3e7b54491..af156dd26 100644 --- a/tensorflow_/models/others/resnet_.py +++ b/tensorflow_/models/others/resnet_.py @@ -17,7 +17,7 @@ layer_register from tensorpack.tfutils import argscope import tensorflow.contrib.slim as slim -from tensorflow_.models.common import ImageNetModel, conv2d, se_block +from .common_ import ImageNetModel, conv2d, se_block @layer_register(log_shape=True) diff --git a/tensorflow_/models/resnet.py b/tensorflow_/models/resnet.py index 0342fd5c5..a1189c212 100644 --- a/tensorflow_/models/resnet.py +++ b/tensorflow_/models/resnet.py @@ -441,6 +441,7 @@ def get_resnet(blocks, width_scale=1.0, model_name=None, pretrained=False, + sess=None, root=os.path.join('~', '.tensorflow', 'models'), **kwargs): """ @@ -460,6 +461,8 @@ def get_resnet(blocks, Model name for loading pretrained model. pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. @@ -509,8 +512,20 @@ def get_resnet(blocks, channels = [[int(cij * width_scale) for cij in ci] for ci in channels] init_block_channels = int(init_block_channels * width_scale) - def net_lambda(x): - return resnet( + if pretrained and ((model_name is None) or (not model_name)): + raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.") + + def net_lambda(x, + channels=channels, + init_block_channels=init_block_channels, + bottleneck=bottleneck, + conv1_stride=conv1_stride, + use_se=use_se, + pretrained=pretrained, + sess=sess, + model_name=model_name, + root=root): + y_net = resnet( x=x, channels=channels, init_block_channels=init_block_channels, @@ -518,20 +533,15 @@ def net_lambda(x): conv1_stride=conv1_stride, use_se=use_se, **kwargs) - net = net_lambda + if pretrained: + from .model_store import download_model + download_model( + sess=sess, + model_name=model_name, + local_model_store_dir_path=root) + return y_net - param_dict = None - if pretrained: - if (model_name is None) or (not model_name): - raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.") - # from .model_store import get_model_file - # param_dict = - # net.load_weights( - # filepath=get_model_file( - # model_name=model_name, - # local_model_store_dir_path=root)) - - return net, param_dict + return net_lambda def resnet10(**kwargs): @@ -543,6 +553,8 @@ def resnet10(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -558,6 +570,8 @@ def resnet12(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -573,6 +587,8 @@ def resnet14(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -588,6 +604,8 @@ def resnet16(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -603,6 +621,8 @@ def resnet18_wd4(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -618,6 +638,8 @@ def resnet18_wd2(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -633,6 +655,8 @@ def resnet18_w3d4(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -647,6 +671,8 @@ def resnet18(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -661,6 +687,8 @@ def resnet34(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -675,6 +703,8 @@ def resnet50(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -690,6 +720,8 @@ def resnet50b(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -704,6 +736,8 @@ def resnet101(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -719,6 +753,8 @@ def resnet101b(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -733,6 +769,8 @@ def resnet152(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -748,6 +786,8 @@ def resnet152b(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -763,6 +803,8 @@ def resnet200(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -778,6 +820,8 @@ def resnet200b(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -792,6 +836,8 @@ def seresnet18(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -806,6 +852,8 @@ def seresnet34(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -820,6 +868,8 @@ def seresnet50(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -835,6 +885,8 @@ def seresnet50b(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -849,6 +901,8 @@ def seresnet101(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -864,6 +918,8 @@ def seresnet101b(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -878,6 +934,8 @@ def seresnet152(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -893,6 +951,8 @@ def seresnet152b(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -908,6 +968,8 @@ def seresnet200(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ @@ -923,6 +985,8 @@ def seresnet200b(**kwargs): ---------- pretrained : bool, default False Whether to load the pretrained weights for model. + sess: Session or None, default None + A Session to use to load the weights. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. """ diff --git a/tensorflow_/utils.py b/tensorflow_/utils.py index 1eb53b5f0..db0a5d2bb 100644 --- a/tensorflow_/utils.py +++ b/tensorflow_/utils.py @@ -1,8 +1,16 @@ +import numpy as np import tensorflow as tf from .model_provider import get_model +def save_model_params(sess, + file_path): + # assert file_path.endswith('.npz') + param_dict = {v.name: v.eval(sess) for v in tf.global_variables()} + np.savez_compressed(file_path, **param_dict) + + def load_model_params(net, param_dict, sess, diff --git a/tensorflow_/utils_tp.py b/tensorflow_/utils_tp.py index 815e567d8..e9400ae9b 100644 --- a/tensorflow_/utils_tp.py +++ b/tensorflow_/utils_tp.py @@ -224,9 +224,13 @@ def prepare_tf_context(num_gpus, def prepare_model(model_name, + classes, + use_pretrained, pretrained_model_file_path): + kwargs = {'pretrained': use_pretrained, + 'classes': classes} - net = get_model(model_name) + net = get_model(model_name, **kwargs) net = ImageNetModel(model_lambda=net) inputs_desc = None diff --git a/train_tf.py b/train_tf.py index 490f683e5..8503c7956 100644 --- a/train_tf.py +++ b/train_tf.py @@ -210,8 +210,11 @@ def main(): num_gpus=args.num_gpus, batch_size=args.batch_size) + classes = 1000 net, inputs_desc = prepare_model( model_name=args.model, + classes=classes, + use_pretrained=args.use_pretrained, pretrained_model_file_path=args.resume.strip()) train_dataflow = get_data(