更快的无重复加权抽样

49
这个问题引发了一个新的R包:wrswoR。R的默认无重复采样使用sample.int似乎需要二次运行时间,例如使用从均匀分布中抽取的权重。对于大样本量来说,这很慢。有人知道一个更快的实现方法可以在R内部使用吗?两个选择是“有放回拒绝采样”(请参见stats.sx上的此问题)和Wong和Easton(1980)的算法(在StackOverflow答案中有一个Python实现)。

感谢Ben Bolker提供的线索,该线索指向了在使用replace=F和不均匀权重调用sample.int时内部调用的C函数:ProbSampleNoReplace。实际上,该代码显示了两个嵌套的for循环(random.c的420行及之后)。

以下是用于经验性分析运行时间的代码:

library(plyr)

sample.int.test <- function(n, p) {
    sample.int(2 * n, n, replace=F, prob=p); NULL }

times <- ldply(
  1:7,
  function(i) {
    n <- 1024 * (2 ** i)
    p <- runif(2 * n)
    data.frame(
      n=n,
      user=system.time(sample.int.test(n, p), gcFirst=T)['user.self'])
  },
  .progress='text'
)

times

library(ggplot2)
ggplot(times, aes(x=n, y=user/n)) + geom_point() + scale_x_log10() +
  ylab('Time per unit (s)')

# Output:
       n   user
1   2048  0.008
2   4096  0.028
3   8192  0.100
4  16384  0.408
5  32768  1.645
6  65536  6.604
7 131072 26.558

Plot

编辑:感谢Arun指出,无权重抽样似乎没有这种性能惩罚。


1
我不明白为什么 runif() 会使运行时间呈二次方增长... - Ben Bolker
3
似乎使用probsample操作会比较耗时。 - Arun
3
该使用的算法可在 https://github.com/wch/r-source/blob/trunk/src/main/random.c 上查看:搜索“ProbSampleReplace”。我不知道这是否有用,但它应该能够让你大致了解使用的算法以及它是否可以轻松改进。请注意,它正在对整个向量进行排序... - Ben Bolker
1
我偶然发现了一些其他链接(需要订阅,所以在星期一之前我无法仔细检查它们)。http://link.springer.com/chapter/10.1007%2F978-3-642-30347-0_27 和 http://dl.acm.org/citation.cfm?id=1711169。不确定是否有任何真正的实现。 - Jouni Helske
1
@krlmlr 我知道发一条“感谢”评论不太好,但是你的 wrswoR 包将我的代码运行时间从大约8小时缩短到了不到1分钟!非常感谢! - Phil
显示剩余9条评论
3个回答

24

更新:

感谢 @Hemmo、@Dinrem、@krlmlr 和 @rtlgrmpf,实现了Efraimidis & Spirakis算法的Rcpp版本:

library(inline)
library(Rcpp)
src <- 
'
int num = as<int>(size), x = as<int>(n);
Rcpp::NumericVector vx = Rcpp::clone<Rcpp::NumericVector>(x);
Rcpp::NumericVector pr = Rcpp::clone<Rcpp::NumericVector>(prob);
Rcpp::NumericVector rnd = rexp(x) / pr;
for(int i= 0; i<vx.size(); ++i) vx[i] = i;
std::partial_sort(vx.begin(), vx.begin() + num, vx.end(), Comp(rnd));
vx = vx[seq(0, num - 1)] + 1;
return vx;
'
incl <- 
'
struct Comp{
  Comp(const Rcpp::NumericVector& v ) : _v(v) {}
  bool operator ()(int a, int b) { return _v[a] < _v[b]; }
  const Rcpp::NumericVector& _v;
};
'
funFast <- cxxfunction(signature(n = "Numeric", size = "integer", prob = "numeric"),
                       src, plugin = "Rcpp", include = incl)

# See the bottom of the answer for comparison
p <- c(995/1000, rep(1/1000, 5))
n <- 100000
system.time(print(table(replicate(funFast(6, 3, p), n = n)) / n))

      1       2       3       4       5       6 
1.00000 0.39996 0.39969 0.39973 0.40180 0.39882 
   user  system elapsed 
   3.93    0.00    3.96 
# In case of:
# Rcpp::IntegerVector vx = Rcpp::clone<Rcpp::IntegerVector>(x);
# i.e. instead of NumericVector
      1       2       3       4       5       6 
1.00000 0.40150 0.39888 0.39925 0.40057 0.39980 
   user  system elapsed 
   1.93    0.00    2.03 

旧版本:

让我们尝试几种可能的方法:

带替换的简单拒绝抽样。 这个函数比@krlmlr提供的sample.int.rej简单得多,即样本大小始终等于n。正如我们将看到的那样,在权重服从均匀分布的情况下,它仍然非常快,但在其他情况下极其缓慢。

fastSampleReject <- function(all, n, w){
  out <- numeric(0)
  while(length(out) < n)
    out <- unique(c(out, sample(all, n, replace = TRUE, prob = w)))
  out[1:n]
}
Wong和Easton(1980年)的算法。这是this Python版本的实现。它很稳定,但与其他函数相比速度较慢。请注意,本文不进行解释,保留HTML标记。
fastSample1980 <- function(all, n, w){
  tws <- w
  for(i in (length(tws) - 1):0)
    tws[1 + i] <- sum(tws[1 + i], tws[1 + 2 * i + 1], 
                      tws[1 + 2 * i + 2], na.rm = TRUE)      
  out <- numeric(n)
  for(i in 1:n){
    gas <- tws[1] * runif(1)
    k <- 0        
    while(gas > w[1 + k]){
      gas <- gas - w[1 + k]
      k <- 2 * k + 1
      if(gas > tws[1 + k]){
        gas <- gas - tws[1 + k]
        k <- k + 1
      }
    }
    wgh <- w[1 + k]
    out[i] <- all[1 + k]        
    w[1 + k] <- 0
    while(1 + k >= 1){
      tws[1 + k] <- tws[1 + k] - wgh
      k <- floor((k - 1) / 2)
    }
  }
  out
}
Rcpp实现了Wong和Easton的算法。 可能可以进一步优化,因为这是我第一个可用的Rcpp函数,但无论如何它都运行良好。
library(inline)
library(Rcpp)

src <-
'
Rcpp::NumericVector weights = Rcpp::clone<Rcpp::NumericVector>(w);
Rcpp::NumericVector tws = Rcpp::clone<Rcpp::NumericVector>(w);
Rcpp::NumericVector x = Rcpp::NumericVector(all);
int k, num = as<int>(n);
Rcpp::NumericVector out(num);
double gas, wgh;

if((weights.size() - 1) % 2 == 0){
  tws[((weights.size()-1)/2)] += tws[weights.size()-1] + tws[weights.size()-2];
}
else
{
  tws[floor((weights.size() - 1)/2)] += tws[weights.size() - 1];
}

for (int i = (floor((weights.size() - 1)/2) - 1); i >= 0; i--){
  tws[i] += (tws[2 * i + 1]) + (tws[2 * i + 2]);
}
for(int i = 0; i < num; i++){
  gas = as<double>(runif(1)) * tws[0];
  k = 0;
  while(gas > weights[k]){
    gas -= weights[k];
    k = 2 * k + 1;
    if(gas > tws[k]){
      gas -= tws[k];
      k += 1;
    }
  }
  wgh = weights[k];
  out[i] = x[k];
  weights[k] = 0;
  while(k > 0){
    tws[k] -= wgh;
    k = floor((k - 1) / 2);
  }
  tws[0] -= wgh;
}
return out;
'

fun <- cxxfunction(signature(all = "numeric", n = "integer", w = "numeric"),
                   src, plugin = "Rcpp")

现在是一些结果:

times1 <- ldply(
  1:6,
  function(i) {
    n <- 1024 * (2 ** i)
    p <- runif(2 * n) # Uniform distribution
    p <- p/sum(p)
    data.frame(
      n=n,
      user=c(system.time(sample.int.test(n, p), gcFirst=T)['user.self'],
             system.time(weighted_Random_Sample(1:(2*n), p, n), gcFirst=T)['user.self'],
             system.time(fun(1:(2*n), n, p), gcFirst=T)['user.self'],
             system.time(sample.int.rej(2*n, n, p), gcFirst=T)['user.self'],
             system.time(fastSampleReject(1:(2*n), n, p), gcFirst=T)['user.self'],
             system.time(fastSample1980(1:(2*n), n, p), gcFirst=T)['user.self']),
      id=c("Base", "Reservoir", "Rcpp", "Rejection", "Rejection simple", "1980"))
  },
  .progress='text'
)


times2 <- ldply(
  1:6,
  function(i) {
    n <- 1024 * (2 ** i)
    p <- runif(2 * n - 1)
    p <- p/sum(p) 
    p <- c(0.999, 0.001 * p) # Special case
    data.frame(
      n=n,
      user=c(system.time(sample.int.test(n, p), gcFirst=T)['user.self'],
             system.time(weighted_Random_Sample(1:(2*n), p, n), gcFirst=T)['user.self'],
             system.time(fun(1:(2*n), n, p), gcFirst=T)['user.self'],
             system.time(sample.int.rej(2*n, n, p), gcFirst=T)['user.self'],
             system.time(fastSampleReject(1:(2*n), n, p), gcFirst=T)['user.self'],
             system.time(fastSample1980(1:(2*n), n, p), gcFirst=T)['user.self']),
      id=c("Base", "Reservoir", "Rcpp", "Rejection", "Rejection simple", "1980"))
  },
  .progress='text'
)

enter image description here

enter image description here

arrange(times1, id)
       n  user               id
1   2048  0.53             1980
2   4096  0.94             1980
3   8192  2.00             1980
4  16384  4.32             1980
5  32768  9.10             1980
6  65536 21.32             1980
7   2048  0.02             Base
8   4096  0.05             Base
9   8192  0.18             Base
10 16384  0.75             Base
11 32768  2.99             Base
12 65536 12.23             Base
13  2048  0.00             Rcpp
14  4096  0.01             Rcpp
15  8192  0.03             Rcpp
16 16384  0.07             Rcpp
17 32768  0.14             Rcpp
18 65536  0.31             Rcpp
19  2048  0.00        Rejection
20  4096  0.00        Rejection
21  8192  0.00        Rejection
22 16384  0.02        Rejection
23 32768  0.02        Rejection
24 65536  0.03        Rejection
25  2048  0.00 Rejection simple
26  4096  0.01 Rejection simple
27  8192  0.00 Rejection simple
28 16384  0.01 Rejection simple
29 32768  0.00 Rejection simple
30 65536  0.05 Rejection simple
31  2048  0.00        Reservoir
32  4096  0.00        Reservoir
33  8192  0.00        Reservoir
34 16384  0.02        Reservoir
35 32768  0.03        Reservoir
36 65536  0.05        Reservoir

arrange(times2, id)
       n  user               id
1   2048  0.43             1980
2   4096  0.93             1980
3   8192  2.00             1980
4  16384  4.36             1980
5  32768  9.08             1980
6  65536 19.34             1980
7   2048  0.01             Base
8   4096  0.04             Base
9   8192  0.18             Base
10 16384  0.75             Base
11 32768  3.11             Base
12 65536 12.04             Base
13  2048  0.01             Rcpp
14  4096  0.02             Rcpp
15  8192  0.03             Rcpp
16 16384  0.08             Rcpp
17 32768  0.15             Rcpp
18 65536  0.33             Rcpp
19  2048  0.00        Rejection
20  4096  0.00        Rejection
21  8192  0.02        Rejection
22 16384  0.02        Rejection
23 32768  0.05        Rejection
24 65536  0.08        Rejection
25  2048  1.43 Rejection simple
26  4096  2.87 Rejection simple
27  8192  6.17 Rejection simple
28 16384 13.68 Rejection simple
29 32768 29.74 Rejection simple
30 65536 73.32 Rejection simple
31  2048  0.00        Reservoir
32  4096  0.00        Reservoir
33  8192  0.02        Reservoir
34 16384  0.02        Reservoir
35 32768  0.02        Reservoir
36 65536  0.04        Reservoir

显然,我们可以拒绝函数1980,因为在两种情况下它都比Base慢。当第二种情况中只有一个概率为0.999时,Rejection simple也会遇到麻烦。
因此,剩下的选择是RejectionRcppReservoir。最后一步是检查这些值本身是否正确。为了确保它们正确,我们将使用sample作为基准(还要消除由于无替换抽样而导致的概率不必与p相符的混淆)。
p <- c(995/1000, rep(1/1000, 5))
n <- 100000

system.time(print(table(replicate(sample(1:6, 3, repl = FALSE, prob = p), n = n))/n))
      1       2       3       4       5       6 
1.00000 0.39992 0.39886 0.40088 0.39711 0.40323  # Benchmark
   user  system elapsed 
   1.90    0.00    2.03 

system.time(print(table(replicate(sample.int.rej(2*3, 3, p), n = n))/n))
      1       2       3       4       5       6 
1.00000 0.40007 0.40099 0.39962 0.40153 0.39779 
   user  system elapsed 
  76.02    0.03   77.49 # Slow

system.time(print(table(replicate(weighted_Random_Sample(1:6, p, 3), n = n))/n))
      1       2       3       4       5       6 
1.00000 0.49535 0.41484 0.36432 0.36338 0.36211  # Incorrect
   user  system elapsed 
   3.64    0.01    3.67 

system.time(print(table(replicate(fun(1:6, 3, p), n = n))/n))
      1       2       3       4       5       6 
1.00000 0.39876 0.40031 0.40219 0.40039 0.39835 
   user  system elapsed 
   4.41    0.02    4.47 

注意这里有几个细节。由于某种原因,weighted_Random_Sample 返回的值不正确(我还没有深入研究,但是它在假定均匀分布时可以正常工作)。sample.int.rej 在重复抽样中非常缓慢。
总之,在需要重复抽样时,Rcpp 是最佳选择,而在其他情况下,sample.int.rej 更快且更易于使用。

非常好,特别是测试采样器的代码!可能weighted_Random_Sample受到IEEE浮点值有限精度的影响。不幸的是,Dinre还没有回复我的请求,关于这个话题,但请参见数学.sx上的这个相关问题 :-) - krlmlr
不要着急,sort(rexp(n)) 已经很慢了。让我们等待 Rcpp 的解决方案,应该不会太难。 - krlmlr
2
为了在处理过程中增加额外的困难,我在我的电脑上得到了极其相似的结果,其中基准测试1.0000 0.4017 0.4002 0.4091 0.3917 0.3973和加权随机采样1.0000 0.3991 0.4056 0.4012 0.4014 0.3927。我使用的是 p <- c(955, rep(1, 5)),因为这实际上是一种更可靠的抽样方法。所有主要的抽样函数都接受权值,并且不假设sum(prob)==1,因此我更喜欢在我的工作中使用权值。然而,这可能会限制WRS算法的使用方式。 - Dinre
1
@Dinre:这表明你的实现取决于权重的大小——使用“-rexp(n) / prob = log(runif(n)) / prob”在数值上比“runif(n) ^ (1 / prob)”更健壮,并且在排序顺序方面是等效的。 - krlmlr
@krlmlr,我只是直接翻译了Egraimidis&Spirakis论文中提出的算法作为概念验证。它有效。将其加强为一个健壮的函数需要更多的工作,不仅仅是数值工作;目前它也没有检查参数的有效性。在当前形式下,该函数只是证明该算法可以成功使用的证明,这是证明算法存在于实际空间的第一步。该论文仅证明该算法存在于理论空间中。 - Dinre
显示剩余4条评论

21
我决定深入研究一些评论,并发现Efraimidis & Spirakis论文非常有趣(感谢@Hemmo找到参考文献)。 本文的总体思路是:通过生成随机均匀数并将其提高到每个项目的重量的倒数幂来创建密钥。 然后,您只需取最高的键值作为样本。 这非常出色!
weighted_Random_Sample <- function(
    .data,
    .weights,
    .n
    ){

    key <- runif(length(.data)) ^ (1 / .weights)
    return(.data[order(key, decreasing=TRUE)][1:.n])
}

如果您将“ .n”设置为“ .data”的长度(应始终为“ .weights”的长度),则实际上这是一种加权水库置换,但该方法对于采样和置换都有效。更新:我应该提到上述函数期望权重大于零。否则,key < - runif(length(.data))^(1 / .weights)将无法正确排序。

仅仅是为了好玩,我也使用了原帖中的测试场景来比较这两个函数。

set.seed(1)

times_WRS <- ldply(
1:7,
function(i) {
    n <- 1024 * (2 ** i)
    p <- runif(2 * n)
    n_Set <- 1:(2 * n)
    data.frame(
      n=n,
      user=system.time(weighted_Random_Sample(n_Set, p, n), gcFirst=T)['user.self'])
  },
  .progress='text'
)

sample.int.test <- function(n, p) {
sample.int(2 * n, n, replace=F, prob=p); NULL }

times_sample.int <- ldply(
  1:7,
  function(i) {
    n <- 1024 * (2 ** i)
    p <- runif(2 * n)
    data.frame(
      n=n,
      user=system.time(sample.int.test(n, p), gcFirst=T)['user.self'])
  },
  .progress='text'
)

times_WRS$group <- "WRS"
times_sample.int$group <- "sample.int"
library(ggplot2)

ggplot(rbind(times_WRS, times_sample.int) , aes(x=n, y=user/n, col=group)) + geom_point() + scale_x_log10() +  ylab('Time per unit (s)')

这里是时间:

times_WRS
#        n user
# 1   2048 0.00
# 2   4096 0.01
# 3   8192 0.00
# 4  16384 0.01
# 5  32768 0.03
# 6  65536 0.06
# 7 131072 0.16

times_sample.int
#        n  user
# 1   2048  0.02
# 2   4096  0.05
# 3   8192  0.14
# 4  16384  0.58
# 5  32768  2.33
# 6  65536  9.23
# 7 131072 37.79

performance comparison


1
这是否真的与无重复加权抽样相同,还是只是一个近似?我在math.sx上提出了这个问题,但没有得到答案... - krlmlr
除非你只想对人群的一部分进行抽样...? - krlmlr
1
不是的。因为你没有使用替换,所以这是同一件事情。你仍然对整个集合进行排序,但是从排序后的集合中选择1:n个项目。 - Dinre
我想我应该修改我的最后一条评论:在这个算法中是一样的,因为整个集合总是有序的。这就是为什么你不需要循环来评估每个项目选择的有效性。 - Dinre
3
Pavlos S. Efraimidis 和 Paul G. Spirakis 提出的算法是我很久以来见过的最美妙的东西,仅仅因为它的简单。这就像通过 FFT 实现卷积一样甜美,不确定哪个更胜一筹... 注意:作者证明了他们的算法等价于加权无重复随机抽样。 - krlmlr
显示剩余10条评论

3

让我介绍一种基于带替换的拒绝抽样的更快速方法的实现。其思路如下:

  • 生成一个比所需大小“略大”的带替换样本

  • 丢弃重复值

  • 如果没有抽取足够的值,则使用调整后的nsizeprob参数递归调用相同的过程

  • 将返回的索引重新映射到原始索引

我们需要抽样多少样本?假设分布均匀,结果是 从N个值中选出x个唯一值所需的预期试验次数。这是两个调和数(H_n and H_{n - size})的差异。前几个调和数已经被列成表格了,否则将使用自然对数的逼近值。(这只是一个大致数字,这里没有必要太精确。)现在,对于非均匀分布,预期抽取的物品数量只能更大,因此我们不会抽取太多的样本。此外,抽样次数的限制为两倍的人口数量——我假设进行几次递归调用比抽取O(n ln n)项更快。

代码可在 R 包 wrswoRsample.int.rej 程序中找到 sample_int_rej.R。安装方式为:

library(devtools)
install_github('wrswoR', 'muelleki')

看起来它的工作速度“足够快”,但是尚未进行正式的运行时测试。此外,该软件包仅在Ubuntu中进行了测试。感谢您的反馈。


是的,如果假设一个均匀分布,那么它是快速的;但是假设不太方便的情况下,它会变得相当糟糕。今天我将发布一个关于此的答案,这个的R实现,并尝试完成它的Rcpp版本。 - Julius Vainora
@Julius:期待基准测试的结果 :-) 我已经使用指数分布权重测试了我的代码(这是使用IEEE浮点数可能遇到的最糟糕的情况),本来预计会出现非常糟糕的表现,但令我惊讶的是,情况并没有那么糟糕... - krlmlr
@krlmlr,根据您描述的算法,包含概率使用的权重将不会成比例。 - Ferdinand.kraft
@ Ferdinand.kraft: 哦,真的吗?能具体说明一下吗? - krlmlr
在拒绝抽样中,如果在有替换的步骤中存在重复项,则必须丢弃整个样本。即使如此,包含概率也以一种复杂的非线性方式与权重相关联。在基于设计的推断中,我们需要最终的包含概率,而不是算法内部使用的权重。因此,我的观点是:如何在使用您提出的方法绘制样本的权重和大小的情况下计算它们? - Ferdinand.kraft
@Ferdinand.kraft:whuber的回答详细解释了为什么可以使用相对权重逐个采样元素并丢弃每个重复元素。现在,在这个过程中,逐个采样元素等同于采样一批然后丢弃重复项。如果您仍然有疑问,您可以比较我的wrswoR包中实现的结果与sample.int甚至Efraimidis&Spirakis算法的实现结果。如果您发现任何有意义的差异,请告诉我。 - krlmlr

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接