-
Notifications
You must be signed in to change notification settings - Fork 332
gazeml
sngyo edited this page Nov 4, 2019
·
2 revisions
# Export GazeML's ELG eye-landmark model to an ONNX file.
#
# Steps: build/restore the ELG model, rewrite batch-norm update ops so the
# graph can be frozen, freeze variables into constants, replace the
# `learning_params/Placeholder_1` placeholder with a constant False, then
# convert the frozen graph to ONNX (opset 10) and write it to `onnx_name`.
#
# NOTE(review): `ELG` (the model class) and `data_source` (the input data
# source) are NOT defined here; this snippet presumably runs inside a GazeML
# script that has already created them -- TODO confirm against the host script.
import copy

import tensorflow as tf  # the script uses `tf` throughout; import it explicitly
import tf2onnx
from tf2onnx.optimizer.transpose_optimizer import TransposeOptimizer

# Model variant: big = 180x108 input, 3 hourglass modules, 64 feature maps;
# small = 60x36 input, 2 hourglass modules, 32 feature maps.
BIG_MODEL = False
if BIG_MODEL:
    EYE_WIDTH, EYE_HEIGHT = 180, 108
    FIRST_LAYER_STRIDE, NUM_MODULES, NUM_FEATURE_MAPS = 3, 3, 64
else:
    EYE_WIDTH, EYE_HEIGHT = 60, 36
    FIRST_LAYER_STRIDE, NUM_MODULES, NUM_FEATURE_MAPS = 1, 2, 32
# EYE_WIDTH / EYE_HEIGHT are not read below except in the output file name;
# presumably the host script's data source uses them too -- verify.

sess = tf.Session()
elgmodel = ELG(
    sess,
    train_data={'videostream': data_source},
    first_layer_stride=FIRST_LAYER_STRIDE,
    num_modules=NUM_MODULES,
    num_feature_maps=NUM_FEATURE_MAPS,
    learning_schedule=[
        {'loss_terms_to_optimize': {'dummy': ['hourglass', 'radius']}},
    ],
)
elgmodel.initialize_if_not(training=False)
elgmodel.checkpoint.load_all()

# Graph node names of the model's three outputs. The last hourglass module is
# named hg_<NUM_MODULES> (hg_3 for the big model, hg_2 for the small one).
LANDMARKS_NODE = 'upscale/mul'
HEATMAPS_NODE = 'hourglass/hg_{}/after/hmap/conv/BiasAdd'.format(NUM_MODULES)
RADIUS_NODE = 'radius/out/fc/BiasAdd'
output_node_names = [LANDMARKS_NODE, HEATMAPS_NODE, RADIUS_NODE]

# Looking the tensors up fails fast (KeyError) if a node name is wrong.
eye = sess.graph.get_tensor_by_name('eye:0')
landmarks = sess.graph.get_tensor_by_name(LANDMARKS_NODE + ':0')
heatmaps = sess.graph.get_tensor_by_name(HEATMAPS_NODE + ':0')
radius = sess.graph.get_tensor_by_name(RADIUS_NODE + ':0')

# Fix batch-norm nodes so the graph can be frozen: RefSwitch / AssignSub /
# AssignAdd operate on variable references, so turn them into their
# value-based counterparts and read the 'moving_*' variables through their
# '/read' identity nodes instead.
gd = sess.graph.as_graph_def()
for node in gd.node:
    if node.op == 'RefSwitch':
        node.op = 'Switch'
        for index, input_name in enumerate(node.input):
            if 'moving_' in input_name:
                node.input[index] = input_name + '/read'
    elif node.op in ('AssignSub', 'AssignAdd'):
        node.op = 'Sub' if node.op == 'AssignSub' else 'Add'
        if 'use_locking' in node.attr:
            del node.attr['use_locking']

# Freeze the graph: bake variable values into constants, keeping only what the
# output nodes need.
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
    sess,
    gd,
    output_node_names,
)

# Replace the placeholder with a constant False so the exported model has a
# single input. NOTE(review): presumably this is the training-mode flag of
# `learning_params` -- confirm.
target_node_name = 'learning_params/Placeholder_1'
# Build the constant in a scratch graph and force its name. Creating it in the
# session graph, which already has a node with this name, would make
# TensorFlow uniquify it (e.g. '..._1_1') and leave its consumers dangling.
with tf.Graph().as_default():
    c = tf.constant(False, dtype=tf.bool, shape=[], name=target_node_name)
const_node_def = copy.deepcopy(c.op.node_def)
const_node_def.name = target_node_name

# Edit the frozen GraphDef in place so its `versions` and `library` fields
# are preserved (rebuilding a fresh GraphDef would drop them).
detected = False
for node in frozen_graph_def.node:
    if node.name == target_node_name:
        node.CopyFrom(const_node_def)
        detected = True
if not detected:
    # Not fatal: the placeholder may simply be unreachable from the outputs
    # and therefore already pruned by convert_variables_to_constants.
    print('warning: node {!r} not found in the frozen graph; '
          'nothing replaced'.format(target_node_name))

# Convert to ONNX. tf.import_graph_def prefixes every node with 'import/'.
input_names = ['import/eye:0']
output_names = ['import/{}:0'.format(name) for name in output_node_names]
onnx_name = 'gazeml_elg_i{}x{}_n{}.onnx'.format(
    EYE_WIDTH, EYE_HEIGHT, NUM_FEATURE_MAPS)

graph1 = tf.Graph()
with graph1.as_default():
    tf.import_graph_def(frozen_graph_def)
onnx_graph = tf2onnx.tfonnx.process_tf_graph(
    graph1,
    input_names=input_names,
    output_names=output_names,
    opset=10,
)
# optimize() returns the optimized graph; use the returned object rather than
# relying on in-place mutation of the argument.
onnx_graph = TransposeOptimizer().optimize(onnx_graph)
model_proto = onnx_graph.make_model('gazeml')
with open(onnx_name, 'wb') as f:
    f.write(model_proto.SerializeToString())
(c) 2019 ax Inc. & AXELL CORPORATION