Skip to content

Commit

Permalink
Work on TF-convert script for ResNet
Browse files Browse the repository at this point in the history
  • Loading branch information
osmr committed Oct 13, 2018
1 parent b5f9ba6 commit 3d6de95
Show file tree
Hide file tree
Showing 6 changed files with 369 additions and 286 deletions.
63 changes: 63 additions & 0 deletions convert_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,15 @@ def prepare_dst_model(dst_fwk,
if weight.name:
dst_params.setdefault(weight.name, []).append(weight)
dst_params[weight.name] = (layer, weight)
elif dst_fwk == "tensorflow":
import tensorflow as tf
from tensorflow_.utils import prepare_model as prepare_model_tf
dst_net = prepare_model_tf(
model_name=dst_model,
classes=num_classes,
use_pretrained=False)
dst_param_keys = [v.name for v in tf.global_variables()]
dst_params = {v.name: v for v in tf.global_variables()}
else:
raise ValueError("Unsupported dst fwk: {}".format(dst_fwk))

Expand Down Expand Up @@ -443,6 +452,50 @@ def process_width(src_key, dst_key, src_weight):
dst_net.save_weights(dst_params_file_path)


def convert_gl2tf(dst_net,
dst_params_file_path,
dst_params,
dst_param_keys,
src_params,
src_param_keys):
dst_param_keys = [key.replace('/kernel:', '/weight:') for key in dst_param_keys]

src_param_keys.sort()
src_param_keys.sort(key=lambda var: ['{:10}'.format(int(x)) if
x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])

dst_param_keys.sort()
dst_param_keys.sort(key=lambda var: ['{:10}'.format(int(x)) if
x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])

dst_param_keys = [key.replace('/weight:', '/kernel:') for key in dst_param_keys]

import tensorflow as tf
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i, (src_key, dst_key) in enumerate(zip(src_param_keys, dst_param_keys)):
# assert (tuple(dst_params[dst_key].get_shape().as_list()) == src_params[src_key].shape)
src_value = src_params[src_key]._data[0].asnumpy()
if len(src_value.shape) == 4:
# if not (tuple(dst_params[dst_key].get_shape().as_list()[::-1]) == src_params[src_key].shape):
# a = 1
assert (tuple(dst_params[dst_key].get_shape().as_list()[::-1]) == src_params[src_key].shape)
src_value = np.transpose(src_params[src_key]._data[0].asnumpy(), axes=(3, 2, 1, 0))
elif len(src_value.shape) == 2:
assert (tuple(dst_params[dst_key].get_shape().as_list()[::-1]) == src_params[src_key].shape)
src_value = np.transpose(src_params[src_key]._data[0].asnumpy(), axes=(1, 0))
else:
# if not (tuple(dst_params[dst_key].get_shape().as_list()) == src_params[src_key].shape):
# a = 1
assert (tuple(dst_params[dst_key].get_shape().as_list()) == src_params[src_key].shape)
sess.run(dst_params[dst_key].assign(src_value))
# print(dst_params[dst_key].eval(sess))
saver = tf.train.Saver()
saver.save(
sess=sess,
save_path=dst_params_file_path)


def convert_pt2pt(dst_params_file_path,
dst_params,
dst_param_keys,
Expand Down Expand Up @@ -589,6 +642,14 @@ def main():
dst_param_keys=dst_param_keys,
src_params=src_params,
src_param_keys=src_param_keys)
elif args.src_fwk == "gluon" and args.dst_fwk == "tensorflow":
convert_gl2tf(
dst_net=dst_net,
dst_params_file_path=args.dst_params,
dst_params=dst_params,
dst_param_keys=dst_param_keys,
src_params=src_params,
src_param_keys=src_param_keys)
elif args.src_fwk == "pytorch" and args.dst_fwk == "gluon":
convert_pt2gl(
dst_net=dst_net,
Expand All @@ -608,6 +669,8 @@ def main():
src_arg_params=src_arg_params,
src_model=args.src_model,
ctx=ctx)
else:
raise NotImplementedError

logging.info('Convert {}-model {} into {}-model {}'.format(
args.src_fwk, args.src_model, args.dst_fwk, args.dst_model))
Expand Down
2 changes: 1 addition & 1 deletion eval_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tensorpack.input_source import QueueInput, StagingInput

from common.logger_utils import initialize_logging
from tensorflow_.utils import prepare_tf_context, prepare_model, get_data, calc_flops
from tensorflow_.utils_tp import prepare_tf_context, prepare_model, get_data, calc_flops


def parse_args():
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_/model_provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .models.resnet import *
from tensorflow_.models.others.shufflenetv2 import *
# from tensorflow_.models.others.shufflenetv2 import *

__all__ = ['get_model']

Expand Down Expand Up @@ -35,7 +35,7 @@
'seresnet200': seresnet200,
'seresnet200b': seresnet200b,

'shufflenetv2_wd2': shufflenetv2_wd2,
# 'shufflenetv2_wd2': shufflenetv2_wd2,
}


Expand Down
Loading

0 comments on commit 3d6de95

Please sign in to comment.