TL;DR 自numpy v1.17.0起,建议使用numpy.random.default_rng()
对象而不是numpy.random
。对于选择:
import numpy as np
rng = np.random.default_rng()
rng.choice(...)
除了其他在v1.17中引入的随机API更改之外,此新版本的choice现在更加智能,并且在大多数情况下应该是最快的。旧版本保持不变以实现向后兼容性!
正如评论中所提到的,numpy一直存在一个问题,即
np.random.choice
实现对于
k<<n
与Python标准库中的
random.sample
相比是不够有效的。
问题在于
np.random.choice(arr, size=k, replace=False)
被实现为
permutation(arr)[:k]
。在数组很大而k很小的情况下,计算整个数组的排列是浪费时间和内存的。标准Python的
random.sample
以更加直接的方式工作——它只是迭代地进行采样,同时记录已经采样的内容或者从哪里采样。
在v1.17.0中,numpy介绍了
numpy.random
包的重构和改进(
文档,
更新内容,
性能)。我强烈建议您至少查看第一个链接。请注意,正如其所述,为了向后兼容,旧的
numpy.random
API保持不变,仍然使用旧的实现。
因此,现在推荐使用随机API的新方法是使用
numpy.random.default_rng()
对象而不是
numpy.random
。请注意,它是一个对象,它还接受可选的种子参数,因此您可以方便地传递它。它还默认使用了一个速度更快的不同生成器(有关详细信息,请参见上面的性能链接)。
关于您的情况,现在您可以使用
np.random.default_rng().choice(...)
。除了更快之外,由于改进后的随机生成器,
choice
本身也变得更加智能。现在它仅对足够大的数组(>10000个元素)和相对较大的k(>1/50大小)使用整个数组排列。否则,它使用Floyd的抽样算法(
简短描述,
numpy实现)。
这是在我的笔记本电脑上的性能比较:
从包含10000个元素的数组中取100个样本,重复进行10000次:
random.sample elapsed: 0.8711776689742692
np.random.choice elapsed: 1.9704092079773545
np.random.default_rng().choice elapsed: 0.818919860990718
从10000个元素的数组中取1000个样本,重复进行10000次:
random.sample elapsed: 8.785315042012371
np.random.choice elapsed: 1.9777243090211414
np.random.default_rng().choice elapsed: 1.05490942299366
从10000个元素的数组中取10000个样本,重复10000次:
random.sample elapsed: 80.15063399000792
np.random.choice elapsed: 2.0218082449864596
np.random.default_rng().choice elapsed: 2.8596064270241186
我使用的代码:
import numpy as np
import random
from timeit import default_timer as timer
from contextlib import contextmanager
@contextmanager
def timeblock(label):
start = timer()
try:
yield
finally:
end = timer()
print ('{} elapsed: {}'.format(label, end - start))
def f1(a, n_sample):
return random.sample(range(len(a)), n_sample)
def f2(a, n_sample):
return np.random.choice(len(a), n_sample, replace=False)
def f3(a, n_sample):
return np.random.default_rng().choice(len(a), n_sample, replace=False)
a = np.random.uniform(1., 100., 10000)
n_sample = 100
N = 100000
print(f'{N} times {n_sample} samples')
with timeblock("random.sample"):
for _ in range(N):
f1(a, n_sample)
with timeblock("np.random.choice"):
for _ in range(N):
f2(a, n_sample)
with timeblock("np.random.default_rng().choice"):
for _ in range(N):
f3(a, n_sample)
np.random.choice
进行无重复的随机抽样,方法是生成数组中所有索引的排列,然后取其中的前n_sample
个(参见这一行)。如果n_sample
远小于数组a
中的元素数量,则这种方法会变得非常低效。 - ali_mrandom.sample
函数只会抽取n_samples
个随机样本。它可以通过两种方式之一实现这一点——如果n_samples << N
,则跟踪已选择的项目(item)来完成;或者维护一个可选候选项的收缩池(ifn_samples
相对于N
较大)。你可以在此处查看源代码——它相当简单易读。 - ali_m