Rcpp使用outer和pmax函数

3

我有一个R函数,需要对长度约为5000的向量进行大约100万次的计算。是否可以通过在Rcpp中实现来加速它?我以前几乎没有使用过Rcpp,并且下面的代码似乎无法工作:

set.seet(1)
a <- rt(5e3, df = 2)
b <- rt(5e3, df = 2.5)
c <- rt(5e3, df = 3)
d <- rt(5e3, df = 3.5)
sum((1 - outer(a, b, pmax)) * (1 - outer(c, d, pmax)))
#[1] -367780.1

#include <Rcpp.h>
using namespace Rcpp;
// [[Rcpp::export]]
double f_outer(NumericVector u, NumericVector v, NumericVector x, NumericVector y) {
double result = sum((1 - Rcpp::outer(u, v, Rcpp::pmax)) * (1 - Rcpp::outer(x, y, Rcpp::pmax)));
return(result);
}

非常感谢你!

通常情况下,在C++中使用循环来避免不必要的内存分配。 - F. Privé
1个回答

4

F. Privé 是正确的 - 我们将使用循环; 我在一个名为 so-answer.cpp 的文件中有以下 C++ 代码:

#include <Rcpp.h>

using namespace Rcpp;

// [[Rcpp::export]]
double f_outer(NumericVector u, NumericVector v, NumericVector x, NumericVector y) {
    // We'll use the size of the first and second vectors for our for loops
    int n = u.size();
    int m = v.size();
    // Make sure the vectors are appropriately sized for what we're doing
    if ( (n != x.size() ) || ( m != y.size() ) ) {
        ::Rf_error("Vectors not of compatible sizes.");
    }
    // Initialize a result variable
    double result = 0.0;
    // And use loops instead of outer
    for ( int i = 0; i < n; ++i ) {
        for ( int j = 0; j < m; ++j ) {
            result += (1 - std::max(u[i], v[j])) * (1 - std::max(x[i], y[j]));
        }
    }
    // Then return the result
    return result;
}

接下来我们在R中看到,C++代码给出了与您的R代码相同的答案,但运行速度要快得多:

library(Rcpp) # for sourceCpp()
library(microbenchmark) # for microbenchmark() (for benchmarking)
sourceCpp("so-answer.cpp") # compile our C++ code and make it available in R
set.seed(1) # for reproducibility
a <- rt(5e3, df = 2)
b <- rt(5e3, df = 2.5)
c <- rt(5e3, df = 3)
d <- rt(5e3, df = 3.5)
sum((1 - outer(a, b, pmax)) * (1 - outer(c, d, pmax)))
#> [1] -69677.99
f_outer(a, b, c, d)
#> [1] -69677.99
# Same answer, so looking good. Which one's faster?
microbenchmark(base = sum((1 - outer(a, b, pmax)) * (1 - outer(c, d, pmax))),
               rcpp = f_outer(a, b, c, d))
#> Unit: milliseconds
#>  expr       min        lq      mean    median        uq        max neval
#>  base 3978.9201 4119.6757 4197.9292 4131.3300 4144.4524 10121.5558   100
#>  rcpp  118.8963  119.1531  129.4071  119.4767  122.5218   909.2744   100
#>  cld
#>    b
#>   a

这是由 reprex 包 (v0.2.1) 在2018年12月13日创建的。


干得好,你是我的英雄! - user8934968

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接