如何将本地训练的TensorFlow图文件部署到Google Cloud平台?

3
我已经按照 TensorFlow for Poets 教程,使用自己的分类替换了原有的 flower_photos 数据集。现在我的 labels.txt 和 graph.pb 文件已经保存在本地计算机上。
我想知道有没有办法将这个预训练模型部署到 Google Cloud Platform 上?我已经阅读了文档,但只找到了有关如何在其 ML Engine 中创建、训练和部署模型的说明。但是,我不想在 Google 的服务器上花费金钱来训练我的模型,我只需要它们来托管我的模型,以便可以调用它进行预测。
有其他人遇到同样的问题吗?
2个回答

1
部署本地训练模型是一种受支持的用例;无论您在哪里进行了培训,说明 基本相同:

要部署模型版本,您需要:

在 Google Cloud Storage 上保存的 TensorFlow SavedModel。您可以通过以下方式获取模型:

  • 按照云 ML 引擎训练步骤在云中进行训练。

  • 在其他地方进行训练并导出为 SavedModel。

不幸的是,TensorFlow for Poets 没有展示如何导出 SavedModel(我已经提交了一个功能请求来解决这个问题)。与此同时,您可以编写一个 "转换器" 脚本,例如下面的脚本(您也可以在训练结束时执行此操作,而不是保存 graph.pb 并重新读取它):

input_graph = 'graph.pb'
saved_model_dir = 'my_model'

with tf.Graph() as graph:
  # Read in the export graph
  with tf.gfile.FastGFile(input_graph, 'rb') as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())
      tf.import_graph_def(graph_def, name='')

  # CloudML Engine and early versions of TensorFlow Serving do
  # not currently support graphs without variables. Add a
  # prosthetic variable.
  dummy_var = tf.Variable(0)

  # Define SavedModel Signature (inputs and outputs)
  in_image = graph.get_tensor_by_name('DecodeJpeg/contents:0')
  inputs = {'image_bytes': 
tf.saved_model.utils.build_tensor_info(in_image)}

  out_classes = graph.get_tensor_by_name('final_result:0')
  outputs = {'prediction': tf.saved_model.utils.build_tensor_info(out_classes)}

  signature = tf.saved_model.signature_def_utils.build_signature_def(
      inputs=inputs,
      outputs=outputs,
      method_name='tensorflow/serving/predict'
  )

  # Save out the SavedModel.
  b = saved_model_builder.SavedModelBuilder(saved_model_dir)
  b.add_meta_graph_and_variables(sess,
                                 [tf.saved_model.tag_constants.SERVING],
                                 signature_def_map={'predict_images': signature})
  b.save() 

(基于此代码实验室此SO帖子的未经测试的代码)。

如果您希望输出使用字符串标签而不是整数索引,请进行以下更改:

  # Loads label file, strips off carriage return
  label_lines = [line.rstrip() for line 
                 in tf.gfile.GFile("retrained_labels.txt")]
  out_classes = graph.get_tensor_by_name('final_result:0')
  out_labels = tf.gather(label_lines, ot_classes)
  outputs = {'prediction': tf.saved_model.utils.build_tensor_info(out_labels)}

1
很抱歉,我的回答只是部分解答。我已经能够完成这个任务...但还存在一些问题尚未解决。我将训练好的pb和txt文件移植到了服务器上,安装了Tensorflow,并通过HTTP请求调用训练好的模型。第一次运行时它完美地工作...然后每隔一次就会失败。 tensorflow deployment on openshift, errors with gunicorn and mod_wsgi 惊讶于没有更多的人试图解决这个普遍问题。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接