我正在使用C++中的TensorFlow训练我的模型,Python仅用于构建图形。那么有没有一种纯粹使用C++保存和恢复图形及其状态的方法? 我知道Python类tf.train.Saver
,但据我所知它在C++中不存在。
我正在使用C++中的TensorFlow训练我的模型,Python仅用于构建图形。那么有没有一种纯粹使用C++保存和恢复图形及其状态的方法? 我知道Python类tf.train.Saver
,但据我所知它在C++中不存在。
tf.train.Saver
类目前仅存在于Python中,但是 (i) 它是由TensorFlow操作构建的,您可以从C++中运行它们,(ii) 它公开了Saver.as_saver_def()
方法,让您获得一个包含必须运行以保存或还原模型的操作名称的SaverDef
协议缓冲区。
在Python中,您可以按如下方式获取保存和恢复操作的名称:
saver = tf.train.Saver(...)
saver_def = saver.as_saver_def()
# The name of the tensor you must feed with a filename when saving/restoring.
print saver_def.filename_tensor_name
# The name of the target operation you must run when restoring.
print saver_def.restore_op_name
# The name of the target operation you must run when saving.
print saver_def.save_tensor_name
C++中,要从检查点恢复,您需要调用Session::Run()
,将检查点文件名作为saver_def.filename_tensor_name
传入,目标操作为saver_def.restore_op_name
。要保存另一个检查点,您需要再次调用Session::Run()
,将检查点文件名作为saver_def.filename_tensor_name
传入,并获取saver_def.save_tensor_name
的值。
${HOME}/.local/lib/python2.7/site-packages/tensorflow/include/tensorflow/core/protobuf/saver.pb.h
)。// save
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = "some/path";
tensor_dict feed_dict = {{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}};
status = sess->Run(feed_dict, {}, {graph_def.saver_def().save_tensor_name()}, nullptr);
// restore
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = "some/path";
tensor_dict feed_dict = {{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}};
status = sess->Run(feed_dict, {}, {graph_def.saver_def().restore_op_name()}, nullptr);
这是基于未记录的Python方式(更多细节)恢复模型的方法。
def restore(sess, metaGraph, fn):
restore_op_name = metaGraph.as_saver_def().restore_op_name # u'save/restore_all'
restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name # u'save/Const'
sess.run(restore_op, {filename_tensor_name: fn})
想要一个运作完整的版本请点击这里。
tf :: Tensor string(tf :: DT_STRING,tf :: TensorShape({1,1}));
提供字符串:string.matrix <std :: string>()(0,0)= file_path_ + filename;
执行:TF_CHECK_OK(session_->Run({{“save / Const:0”,string }}, {},{“save / control_dependency”},nullptr));
- Trevirsaver = tf.train.Saver(...)
代码行。我可以确认,必须将“save/control_dependency:0”重命名为“save/control_dependency”。 - Snurka Bill