diff --git a/tensorflow_/models/model_store.py b/tensorflow_/models/model_store.py index acd268b50..a30d5f624 100644 --- a/tensorflow_/models/model_store.py +++ b/tensorflow_/models/model_store.py @@ -41,7 +41,7 @@ def get_model_file(model_name, """ error, sha1_hash, repo_release_tag = get_model_name_suffix_data(model_name) short_sha1 = sha1_hash[:8] - file_name = '{name}-{error}-{short_sha1}.tfm'.format( + file_name = '{name}-{error}-{short_sha1}.tf.npz'.format( name=model_name, error=error, short_sha1=short_sha1) diff --git a/tensorflow_/models/resnet.py b/tensorflow_/models/resnet.py index fb9a203a4..0342fd5c5 100644 --- a/tensorflow_/models/resnet.py +++ b/tensorflow_/models/resnet.py @@ -462,6 +462,13 @@ def get_resnet(blocks, Whether to load the pretrained weights for model. root : str, default '~/.tensorflow/models' Location for keeping the model parameters. + + Returns + ------- + Function + Model script. + Dict or None + Model parameter dict. """ if blocks == 10: @@ -513,16 +520,18 @@ def net_lambda(x): **kwargs) net = net_lambda + param_dict = None if pretrained: if (model_name is None) or (not model_name): raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.") # from .model_store import get_model_file + # param_dict = # net.load_weights( # filepath=get_model_file( # model_name=model_name, # local_model_store_dir_path=root)) - return net + return net, param_dict def resnet10(**kwargs): @@ -959,7 +968,7 @@ def _test(): for model in models: - net = model(pretrained=pretrained) + net, _ = model(pretrained=pretrained) x = tf.placeholder( dtype=tf.float32, diff --git a/tensorflow_/utils.py b/tensorflow_/utils.py index bd3381991..1eb53b5f0 100644 --- a/tensorflow_/utils.py +++ b/tensorflow_/utils.py @@ -3,6 +3,20 @@ from .model_provider import get_model +def load_model_params(net, + param_dict, + sess, + ignore_missing=False): + for param_name, param_data in param_dict: + with tf.variable_scope(param_name, reuse=True): + try: + var = tf.get_variable(param_name) + sess.run(var.assign(param_data)) + except ValueError: + if not ignore_missing: + raise + + def prepare_model(model_name, classes, use_pretrained):