为什么random.sample比numpy的random.choice更快?

18

我需要一种无重复抽样数组a的方法。我尝试了两种方法(见以下MCVE),使用了random.sample()np.random.choice

我认为numpy函数会更快,但事实证明它并不是如此。在我的测试中,random.samplenp.random.choice快约15%。

这是否正确?如果是,为什么?

import numpy as np
import random
import time
from contextlib import contextmanager


@contextmanager
def timeblock(label):
    start = time.clock()
    try:
        yield
    finally:
        end = time.clock()
        print ('{} elapsed: {}'.format(label, end - start))


def f1(a, n_sample):
    return random.sample(range(len(a)), n_sample)


def f2(a, n_sample):
    return np.random.choice(len(a), n_sample, replace=False)


# Generate random array
a = np.random.uniform(1., 100., 10000)
# Number of samples' indexes to randomly take from a
n_sample = 100
# Number of times to repeat functions f1 and f2
N = 100000

with timeblock("random.sample"):
    for _ in range(N):
        f1(a, n_sample)

with timeblock("np.random.choice"):
    for _ in range(N):
        f2(a, n_sample)

5
https://github.com/numpy/numpy/issues/2764 - ayhan
我明白了,这是一个长期存在的问题。@ayhan,你能否把你的评论写成答案,这样我就可以标记为已接受的答案了吗? - Gabriel
3
我认为问题在于np.random.choice进行无重复的随机抽样,方法是生成数组中所有索引的排列,然后取其中的前n_sample个(参见这一行)。如果n_sample远小于数组a中的元素数量,则这种方法会变得非常低效。 - ali_m
3
另一方面,random.sample 函数只会抽取 n_samples 个随机样本。它可以通过两种方式之一实现这一点——如果 n_samples << N,则跟踪已选择的项目(item)来完成;或者维护一个可选候选项的收缩池(if n_samples 相对于 N 较大)。你可以在此处查看源代码——它相当简单易读。 - ali_m
1个回答

17

TL;DR 自numpy v1.17.0起,建议使用numpy.random.default_rng()对象而不是numpy.random。对于选择:

import numpy as np

rng = np.random.default_rng()    # you can pass seed
rng.choice(...)    # interface is the same

除了其他在v1.17中引入的随机API更改之外,此新版本的choice现在更加智能,并且在大多数情况下应该是最快的。旧版本保持不变以实现向后兼容性!


正如评论中所提到的,numpy一直存在一个问题,即np.random.choice实现对于k<<n与Python标准库中的random.sample相比是不够有效的。
问题在于np.random.choice(arr, size=k, replace=False)被实现为permutation(arr)[:k]。在数组很大而k很小的情况下,计算整个数组的排列是浪费时间和内存的。标准Python的random.sample以更加直接的方式工作——它只是迭代地进行采样,同时记录已经采样的内容或者从哪里采样。
在v1.17.0中,numpy介绍了numpy.random包的重构和改进(文档更新内容性能)。我强烈建议您至少查看第一个链接。请注意,正如其所述,为了向后兼容,旧的numpy.random API保持不变,仍然使用旧的实现。
因此,现在推荐使用随机API的新方法是使用numpy.random.default_rng() 对象而不是numpy.random。请注意,它是一个对象,它还接受可选的种子参数,因此您可以方便地传递它。它还默认使用了一个速度更快的不同生成器(有关详细信息,请参见上面的性能链接)。
关于您的情况,现在您可以使用np.random.default_rng().choice(...)。除了更快之外,由于改进后的随机生成器,choice本身也变得更加智能。现在它仅对足够大的数组(>10000个元素)和相对较大的k(>1/50大小)使用整个数组排列。否则,它使用Floyd的抽样算法(简短描述numpy实现)。

这是在我的笔记本电脑上的性能比较:

从包含10000个元素的数组中取100个样本,重复进行10000次:

random.sample elapsed: 0.8711776689742692
np.random.choice elapsed: 1.9704092079773545
np.random.default_rng().choice elapsed: 0.818919860990718

从10000个元素的数组中取1000个样本,重复进行10000次:

random.sample elapsed: 8.785315042012371
np.random.choice elapsed: 1.9777243090211414
np.random.default_rng().choice elapsed: 1.05490942299366

从10000个元素的数组中取10000个样本,重复10000次:

random.sample elapsed: 80.15063399000792
np.random.choice elapsed: 2.0218082449864596
np.random.default_rng().choice elapsed: 2.8596064270241186

我使用的代码:

import numpy as np
import random
from timeit import default_timer as timer
from contextlib import contextmanager


@contextmanager
def timeblock(label):
    start = timer()
    try:
        yield
    finally:
        end = timer()
        print ('{} elapsed: {}'.format(label, end - start))


def f1(a, n_sample):
    return random.sample(range(len(a)), n_sample)


def f2(a, n_sample):
    return np.random.choice(len(a), n_sample, replace=False)


def f3(a, n_sample):
    return np.random.default_rng().choice(len(a), n_sample, replace=False)


# Generate random array
a = np.random.uniform(1., 100., 10000)
# Number of samples' indexes to randomly take from a
n_sample = 100
# Number of times to repeat tested functions
N = 100000

print(f'{N} times {n_sample} samples')
with timeblock("random.sample"):
    for _ in range(N):
        f1(a, n_sample)

with timeblock("np.random.choice"):
    for _ in range(N):
        f2(a, n_sample)

with timeblock("np.random.default_rng().choice"):
    for _ in range(N):
        f3(a, n_sample)

有趣的是,np.random.RandomState().choice 的速度比 np.random.default_rng().choice 慢了4倍以上。 - isarandi
我意识到 RandomState 现在已经被弃用了。 - isarandi

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接