从右到左更快地评估矩阵乘法

14

我注意到在R语言中,从右到左计算二次形式中的矩阵运算比从左到右更快,具体取决于括号的放置方式。显然,它们都执行了相同数量的计算。不知道这是为什么。这与内存分配有关吗?

# A: 5000 * 5000
# B: 5000 * 2
A = matrix(runif(5000 * 5000), nrow = 5000)
B = matrix(rbinom(5000 * 2, size = 2, prob = 0.3), nrow = 5000)

microbenchmark((t(B) %*% A) %*% B, t(B) %*% (A %*% B), times = 100)

这里是会话信息

下面是会话信息:

R version 4.2.0 (2022-04-22)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Big Sur 11.4

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] Rcpp_1.0.9           microbenchmark_1.4.9

loaded via a namespace (and not attached):
 [1] compiler_4.2.0           fastmap_1.1.0            cli_3.3.0                htmltools_0.5.3          tools_4.2.0             
 [6] RcppArmadillo_0.11.2.4.0 rstudioapi_0.13          yaml_2.3.5               rmarkdown_2.14           knitr_1.39              
[11] xfun_0.31                digest_0.6.29            rlang_1.0.4              evaluate_0.15           

编辑: 一个简化版本的矩阵乘法,显示相同的错误。

k <- 5000L; m <- n <- 2L;
A <- matrix(rnorm(k * k), k, k);
B <- matrix(rnorm(k * n), k, n);
tB <- t(B);
microbenchmark::microbenchmark(tB %*% A, A %*% B, times = 100)

1
我无法重现这个问题(实际上,在我的系统上第二个要慢一些)。你能发一下 sessionInfo() 吗?(我正在使用 openblas) - user20650
1
我知道这并没有直接回答你的问题,但你也可以考虑crossprod()/tcrossprod()的相对时间。 - Ben Bolker
1
我可以完全复现这个问题。有趣的是,@BenBolker使用microbenchmark(crossprod(B, A) %*% B, t(B) %*% (A %*% B), times = 100)将第一个操作减少了一半,降至67毫秒,但仍然比第二个慢得可测。基本的R操作%*%非常复杂,包含NA和无限检查,并且在矩阵或向量很薄的情况下存在重大的开销潜力。也许这篇相关的文章会对您有所帮助:https://dev59.com/PFUK5IYBdhLWcg3w9zuJ - dcsuka
2
这里有太多的红鲱鱼 - 你应该尝试简化。以下展示了我系统上的“问题”:k <- 5000L; m <- n <- 2L; A <- matrix(rnorm(k * k), k, k); B <- matrix(rnorm(k * n), k, n); tB <- t(B); microbenchmark::microbenchmark(tB %*% A, A %*% B) - Mikael Jagan
就像@user20650一样,我无法重现这个问题。我也在使用OpenBLAS。 - Ralf Stubner
显示剩余14条评论
1个回答

4

看起来这是实现的问题,与访问元素的顺序有关。请参见为什么在迭代2D数组时循环的顺序会影响性能?

当按不同的循环顺序执行相同操作时,时间上会有所不同。

Rcpp::cppFunction("NumericMatrix mm(const NumericMatrix& A, const NumericMatrix& B) {
int M = A.nrow();
int N = A.ncol();
int P = B.ncol();
NumericMatrix res(M, P);
for (int n=0; n<N; ++n) {  //Loop n, p, m
  for (int p=0; p<P; ++p) {
    for (int m=0; m<M; ++m) {
      res[m+p*M] += A[m+M*n] * B[p*N+n];
    }
  }
}
return res;}")

Rcpp::cppFunction("NumericMatrix mm2(const NumericMatrix& A, const NumericMatrix& B) {
int M = A.nrow();
int N = A.ncol();
int P = B.ncol();
NumericMatrix res(M, P);
for (int m=0; m<M; ++m) {  //Loop m, p, n
  for (int p=0; p<P; ++p) {
     for (int n=0; n<N; ++n) {
      res[m+p*M] += A[m+M*n] * B[p*N+n];
    }
  }
}
return res;}")

k <- 5000L; m <- n <- 2L;
A <- matrix(rnorm(k * k), k, k);
B <- matrix(rnorm(k * n), k, n);
tB <- t(B);

met <- alist("(tB*A)B"     = tB %*% A %*% B,
             "tB(A*B)"     = tB %*% (A %*% B),
             "mm (tB*A)B"  = mm(mm(tB, A), B),
             "mm tB(A*B)"  = mm(tB, mm(A, B)),
             "mm2 (tB*A)B" = mm2(mm2(tB, A), B),
             "mm2 tB(A*B)" = mm2(tB, mm2(A, B)),
             "cp(B,A)B"    = crossprod(B, A) %*% B,
             "cp(B,A*B)"   = crossprod(B, A %*% B) )

bench::mark(exprs = met)
#  expression     min median itr/s…¹ mem_a…² gc/se…³ n_itr  n_gc total…⁴ result  
#  <bch:expr>  <bch:> <bch:>   <dbl> <bch:b>   <dbl> <int> <dbl> <bch:t> <list>  
#1 (tB*A)B     79.5ms 80.1ms    12.5 78.17KB       0     7     0   562ms <dbl[…]>
#2 tB(A*B)     33.8ms 34.4ms    28.4 78.17KB       0    15     0   528ms <dbl[…]>
#3 mm (tB*A)B  61.9ms 62.5ms    15.9  3.85MB       0     8     0   502ms <dbl[…]>
#4 mm tB(A*B)  19.9ms 20.7ms    48.1 83.16KB       0    25     0   520ms <dbl[…]>
#5 mm2 (tB*A)B 35.9ms 39.4ms    25.8 87.29KB       0    13     0   504ms <dbl[…]>
#6 mm2 tB(A*B) 47.8ms 48.1ms    20.6 83.16KB       0    11     0   535ms <dbl[…]>
#7 cp(B,A)B    44.1ms 44.5ms    22.4 80.42KB       0    12     0   536ms <dbl[…]>
#8 cp(B,A*B)   34.1ms 36.5ms    27.1 78.17KB       0    14     0   516ms <dbl[…]>

microbenchmark::microbenchmark(list = met)
#Unit: milliseconds
#        expr      min       lq     mean   median       uq      max neval
#     (tB*A)B 77.09484 77.86891 79.09483 78.44832 80.08971 87.05563   100
#     tB(A*B) 33.63306 34.22562 36.08482 35.14064 36.64080 51.39962   100
#  mm (tB*A)B 62.05235 64.14361 66.54568 65.16927 67.98617 75.96242   100
#  mm tB(A*B) 19.67066 20.28369 20.83781 20.53820 21.19940 23.64119   100
# mm2 (tB*A)B 35.31290 35.70006 36.62846 36.10282 37.41669 40.47473   100
# mm2 tB(A*B) 48.16574 49.70702 51.55844 50.26292 52.46479 67.44558   100
#    cp(B,A)B 43.18166 44.01366 45.28434 44.71301 46.41521 48.97891   100
#   cp(B,A*B) 33.62158 34.47070 35.84743 35.11853 36.55979 48.89021   100

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接