如何在C++中保存和恢复TensorFlow图形及其状态?

4

我正在使用C++中的TensorFlow训练我的模型,Python仅用于构建图形。那么有没有一种纯粹使用C++保存和恢复图形及其状态的方法? 我知道Python类tf.train.Saver,但据我所知它在C++中不存在。

2个回答

9
tf.train.Saver类目前仅存在于Python中,但是 (i) 它是由TensorFlow操作构建的,您可以从C++中运行它们,(ii) 它公开了Saver.as_saver_def()方法,让您获得一个包含必须运行以保存或还原模型的操作名称的SaverDef协议缓冲区

在Python中,您可以按如下方式获取保存和恢复操作的名称:

saver = tf.train.Saver(...)
saver_def = saver.as_saver_def()

# The name of the tensor you must feed with a filename when saving/restoring.
print saver_def.filename_tensor_name

# The name of the target operation you must run when restoring.
print saver_def.restore_op_name

# The name of the target operation you must run when saving.
print saver_def.save_tensor_name

C++中,要从检查点恢复,您需要调用Session::Run(),将检查点文件名作为saver_def.filename_tensor_name传入,目标操作为saver_def.restore_op_name。要保存另一个检查点,您需要再次调用Session::Run(),将检查点文件名作为saver_def.filename_tensor_name传入,并获取saver_def.save_tensor_name的值。


2
太好了!我必须从一个字符串的末尾删除“:0”。此外,在恢复模型期间,相对路径不起作用。Tensorcreation:tf :: Tensor string(tf :: DT_STRING,tf :: TensorShape({1,1})); 提供字符串:string.matrix <std :: string>()(0,0)= file_path_ + filename; 执行:TF_CHECK_OK(session_->Run({{“save / Const:0”,string }}, {},{“save / control_dependency”},nullptr)); - Trevir
@Trevir,mrry:你能发一下代码片段吗?我刚接触tensorflow,文档并没有什么帮助..如果你能帮忙,我将不胜感激! - Surfer on the fall
@Surferonthefall:前面的评论已经包含了所有必要的代码。使用Python脚本获取正确的操作名称,例如“save/Const:0”。之后,您可以通过session->run方法在C++中使用操作名称。 - Trevir
厉害的黑科技解决方案!Python脚本必须包含saver = tf.train.Saver(...)代码行。我可以确认,必须将“save/control_dependency:0”重命名为“save/control_dependency”。 - Snurka Bill
@mrry 能否看一下这个问题?https://github.com/tensorflow/tensorflow/issues/10669#issuecomment-462230580 - Sathyamoorthy R
TF 2.0 的等效物是什么?使用“save/control_dependency”操作保存的我的检查点无法加载回 Keras 模型。 - jregalad

3
最近的TensorFlow版本包含了一些辅助函数,可以在不使用Python的情况下在C++中进行同样的操作。这些函数是从pip包中的ProtoBuf生成的(${HOME}/.local/lib/python2.7/site-packages/tensorflow/include/tensorflow/core/protobuf/saver.pb.h)。
// save
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = "some/path";
tensor_dict feed_dict = {{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}};
status = sess->Run(feed_dict, {}, {graph_def.saver_def().save_tensor_name()}, nullptr);

// restore
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = "some/path";
tensor_dict feed_dict = {{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}};
status = sess->Run(feed_dict, {}, {graph_def.saver_def().restore_op_name()}, nullptr);

这是基于未记录的Python方式(更多细节)恢复模型的方法。

def restore(sess, metaGraph, fn):
    restore_op_name = metaGraph.as_saver_def().restore_op_name   # u'save/restore_all'
    restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
    filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name  # u'save/Const'
    sess.run(restore_op, {filename_tensor_name: fn})

想要一个运作完整的版本请点击这里


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接