TensorFlow 2.0 C++ - Loading a pre-trained model

Can anyone give me some hints on how to load a model, trained and exported with Keras in Python, using the TensorFlow 2.0 C++ API?

I can't find any information on this, only material targeting TensorFlow versions < 2.

Thanks!

3 Answers

OK, I found a solution, but there is a follow-up problem:

In Python, you need to export the model with:

# Saves in the TF2 SavedModel format (writes a directory named 'model')
tf.keras.models.save_model(model, 'model')

In C++, you need the following code to load it:

#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"

// ...

tensorflow::SavedModelBundle model;
tensorflow::Status status = tensorflow::LoadSavedModel(
  tensorflow::SessionOptions(), 
  tensorflow::RunOptions(), 
  "path/to/model/folder", 
  {tensorflow::kSavedModelTagServe}, 
  &model);
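Sanity-checking the returned tensorflow::Status before using the session makes failures much easier to diagnose; a minimal sketch, assuming this sits inside main():

#include <iostream>

// Abort with the error message if the bundle failed to load.
if (!status.ok()) {
  std::cerr << "Failed to load SavedModel: " << status.ToString() << std::endl;
  return 1;
}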

Based on this article: Restoring a model from a TensorFlow checkpoint in C++

If I now try to set the inputs and outputs, it throws the errors "Could not find node with name 'outputlayer'" and "Invalid argument: Tensor input:0, specified in either feed_devices or fetch_devices was not found in the Graph".

Does anyone have an idea where the problem is?


Your initial approach is good. You need the saved_model_cli tool that ships with TensorFlow. It prints output similar to the following:

PS C:\model_dir> saved_model_cli show --dir . --all
[...]

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s): 
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s): 
    inputs['flatten_input'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 2)
        name: serving_default_flatten_input:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['dense_2'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict

You need to look up the names of the inputs and outputs you are going to use. Here, these are:

        name: serving_default_flatten_input:0

for the input, and

        name: StatefulPartitionedCall:0

for the output.
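If you would rather not shell out to saved_model_cli, the same information can be read programmatically from the loaded bundle's MetaGraphDef. A sketch, assuming the plain tensorflow::SavedModelBundle from the first answer (the SavedModelBundleLite used below exposes the same map via GetSignatures()):

#include <iostream>

#include "tensorflow/cc/saved_model/loader.h"

// Print every SignatureDef with the tensor names behind its logical
// input/output keys, mirroring what `saved_model_cli show --all` reports.
void PrintSignatures(const tensorflow::SavedModelBundle& bundle) {
  for (const auto& entry : bundle.meta_graph_def.signature_def()) {
    std::cout << "signature_def['" << entry.first << "']:\n";
    for (const auto& input : entry.second.inputs()) {
      std::cout << "  input  '" << input.first << "' -> " << input.second.name() << "\n";
    }
    for (const auto& output : entry.second.outputs()) {
      std::cout << "  output '" << output.first << "' -> " << output.second.name() << "\n";
    }
  }
}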

Once you have these names, you can embed them in your code:

#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/platform/logging.h"

// ...

// We use SavedModelBundleLite as the in-memory object for TensorFlow's model bundle.
const auto savedModelBundle = std::make_unique<tensorflow::SavedModelBundleLite>();

// Create dummy options.
tensorflow::SessionOptions sessionOptions;
tensorflow::RunOptions runOptions;

// Load the model bundle.
const auto loadResult = tensorflow::LoadSavedModel(
        sessionOptions,
        runOptions,
        modelPath, //std::string containing path of the model bundle
        { tensorflow::kSavedModelTagServe },
        savedModelBundle.get());

// Check if loading was okay.
TF_CHECK_OK(loadResult);

// Provide input data.
tensorflow::Tensor tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({ 2 }));
tensor.vec<float>()(0) = 20.f;
tensor.vec<float>()(1) = 6000.f;

// Link the data with the tensor names so TensorFlow knows where to feed each entry and what to fetch.
std::vector<std::pair<std::string, tensorflow::Tensor>> feedInputs = { {"serving_default_flatten_input:0", tensor} };
std::vector<std::string> fetches = { "StatefulPartitionedCall:0" };

// We need to store the results somewhere.
std::vector<tensorflow::Tensor> outputs;

// Let's run the model...
auto status = savedModelBundle->GetSession()->Run(feedInputs, fetches, {}, &outputs);
TF_CHECK_OK(status);

// ... and print out its predictions.
for (const auto& record : outputs) {
    LOG(INFO) << record.DebugString();
}

Running this code produces the following output:

Directory ./model_bundle does contain a model.
2022-08-03 10:50:43.367619: I tensorflow/cc/saved_model/reader.cc:43] Reading SavedModel from: ./model_bundle 
2022-08-03 10:50:43.370764: I tensorflow/cc/saved_model/reader.cc:81] Reading meta graph with tags { serve }
2022-08-03 10:50:43.370862: I tensorflow/cc/saved_model/reader.cc:122] Reading SavedModel debug info (if present) from: ./model_bundle 
2022-08-03 10:50:43.371034: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-08-03 10:50:43.390553: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
2022-08-03 10:50:43.391459: I tensorflow/cc/saved_model/loader.cc:228] Restoring SavedModel bundle.
2022-08-03 10:50:43.426841: I tensorflow/cc/saved_model/loader.cc:212] Running initialization op on SavedModel bundle at path: ./model_bundle 
2022-08-03 10:50:43.433764: I tensorflow/cc/saved_model/loader.cc:301] SavedModel load for tags { serve }; Status: success: OK. Took 66144 microseconds.
2022-08-03 10:50:43.450891: I TensorflowPoC.cpp:46] Tensor<type: float shape: [1,1] values: [-1667.12402]>

TensorflowPoC.exe (process 21228) exited with code 0.
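If you need the numeric value instead of DebugString(), you can read it directly from the output tensor. A small sketch, assuming the [1,1] float output shown in the log above:

// Pull the scalar prediction out of the rank-2 [1,1] output tensor.
const float prediction = outputs[0].matrix<float>()(0, 0);
LOG(INFO) << "Predicted value: " << prediction;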

You need to check the names of the inputs and outputs. Use the TensorBoard option to display the model structure (it is in the Graph tab), or use a model viewer such as Netron.
