使用列和行索引从矩阵中提取数值

3

假设我有两个矩阵:

> a
     [,1] [,2] [,3] [,4] [,5] [,6] [,7]
[1,]    6   10    5    7    2    2    6
[2,]   10    6    7    7    4    3   12
[3,]   11   10    2   10    6   11    9

并且

> b
         [,1] [,2] [,3]
    [1,]    4    1    4
    [2,]    3    6    3
    [3,]    2    5    2

ab中的行数相同。我正在寻找一种向量化的方法,以逐行基础从a中提取由b中列号指示的项。因此,结果c应如下所示:

> c
     [,1] [,2] [,3]
[1,]    7    6    7
[2,]    7    3    7
[3,]    10    6   10
a[,b[1,]]a[,b[2,]]a[,b[3,]] 只能正确地处理行 1、2 和 3 的结果。这是否可以通过简单的矩阵函数完成?是否需要使用 apply 函数?
我试图改编一个类似问题的解决方案,参见使用行列索引从矩阵中获取值,但是不理解在这里如何使用 cbind 提取矩阵元素。
2个回答

3

你可以尝试

t(sapply(seq_len(nrow(a)), function(i) a[i, b[i, ]]))
#      [,1] [,2] [,3]
# [1,]    7    6    7
# [2,]    7    3    7
# [3,]   10    6   10

而且你可能会发现,使用 vapply 命令的解决方案比上面使用 sapply 命令的解决方案略微提高了速度。

s <- seq_len(nrow(a))
t(vapply(s, function(i) a[i, b[i, ]], numeric(ncol(b))))
#      [,1] [,2] [,3]
# [1,]    7    6    7
# [2,]    7    3    7
# [3,]   10    6   10

或者使用 for 循环的解决方案如下:

m <- matrix(, nrow(b), ncol(b))
for(i in seq_len(nrow(a))) { m[i, ] <- a[i, b[i, ]] }
m
#      [,1] [,2] [,3]
# [1,]    7    6    7
# [2,]    7    3    7
# [3,]   10    6   10

@David - 太好了。我担心它会很慢,所以我正在努力做出更好的东西。 - Rich Scriven
1
在我的机器上,对于一个包含整数数据的200kb对象,它大约需要1秒钟完成任务。对于我的应用程序来说,速度非常出色。再次感谢。 - David

3
这是一个cbind版本。
 t(`dim<-`(a[cbind(rep(1:nrow(a), each=ncol(b)), c(t(b)))], dim(b)))
 #     [,1] [,2] [,3]
 #[1,]    7    6    7
 #[2,]    7    3    7
 #[3,]   10    6   10

或者像@thelatemail所建议的那样。
 matrix(a[cbind(c(row(b)),c(b))],nrow=nrow(a))
 #     [,1] [,2] [,3]
 #[1,]    7    6    7
 #[2,]    7    3    7
 #[3,]   10    6   10

基准测试

set.seed(24)
a1 <- matrix(sample(1:10, 2e5*7, replace=TRUE), ncol=7)
set.seed(28)
b1 <- matrix(sample(1:7,2e5*3, replace=TRUE), ncol=3)

f1 <- function() {s <- seq_len(nrow(a1))
 t(vapply(s, function(i) a1[i, b1[i,]],numeric(ncol(b1))))
}
f2 <- function() {matrix(a1[cbind(c(row(b1)),c(b1))], nrow=nrow(a1)) }
f3 <- function(){t(`dim<-`(a1[cbind(rep(1:nrow(a1),
                    each=ncol(b1)), c(t(b1)))], dim(b1)))} 
library(microbenchmark)
microbenchmark(f1(), f2(), f3(), unit='relative', times=10L)
#Unit: relative
# expr       min        lq      mean    median        uq       max neval cld
#f1() 16.636045 16.603856 15.319595 15.799335 13.869147 14.629315    10   b
#f2()  1.000000  1.000000  1.000000  1.000000  1.000000  1.000000    10  a 
#f3()  1.310433  1.306228  1.258715  1.278504  1.237299  1.236448    10  a 

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接