import tensorflow as tf


def main(argv=None):
    """Save a variable together with its exponential-moving-average shadow.

    Creates a float32 variable named "v", attaches an
    ExponentialMovingAverage (decay 0.99) to every global variable, sets v to
    10, runs one averaging step and writes everything (v and its shadow
    "v/ExponentialMovingAverage") to the checkpoint Model/model_ema.ckpt.

    Args:
        argv: Command-line arguments forwarded by ``tf.app.run``; unused.
    """
    v = tf.Variable(0, dtype=tf.float32, name="v")
    for var in tf.global_variables():
        print(var.name)
    # v:0

    ema = tf.train.ExponentialMovingAverage(0.99)
    # apply() creates one shadow variable per input variable and returns the
    # op that updates all of them.
    maintain_averages_op = ema.apply(tf.global_variables())
    for var in tf.global_variables():
        print(var.name)
    # v:0
    # v/ExponentialMovingAverage:0

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.assign(v, 10))
        # One update step: shadow = 0.99 * shadow + 0.01 * v (expected ~0.1 here,
        # since the shadow starts from v's initial value 0).
        sess.run(maintain_averages_op)
        # Saver.save refuses to write when the parent directory is missing.
        tf.gfile.MakeDirs("Model")
        saver.save(sess, "Model/model_ema.ckpt")


if __name__ == '__main__':
    # tf.app.run() looks up `main` in __main__, so it must be defined above.
    tf.app.run()
# Restore fragment ("method 1" of reading the saved model).
# NOTE(review): this snippet never defines `saver` or `v`; the lines that
# built the graph and the Saver appear to be missing from this copy. What it
# prints depends on how that missing Saver was constructed (plain Saver ->
# the stored value of "v"; a mapping from 'v/ExponentialMovingAverage' to v
# -> the shadow value) — TODO restore the missing lines.
with tf.Session() as sess:
    # Load variable values from the checkpoint written by the EMA save script.
    saver.restore(sess, "./Model/model_ema.ckpt")
    print(sess.run(v))

if __name__ == '__main__':
    # tf.app.run() calls `main` from __main__; no `main` is visible for this
    # snippet — presumably part of the missing lines.
    tf.app.run()
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# 获取影子变量方式2 — reading the shadow variable, method 2:
# let ExponentialMovingAverage build the checkpoint-name mapping.
import tensorflow as tf


def main(argv=None):
    """Load the EMA shadow value stored in ./Model/model_ema.ckpt into ``v``.

    Rebuilds a variable named "v" and asks an ExponentialMovingAverage (decay
    0.99) for its restore mapping, which points the checkpoint entry
    'v/ExponentialMovingAverage' at this graph's v. After restoring, v holds
    the saved moving-average value, which is printed.

    Args:
        argv: Command-line arguments forwarded by ``tf.app.run``; unused.
    """
    # The variable name must be identical to the one used when saving ("v");
    # the restore mapping below is derived from it.
    v = tf.Variable(0, dtype=tf.float32, name="v")
    ema = tf.train.ExponentialMovingAverage(0.99)

    # Example output:
    # {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
    # The key's "v" prefix comes from name="v" above.
    restore_map = ema.variables_to_restore()
    print(restore_map)
    saver = tf.train.Saver(restore_map)

    with tf.Session() as session:
        saver.restore(session, "./Model/model_ema.ckpt")
        print(session.run(v))


if __name__ == '__main__':
    tf.app.run()
# 常量保存 — Saving the model as constants (freeze variables into a GraphDef)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
import tensorflow as tf
from tensorflow.python.framework import graph_util


def main(argv=None):
    """Freeze a tiny v1 + v2 graph into a constant-only GraphDef file.

    Builds variables 'v1' = [1.0] and 'v2' = [2.0] and an 'add' node, then
    uses graph_util.convert_variables_to_constants to replace the variables
    with their current values, with 'add' as the output node. The serialized
    GraphDef is written to ./model/combined.pd.

    Args:
        argv: Command-line arguments forwarded by ``tf.app.run``; unused.
    """
    v1 = tf.Variable(initial_value=tf.constant(value=1.0, shape=[1]), name='v1')
    v2 = tf.Variable(initial_value=tf.constant(value=2.0, shape=[1]), name='v2')
    # Only the node name matters: 'add' is the output node kept when freezing.
    tf.add(x=v1, y=v2, name='add')
    with tf.Session() as sess:
        # Variables must hold values before they can be folded into constants.
        tf.global_variables_initializer().run()
        graph_def = tf.get_default_graph().as_graph_def()
        output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess, input_graph_def=graph_def, output_node_names=['add'])

    # GFile does not create missing parent directories.
    tf.gfile.MakeDirs('./model')
    # NOTE: the path (including the '.pd' suffix) must match the loader's;
    # '.pb' is the conventional extension for a serialized GraphDef.
    with tf.gfile.GFile(name='./model/combined.pd', mode='wb') as f:
        f.write(output_graph_def.SerializeToString())


if __name__ == '__main__':
    tf.app.run()
# 常量恢复 — Restoring the model from constants (load the frozen GraphDef)
1 2 3 4 5 6 7 8 9 10 11 12 13 14
import tensorflow as tf
from tensorflow.python.platform import gfile


def main(argv=None):
    """Load the frozen GraphDef from ./model/combined.pd and evaluate 'add:0'.

    Parses the serialized GraphDef, imports it into the default graph, and
    runs the tensor 'add:0' (its name inside the saved GraphDef). The import
    returns a list, so the printed value is a one-element list.

    Args:
        argv: Command-line arguments forwarded by ``tf.app.run``; unused.
    """
    model_filename = './model/combined.pd'
    with tf.Session() as sess:
        # FastGFile is deprecated; GFile is its supported replacement.
        with gfile.GFile(model_filename, mode='rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # No variables need initializing: the frozen graph holds only
        # constants, so the imported tensor can be evaluated directly.
        result = tf.import_graph_def(graph_def, return_elements=['add:0'])
        print(sess.run(result))