CPython字符串拼接优化失败案例

9

问题

为什么在CPython中,

def add_string(n):
    s = ''
    for _ in range(n):
        s += ' '

需要花费线性时间,但是

def add_string_in_list(n):
    l = ['']
    for _ in range(n):
        l[0] += ' '

需要二次时间?


证明:

Timer(partial(add_string, 1000000)).timeit(1)
#>>> 0.1848409200028982
Timer(partial(add_string, 10000000)).timeit(1)
#>>> 1.1123797750042286

Timer(partial(add_string_in_list, 10000)).timeit(1)
#>>> 0.0033865350123960525
Timer(partial(add_string_in_list, 100000)).timeit(1)
#>>> 0.25131178900483064

我的理解

CPython在字符串引用计数为1时有一个优化,用于字符串相加。

这是因为Python中的字符串是不可变的,因此通常不能进行编辑。如果存在对字符串的多个引用,并且对其进行了更改,所有引用都将看到更改后的字符串。显然,这并不是我们想要的,因此不能使用多个引用进行更改操作。

但是,如果只有一个引用指向该字符串,则更改该值只会更改该引用所需的字符串。您可以通过以下方式测试这可能是原因:

from timeit import Timer
from functools import partial

def add_string_two_references(n):
    s = ''
    for _ in range(n):
        s2 = s
        s += ' '

Timer(partial(add_string_two_references, 20000)).timeit(1)
#>>> 0.032532954995986074
Timer(partial(add_string_two_references, 200000)).timeit(1)
#>>> 1.0898985149979126

我不确定为什么这个因素只有30倍,而不是预期的100倍,但我认为这是开销问题。


我不知道的事情

那么为什么列表版本会创建两个引用呢?这是否是阻止优化的原因?

您可以检查它是否对普通对象进行了任何区别处理:

class Counter:
    def __iadd__(self, other):
        print(sys.getrefcount(self))

s = Counter()
s += None
#>>> 6

class Counter:
    def __iadd__(self, other):
        print(sys.getrefcount(self))

l = [Counter()]
l[0] += None
#>>> 6

如果一个字符串存在多个引用并且它发生了改变,那么所有的引用都会看到这个改变后的字符串。但是这并不是真实情况。 - ElmoVanKielmo
"it changes" ≜ "被改变"; 我会修正措辞。// 已修正。 - Veedrac
不管是 x = 'a'; y = x; x = 'b' 还是 x = 'a'; y = x; x += 'b'y 都仍然是 'a'。 - ElmoVanKielmo
2
我同意。但在那种情况下,字符串没有被改变。只有当安全时(即只有一个引用)CPython才会改变字符串;否则y会改变,这是不好的。 - Veedrac
2
尽管 += 是一种特殊的部分优化情况,但建议使用 .join(),因为它在所有实现中都得到保证的优化,并且将受到开发人员的青睐。 - Davidmh
2个回答

9
在基于列表的方法中,从列表的索引0处取出字符串并进行修改,然后将其放回到索引0处的列表中。
在这个短暂的时刻,解释器仍然拥有列表中旧版本的字符串,并且无法执行原地修改。
如果您查看Python源代码,则会发现没有支持原地修改列表元素。因此,必须从列表中检索对象(在本例中为字符串),对其进行修改,然后将其放回。
换句话说,list类型完全不知道str类型对+=运算符的支持。

请考虑以下代码:

l = ['abc', 'def']
def nasty():
    global l
    l[0] = 'ghi'
    l[1] = 'jkl'
    return 'mno'
l[0] += nasty()

l 的值为 ['abcmno', 'jkl'],这证明了从列表中取出了 'abc',然后执行了 nasty() 修改了列表的内容,将字符串 'abc''mno' 进行了连接,并将结果赋给了 l[0]。如果在访问并就地修改 l[0] 之前评估了 nasty(),则结果将是 'ghimno'


6
所以为什么列表版本会创建两个引用?
l[0] += ' ' 中,一个引用在 l[0] 中。另一个引用是暂时创建的,用于执行 += 操作。
以下是两个更简单的函数来展示这种效果:
>>> def f():
...     l = ['']
...     l[0] += ' '
...     
>>> def g():
...     s = ''
...     s += ' '
...     

拆解它们可以得到

>>> from dis import dis
>>> dis(f)
  2           0 LOAD_CONST               1 ('')
              3 BUILD_LIST               1
              6 STORE_FAST               0 (l)

  3           9 LOAD_FAST                0 (l)
             12 LOAD_CONST               2 (0)
             15 DUP_TOPX                 2
             18 BINARY_SUBSCR       
             19 LOAD_CONST               3 (' ')
             22 INPLACE_ADD         
             23 ROT_THREE           
             24 STORE_SUBSCR        
             25 LOAD_CONST               0 (None)
             28 RETURN_VALUE        
>>> dis(g)
  2           0 LOAD_CONST               1 ('')
              3 STORE_FAST               0 (s)

  3           6 LOAD_FAST                0 (s)
              9 LOAD_CONST               2 (' ')
             12 INPLACE_ADD         
             13 STORE_FAST               0 (s)
             16 LOAD_CONST               0 (None)
             19 RETURN_VALUE        

f中,BINARY_SUBSCR(切片)指令将l[0]置于VM堆栈顶部。DUP_TOPX在堆栈上复制顶部的n项。两个函数(请参见ceval.c)都增加了引用计数;DUP_TOPX(在Py3中为DUP_TOP_TWO)直接执行,而BINARY_SUBSCR则使用PyObject_GetItem。因此,字符串的引用计数现在至少为三。 g没有这个问题。它在使用LOAD_FAST将项推入堆栈时创建一个额外的引用,使得引用计数为两个,这是VM堆栈上项目的最小数量,因此它可以进行优化。

1
+1 我的答案和你一样,但是我花了更长时间来写它 :) - chepner
@chepner,看到你使用完全相同的测试函数,我感到很惊讶——除了你的“g”是我的“f”,而你的“f”则是我的“g” :) - Fred Foo
我的答案中的代码与您的反汇编输出完全相同。 - ElmoVanKielmo
非常感谢。我没有预料到DUP_TOP_TWO会增加引用计数。此外,非常感谢提供源链接,这总是很好的。 - Veedrac
2
我想我会把接受权给ElmoVanKielmo,但你们两个都真的很值得 :). - Veedrac

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接