Skip to content

Commit

Permalink
Work on TF branch
Browse files Browse the repository at this point in the history
  • Loading branch information
osmr committed Oct 15, 2018
1 parent 3d6de95 commit 71a49ef
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tensorflow_/models/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_model_file(model_name,
"""
error, sha1_hash, repo_release_tag = get_model_name_suffix_data(model_name)
short_sha1 = sha1_hash[:8]
file_name = '{name}-{error}-{short_sha1}.tfm'.format(
file_name = '{name}-{error}-{short_sha1}.tf.npz'.format(
name=model_name,
error=error,
short_sha1=short_sha1)
Expand Down
13 changes: 11 additions & 2 deletions tensorflow_/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,13 @@ def get_resnet(blocks,
Whether to load the pretrained weights for model.
root : str, default '~/.tensorflow/models'
Location for keeping the model parameters.
Returns
-------
Function
Model script.
Dict or None
Model parameter dict.
"""

if blocks == 10:
Expand Down Expand Up @@ -513,16 +520,18 @@ def net_lambda(x):
**kwargs)
net = net_lambda

param_dict = None
if pretrained:
if (model_name is None) or (not model_name):
raise ValueError("Parameter `model_name` should be properly initialized for loading pretrained model.")
# from .model_store import get_model_file
# param_dict =
# net.load_weights(
# filepath=get_model_file(
# model_name=model_name,
# local_model_store_dir_path=root))

return net
return net, param_dict


def resnet10(**kwargs):
Expand Down Expand Up @@ -959,7 +968,7 @@ def _test():

for model in models:

net = model(pretrained=pretrained)
net, _ = model(pretrained=pretrained)

x = tf.placeholder(
dtype=tf.float32,
Expand Down
14 changes: 14 additions & 0 deletions tensorflow_/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,20 @@
from .model_provider import get_model


def load_model_params(net,
                      param_dict,
                      sess,
                      ignore_missing=False):
    """
    Assign pretrained parameter values to the corresponding TF variables.

    Parameters
    ----------
    net : object
        Model object (currently unused here; kept for interface
        compatibility with callers — TODO confirm).
    param_dict : dict
        Mapping from variable-scope/parameter name to its value
        (presumably a NumPy array loaded from a `.tf.npz` file — verify
        against `get_model_file`).
    sess : tf.Session
        Session used to run the assignment ops.
    ignore_missing : bool, default False
        If True, silently skip parameters that have no matching variable
        in the current graph; otherwise re-raise the lookup error.

    Raises
    ------
    ValueError
        If a parameter name does not correspond to an existing variable
        and `ignore_missing` is False.
    """
    # BUG FIX: iterating a dict directly yields only the keys; we need
    # (name, value) pairs, so iterate over .items().
    for param_name, param_data in param_dict.items():
        # reuse=True: look up the already-created variable rather than
        # defining a new one.
        with tf.variable_scope(param_name, reuse=True):
            try:
                var = tf.get_variable(param_name)
                sess.run(var.assign(param_data))
            except ValueError:
                # tf.get_variable raises ValueError when no variable with
                # this name exists in the scope.
                if not ignore_missing:
                    raise


def prepare_model(model_name,
classes,
use_pretrained):
Expand Down

0 comments on commit 71a49ef

Please sign in to comment.