为什么这个二分查找优化会慢很多?

8
一项所谓的优化使代码变慢了两倍以上。
我通过找到值x出现的范围来计算在排序列表a中值x的出现次数:
from bisect import bisect_left, bisect_right

def count(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x)
    return stop - start

但是,嘿,事情还没开始,我们可以通过省略开始之前的部分来优化第二次搜索(文档):

def count(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x, start)
    return stop - start

但是当我进行基准测试时,优化后的版本花费的时间超过了两倍

 254 ms ±  1 ms  original
 525 ms ±  2 ms  optimized

为什么?

这个基准测试会生成一个从0到99999的一千万个随机整数的已排序列表,然后计算所有不同的整数(只是为了基准测试,不需要指出Counter)(在线试用!):

import random
from bisect import bisect_left, bisect_right
from timeit import repeat
from statistics import mean, stdev

def original(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x)
    return stop - start

def optimized(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x, start)
    return stop - start

a = sorted(random.choices(range(100_000), k=10_000_000))
unique = set(a)

def count_all():
    for x in unique:
        count(a, x)
for count in original, optimized:
    times = repeat(count_all, number=1)
    ts = [t * 1e3 for t in sorted(times)[:3]]
    print(f'{round(mean(ts)):4} ms ± {round(stdev(ts)):2} ms ', count.__name__)

2
当两个调用在完全相同的数组段上运行时,是否会有缓存奖励? - slothrop
1
另一个想法:我们对这个特定数据有一些“了解”,而二分查找算法没有,即对于大多数x的值,stop不会离start太远(平均来说,在100个位置的10百万项列表中)。考虑到这一点,给出“提示”是否真的有助于算法?实际上,这使得搜索首先尝试((len(a)+start)/2),这在平均情况下会太靠右。这是否实际上比从len(a)/2开始更糟糕? - slothrop
1
https://tio.run/##pVNLstMwENz7FLOLDSJlExZUCh@AM7x6pVLscaJCPyT51QOKQ3EFLhYkjZKYDRu8sXu6Z3o0I7tv8WLN4aPz16vUzvoIXpjZ6mbxVsNJBpwiVIYQV7hEdgNeni@RxFFqlHexR4eiMiGKKEOUU7ixGoVhKT7jS9PMuIBNhaQRqhUMXrtjA@lJaUk6bn2Jrqx1D7L0sWE9xtUbEr2jStXIpT7ld5z/34lR3j8MJyVCgM8mkk3251waGTlvA6qFwYtQK9YuileK7kswOZb3JlM98my8oN/knZU9CQWT1S7cgwXB2xGGe@jW5sPmE9Ui1DQi@Ya0ojQgugn76WLlhCHDM7ZDz/u@7xh8GYf8VVDXrEZ@LT0HTMPpShmVdt5q4do0AAZiq/qbonDiy4Ymu5rIhVJtPd9iPbyCNFV33JwvKeuisqjgLLzdJvZYN2XRREboC8o3NiO6q@3dmIFZ9Qn9ONBunZcm0ymXwS6/hZfBmrAjPuYiTxHewICH0m7poo6xuHRPx8Pz86basvvhk9/c5l@hjaHrjh9@gg7w@xdUpvweRL0v1I7REfecG6GR8@56/QM - Kelly Bundy
3
我在想,如果没有提示,bisect_right所访问的数组元素很大程度上与bisect_left相同,并从中获得缓存优势。 假设我们的值x在一个10k个元素的数组中的位置为850到900。 bisect_left访问的元素为5000、2500、1250、625、938、782、860、821、841、851、846、849、850。 "未优化"的bisect_right尝试访问的元素为:5000、2500、1250、625、938、782、860、899、919、909、904、902、901、900。 因此,“未优化”的bisect_right的前6个列表访问刚好被bisect_left访问过:它们是否可用于缓存,因此存在时间差异? - slothrop
2
@KellyBundy:从整个程序的角度来看,Python 太随机了,使得这种方法并不是很有用(生成数据会产生许多页面错误)。尽管如此,一个小小的调整可以减少任何缓存问题的影响,就是确保所有具有相同值的 int 是相同的对象,并且它们一次性按顺序生成(因此它们在内存中很可能是连续的)。你所要做的就是将 range(100_000) 改为 tuple(range(100_000)。这样做可以将“原始”提高约 34%,将“优化”提高约 61%。减少缓存的影响更能改善优化,这意味着它实际上是一个缓存未命中问题。 - ShadowRanger
显示剩余11条评论
2个回答

1

在基准测试中有一些方面会触发不良的缓存效应。

首先,我敢打赌这个断言对你来说也会通过(就像对我来说一样):

assert list(unique) == sorted(unique)

不能保证一定能通过,但基于CPython的set类型和整数哈希实现到目前为止,很有可能能够通过。

这意味着你的for x in unique试图以严格递增的顺序尝试x。这使得bisect_left()内部的潜在探测序列在一个x到另一个x时非常相似,因此比较的许多值可能正在缓存中。原始版中的bisect_right()也是如此,但在优化版本中,bisect_right()的潜在探测序列在尝试之间不同,因为起始索引在尝试之间也不同。

要使两个版本都“大大减慢”,请在assert之后添加以下内容:

unique = list(unique)
random.shuffle(unique)

现在,在尝试中输入的变量x没有规律性,因此在各次尝试中也不存在潜在的探测序列系统相关性。
其他缓存效应出现在单个尝试中。在原始版本中,bisect_left()bisect_right()之间的潜在探测序列完全相同。阅读以解决bisect_left()的条目很可能仍然保留在缓存中,供bisect_right()重用。
但是,在经过优化的版本中,潜在的探测序列是不同的,因为切片边界不同。例如,bisect_left()将始终从将xa[5000000]进行比较开始。在原始版本中,bisect_right()也将始终通过进行相同的比较来开始,但在经过优化的版本中,它几乎总是会选择一个不同的a索引开始-而这个索引纯粹是由于运气而等待在缓存中。
话虽如此,我通常在自己的代码中使用您的优化。但这是因为我通常有比整数比较更昂贵的比较操作,因此节省一次比较的价值远大于节省一些缓存未命中。小整数的比较非常便宜,因此节省其中一些的价值很小。

哈,最初我已经洗牌了,以避免集合的“排序”。然后为了简化基准测试而将其删除,因为速度比仍然相似。我认为还有另一个缓存效应,也被 slothrop 指出:即使洗牌,二分树中的“顶部”索引始终被全部使用(当使用完整范围时,但不是当使用起始位置时)。 - Kelly Bundy
是的,这就是我在回答的后半部分所得出的结论:“其他缓存效应发生在单次尝试中...”。 - Tim Peters
值得注意的是,“单次尝试中的其他缓存效应”不仅仅是糟糕基准测试的结果,这是优化本身的自然结果。您说得对,它确实只适用于廉价比较;如果我调整基准测试使比较稍微昂贵一些(通过创建一个带有重载的int子类__lt__,无条件返回NotImplemented,因此它最终执行一个无用的Python调用,然后回退到__gt__的C实现),以及为每个值使用唯一实例,则“optimized”将其捆绑在一起。 - ShadowRanger
在线尝试! - ShadowRanger
@Tim 好的... 我理解"在同一个try块内"是这样一个限定条件。听起来像你只在每个左右搜索对之间讨论,在bisect_right中的后续讨论并没有扩大这种印象,因为它专注于命中/未命中。 - Kelly Bundy
显示剩余2条评论

0

我尝试在Python中模拟缓存,并测量各种缓存大小的缓存未命中率:

        ORIGINAL:              OPTIMIZED:
cache |         cache-misses |         cache-misses |
 size |  time   line   item  |  time   line   item  |  
------+----------------------+----------------------+
 1024 | 1.98 s  59.1%  16.4% | 4.90 s  74.4%  57.1% |
 2048 | 2.30 s  59.1%  16.4% | 5.28 s  72.5%  56.8% |
 4096 | 2.16 s  59.0%  16.4% | 5.30 s  70.4%  56.4% |
 8192 | 2.33 s  59.0%  16.4% | 6.09 s  68.2%  56.0% |
16384 | 2.80 s  59.0%  16.4% | 6.30 s  65.8%  55.6% |

我为列表使用了代理对象。获取列表项通过getitem函数进行,该函数具有最左列中显示的LRU缓存大小。而且getitem也不直接访问列表。它通过getline函数进行,该函数获取“缓存行”,即8个连续列表元素的块。它具有缓存大小除以8的LRU缓存。

这远非完美,即与真实情况相比,无法测量真正的缓存未命中,特别是因为它仅模拟在列表中缓存引用而不是列表元素对象。但我仍然觉得它很有趣。我的函数的原始版本显示出较少的缓存未命中,并且未命中率似乎在各种缓存大小中都相当稳定。优化版本显示出更多的缓存未命中,较大的缓存大小有助于降低未命中率。

我的代码(在此在线尝试!):

import random
from bisect import bisect_left, bisect_right
from timeit import timeit
from statistics import mean, stdev
from functools import lru_cache

def original(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x)
    return stop - start

def optimized(a, x):
    start = bisect_left(a, x)
    stop = bisect_right(a, x, start)
    return stop - start

a = sorted(random.choices(range(100_000), k=10_000_000))
unique = set(a)

class Proxy:
    __len__ = a.__len__
    def __getitem__(self, index):
        return getitem(index)
p = Proxy()

def count_all():
    for x in unique:
        count(p, x)

linesize = 8

print('''        ORIGINAL:              OPTIMIZED:
cache |         cache misses |         cache misses |
 size |  time   line   item  |  time   line   item  |  
------+----------------------+----------------------+''')

for cachesize in 1024, 2048, 4096, 8192, 16384:
    print(f'{cachesize:5} |', end='')

    @lru_cache(cachesize // linesize)
    def getline(i):
        i *= linesize
        return a[i : i+linesize]
    
    @lru_cache(cachesize)
    def getitem(index):
        q, r = divmod(index, linesize)
        return getline(q)[r]
    
    for count in original, optimized:
        getline.cache_clear()
        getitem.cache_clear()
        time = timeit(count_all, number=1)
        def misses(func):
            ci = func.cache_info()
            misses = ci.misses / (ci.misses + ci.hits)
            return f'{misses:.1%}'
        print(f'{time:5.2f} s  {misses(getline)}  {misses(getitem)}', end=' |')
    print()

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接