R数据表行操作的首选高性能过程是什么?

5
以下代码是否代表了在遍历R中data.table的行并将每行发现的值传递给函数时的首选过程?或者还有更高效的方法来实现这个功能吗?
library(data.table)
set.seed(2)
n <- 100
b <- c(0.5, 1.5, -1)
phi <- 0.8
X <- cbind(1, matrix(rnorm(n*2, 0, 1), ncol = 2))
y <- X %*% matrix(b, ncol = 1) + rnorm(n, 0, phi)
d <- data.table(y, X)
setnames(d, c("y", "x0", "x1", "x2"))

logpost <- function(d, b1, b2, b3, phi, mub = 1, taub = 10, a = 0.5, z = 0.7){
    N <- nrow(d)
    mu <- b1 + b2 * d$x1 + b3 * d$x2
    lp <- -N * log(phi) -
        (1/(2*phi^2)) * sum( (d$y-mu)^2  ) -
        (1/(2*taub^2))*( (b1-mub)^2 + (b2-mub)^2 + (b3-mub)^2 ) -
        (a+1)*log(phi) - (z/phi)
    lp
}

nn <- 21
grid <- data.table(
expand.grid(b1 = seq(0, 1, len = nn),
    b2 = seq(1, 2, len = nn),
    b3 = seq(-1.5, -0.5, len = nn),
    phi = seq(0.4, 1.2, len = nn)))
grid[, id := 1:.N]
setkey(grid, id)

wraplogpost <- function(dd){
    logpost(d, dd$b1, dd$b2, dd$b3, dd$phi)
}
start <- Sys.time()
grid[, lp := wraplogpost(.SD), by = seq_len(nrow(grid))]
difftime(Sys.time(), start)
# Time difference of 2.081544 secs

编辑:显示前几条记录

> head(grid)
b1 b2   b3 phi id        lp
1: 0.00  1 -1.5 0.4  1 -398.7618
2: 0.05  1 -1.5 0.4  2 -380.3674
3: 0.10  1 -1.5 0.4  3 -363.5356
4: 0.15  1 -1.5 0.4  4 -348.2663
5: 0.20  1 -1.5 0.4  5 -334.5595
6: 0.25  1 -1.5 0.4  6 -322.4152

我尝试使用set,但这种方法似乎不太好

start <- Sys.time()
grid[, lp := NA_real_]
for(i in 1:nrow(grid)){
    llpp <- wraplogpost(grid[i])
    set(grid, i, "lp", llpp)
}
difftime(Sys.time(), start)
# Time difference of 21.71291 secs

编辑:显示前几条记录

> head(grid)
b1 b2   b3 phi id        lp
1: 0.00  1 -1.5 0.4  1 -398.7618
2: 0.05  1 -1.5 0.4  2 -380.3674
3: 0.10  1 -1.5 0.4  3 -363.5356
4: 0.15  1 -1.5 0.4  4 -348.2663
5: 0.20  1 -1.5 0.4  5 -334.5595
6: 0.25  1 -1.5 0.4  6 -322.4152

欢迎提供相关文档或建议。

编辑:根据评论:

start <- Sys.time()
grid[, lp := wraplogpost(.SD), by = .I]
difftime(Sys.time(), start)
Warning messages:
1: In b2 * d$x1 :
    longer object length is not a multiple of shorter object length
2: In b3 * d$x2 :
    longer object length is not a multiple of shorter object length
3: In d$y - mu :
    longer object length is not a multiple of shorter object length
> difftime(Sys.time(), start)
Time difference of 0.01199317 secs
> 
> head(grid)
b1 b2   b3 phi id        lp
1: 0.00  1 -1.5 0.4  1 -620977.2
2: 0.05  1 -1.5 0.4  2 -620977.2
3: 0.10  1 -1.5 0.4  3 -620977.2
4: 0.15  1 -1.5 0.4  4 -620977.2
5: 0.20  1 -1.5 0.4  5 -620977.2
6: 0.25  1 -1.5 0.4  6 -620977.2

该代码生成了错误的值给lp变量。

编辑 感谢评论和回答。我知道这种情况可以通过使用其他方法来解决,但我对使用data.table时的首选方法感兴趣。

再次编辑 感谢回复。由于还没有明确解答如何使用data.table来实现这一点的问题,因此我暂时认为在不转向基本R的情况下无法达到理想效果。


尝试使用 by = .I。它更快,参见 ?.I - Rui Barradas
谢谢。我认为帮助想表达的是,在j中获取行索引应该使用.I而不是作为“by”项。另外,这里的答案:https://dev59.com/LFoU5IYBdhLWcg3wO1QW#37668187,至少对我来说,建议不要在“by”子句中使用`.I`。我是否解释了答案的含义? - t-student
是的,我相信你误解了那个答案。.I返回seq_len(nrow(grid)),但由于它是由data.table计算的值,所以更快。试试看吧。 - Rui Barradas
1
这通常很慢,因为你正在进行大量的 $ 提取。如果你的数据是矩阵而不是列表(即 data.table),那么你的循环将工作得更好。 - Cole
2个回答

5

如果您想获得更好的性能(时间),您可以将逐行函数重写为矩阵计算。

start <- Sys.time()
grid_mat <- as.matrix(grid[, list(b1, b2, b3, 1)])
# function parameters
N <- nrow(d); mub = 1; taub = 10; a = 0.5; z = 0.7
d$const <- 1

# combining d$y - mu in this step already
mu_op <- matrix(c(-d$const, -d$x1, -d$x2, d$y), nrow = 4, byrow = TRUE)
mu_mat <- grid_mat %*% mu_op
mub_mat <- (grid_mat[, c("b1", "b2", "b3")] - mub)^2
# just to save one calculation of the log
phi <- grid$phi
log_phi <- log(grid$phi)

grid$lp2 <- -N * log_phi -
  (1/(2*phi^2)) * rowSums(mu_mat^2) -
  (1/(2*taub^2))*( rowSums(mub_mat) ) -
  (a+1)*log_phi - (z/phi)
head(grid)
difftime(Sys.time(), start)

第一行:
     b1 b2   b3 phi id        lp       lp2
1: 0.00  1 -1.5 0.4  1 -398.7618 -398.7618
2: 0.05  1 -1.5 0.4  2 -380.3674 -380.3674
3: 0.10  1 -1.5 0.4  3 -363.5356 -363.5356
4: 0.15  1 -1.5 0.4  4 -348.2663 -348.2663
5: 0.20  1 -1.5 0.4  5 -334.5595 -334.5595
6: 0.25  1 -1.5 0.4  6 -322.4152 -322.4152

关于时间:

# on your code on my pc:
Time difference of 4.390684 secs
# my code on my pc:
Time difference of 0.680476 secs

很棒的答案,你抓住了加速的要点。我投票支持你的回答! - ThomasIsCoding
谢谢,但是你的回答让它变成了一个更加可用的函数。 - Jakob Gepp

3

我认为你可以使用矩阵乘法和其他向量化技术简化你的代码,这有助于避免逐行运行函数logpost


以下是logpost的向量化版本,即logpost2

logpost2 <- function(d, dd, mub = 1, taub = 10, a = 0.5, z = 0.7) {
  bmat <- as.matrix(dd[, .(b1, b2, b3)])
  xmat <- cbind(1, as.matrix(d[, .(x1, x2)]))
  phi <- dd$phi
  phi_log <- log(phi)
  lp <- -(a + nrow(d) + 1) * phi_log -
    (1 / (2 * phi^2)) * colSums((d$y - tcrossprod(xmat, bmat))^2) -
    (1 / (2 * taub^2)) * rowSums((bmat - mub)^2) - (z / phi)
  lp
}

你会看到

> start <- Sys.time()

> grid[, lp := logpost2(d, .SD)]

> difftime(Sys.time(), start)
Time difference of 0.1966231 secs

and

> head(grid)
     b1 b2   b3 phi id        lp
1: 0.00  1 -1.5 0.4  1 -398.7618
2: 0.05  1 -1.5 0.4  2 -380.3674
3: 0.10  1 -1.5 0.4  3 -363.5356
4: 0.15  1 -1.5 0.4  4 -348.2663
5: 0.20  1 -1.5 0.4  5 -334.5595
6: 0.25  1 -1.5 0.4  6 -322.4152

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接