如何加速这个Rcpp函数?

4
我希望能够在Rcpp中实现一个简单的“分割-应用-组合”程序,其中数据集(矩阵)被分成组,并返回每组列的总和。这是一个在R中容易实现但常常需要相当长时间的过程。我已经成功地实现了一个Rcpp解决方案,它超越了R的性能,但我想知道是否可以进一步改进。为了说明,这里有一些代码,首先是使用R的代码:
n <- 50000
k <- 50
set.seed(42)
X <- matrix(rnorm(n*k), nrow=n)
g=rep(1:8,length.out=n )

use.for <- function(mat, ind){
  sums <- matrix(NA, nrow=length(unique(ind)), ncol=ncol(mat))
  for(i in seq_along(unique(ind))){
    sums[i,] <- colSums(mat[ind==i,])
  }
  return(sums)
}

use.apply <- function(mat, ind){
  apply(mat,2, function(x) tapply(x, ind, sum))
}

use.dt <- function(mat, ind){ # based on Roland's answer
   dt <- as.data.table(mat)
   dt[, cvar := ind]
   dt2 <- dt[,lapply(.SD, sum), by=cvar]
   as.matrix(dt2[,cvar:=NULL])
}

事实证明,for循环实际上非常快,并且使用Rcpp实现最容易(对我来说)。它的工作原理是为每个组创建一个子矩阵,然后在矩阵上调用colSums。这是使用RcppArmadillo实现的:

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
using namespace Rcpp;
using namespace arma;

// [[Rcpp::export]]
arma::mat use_arma(arma::mat X, arma::colvec G){

  arma::colvec gr = arma::unique(G);
  int gr_n = gr.n_rows;
  int ncol = X.n_cols;

  arma::mat out = zeros(gr_n, ncol); 

  for(int g=0; g<gr_n; g++){
   int g_id = gr(g);
   arma::uvec subvec = find(G==g_id);
   arma::mat submat = X.rows(subvec);
   arma::rowvec res = sum(submat,0);
   out.row(g) = res;     
  }
 return out;
}

然而,根据这个问题的回答,我了解到在C++(就像在R中一样),创建副本是代价昂贵的,但循环不像在R中那么糟糕。由于arma方案依赖于为每个组创建矩阵(代码中的submat),因此我猜想避免这种情况将进一步加速该过程。因此,这里是第二个仅使用循环基于Rcpp的实现:
#include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::export]]
NumericMatrix use_Rcpp(NumericMatrix X, IntegerVector G){

  IntegerVector gr = unique(G);
  std::sort(gr.begin(), gr.end());
  int gr_n = gr.size();
  int nrow = X.nrow(), ncol = X.ncol();

  NumericMatrix out(gr_n, ncol);

  for(int g=0; g<gr_n; g++){
     int g_id = gr(g);

      for (int j = 0; j < ncol; j++) {
      double total = 0;
        for (int i = 0; i < nrow; i++) {

          if (G(i) != g_id) continue;    // not sure how else to do this
          total += X(i, j);
        }
        out(g,j) = total;
      }
  }
      return out;
}

对这些解决方案进行基准测试,包括@Roland提供的use_dt版本(我的先前版本不公平地歧视了data.table),以及@beginneR建议的dplyr解决方案,得到以下结果:
 library(rbenchmark)
 benchmark(use.for(X,g), use.apply(X,g), use.dt(X,g), use.dplyr(X,g), use_arma(X,g), use_Rcpp(X,g), 
+           columns = c("test", "replications", "elapsed", "relative"), order = "relative", replications = 1000)
             test replications elapsed relative
# 5  use_arma(X, g)         1000   29.65    1.000
# 4 use.dplyr(X, g)         1000   42.05    1.418
# 3    use.dt(X, g)         1000   56.94    1.920
# 1   use.for(X, g)         1000   60.97    2.056
# 6  use_Rcpp(X, g)         1000  113.96    3.844
# 2 use.apply(X, g)         1000  301.14   10.156

我的直觉是(use_Rcppuse_arma更好),但事实证明我错了。话虽如此,我猜测在我的use_Rcpp函数中的这行代码if (G(i) != g_id) continue;减慢了整个过程。如果有其他可替代方案,我很愿意学习。

我很高兴自己完成了与R相同的任务所需时间的一半,但也许若干Rcpp比R快得多的例子误导了我的期望值,我想知道是否还能进一步提速。请问有人有什么想法吗?因为我对RcppC++都相对陌生,所以我也欢迎任何编程或代码的评论。

3个回答

4
不,你需要击败的不是 for 循环:
library(data.table)
#it doesn't seem fair to include calls to library in benchmarks
#you need to do that only once in your session after all

use.dt2 <- function(mat, ind){
  dt <- as.data.table(mat)
  dt[, cvar := ind]
  dt2 <- dt[,lapply(.SD, sum), by=cvar]
  as.matrix(dt2[,cvar:=NULL])
}

all.equal(use.dt(X,g), use.dt2(X,g))
#TRUE

benchmark(use.for(X,g), use.apply(X,g), use.dt(X,g), use.dt2(X,g),
          columns = c("test", "replications", "elapsed", "relative"), 
          order = "relative", replications = 50)

#             test replications elapsed relative
#4   use.dt2(X, g)           50    3.12    1.000
#1   use.for(X, g)           50    4.67    1.497
#3    use.dt(X, g)           50    7.53    2.413
#2 use.apply(X, g)           50   17.46    5.596

谢谢指出这一点。我没有想到它会有如此明显的差异,但显然确实如此。感谢! - coffeinjunky
如果您的输入是一个 data.frame,使用 data.table 会更快,因为第一行会复制输入,但对于 data.frame 输入,可以通过 setDT 避免这种情况。 - Roland
最终,我需要对X和求和输出进行多次矩阵乘法。出于这个原因,一开始就将所有东西设置为矩阵。虽然我可以将一些内容更改为data.frame形式,但是除非我将所有内容转换为矩阵,否则我不知道如何继续操作。data.table是否允许对整个表进行类似矩阵的操作,例如取逆等? - coffeinjunky
不,如果是这种情况,您甚至可以考虑使用Armadillo(或Eigen)来完成所有事情。 - Roland
实际上,我确实打算这样做。但是我想优化组成部分,因为该过程非常耗时。 - coffeinjunky
只是一个提示:您不必执行cvar := ind。您可以直接在聚合中使用ind,如下所示:dt[,lapply(.SD, sum), by=ind] - Arun

2
以下是关于你的Rcpp解决方案的评论和内联注释:

以下是关于你的Rcpp解决方案的评论和内联注释:

#include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::export]]
NumericMatrix use_Rcpp(NumericMatrix X, IntegerVector G){

  // Rcpp has a sort_unique() function, which combines the
  // sort and unique steps into one, and is often faster than
  // performing the operations separately. Try `sort_unique(G)`
  IntegerVector gr = unique(G);
  std::sort(gr.begin(), gr.end());
  int gr_n = gr.size();
  int nrow = X.nrow(), ncol = X.ncol();

  // This constructor zero-initializes memory (kind of like
  // making a copy). You should use:
  // 
  //     NumericMatrix out = no_init(gr_n, ncol)
  //
  // to ensure the memory is allocated, but not zeroed.
  // 
  // EDIT: We don't have no_init for matrices right now, but you can hack
  // around that with:
  // 
  //     NumericMatrix out(Rf_allocMatrix(REALSXP, gr_n, ncol));
  NumericMatrix out(gr_n, ncol);

  for(int g=0; g<gr_n; g++){

     // subsetting with operator[] is cheaper, so use gr[g] when
     // you can be sure bounds checks are not necessary
     int g_id = gr(g);

      for (int j = 0; j < ncol; j++) {
      double total = 0;
        for (int i = 0; i < nrow; i++) {

          // similarily here
          if (G(i) != g_id) continue;    // not sure how else to do this
          total += X(i, j);
        }
        // IIUC, you are filling the matrice row-wise. This is slower as
        // R matrices are stored in column-major format, and so filling
        // matrices column-wise will be faster.
        out(g,j) = total;
      }
  }
      return out;
}

2
也许你正在寻找(奇怪命名的)rowsum
library(microbenchmark)
use.rowsum = rowsum

并且

> all.equal(use.for(X, g), use.rowsum(X, g), check.attributes=FALSE)
[1] TRUE
> microbenchmark(use.for(X, g), use.rowsum(X, g), times=5)
Unit: milliseconds
             expr       min        lq    median        uq       max neval
    use.for(X, g) 126.92876 127.19027 127.51403 127.64082 128.06579     5
 use.rowsum(X, g)  10.56727  10.93942  11.01106  11.38697  11.38918     5

不知道那个(名字真的很误导人)函数。看起来好像比其他所有函数都要好。谢谢你指出来! - coffeinjunky

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接