如何在TensorFlow图中的每个节点中获取输入形状?

5

你好:现在我正在将tensorflow的checkpoint模型转换为caffe模型。我已经成功地读取了图并提取了每个节点中的属性值。我得到了“Conv2D”节点中“dilations”,“strides”和“padding”属性的值以及“weights”节点中的形状,但是我无法获取“shape”属性的值,因为在Conv2D的输入节点中为空。然而,这些形状在tensorboard的图表中显示出来了。 以下是我的代码:

new_saver = tf.train.import_meta_graph(meta_path)          
new_saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))
graph_def = sess.graph_def
node_list = graph_def.node

# conv_node, weight_node, from_node are all in node_list
# conv_node: the conv2d node in graph_def
# weight_node: the weights node of conv2d
# from_node: the input feature map node of conv2d

weight_shape_attr = weight_node.attr['shape']
weight_shapes = [dim.size for dim in weight_shape_attr.shape.dim]

strides = [ii for ii in conv_node.attr['strides'].list.i]
dilations = [ii for ii in conv_node.attr['dilations'].list.i]

shapes = from_node.attr['shape']  # this is empty

还有 Tensorboard 图: tensorboard_graph

请注意 Conv2D 节点的输入形状为 ?x79x79x32,它必须存储在模型文件中的某个位置。能否有人提供帮助?任何提示都会很有帮助,谢谢。

1个回答

6

Tensorflow图形有一个名为as_graph_def的方法,它具有可选参数add_shapes(默认为False)。如果设置为True,则会在节点的附加属性中得到_output_shapes。< / p>

因此,您可以尝试以以下方式获取GraphDef:

graph_def = sess.graph.as_graph_def(add_shapes=True)

我看到现在有属性_output_shape,但如何打印所有节点的形状?这个不行:print([n._output_shapes for n in tf.get_default_graph().as_graph_def(add_shapes=True).node]) - Primoz
@Primoz,“_output_shapes”是一个属性,因此应像这样访问它:shapes = node.attr['_output_shapes'],它会提供另一个protobuf对象,可以沿着这个路径导航到整数值,如:shapes.list.shape[0].dim[0].size。可能有更方便的访问值的方法,但我不知道。 - dm0_
我在Python API中找不到add_shapes函数? - Danijel
如果output_shapes不为空且output_shapes.list.shape长度大于0,则返回[dim.size for dim in output_shapes.list.shape[0].dim],否则返回''。 - Andrey

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接