Error converting denseConv ckpt to pb file #12

Open
bolyor opened this issue Oct 11, 2018 · 0 comments

bolyor commented Oct 11, 2018

What I did:
python train_ocr_layer.py
After getting the ckpt file, I tried to convert it to a pb file with this code:

import tensorflow as tf
from tensorflow.python.framework import graph_util

def freeze_graph(input_checkpoint, output_graph):
    '''
    :param input_checkpoint: checkpoint path (overwritten below by the latest checkpoint found in model_folder)
    :param output_graph: path where the PB model is saved
    :return:
    '''
    model_folder = './checkpoints/'
    checkpoint = tf.train.get_checkpoint_state(model_folder)  # check whether the ckpt files in the directory are usable
    input_checkpoint = checkpoint.model_checkpoint_path  # get the ckpt file path

    # Name(s) of the output node(s); each must be a node that exists in the original model
    output_node_names = "group_deps"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()  # get the default graph
    input_graph_def = graph.as_graph_def()  # serialized GraphDef representing the current graph

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # restore the variable values into the graph
        output_graph_def = graph_util.convert_variables_to_constants(  # freeze the model: replace variables with constants holding their values
            sess=sess,
            input_graph_def=input_graph_def,  # same as sess.graph_def
            output_node_names=output_node_names.split(","))  # separate multiple output nodes with commas

        with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
            f.write(output_graph_def.SerializeToString())  # write the serialized graph
        print("%d ops in the final graph." % len(output_graph_def.node))  # number of op nodes in the frozen graph

        for op in graph.get_operations():
            print(op.name, op.values())

if __name__ == '__main__':

    input_checkpoint = './checkpoints/denseConv0.ckpt'
    output_graph = './checkpoints/frozen_graph.pb'
    freeze_graph(input_checkpoint, output_graph)
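For context, a hedged sketch of why the choice of output node matters. `tf.group` names its NoOp `group_deps` by default, so `group_deps` is probably the training op built with `tf.group` in `train_ocr_layer.py` (an assumption, since the training script isn't shown here). In TF 1.x, `convert_variables_to_constants` first keeps only the nodes reachable backwards from `output_node_names`, following control inputs as well as data inputs, so a training-step output pulls the batch-norm `AssignMovingAvg` update into the frozen graph. The pure-Python function below (`kept_nodes` and the toy graph are hypothetical, not TensorFlow APIs) roughly mirrors that pruning step:

```python
from types import SimpleNamespace


def kept_nodes(nodes, output_node_names):
    """Names of nodes reachable backwards from `output_node_names`.

    Follows data inputs ("name" / "name:N") and control inputs ("^name"),
    roughly like TF 1.x graph_util.extract_sub_graph. `nodes` is any iterable
    of NodeDef-like objects with `.name` and `.input` (e.g. graph_def.node).
    Raises KeyError if a referenced node name is missing from `nodes`.
    """
    by_name = {node.name: node for node in nodes}
    kept, stack = set(), list(output_node_names)
    while stack:
        name = stack.pop()
        if name in kept:
            continue
        kept.add(name)
        for inp in by_name[name].input:
            # Strip the control-dependency prefix and the output-port suffix.
            stack.append(inp.lstrip("^").split(":")[0])
    return kept


# Hypothetical toy graph: a training step grouped by tf.group (default name
# "group_deps") that depends on both the logits and a batch-norm update.
toy_graph = [
    SimpleNamespace(name="x", input=[]),
    SimpleNamespace(name="bn/moving_mean", input=[]),
    SimpleNamespace(name="logits", input=["x"]),
    SimpleNamespace(name="bn/AssignMovingAvg", input=["bn/moving_mean", "x:0"]),
    SimpleNamespace(name="group_deps", input=["^bn/AssignMovingAvg", "^logits"]),
]

print(sorted(kept_nodes(toy_graph, ["logits"])))
# → ['logits', 'x']
print(sorted(kept_nodes(toy_graph, ["group_deps"])))
# → ['bn/AssignMovingAvg', 'bn/moving_mean', 'group_deps', 'logits', 'x']
```

In this toy case, freezing from the inference output (here `logits`; the real node name depends on the model) keeps the assign op out of the `.pb`, while freezing from `group_deps` keeps it.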

After getting the pb file, I can't open it in TensorBoard or convert it to ONNX.
I get this error:

Traceback (most recent call last):
  File "/home/cenhong/.local/lib/python3.6/site-packages/tensorflow/python/framework/importer.py", line 418, in import_graph_def
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input 0 of node import/model/batchnorm/AssignMovingAvg was passed float from import/model/batchnorm//moving_mean:0 incompatible with expected float_ref.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/cenhong/.local/bin/tfpb_tensorboard", line 11, in <module>
    load_entry_point('doml', 'console_scripts', 'tfpb_tensorboard')()
  File "/home/cenhong/do-ml/scripts/tfpb_tensorboard.py", line 33, in main
    tfpb_tensorboard(args.input_path, args.log_path, 6006 if args.port is None else args.port)
  File "/home/cenhong/do-ml/scripts/tfpb_tensorboard.py", line 18, in tfpb_tensorboard
    g_in = tf.import_graph_def(graph_def)
  File "/home/cenhong/.local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "/home/cenhong/.local/lib/python3.6/site-packages/tensorflow/python/framework/importer.py", line 422, in import_graph_def
    raise ValueError(str(e))
ValueError: Input 0 of node import/model/batchnorm/AssignMovingAvg was passed float from import/model/batchnorm//moving_mean:0 incompatible with expected float_ref.

Does anyone know how to solve this? Thanks.
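If re-freezing from an inference output node isn't an option, one workaround that circulates in TensorFlow issue threads for this `float_ref` error is to rewrite the ref-consuming ops into their pure counterparts (`AssignSub` to `Sub`, `AssignAdd` to `Add`, `RefSwitch` to `Switch`). This is a minimal sketch under assumptions: it assumes `AssignMovingAvg` is an `AssignSub` node (which is how TF 1.x `assign_moving_average` builds it), that batch-norm variables use the default `moving_*` names with a `/read` identity node, and `strip_ref_ops` is a hypothetical helper, not part of this repository or of TensorFlow:

```python
from types import SimpleNamespace

# Pure-op replacements for ops that require a ref (float_ref) input.
REF_TO_PURE = {"AssignSub": "Sub", "AssignAdd": "Add", "RefSwitch": "Switch"}


def strip_ref_ops(nodes):
    """Rewrite ref-consuming ops in place so a frozen graph can be imported.

    `nodes` is any iterable of NodeDef-like objects with `.op`, `.input`
    (mutable sequence of str) and `.attr` (mapping), e.g. graph_def.node.
    """
    for node in nodes:
        new_op = REF_TO_PURE.get(node.op)
        if new_op is None:
            continue
        if node.op == "RefSwitch":
            # Assumption: feed the variable's "/read" identity output
            # instead of the variable ref itself.
            for i, name in enumerate(node.input):
                if "moving_" in name:
                    node.input[i] = name + "/read"
        node.op = new_op
        # Sub/Add/Switch have no `use_locking` attr; drop it if present.
        if "use_locking" in node.attr:
            del node.attr["use_locking"]


# With TensorFlow 1.x (assumed), the same function can patch the saved file:
#   graph_def = tf.GraphDef()
#   with tf.gfile.GFile("./checkpoints/frozen_graph.pb", "rb") as f:
#       graph_def.ParseFromString(f.read())
#   strip_ref_ops(graph_def.node)
#   tf.import_graph_def(graph_def)

# Toy stand-ins for NodeDefs (hypothetical names) to show the rewrite.
toy_nodes = [
    SimpleNamespace(op="AssignSub", input=["bn/moving_mean", "bn/delta"],
                    attr={"use_locking": False, "T": "float"}),
    SimpleNamespace(op="AssignAdd", input=["step", "one"], attr={"use_locking": True}),
    SimpleNamespace(op="RefSwitch", input=["bn/moving_mean", "is_training"], attr={}),
    SimpleNamespace(op="Conv2D", input=["x", "w"], attr={}),
]
strip_ref_ops(toy_nodes)
print([n.op for n in toy_nodes])
# → ['Sub', 'Add', 'Switch', 'Conv2D']
```

Note that the rewritten graph still contains the (now harmless) moving-average arithmetic; freezing from the correct inference output is the cleaner fix, since those nodes are then never included.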
