在Tensorflow中,获取图中所有张量的名称。

142

我正在使用 Tensorflowskflow 创建神经网络;由于某些原因,我想要获取给定输入的一些内部张量的值,因此我使用 myClassifier.get_layer_value(input, "tensorName"),其中 myClassifier 是一个 skflow.estimators.TensorFlowEstimator

但是,即使知道张量名称(而且我在操作和张量之间感到困惑),我仍然很难找到张量名称的正确语法,因此我正在使用 tensorboard 来绘制图形并查找张量名称。

有没有一种方法可以枚举出图中的所有张量而不使用 tensorboard?

10个回答

203

你可以做

[n.name for n in tf.get_default_graph().as_graph_def().node]

此外,如果您正在IPython笔记本中进行原型制作,您可以直接在笔记本中显示图形,请参见Alexander的Deep Dream notebook中的show_graph函数。


2
你可以通过在推导式末尾添加 if "Variable" in n.op 来过滤例如变量这样的内容。 - Radu
如果您知道名称,是否有一种方法可以获取特定节点? - Rocket Pingu
要了解有关图节点的更多信息,请访问:https://www.tensorflow.org/extend/tool_developers/#nodes - Ivan Talalaev
4
上述命令返回所有操作/节点的名称。若要获取所有张量的名称,请执行以下命令:tensors_per_node = [node.values() for node in graph.get_operations()] tensor_names = [tensor.name for tensors in tensors_per_node for tensor in tensors] - gebbissimo

42

我会尝试总结答案:

获取图中所有节点:(类型为tensorflow.core.framework.node_def_pb2.NodeDef

all_nodes = [n for n in tf.get_default_graph().as_graph_def().node]

获取图中的所有操作:(输入tensorflow.python.framework.ops.Operation

all_ops = tf.get_default_graph().get_operations()

获取图中的所有变量(输入:tensorflow.python.ops.resource_variable_ops.ResourceVariable

all_vars = tf.global_variables()

获取图中所有张量: (类型为tensorflow.python.framework.ops.Tensor)

all_tensors = [tensor for op in tf.get_default_graph().get_operations() for tensor in op.values()]

获取图中的所有占位符: (类型为tensorflow.python.framework.ops.Tensor

all_placeholders = [placeholder for op in tf.get_default_graph().get_operations() if op.type=='Placeholder' for placeholder in op.values()]

Tensorflow 2

要在Tensorflow 2中获取图形,您需要首先实例化一个tf.function,然后访问graph属性,而不是tf.get_default_graph(),例如:

graph = func.get_concrete_function().graph

其中func是一个tf.function


4
注意那个TF2版本! - ibarrond

25

使用 get_operations 比 Yaroslav 答案中提到的方法稍微快一些。这是一个快速的示例:

import tensorflow as tf

a = tf.constant(1.3, name='const_a')
b = tf.Variable(3.1, name='variable_b')
c = tf.add(a, b, name='addition')
d = tf.multiply(c, a, name='multiply')

for op in tf.get_default_graph().get_operations():
    print(str(op.name))

3
你无法使用 tf.get_operations() 获得张量,只能获取操作。 - Soulduck
@Soulduck,你可以使用op.values()获取每个操作的张量,例如: last_tensor = graph.get_operations()[-1].values() 其中,graph.get_operations()[-1]是图中的最后一个操作。 - Youcef4k

11

2
该函数已被弃用。 - CAFEBABE
8
它的后继函数是 tf.global_variables() - bluenote10
11
这仅获取变量,不包括张量。 - Rajarshee Mitra
在Tensorflow 1.9.0中,显示all_variables (来自tensorflow.python.ops.variables)已被弃用,并将在2017-03-02之后删除 - stackoverYC
module 'tensorflow' has no attribute 'all_variables' - skytree

5
我认为这个也可以:
print(tf.contrib.graph_editor.get_tensors(tf.get_default_graph()))

但是与Salvado和Yaroslav的答案相比,我不知道哪一个更好。


这个使用从冻结的inference_graph.pb文件导入的图表,与tensorflow目标检测API一起使用。谢谢。 - simo23

5

接受的答案只提供了一个带有名称的字符串列表。我更喜欢另一种方法,它可以让您(几乎)直接访问张量:

graph = tf.get_default_graph()
list_of_tuples = [op.values() for op in graph.get_operations()]

list_of_tuples 现在包含每个张量,每个张量都在一个元组中。您也可以调整它以直接获取张量:

graph = tf.get_default_graph()
list_of_tuples = [op.values()[0] for op in graph.get_operations()]

4

因为原帖请求的是张量列表,而不是操作/节点列表,所以代码应该稍有不同:

graph = tf.get_default_graph()    
tensors_per_node = [node.values() for node in graph.get_operations()]
tensor_names = [tensor.name for tensors in tensors_per_node for tensor in tensors]

3

之前的回答很好,我想分享一个我写的从图中选择张量的实用函数:

def get_graph_op(graph, and_conds=None, op='and', or_conds=None):
    """Selects nodes' names in the graph if:
    - The name contains all items in and_conds
    - OR/AND depending on op
    - The name contains any item in or_conds

    Condition starting with a "!" are negated.
    Returns all ops if no optional arguments is given.

    Args:
        graph (tf.Graph): The graph containing sought tensors
        and_conds (list(str)), optional): Defaults to None.
            "and" conditions
        op (str, optional): Defaults to 'and'. 
            How to link the and_conds and or_conds:
            with an 'and' or an 'or'
        or_conds (list(str), optional): Defaults to None.
            "or conditions"

    Returns:
        list(str): list of relevant tensor names
    """
    assert op in {'and', 'or'}

    if and_conds is None:
        and_conds = ['']
    if or_conds is None:
        or_conds = ['']

    node_names = [n.name for n in graph.as_graph_def().node]

    ands = {
        n for n in node_names
        if all(
            cond in n if '!' not in cond
            else cond[1:] not in n
            for cond in and_conds
        )}

    ors = {
        n for n in node_names
        if any(
            cond in n if '!' not in cond
            else cond[1:] not in n
            for cond in or_conds
        )}

    if op == 'and':
        return [
            n for n in node_names
            if n in ands.intersection(ors)
        ]
    elif op == 'or':
        return [
            n for n in node_names
            if n in ands.union(ors)
        ]

所以,如果您有一个包含操作的图表:
['model/classifier/dense/kernel',
'model/classifier/dense/kernel/Assign',
'model/classifier/dense/kernel/read',
'model/classifier/dense/bias',
'model/classifier/dense/bias/Assign',
'model/classifier/dense/bias/read',
'model/classifier/dense/MatMul',
'model/classifier/dense/BiasAdd',
'model/classifier/ArgMax/dimension',
'model/classifier/ArgMax']

那么,运行以下代码:
get_graph_op(tf.get_default_graph(), ['dense', '!kernel'], 'or', ['Assign'])

返回:

['model/classifier/dense/kernel/Assign',
'model/classifier/dense/bias',
'model/classifier/dense/bias/Assign',
'model/classifier/dense/bias/read',
'model/classifier/dense/MatMul',
'model/classifier/dense/BiasAdd']

1
以下解决方案适用于我在TensorFlow 2.3中使用 -
def load_pb(path_to_pb):
    with tf.io.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

tf_graph = load_pb(MODEL_FILE)
sess = tf.compat.v1.Session(graph=tf_graph)

# Show tensor names in graph
for op in tf_graph.get_operations():
    print(op.values())

其中MODEL_FILE是您冻结图的路径。

引用自这里


0
这对我有用:
for n in tf.get_default_graph().as_graph_def().node:
    print('\n',n)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接