如何让TensorFlow的'import_graph_def'返回Tensors

7
如果我尝试导入保存的 TensorFlow 图定义:TensorFlow,那么就需要进行以下操作:
import tensorflow as tf
from tensorflow.python.platform import gfile

with gfile.FastGFile(FLAGS.model_save_dir.format(log_id) + '/graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
x, y, y_ = tf.import_graph_def(graph_def, 
                               return_elements=['data/inputs',
                                                'output/network_activation',
                                                'data/correct_outputs'],
                               name='')

返回的值并不是预期的张量(Tensor),而是其他类型的值。例如,我们期望得到 x,但实际得到的却是另外一种类型的值。
Tensor("data/inputs:0", shape=(?, 784), dtype=float32)

我理解

name: "data/inputs_1"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
    }
  }
}

也就是说,我得到的不是预期的张量x ,而是 x.op。这让我感到困惑,因为文档似乎表明我应该得到一个 Tensor (尽管有许多在那里,使它难以理解)。

如何使tf.import_graph_def返回特定的Tensor,以便我可以使用它们(例如,在加载模型或运行分析时进行喂养)?


代码的第二行应该是 from tensorflow.python.platform import gfile - tobe
1个回答

6

'data/inputs''output/network_activation''data/correct_outputs' 这些名称实际上是操作名。要让 tf.import_graph_def() 返回 tf.Tensor 对象,您应该将输出索引附加到操作名称,对于单输出操作,通常为 ':0'

x, y, y_ = tf.import_graph_def(graph_def, 
                               return_elements=['data/inputs:0',
                                                'output/network_activation:0',
                                                'data/correct_outputs:0'],
                               name='')

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接