TensorFLow用Saver保存和恢复变量

本文为大家分享了TensorFLow用Saver保存和恢复变量的具体代码,供大家参考,具体内容如下

建立文件tensor_save.py, 保存变量v1,v2的tensor到checkpoint files中,名称分别设置为v3,v4。

import tensorflow as tf

# Create some variables.

v1 = tf.Variable(3, name="v1")

v2 = tf.Variable(4, name="v2")

# Create model

y=tf.add(v1,v2)

# Add an op to initialize the variables.

init_op = tf.initialize_all_variables()

# Add ops to save and restore all the variables.

saver = tf.train.Saver({'v3':v1,'v4':v2})

# Later, launch the model, initialize the variables, do some work, save the

# variables to disk.

with tf.Session() as sess:

sess.run(init_op)

print("v1 = ", v1.eval())

print("v2 = ", v2.eval())

# Save the variables to disk.

save_path = saver.save(sess, "f:/tmp/model.ckpt")

print ("Model saved in file: ", save_path)

建立文件tensor_restror.py, 将checkpoint files中名称分别为v3,v4的tensor分别恢复到变量v3,v4中。

import tensorflow as tf

# Create some variables.

v3 = tf.Variable(0, name="v3")

v4 = tf.Variable(0, name="v4")

# Create model

y=tf.mul(v3,v4)

# Add ops to save and restore all the variables.

saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and

# do some work with the model.

with tf.Session() as sess:

# Restore variables from disk.

saver.restore(sess, "f:/tmp/model.ckpt")

print ("Model restored.")

print ("v3 = ", v3.eval())

print ("v4 = ", v4.eval())

print ("y = ",sess.run(y))

以上是 TensorFLow用Saver保存和恢复变量 的全部内容, 来源链接: utcz.com/z/352107.html

回到顶部