tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例

获取 ckpt 模型中的变量(节点)名称

from tensorflow.python import pywrap_tensorflow

# Path prefix of the checkpoint whose variables are listed below.
checkpoint_path = 'model.ckpt-8000'

# Open the checkpoint and fetch the mapping of variable name -> shape.
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()

# Print the name of every variable stored in the checkpoint.
for key in var_to_shape_map:
    print("tensor_name: ", key)

获取 pb 模型中的节点名称

import tensorflow as tf

import os

# Path of the serialized GraphDef (.pb) file that create_graph() opens and parses.
model_name = './mobilenet_v2_140_inf_graph.pb'

def create_graph():
    """Load the GraphDef stored at ``model_name`` into the default graph.

    Reads the serialized protobuf from the module-level ``model_name`` path,
    parses it into a ``tf.GraphDef`` and imports it with ``name=''`` so that
    no name prefix is added to the imported nodes.

    :return: None
    """
    # GFile is used instead of the deprecated FastGFile alias; the file is
    # opened in binary mode because it holds a serialized protobuf.
    with tf.gfile.GFile(model_name, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    tf.import_graph_def(graph_def, name='')

# Import the .pb graph into the default graph before inspecting it.
create_graph()

# Names of all nodes in the default graph's GraphDef.
tensor_name_list = [node.name for node in tf.get_default_graph().as_graph_def().node]

for tensor_name in tensor_name_list:
    # The extra '\n' argument leaves a blank line after each printed name.
    print(tensor_name,'\n')

将 ckpt 模型转换为 pb 模型

def freeze_graph(input_checkpoint, output_graph, output_node_names="xxx"):
    """Convert a checkpoint into a PB file with variables turned into constants.

    :param input_checkpoint: checkpoint path prefix; the graph is imported
        from ``input_checkpoint + '.meta'`` and the variable values are
        restored from ``input_checkpoint``.
    :param output_graph: path where the PB model is saved.
    :param output_node_names: comma-separated names of the output nodes
        passed to the conversion. The default ``"xxx"`` is only a placeholder
        and has to be replaced by real node names of the model.
    :return: None
    """
    # Whitespace around the commas is dropped so "a, b" yields ["a", "b"].
    node_names = [name.strip() for name in output_node_names.split(",")]

    # Rebuild the graph structure from the .meta file; clear_devices=True
    # clears the device fields recorded in the imported graph.
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    with tf.Session() as sess:
        # Load the variable values that will be turned into constants.
        saver.restore(sess, input_checkpoint)

        # Accessed through tf.graph_util because no standalone `graph_util`
        # name is imported in this file.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,  # equivalent to: sess.graph_def
            output_node_names=node_names)

        # Write the serialized result to the target path.
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

        # List every operation of the source graph with its output tensors.
        for op in graph.get_operations():
            print(op.name, op.values())

以上这篇tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。

以上是 tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例 的全部内容, 来源链接: utcz.com/z/318135.html

回到顶部