合并图表:是否有C++的TensorFlow import_graph_def等效函数?

4

我需要使用自定义输入和输出层来扩展导出的模型。我已经发现可以通过以下方式轻松地完成:

with tf.Graph().as_default() as g1: # actual model
    in1 = tf.placeholder(tf.float32,name="input")
    ou1 = tf.add(in1,2.0,name="output")
with tf.Graph().as_default() as g2: # model for the new output layer
    in2 = tf.placeholder(tf.float32,name="input")
    ou2 = tf.add(in2,2.0,name="output")

gdef_1 = g1.as_graph_def()
gdef_2 = g2.as_graph_def()

with tf.Graph().as_default() as g_combined: #merge together
    x = tf.placeholder(tf.float32, name="actual_input") # the new input layer

    # Import gdef_1, which performs f(x).
    # "input:0" and "output:0" are the names of tensors in gdef_1.
    y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
                             return_elements=["output:0"])

    # Import gdef_2, which performs g(y)
    z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
                             return_elements=["output:0"])

sess = tf.Session(graph=g_combined)

print "result is: ", sess.run(z, {"actual_input:0":5}) #result is: 9

这个可以正常工作。
然而,我需要传递一个指向网络输入的指针,而不是传递任意形状的数据集。问题是,我无法想出在Python中定义和传递指针的解决方案,并且在使用C ++ Api开发网络时,我找不到与tf.import_graph_def函数等效的内容。
这在C++中是否有不同的名称,或者在C++中合并两个图/模型的方式是否有其他方法?
感谢任何建议。
2个回答

2

这并不像Python中那样简单。

你可以使用类似以下代码来加载一个GraphDef

#include <string>
#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/platform/env.h>

tensorflow::GraphDef graph;
std::string graphFileName = "...";
auto status = tensorflow::ReadBinaryProto(
    tensorflow::Env::Default(), graphFileName, &graph);
if (!status.ok()) { /* Error... */ }

然后您可以使用它创建一个会话:

#include <tensorflow/core/public/session.h>

tensorflow::Session *newSession;
auto status = tensorflow::NewSession(tensorflow::SessionOptions(), &newSession);
if (!status.ok()) { /* Error... */ }
status = session->Create(graph);
if (!status.ok()) { /* Error... */ }

或者扩展现有图表的图形:
status = session->Extend(graph);
if (!status.ok()) { /* Error... */ }

这样您就可以将几个GraphDef放入同一个图中。但是,没有额外的设施来提取特定节点,也没有避免名称冲突的方法——您必须自己找到节点,并确保GraphDef没有冲突的操作名称。例如,我使用此函数查找所有名称与给定正则表达式匹配的节点,按名称排序:

#include <vector>
#include <regex>
#include <tensorflow/core/framework/node_def.pb.h>

std::vector<const tensorflow::NodeDef *> GetNodes(const tensorflow::GraphDef &graph, const std::regex &regex)
{
    std::vector<const tensorflow::NodeDef *> nodes;
    for (const auto &node : graph.node())
    {
        if (std::regex_match(node.name(), regex))
        {
            nodes.push_back(&node);
        }
    }
    std::sort(nodes.begin(), nodes.end(),
              [](const tensorflow::NodeDef *lhs, const tensorflow::NodeDef *rhs)
              {
                  return lhs->name() < rhs->name();
              });
    return nodes;
}

0

在C++中,可以通过直接操作要合并的两个图的GraphDefs中的NodeDefs来实现。基本算法是定义两个GraphDefs,使用第二个GraphDef输入的占位符,并将它们重定向到第一个GraphDef的输出。这类似于通过将第二电路的输入连接到第一电路的输出来串联两个电路。

首先,定义示例GraphDefs以及用于观察GraphDefs内部的实用程序。重要的是要注意,来自两个GraphDefs的所有节点必须具有唯一名称。

Status Panel::SampleFirst(GraphDef *graph_def) 
{
    Scope root = Scope::NewRootScope();
    Placeholder p1(root.WithOpName("p1"), DT_INT32);
    Placeholder p2(root.WithOpName("p2"), DT_INT32);
    Add add(root.WithOpName("add"), p1, p2);
    return root.ToGraphDef(graph_def);
}

Status Panel::SampleSecond(GraphDef *graph_def)
{
    Scope root = Scope::NewRootScope();
    Placeholder q1(root.WithOpName("q1"), DT_INT32);
    Placeholder q2(root.WithOpName("q2"), DT_INT32);
    Add sum(root.WithOpName("sum"), q1, q2);
    Multiply multiply(root.WithOpName("multiply"), sum, 4);
    return root.ToGraphDef(graph_def);
}

void Panel::ShowGraphDef(GraphDef &graph_def)
{
    for (int i = 0; i < graph_def.node_size(); i++) {
        NodeDef node_def = graph_def.node(i);
        cout << "NodeDef name is " << node_def.name() << endl;
        cout << "NodeDef op is " << node_def.op() << endl;
        for (const string& input : node_def.input()) {
            cout << "\t input: " << input << endl;
        }
    }
}

现在已经创建了两个GraphDefs,并且第二个GraphDef的输入已连接到第一个GraphDef的输出。这是通过迭代节点并识别第一个操作节点来完成的,其输入为Placeholders,并将这些输入重定向到第一个GraphDef的输出。然后将该节点添加到第一个GraphDef以及所有后续节点中。结果是第一个GraphDef附加了第二个GraphDef。
Status Panel::Append(vector<Tensor> *outputs)
{
    GraphDef graph_def_first;
    GraphDef graph_def_second;
    TF_RETURN_IF_ERROR(SampleFirst(&graph_def_first));
    TF_RETURN_IF_ERROR(SampleSecond(&graph_def_second));

    for (int i = 0; i < graph_def_second.node_size(); i++) {
        NodeDef node_def = graph_def_second.node(i);
        if (node_def.name() == "sum") {
            node_def.set_input(0, "p1");
            node_def.set_input(1, "add");
        }
        *graph_def_first.add_node() = node_def;
    }

    ShowGraphDef(graph_def_first);

    unique_ptr<Session> session(NewSession(SessionOptions()));
    TF_RETURN_IF_ERROR(session->Create(graph_def_first));

    Tensor t1(2);
    Tensor t2(3);
    vector<pair<string, Tensor>> inputs = {{"p1", t1}, {"p2", t2}};

    TF_RETURN_IF_ERROR(session->Run(inputs, {"multiply"}, {}, outputs));

    return Status::OK();
}

这个特定的图表将会接收两个输入,2和3,并将它们相加。然后将该总和(5)再次添加到第一个输入(2)中,然后乘以4以获得结果28。((2+3)+2)*4=28。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接