在Python中生成截断负二项分布

3

我正在尝试生成数据集,这些数据集遵循截断负二项分布,其中数字集合具有最大值。

def truncated_Nbinom(n, p, max_value, size):
    import scipy.stats as sct
    temp_size = size
    while True:
        temp_size *= 2
        temp = sct.nbinom.rvs(n, p, size=temp_size)
        truncated = temp[temp <= max_value]
        if len(truncated) >= size:
            return truncated[:size]

当max_value和n较小时,我可以得到结果。但是当我尝试使用以下值时:

input_1= truncated_Nbinom(99, 0.3, 99, 5000).tolist()

内核一直崩溃。我尝试更改Python端口并提高递归限制,但都没有起作用。你有什么想法可以让我的代码更快吗?


“Dying” 是什么意思? - doctorlove
我正在使用Jupyter Notebook,在编写代码一段时间后,控制台会显示“内核已死亡,将重新启动”,然后才返回代码。 - Gözde Filiz
我怀疑你的程序可能存在无限循环,并且每次将temp_size加倍会消耗内存。 - doctorlove
1个回答

1
这里有一种方法。您可以计算负二项分布下选择 x 的概率,然后将x小于max_value的概率归一化为总和为一。现在,您只需使用适当的概率调用 np.random.choice 即可。
import numpy as np
import pandas as pd
from scipy import stats


def truncated_Nbinom2(n, p, max_value, size):
  support = np.arange(max_value + 1)
  probs = stats.nbinom.pmf(support, n, p)
  probs /= probs.sum()
  return np.random.choice(support, size=size, p=probs)

这里是一幅插图:
arr1 = truncated_Nbinom(9, 0.3, 9, 50000)
arr2 = truncated_Nbinom2(9, 0.3, 9, 50000)

df_counts = pd.DataFrame({
    "version_1": pd.Series(arr1).value_counts(),
    "version_2": pd.Series(arr2).value_counts(),
})

enter image description here


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接