如何加速在R中用于嵌套矩阵匹配和colSums的for循环

5

我有一个看起来很简单的问题,需要比我开发的R实现更快的实现方法。

我为这个示例初始化了随机种子和维度:

set.seed(1)
d1<-400
d2<-20000
d3<-50

我有一个大小为 d1 x d2 的矩阵 X:

X<-as.data.frame(matrix(rnorm(d1*d2),nrow=d1,ncol=d2))
rownames(X)<-paste0("row",1:nrow(X))
colnames(X)<-paste0("col",1:ncol(X))

同时,还有一个包含d1个行索引的向量u:

u<-sample(rownames(X),nrow(X),replace=TRUE)

我还有一个矩阵 C,它的行有名称,大小为 d3 x d2

C<-matrix(rnorm(d3*d2),nrow=d3,ncol=d2)
rownames(C)<-sample(rownames(X),nrow(C),replace=FALSE)

现在,使用以下非常缓慢的循环,我正在将矩阵C填充为匹配X行的总和:

system.time(
    for(i in 1:nrow(C)){
        indexes<-which(u==rownames(C)[i])
        C[i,] <- colSums(X[indexes,])
    }
)

在我的个人电脑上,这个操作大约需要11.5秒钟的时间,但我相信通过避免for循环可以加速。有什么好的建议吗?非常感谢!


1
在创建u时,replace=TRUE背后的理由是什么? - jay.sf
3个回答

3
您可以尝试使用sapply进行循环。
system.time(
  C2 <- `dimnames<-`(t(sapply(match(rownames(C), u), function(x) 
    colSums(X[x, ]))), list(rownames(C), NULL))
)
#  user  system elapsed 
# 20.06    0.03   20.14 

stopifnot(all.equal(C, C2))

相比之下

system.time(
  for(i in 1:nrow(C)){
    indexes <- which(u == rownames(C)[i])
    C[i, ] <- colSums(X[indexes, ])
  }
)
#  user  system elapsed 
# 20.76    0.69   28.30  

目前,这只是一个单一的测量。

更新

似乎运行得稍微快一点...

Unit: seconds
    expr      min       lq     mean   median       uq      max neval cld
 forloop 20.44852 20.57730 21.67771 20.74106 21.01723 29.63220    10   a
  sapply 19.86707 20.17126 21.34529 20.50283 20.81254 29.73764    10   a

更新 2

但是你可以使用parallel::parSapply来实现。

system.time({
  library(parallel)
  cl <- makeCluster(detectCores() - 1)
  clusterExport(cl, c("C", "u", "X"))
  C3 <- parSapply(cl, match(rownames(C), u), function(x) colSums(X[x, ]))
  stopCluster(cl)
  C3 <- `dimnames<-`(t(C3), list(rownames(C), NULL))
})
# user  system elapsed 
# 0.81    3.16    9.82

stopifnot(all.equal(C, C3))

现在,我的计算机使用 for 循环与你的一样快 :)


谢谢jay.sf,这是一个有趣的解决方案,但还有改进的余地。我正在考虑用C实现这个,但我想在社区中寻找纯R的解决方案 :-) - Federico Giorgi
1
@FedericoGiorgi 请查看更新。对于C语言,您可能需要了解Rcpp - jay.sf

3

使用matrixStats::colSums2函数并选项传递行索引,将rownames()移出循环(需要将X转换为矩阵):

Xm <- as.matrix(X)
names_of_rows <- rownames(C)
system.time(for (i in 1:nrow(C)) {
  indexes <- which(u == names_of_rows[i])
  C[i, ] <-  matrixStats::colSums2(Xm, rows = indexes)
})
# 0.03 sek

2
到目前为止所有的解决方案都很好,但这一个是完美的。速度提升了100倍以上。谢谢minem,你帮了大忙。 - Federico Giorgi

1

在这里提供一个使用data.table的解决方案。如果OP只想要一个基于R语言的解决方案,我会删除此帖:

library(data.table)
mtd_dt <- function() {
    setDT(dtX)[, u := as.integer(gsub("row","",u))]
    mX <- melt(dtX, id.var="u", variable.name="col")
    C2 <- data.table(rn=seq_len(nrow(C)), u=as.integer(gsub("row","",rownames(C))))
    dcast(mX[C2, on=.(u)][, sum(value), by=.(rn, col)], rn ~ col, value.var="V1")[,
        "NA" := NULL][,
            lapply(.SD, function(x) replace(x, is.na(x), 0))]
}

时间:

# A tibble: 2 x 14
  expression      min     mean   median      max `itr/sec` mem_alloc  n_gc n_itr total_time result                    memory                time    gc             
  <chr>      <bch:tm> <bch:tm> <bch:tm> <bch:tm>     <dbl> <bch:byt> <dbl> <int>   <bch:tm> <list>                    <list>                <list>  <list>         
1 mtd0()        59.1s    59.1s    59.1s    59.1s    0.0169     447MB    24     1      59.1s <dbl [50 x 20,000]>       <Rprofmem [44,515 x ~ <bch:t~ <tibble [1 x 3~
2 mtd_dt()       2.7s     2.7s     2.7s     2.7s    0.370      309MB     4     1       2.7s <data.table [50 x 20,001~ <Rprofmem [88,029 x ~ <bch:t~ <tibble [1 x 3~

计时代码:
mtd0 <- function() {
    for (i in 1:nrow(C)) {
        indexes <- which(u==rownames(C)[i])
        C[i, ] <- colSums(X[indexes, ])
    }
    C
}

bench::mark(mtd0(), mtd_dt(), check=FALSE)

数据:

library(data.table)
set.seed(0)
#d1 <- 10
#d2 <- 10
#d3 <- 5
d1<-400
d2<-20000
d3<-50

X <- as.data.frame(matrix(rnorm(d1*d2),nrow=d1,ncol=d2))
rownames(X) <- paste0("row",1:nrow(X))
colnames(X) <- paste0("col",1:ncol(X))
dtX <- X

u <- sample(rownames(X),nrow(X),replace=TRUE)

C <- matrix(0,nrow=d3,ncol=d2)
rownames(C) <- sample(rownames(X),nrow(C),replace=FALSE)

1
谢谢!这是一个非常有趣的解决方案,利用了强大的data.table。虽然不是获胜者,但因为matrixStats::colSums2()的解决方案更快。 - Federico Giorgi

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接