为什么这个朴素的矩阵乘法比基本的R语言更快?

33
在R语言中,矩阵乘法非常优化,即实际上只是调用了BLAS/LAPACK。然而,我惊讶地发现,这个非常朴素的C++矩阵向量乘法代码似乎可靠地快了30%。
 library(Rcpp)

 # Simple C++ code for matrix multiplication
 mm_code = 
 "NumericVector my_mm(NumericMatrix m, NumericVector v){
   int nRow = m.rows();
   int nCol = m.cols();
   NumericVector ans(nRow);
   double v_j;
   for(int j = 0; j < nCol; j++){
     v_j = v[j];
     for(int i = 0; i < nRow; i++){
       ans[i] += m(i,j) * v_j;
     }
   }
   return(ans);
 }
 "
 # Compiling
 my_mm = cppFunction(code = mm_code)

 # Simulating data to use
 nRow = 10^4
 nCol = 10^4

 m = matrix(rnorm(nRow * nCol), nrow = nRow)
 v = rnorm(nCol)

 system.time(my_ans <- my_mm(m, v))
#>    user  system elapsed 
#>   0.103   0.001   0.103 
 system.time(r_ans <- m %*% v)
#>   user  system elapsed 
#>  0.154   0.001   0.154 

 # Double checking answer is correct
 max(abs(my_ans - r_ans))
 #> [1] 0

%*%这个 base R 的函数是否执行了某种数据检查,而我却跳过了呢?

编辑:

在理解了发生了什么之后(感谢 Stack Overflow!),值得注意的是,对于 R 的%*%而言,这是最坏的情况,即矩阵与向量相乘。例如,@RalfStubner 指出,使用 RcppArmadillo 实现的矩阵-向量乘法比我演示的朴素实现还要快,这意味着它比 base R 快得多,但在矩阵-矩阵乘法时(当两个矩阵都很大且是方形的)几乎与 base R 的%*%相同:

 arma_code <- 
   "arma::mat arma_mm(const arma::mat& m, const arma::mat& m2) {
 return m * m2;
 };"
 arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")

 nRow = 10^3 
 nCol = 10^3

 mat1 = matrix(rnorm(nRow * nCol), 
               nrow = nRow)
 mat2 = matrix(rnorm(nRow * nCol), 
               nrow = nRow)

 system.time(arma_mm(mat1, mat2))
#>   user  system elapsed 
#>   0.798   0.008   0.814 
 system.time(mat1 %*% mat2)
#>   user  system elapsed 
#>   0.807   0.005   0.822  

所以 R 当前版本 (v3.5.0) 的 %*% 对于矩阵乘法已经非常优化,但如果你愿意跳过检查的话,它在矩阵-向量运算方面还是可以显著加速的。


3
R的方法必须处理缺失值,虽然这可能不能解释所有情况。此外,根据我对计算数值方法的了解,你的朴素方法在某些情况下可能会出现不可接受的不准确性,因此其他方法将用更多时间换取更好的准确性。 - joran
看一下:getAnywhere(%*%),我们有:function (x, y) .Primitive("%*%")。所以,这是与一个_C_库进行接口交互,但正如@joran指出的那样,您没有考虑NA处理。 - coatless
1
@joran:据我所知,这个处理NA的很好。我唯一能看到的区别是它会产生一个向量而不是矩阵。 - Cliff AB
当其中一个维度为1时,可能存在矩阵x矩阵的优化并不特别有用的情况吗? - Cliff AB
7
这个帖子已经有些年头了,自Radford写下这篇文章以来,他可能已经成功地将一些改进引入了R中。我认为,至少可以总结出处理NA、Inf和NaN并不总是简单的,需要一些工作。 - joran
3
使用线性代数库进行矩阵乘法可以获得巨大的改进,因为它们更好地处理内存和缓存。对于矩阵-向量乘法,内存问题不是那么重要,所以优化的空间较小。例如,请参见此文 - F. Privé
3个回答

33

快速查看特别是这里的names.c,你会找到C函数do_matprod,它被%*%调用,并在文件array.c中找到。(有趣的是,事实证明,crossprodtcrossprod也调度到了同一个函数)。这里是一个链接do_matprod代码。

浏览该函数时,您可以看到它处理了您的朴素实现所没有的一些东西,包括:

  1. 保留行和列名称,如果有意义的话。
  2. 当调用%*%时,如果被操作的两个对象属于已提供这种方法的类,则允许调度到替代的S4方法。(这就是函数此部分正在发生的事情。)
  3. 处理实数矩阵和复数矩阵。
  4. 实现一系列规则来处理矩阵和矩阵、向量和矩阵、矩阵和向量、向量和向量的乘法。(请记住,在R中进行交叉乘法时,LHS上的向量被视为行向量,而在RHS上,则被视为列向量;这就是代码使其如此的原因。)

在函数的末尾附近, 它会分配到 matprodcmatprod 中的任一个。有趣的是(至少对我而言),在实矩阵的情况下,如果任一矩阵可能包含NaNInf值,则matprod将分配( 这里 ) 到一个名为simple_matprod的函数中,该函数与您自己的函数一样简单直接。否则,它会分派到几个BLAS Fortran例程之一,假定可以保证统一“行为良好”的矩阵元素,这些例程应该更快。


1
有趣的 (+1)。如果这些是唯一的区别,那么其中一个暗示是,如果我知道我正在进行普通矩阵 x 向量操作,我应该使用my_mm。这让我感到惊讶。 - Cliff AB
5
您可以尝试使用适当的BLAS函数,直接或间接地通过RcppArmadillo,并使用多线程的BLAS,以获得更大的效益。 - Ralf Stubner
你有没有想法为什么在R中相同的矩阵乘法根据矩阵的顺序速度会有所不同?(https://dev59.com/dFEG5IYBdhLWcg3wH0U0) - dcsuka

9

Josh的回答解释了为什么R的矩阵乘法不如这种朴素方法快。我很好奇使用RcppArmadillo可以获得多少收益。代码足够简单:

arma_code <- 
  "arma::vec arma_mm(const arma::mat& m, const arma::vec& v) {
       return m * v;
   };"
arma_mm = cppFunction(code = arma_code, depends = "RcppArmadillo")

基准测试:

> microbenchmark::microbenchmark(my_mm(m,v), m %*% v, arma_mm(m,v), times = 10)
Unit: milliseconds
          expr      min       lq      mean    median        uq       max neval
   my_mm(m, v) 71.23347 75.22364  90.13766  96.88279  98.07348  98.50182    10
       m %*% v 92.86398 95.58153 106.00601 111.61335 113.66167 116.09751    10
 arma_mm(m, v) 41.13348 41.42314  41.89311  41.81979  42.39311  42.78396    10

RcppArmadillo提供了更好的语法和更好的性能。

好奇心促使我尝试使用BLAS直接解决这个问题:

blas_code = "
NumericVector blas_mm(NumericMatrix m, NumericVector v){
  int nRow = m.rows();
  int nCol = m.cols();
  NumericVector ans(nRow);
  char trans = 'N';
  double one = 1.0, zero = 0.0;
  int ione = 1;
  F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
           &ione, &zero, ans.begin(), &ione);
  return ans;
}"
blas_mm <- cppFunction(code = blas_code, includes = "#include <R_ext/BLAS.h>")

基准测试:

Unit: milliseconds
          expr      min       lq      mean    median        uq       max neval
   my_mm(m, v) 72.61298 75.40050  89.75529  96.04413  96.59283  98.29938    10
       m %*% v 95.08793 98.53650 109.52715 111.93729 112.89662 128.69572    10
 arma_mm(m, v) 41.06718 41.70331  42.62366  42.47320  43.22625  45.19704    10
 blas_mm(m, v) 41.58618 42.14718  42.89853  42.68584  43.39182  44.46577    10

Armadillo和BLAS(在我的情况下是OpenBLAS)几乎是一样的。而BLAS代码也是R最终所做的。因此,R所做的2/3工作是错误检查等。


2
而且如果您的操作系统/编译器支持,可能还会使用OpenMP。 - Dirk Eddelbuettel
@Dirk 我本来以为Armadillo会直接将这样简单的东西转发给BLAS(在我的情况下也是多线程的)。至少它们的速度是一样快的... - Ralf Stubner
非常有趣。检查成本不像矩阵乘法计算那样快速增长,因此在这种情况下,这种成本就消失了。 - Cliff AB
@CliffAB 是的。此外,对于矩阵-矩阵运算,如果采用朴素方法实现BLAS,要想在内存访问方面超越它是更加困难的,详见F.Prive提供的链接。 - Ralf Stubner
@RalfStubner,R的矩阵乘法在使用宽/窄矩阵时表现非常奇怪[取决于顺序](https://dev59.com/dFEG5IYBdhLWcg3wH0U0)。您有任何想法是什么原因导致这种情况? - dcsuka

2

补充一下Ralf Stubner的解决方案,你可以使用以下C ++版本来

  1. 同时执行多个列,以避免多次重新读取输出向量。
  2. 添加__restrict__以潜在地允许矢量操作(我猜只是读取,所以可能无关紧要)。
#include <Rcpp.h>
using namespace Rcpp;

inline void mat_vec_mult_vanilla
(double const * __restrict__ m, 
 double const * __restrict__ v, 
 double * __restrict__ const res, 
 size_t const dn, size_t const dm) noexcept {
  for(size_t j = 0; j < dm; ++j, ++v){
    double * r = res;
    for(size_t i = 0; i < dn; ++i, ++r, ++m)
      *r += *m * *v;
  }
}

inline void mat_vec_mult
(double const * __restrict__ const m, 
 double const * __restrict__ const v, 
 double * __restrict__ const res, 
 size_t const dn, size_t const dm) noexcept {
  size_t j(0L);
  double const * vj = v,
               * mi = m;
  constexpr size_t const ncl(8L);
  {
    double const * mvals[ncl];
    size_t const end_j = dm - (dm % ncl),
                   inc = ncl * dn;
    for(; j < end_j; j += ncl, vj += ncl, mi += inc){
      double *r = res;
      mvals[0] = mi;
      for(size_t i = 1; i < ncl; ++i)
        mvals[i] = mvals[i - 1L] + dn;
      for(size_t i = 0; i < dn; ++i, ++r)
        for(size_t ii = 0; ii < ncl; ++ii)
          *r += *(vj + ii) * *mvals[ii]++;
    }
  }
  
  mat_vec_mult_vanilla(mi, vj, res, dn, dm - j);
}

// [[Rcpp::export("mat_vec_mult", rng = false)]]
NumericVector mat_vec_mult_cpp(NumericMatrix m, NumericVector v){
  size_t const dn = m.nrow(), 
               dm = m.ncol();
  NumericVector res(dn);
  mat_vec_mult(&m[0], &v[0], &res[0], dn, dm);
  return res;
}

// [[Rcpp::export("mat_vec_mult_vanilla", rng = false)]]
NumericVector mat_vec_mult_vanilla_cpp(NumericMatrix m, NumericVector v){
  size_t const dn = m.nrow(), 
               dm = m.ncol();
  NumericVector res(dn);
  mat_vec_mult_vanilla(&m[0], &v[0], &res[0], dn, dm);
  return res;
}

在我的Makevars文件和gcc-8.3中使用-O3的结果是

set.seed(1)
dn <- 10001L
dm <- 10001L
m <- matrix(rnorm(dn * dm), dn, dm)
lv <- rnorm(dm)

all.equal(drop(m %*% lv), mat_vec_mult(m = m, v = lv))
#R> [1] TRUE
all.equal(drop(m %*% lv), mat_vec_mult_vanilla(m = m, v = lv))
#R> [1] TRUE

bench::mark(
  R              = m %*% lv, 
  `OP's version` = my_mm(m = m, v = lv), 
  `BLAS`         = blas_mm(m = m, v = lv),
  `C++ vanilla`  = mat_vec_mult_vanilla(m = m, v = lv), 
  `C++`          = mat_vec_mult(m = m, v = lv), check = FALSE)
#R> # A tibble: 5 x 13
#R>   expression        min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result memory                 time          gc               
#R>   <bch:expr>   <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list> <list>                 <list>        <list>           
#R> 1 R             147.9ms    151ms      6.57    78.2KB        0     4     0      609ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [4]>  <tibble [4 × 3]> 
#R> 2 OP's version   56.9ms   57.1ms     17.4     78.2KB        0     9     0      516ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [9]>  <tibble [9 × 3]> 
#R> 3 BLAS           90.1ms   90.7ms     11.0     78.2KB        0     6     0      545ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [6]>  <tibble [6 × 3]> 
#R> 4 C++ vanilla    57.2ms   57.4ms     17.4     78.2KB        0     9     0      518ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [9]>  <tibble [9 × 3]> 
#R> 5 C++              51ms   51.4ms     19.3     78.2KB        0    10     0      519ms <NULL> <Rprofmem[,3] [2 × 3]> <bch:tm [10]> <tibble [10 × 3]>

稍有改进。然而,结果可能非常依赖于BLAS版本。我使用的版本是

sessionInfo()
#R> #...
#R> Matrix products: default
#R> BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
#R> LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1
#R> ...

我使用 Rcpp::sourceCpp() 编译的整个文件是

#include <Rcpp.h>
#include <R_ext/BLAS.h>
using namespace Rcpp;

inline void mat_vec_mult_vanilla
(double const * __restrict__ m, 
 double const * __restrict__ v, 
 double * __restrict__ const res, 
 size_t const dn, size_t const dm) noexcept {
  for(size_t j = 0; j < dm; ++j, ++v){
    double * r = res;
    for(size_t i = 0; i < dn; ++i, ++r, ++m)
      *r += *m * *v;
  }
}

inline void mat_vec_mult
(double const * __restrict__ const m, 
 double const * __restrict__ const v, 
 double * __restrict__ const res, 
 size_t const dn, size_t const dm) noexcept {
  size_t j(0L);
  double const * vj = v,
               * mi = m;
  constexpr size_t const ncl(8L);
  {
    double const * mvals[ncl];
    size_t const end_j = dm - (dm % ncl),
                   inc = ncl * dn;
    for(; j < end_j; j += ncl, vj += ncl, mi += inc){
      double *r = res;
      mvals[0] = mi;
      for(size_t i = 1; i < ncl; ++i)
        mvals[i] = mvals[i - 1L] + dn;
      for(size_t i = 0; i < dn; ++i, ++r)
        for(size_t ii = 0; ii < ncl; ++ii)
          *r += *(vj + ii) * *mvals[ii]++;
    }
  }
  
  mat_vec_mult_vanilla(mi, vj, res, dn, dm - j);
}

// [[Rcpp::export("mat_vec_mult", rng = false)]]
NumericVector mat_vec_mult_cpp(NumericMatrix m, NumericVector v){
  size_t const dn = m.nrow(), 
               dm = m.ncol();
  NumericVector res(dn);
  mat_vec_mult(&m[0], &v[0], &res[0], dn, dm);
  return res;
}

// [[Rcpp::export("mat_vec_mult_vanilla", rng = false)]]
NumericVector mat_vec_mult_vanilla_cpp(NumericMatrix m, NumericVector v){
  size_t const dn = m.nrow(), 
               dm = m.ncol();
  NumericVector res(dn);
  mat_vec_mult_vanilla(&m[0], &v[0], &res[0], dn, dm);
  return res;
}

// [[Rcpp::export(rng = false)]]
NumericVector my_mm(NumericMatrix m, NumericVector v){
  int nRow = m.rows();
  int nCol = m.cols();
  NumericVector ans(nRow);
  double v_j;
  for(int j = 0; j < nCol; j++){
    v_j = v[j];
    for(int i = 0; i < nRow; i++){
      ans[i] += m(i,j) * v_j;
    }
  }
  return(ans);
}

// [[Rcpp::export(rng = false)]]
NumericVector blas_mm(NumericMatrix m, NumericVector v){
  int nRow = m.rows();
  int nCol = m.cols();
  NumericVector ans(nRow);
  char trans = 'N';
  double one = 1.0, zero = 0.0;
  int ione = 1;
  F77_CALL(dgemv)(&trans, &nRow, &nCol, &one, m.begin(), &nRow, v.begin(),
           &ione, &zero, ans.begin(), &ione);
  return ans;
}

/*** R
set.seed(1)
dn <- 10001L
dm <- 10001L
m <- matrix(rnorm(dn * dm), dn, dm)
lv <- rnorm(dm)

all.equal(drop(m %*% lv), mat_vec_mult(m = m, v = lv))
all.equal(drop(m %*% lv), mat_vec_mult_vanilla(m = m, v = lv))

bench::mark(
  R              = m %*% lv, 
  `OP's version` = my_mm(m = m, v = lv), 
  `BLAS`         = blas_mm(m = m, v = lv),
  `C++ vanilla`  = mat_vec_mult_vanilla(m = m, v = lv), 
  `C++`          = mat_vec_mult(m = m, v = lv), check = FALSE)
*/

1
有趣的是,在你的结果中,BLAS比直接使用C++版本(你的或我的)要慢得多。@RalfStubner的结果显示,他的BLAS大约比我的快两倍。难道Ralf的BLAS使用了2个(或更多)线程?或者使用了不同的版本? - Cliff AB
RalfStubner表示他正在使用OpenBLAS。我正在使用默认的BLAS,所以我认为这是差异的原因。我怀疑这只是实现上的问题,但也可能是他使用了更多的线程。 - Benjamin Christoffersen

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接