Notes on Applying Trained Models

[TOC]

Converting a checkpoint (ckpt) to a frozen graph (pb)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def save_pb(dst_pb, sess, output_node_names=None):
    """Freeze the graph held by ``sess`` and write it to ``dst_pb``.

    Variables are replaced by constants holding their current values in
    ``sess``, so the session must already contain the trained weights
    (e.g. restored beforehand with ``tf.train.Saver(...).restore(sess, ckpt)``).

    :param dst_pb: destination path of the serialized ``GraphDef``.
    :param sess: session whose ``graph_def`` and variable values are frozen.
    :param output_node_names: node name(s) to keep, given as a single string
        or an iterable of strings; the graph is pruned to what they need.
        Required in practice: the ``None`` default is kept only so the
        signature stays backward compatible.
    :return: the frozen ``GraphDef`` that was written to ``dst_pb``.
    :raises ValueError: if ``output_node_names`` is ``None`` or empty.
    """
    if not output_node_names:
        raise ValueError('output_node_names must name at least one output node')
    # convert_variables_to_constants (TF 1.x) only accepts a list of names,
    # so accept a single name or any iterable and normalize it here.
    if isinstance(output_node_names, str):
        output_node_names = [output_node_names]
    else:
        output_node_names = list(output_node_names)

    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names)
    # GFile is the non-deprecated replacement for FastGFile.
    with tf.gfile.GFile(dst_pb, mode='wb') as f:
        f.write(constant_graph.SerializeToString())
    return constant_graph

def main():
    """Export the restored model's session graph to ``dest.pb``."""
    # NOTE(review): ``model`` is not defined in this snippet; it is assumed
    # to be a restored model object exposing a TF ``sess`` attribute.
    output_names = ['out_argmax', 'softmax', 'out_put_k_indices']
    save_pb('dest.pb', model.sess, output_node_names=output_names)

Loading and running a frozen graph (pb)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def read_pb(pb_file, in_elem, return_elements):
    """Load a frozen ``GraphDef`` and open a session on it.

    The graph is imported under ``tf.import_graph_def``'s default ``"import"``
    name scope, so every tensor name must carry that prefix, e.g.::

        in_elem = ['import/X:0', 'import/drop_out:0']
        return_elements = ['import/Softmax:0', 'import/out_put_k_indices:0']

    :param pb_file: path to the serialized graph, e.g. ``'./tmp/model-20000.pb'``.
    :param in_elem: iterable of input tensor names (``'<op>:<index>'``).
    :param return_elements: iterable of output tensor names to look up.
    :return: ``(sess, in_tensor, out_tensor)``. ``sess`` is a new
        ``tf.Session`` bound to the imported graph, and ``in_tensor`` and
        ``out_tensor`` map each requested name to its tensor. The caller owns
        ``sess`` and should close it when done.
    """
    graph_def = tf.GraphDef()
    # Context manager guarantees the handle is closed even if parsing fails.
    with tf.gfile.GFile(pb_file, 'rb') as f:
        graph_def.ParseFromString(f.read())

    graph = tf.Graph()
    with graph.as_default():
        # No return_elements requested here; tensors are looked up by name below.
        tf.import_graph_def(graph_def)

    in_tensor = {name: graph.get_tensor_by_name(name) for name in in_elem}
    out_tensor = {name: graph.get_tensor_by_name(name) for name in return_elements}

    sess = tf.Session(graph=graph)
    return sess, in_tensor, out_tensor