Rolling weighted mean with data.table

3
I would like to compute a rolling weighted mean per group for a data.table like this one:

library(data.table)

DT <- data.table(group = rep(c(1,2), each = 5), value = 1:10, weight = 11:20)
   group value weight
 1:     1     1     11
 2:     1     2     12
 3:     1     3     13
 4:     1     4     14
 5:     1     5     15
 6:     2     6     16
 7:     2     7     17
 8:     2     8     18
 9:     2     9     19
10:     2    10     20

I found a working solution that uses the runner package in the question "Rolling over function with 2 vector arguments":

my_weighted_mean <- function(data) {
  weighted.mean(data[, 1], w = data[, 2])
}

DT[, weighted_mean := runner::runner(x = .SD, f = my_weighted_mean,
                                     k = 3, na_pad = TRUE),
   .SDcols = c("value", "weight"),
   by = list(group)]

However, this code runs very slowly.

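I guess frollapply could be used instead, but the following code does not work, because I do not know how to call frollapply with a function that takes two columns: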

 DT[, weighted_mean := frollapply(value, FUN = weighted.mean, n = 3, w = weights), by = list(group)]

I am looking for a solution with better performance (and one that does not require the runner package).

2 Answers

6
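"frollapply with a two-column function": do not roll over the values; roll over the row indices instead. The inner function can then use as many columns as it needs.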
 DT[, weighted_mean := frollapply(seq_len(.N),
                                  FUN = function(ind) weighted.mean(value[ind], weight[ind]),
                                  n = 3),
    by = .(group)]
#     group value weight weighted_mean
#     <num> <int>  <int>         <num>
#  1:     1     1     11            NA
#  2:     1     2     12            NA
#  3:     1     3     13      2.055556
#  4:     1     4     14      3.051282
#  5:     1     5     15      4.047619
#  6:     2     6     16            NA
#  7:     2     7     17            NA
#  8:     2     8     18      7.039216
#  9:     2     9     19      8.037037
# 10:     2    10     20      9.035088
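As a sanity check on the output above, the value for row 3 can be recomputed by hand. This is only a sketch: `ind`, `by_hand` and `via_fun` are illustrative names that do not appear in the answer. For the window ending at row 3 of group 1, the index vector passed to the inner function is 1, 2, 3, so the result is sum(value * weight) / sum(weight) over those rows.

```r
library(data.table)

DT <- data.table(group = rep(c(1, 2), each = 5), value = 1:10, weight = 11:20)

# Row indices covered by the window that ends at row 3 of group 1.
ind <- 1:3

# Weighted mean written out explicitly, and via weighted.mean().
by_hand <- sum(DT$value[ind] * DT$weight[ind]) / sum(DT$weight[ind])
via_fun <- weighted.mean(DT$value[ind], DT$weight[ind])

# Same computation through frollapply rolling over indices, as in the answer.
DT[, weighted_mean := frollapply(seq_len(.N),
                                 FUN = function(ind) weighted.mean(value[ind], weight[ind]),
                                 n = 3),
   by = .(group)]

print(c(by_hand = by_hand, via_fun = via_fun, rolled = DT$weighted_mean[3]))
```

All three numbers should agree, which shows that rolling over `seq_len(.N)` simply hands the inner function the row positions of each window.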

1
Darn it, Evans, you beat me to it.

2
Here is another option:
k <- 3L
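# Rolling ratio over the whole table, then blank out the first k - 1 rows of
# each group, whose windows would otherwise reach into the previous group.
# seq_len() (rather than seq()) also behaves correctly when k is 1.
DT[, v := frollsum(value * weight, k) / frollsum(weight, k)][
    rowid(group) %in% seq_len(k - 1L), v := NA_real_]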

Output:

    group value weight        v
 1:     1     1     11       NA
 2:     1     2     12       NA
 3:     1     3     13 2.055556
 4:     1     4     14 3.051282
 5:     1     5     15 4.047619
 6:     2     6     16       NA
 7:     2     7     17       NA
 8:     2     8     18 7.039216
 9:     2     9     19 8.037037
10:     2    10     20 9.035088
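The masking step matters because the rolling sums here are computed over the whole table rather than per group. A small sketch of what happens without it (the column name `unmasked` is made up for illustration): the first k - 1 rows of group 2 receive values from windows that reach back into group 1.

```r
library(data.table)

DT <- data.table(group = rep(c(1, 2), each = 5), value = 1:10, weight = 11:20)
k <- 3L

# Ungrouped rolling ratio, before any masking is applied.
DT[, unmasked := frollsum(value * weight, k) / frollsum(weight, k)]

# Rows 6 and 7 are the first k - 1 rows of group 2. Their windows include
# rows 4 and 5, which belong to group 1, so these values mix the two groups.
print(DT[6:7])

# Masking the first k - 1 rows of every group removes those mixed windows.
DT[, masked := unmasked][rowid(group) %in% seq_len(k - 1L), masked := NA_real_]
print(DT[6:7])
```

Rows from the k-th position of each group onward are unaffected, since their windows lie entirely inside one group.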

This may be faster when there are many groups:

set.seed(0L)
ng <- 1e5
nr <- 1e6
DT <- data.table(group=sample(ng, nr, TRUE), value=rnorm(nr), weight=rnorm(nr))
setkey(DT, group)
DT2 <- copy(DT)
k <- 3L

microbenchmark::microbenchmark(times=3L,
    m0 = DT[, weighted_mean := frollapply(seq_len(.N),
              FUN = function(ind) weighted.mean(value[ind], weight[ind]),
              n = k),
        by = .(group)],
    m1 = DT2[, v := frollsum(value * weight, k) / frollsum(weight, k)][
        rowid(group) %in% seq_len(k - 1L), v := NA_real_]
)
all.equal(DT$weighted_mean, DT2$v)
#[1] TRUE

Timings:

Unit: milliseconds
 expr       min        lq       mean    median         uq      max neval
   m0 5670.6707 5725.5539 5805.01047 5780.4370 5872.18035 5963.924     3
   m1   49.2789   54.5392   59.12413   59.7995   64.04675   68.294     3
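If the separate masking step feels awkward, the same ratio of rolling sums could also be evaluated per group. The sketch below wraps it in a helper; `roll_wmean` is a hypothetical name, not something from either answer, and this grouped variant was not part of the benchmark above, so its speed relative to m1 is not known (calling it once per group presumably adds some overhead when there are very many groups).

```r
library(data.table)

# Hypothetical helper: rolling weighted mean of `value` with weights `weight`
# over a trailing window of k observations, as a ratio of two rolling sums.
roll_wmean <- function(value, weight, k) {
  frollsum(value * weight, k) / frollsum(weight, k)
}

DT <- data.table(group = rep(c(1, 2), each = 5), value = 1:10, weight = 11:20)

# With by = group, every window stays inside its group, and frollsum's default
# fill leaves NA in the first k - 1 rows of each group.
DT[, v_by := roll_wmean(value, weight, 3L), by = group]
print(DT)
```

On the small example table this should reproduce the `v` column shown earlier.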
