Skip to content

Commit

Permalink
Work on TF branch, 2
Browse files Browse the repository at this point in the history
  • Loading branch information
osmr committed Oct 15, 2018
1 parent 71a49ef commit 7f99c7a
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 23 deletions.
12 changes: 8 additions & 4 deletions convert_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def convert_gl2tf(dst_net,
# if not (tuple(dst_params[dst_key].get_shape().as_list()[::-1]) == src_params[src_key].shape):
# a = 1
assert (tuple(dst_params[dst_key].get_shape().as_list()[::-1]) == src_params[src_key].shape)
src_value = np.transpose(src_params[src_key]._data[0].asnumpy(), axes=(3, 2, 1, 0))
src_value = np.transpose(src_params[src_key]._data[0].asnumpy(), axes=(2, 3, 1, 0))
elif len(src_value.shape) == 2:
assert (tuple(dst_params[dst_key].get_shape().as_list()[::-1]) == src_params[src_key].shape)
src_value = np.transpose(src_params[src_key]._data[0].asnumpy(), axes=(1, 0))
Expand All @@ -490,10 +490,14 @@ def convert_gl2tf(dst_net,
assert (tuple(dst_params[dst_key].get_shape().as_list()) == src_params[src_key].shape)
sess.run(dst_params[dst_key].assign(src_value))
# print(dst_params[dst_key].eval(sess))
saver = tf.train.Saver()
saver.save(
# saver = tf.train.Saver()
# saver.save(
# sess=sess,
# save_path=dst_params_file_path)
from tensorflow_.utils import save_model_params
save_model_params(
sess=sess,
save_path=dst_params_file_path)
file_path=dst_params_file_path)


def convert_pt2pt(dst_params_file_path,
Expand Down
3 changes: 3 additions & 0 deletions eval_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,11 @@ def main():
num_gpus=args.num_gpus,
batch_size=args.batch_size)

classes = 1000
net, inputs_desc = prepare_model(
model_name=args.model,
classes=classes,
use_pretrained=args.use_pretrained,
pretrained_model_file_path=args.resume.strip())

val_dataflow = get_data(
Expand Down
67 changes: 65 additions & 2 deletions tensorflow_/models/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Model store which provides pretrained models.
"""

__all__ = ['get_model_file']
__all__ = ['get_model_file', 'load_model', 'download_model']

import os
import zipfile
Expand Down Expand Up @@ -31,7 +31,7 @@ def get_model_file(model_name,
----------
model_name : str
Name of the model.
local_model_store_dir_path : str, default $KERAS_HOME/models
local_model_store_dir_path : str, default $TENSORFLOW_HOME/models
Location for keeping the model parameters.
Returns
Expand Down Expand Up @@ -183,3 +183,66 @@ def _check_sha1(filename, sha1_hash):
sha1.update(data)

return sha1.hexdigest() == sha1_hash


def load_model(sess,
file_path,
ignore_extra=True):
"""
Load model state dictionary from a file.
Parameters
----------
sess: Session or None, default None
A Session to use to load the weights.
file_path : str
Path to the file.
ignore_extra : bool, default True
Whether to silently ignore parameters from the file that are not present in this Module.
"""
import numpy as np
import tensorflow as tf

assert os.path.exists(file_path) and os.path.isfile(file_path)
if file_path.endswith('.npy'):
src_params = np.load(file_path, encoding='latin1').item()
elif file_path.endswith('.npz'):
src_params = dict(np.load(file_path))
else:
raise NotImplementedError
dst_params = {v.name: v for v in tf.global_variables()}
if sess is None:
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for src_key in src_params.keys():
if src_key in dst_params.keys():
assert (src_params[src_key].shape == tuple(dst_params[src_key].get_shape().as_list()))
sess.run(dst_params[src_key].assign(src_params[src_key]))
elif not ignore_extra:
raise Exception("The file `{}` is incompatible with the model".format(file_path))


def download_model(sess,
model_name,
local_model_store_dir_path=os.path.join('~', '.tensorflow', 'models'),
ignore_extra=True):
"""
Load model state dictionary from a file with downloading it if necessary.
Parameters
----------
sess: Session or None, default None
A Session to use to load the weights.
model_name : str
Name of the model.
local_model_store_dir_path : str, default $TENSORFLOW_HOME/models
Location for keeping the model parameters.
ignore_extra : bool, default True
Whether to silently ignore parameters from the file that are not present in this Module.
"""
load_model(
sess=sess,
file_path=get_model_file(
model_name=model_name,
local_model_store_dir_path=local_model_store_dir_path),
ignore_extra=ignore_extra)
2 changes: 1 addition & 1 deletion tensorflow_/models/others/resnet_.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
layer_register
from tensorpack.tfutils import argscope
import tensorflow.contrib.slim as slim
from tensorflow_.models.common import ImageNetModel, conv2d, se_block
from .common_ import ImageNetModel, conv2d, se_block


@layer_register(log_shape=True)
Expand Down
Loading

0 comments on commit 7f99c7a

Please sign in to comment.