如何根据存储在列表中的概率进行随机选择(加权随机分布)?

38

假设有一个概率列表:

P = [0.10, 0.25, 0.60, 0.05]

我可以确保P中所有变量的总和始终为1。

如何编写一个函数,根据列表中的值随机返回一个有效索引?换句话说,在这种特定输入情况下,我希望它返回0的概率为10%,返回1的概率为25%,返回2的概率为60%,返回3的概率为5%。


实际上,从Python 3.6开始就有random.choices(注意末尾的's')可以允许提交相对权重。 - Nick
@NickstandswithUkraine,您能否请您添加一个关于这个问题的答案吗? - Karl Knechtel
请参见 https://dev59.com/kHRC5IYBdhLWcg3wRO3k。我认为这个问题可能更好地成为规范问题。 - Karl Knechtel
还有一个需要考虑的是 https://dev59.com/D3I95IYBdhLWcg3wzhd9,因为不重复抽样的特定情况有些棘手。 - Karl Knechtel
也许将random.choices的信息编辑到顶部回答中会更好,因为接口基本相同。 - Karl Knechtel
@KarlKnechtel 完成了! - Nick
6个回答

62
你可以通过numpy轻松实现这一点。它有一个接受概率参数的choice函数。

你可以通过numpy轻松实现这一点。它有一个choice函数,接受概率参数。

np.random.choice(
  ['pooh', 'rabbit', 'piglet', 'Christopher'], 
  5,
  p=[0.5, 0.1, 0.1, 0.3]
)

简洁明了,虽然我认为在这里使用numpy有些过头了,特别是如果脚本除了标准库之外没有其他依赖。 - salezica

15
基本上,制作一个累积概率分布(CDF)数组。基本上,给定索引的CDF值等于P中所有小于或等于该索引的值的总和。然后生成0到1之间的随机数,并进行二进制搜索(如果您想要线性搜索也可以)。以下是一些简单的代码。
from bisect import bisect
from random import random

P = [0.10,0.25,0.60,0.05]

cdf = [P[0]]
for i in xrange(1, len(P)):
    cdf.append(cdf[-1] + P[i])

random_ind = bisect(cdf,random())

当然,你可以使用类似以下方式生成一组随机索引:
rs = [bisect(cdf, random()) for i in xrange(20)]

产出

[2, 2, 3, 2, 2, 1, 2, 2, 2, 1, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2]

(结果将会有所不同,这是正常的)。当然,对于可能索引如此之少的情况,二分搜索相当不必要,但对于可能索引更多的情况则绝对是推荐使用的。


12

嗯,有趣,那怎么样...

  1. 生成一个0到1之间的数字。

  2. 遍历列表,用每个项目的概率减去你的数字。

  3. 选择在减法后将你的数字降至0或以下的项目。

这很简单,是O(n)的,应该能够工作 :)


如果概率按降序预先排序,迭代很可能会更快地终止。 - Nick

6
这个问题等同于从一个分类分布中进行采样。该分布通常与多项式分布混淆,后者模拟了从分类分布中进行多次采样的结果。
在numpy中,使用numpy.random.multinomial很容易从多项式分布中进行采样,但是特定的分类版本不存在。但是,可以通过从单次试验的多项式分布中进行采样,然后返回输出中的非零元素来实现。
import numpy as np
pvals = [0.10,0.25,0.60,0.05]
ind = np.where(np.random.multinomial(1,pvals))[0][0]

我认为使用argmax()而不是where()[0][0]更简单,而且效果相同。 - Hawk

3
import random

probs = [0.1, 0.25, 0.6, 0.05]
r = random.random()
index = 0
while(r >= 0 and index < len(probs)):
  r -= probs[index]
  index += 1
print index - 1

哈哈,我还以为在你发帖前的两秒钟里我很有原创性。 - salezica
这总是需要O(n)的时间,其中n是len(probs)。我们能做得更好吗? - Sush
@Sush 是的:我们可以对问题进行排序并执行二分搜索。这将减少到O(log n)。 - Nick

2

从Python 3.6 开始,random 模块中有一个名为 choices 的方法(注意末尾的 's')。

引用官方文档:

random.choices(population, weights=None, *, cum_weights=None, k=1) 返回从总体中带有重复的 k 个元素的列表

因此,解决方案如下:

>> choices(['option1', 'option2', 'option3', 'option4'], [0.10, 0.25, 0.60, 0.05])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接