使用Python AST获取所有与给定名称相对应的变量的节点

8
请看下面的代码:

考虑下面的代码:

1 | x = 20
2 | 
3 | def f():
4 |     x = 0
5 |     for x in range(10):
6 |         x += 10
7 |     return x
8 | f()
9 |
10| for x in range(10):
11|     pass
12| x += 1
13| print(x)

上述代码执行后,x的值为10。现在,如何获取所有类名为Name,且其idx,并且引用了第1行、第10行、第12行和第13行中使用的x的节点?
换句话说,函数f中的x与其他x不同。是否可以在仅有脚本和脚本AST而不是执行它的情况下获取它们的AST节点?
1个回答

20
当遍历AST树时,跟踪上下文;从全局上下文开始,然后遇到FunctionDefClassDefLambda节点时,将该上下文记录为堆栈(在退出相关节点时弹出堆栈)。
您可以仅查看全局上下文中的Name节点。您也可以跟踪global标识符(我会使用每个堆栈级别的集合)。
使用NodeVisitor子类
import ast

class GlobalUseCollector(ast.NodeVisitor):
    def __init__(self, name):
        self.name = name
        # track context name and set of names marked as `global`
        self.context = [('global', ())]

    def visit_FunctionDef(self, node):
        self.context.append(('function', set()))
        self.generic_visit(node)
        self.context.pop()

    # treat coroutines the same way
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_ClassDef(self, node):
        self.context.append(('class', ()))
        self.generic_visit(node)
        self.context.pop()

    def visit_Lambda(self, node):
        # lambdas are just functions, albeit with no statements, so no assignments
        self.context.append(('function', ()))
        self.generic_visit(node)
        self.context.pop()

    def visit_Global(self, node):
        assert self.context[-1][0] == 'function'
        self.context[-1][1].update(node.names)

    def visit_Name(self, node):
        ctx, g = self.context[-1]
        if node.id == self.name and (ctx == 'global' or node.id in g):
            print('{} used at line {}'.format(node.id, node.lineno))

示例(假设您的示例代码在t中的AST树已知):

>>> GlobalUseCollector('x').visit(t)
x used at line 1
x used at line 10
x used at line 12
x used at line 13

在函数中使用global x

>>> u = ast.parse('''\
... x = 20
...
... def g():
...     global x
...     x = 0
...     for x in range(10):
...         x += 10
...     return x
...
... g()
... for x in range(10):
...     pass
... x += 1
... print(x)
... ''')
>>> GlobalUseCollector('x').visit(u)
x used at line 1
x used at line 5
x used at line 6
x used at line 7
x used at line 8
x used at line 11
x used at line 13
x used at line 14

1
在Python 3.5+中,您还应该为async def函数定义visit_AsyncFunctionDef(self, node) - pahaz
1
@pahaz,我认为一个别名就足够了:visit_AsyncFuncDef = visit_FuncDef - Martijn Pieters

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接