data.table | 更快的行递归更新组内数据

19

我需要执行以下递归的逐行操作来获取z

myfun = function (xb, a, b) {

z = NULL

for (t in 1:length(xb)) {

    if (t >= 2) { a[t] = b[t-1] + xb[t] }
    z[t] = rnorm(1, mean = a[t])
    b[t] = a[t] + z[t]

}

return(z)

}

set.seed(1)

n_smpl = 1e6 
ni = 5

id = rep(1:n_smpl, each = ni)

smpl = data.table(id)
smpl[, time := 1:.N, by = id]

a_init = 1; b_init = 1
smpl[, ':=' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]

smpl[, z := myfun(xb, a, b), by = id]

我希望能够得到这样的结果:
      id time a b  xb            z
  1:   1    1 1 1   1    0.3735462
  2:   1    2 1 1   2    2.7470924
  3:   1    3 1 1   3    8.4941848
  4:   1    4 1 1   4   20.9883695
  5:   1    5 1 1   5   46.9767390
 ---                              
496: 100    1 1 1 100    0.3735462
497: 100    2 1 1 200  200.7470924
498: 100    3 1 1 300  701.4941848
499: 100    4 1 1 400 1802.9883695
500: 100    5 1 1 500 4105.9767390

这种方法是可行的,但需要时间:

system.time(smpl[, z := myfun(xb, a, b), by = id])
   user  system elapsed 
 33.646   0.994  34.473

给定我的实际数据量(超过200万个观测值),我需要使其更快。我猜想使用by = .(id, time)do.call(myfun, .SD), .SDcols = c('xb', 'a', 'b')会更快,避免了myfun内部的循环。然而,当我在data.table中逐行运行此操作时,我不确定如何更新b及其滞后(可能使用shift)。有任何建议吗?


我得到了 Error in rnorm(1, mean = a[t]) : object 'a' not found 错误。请问您能确保代码在新的 R 会话中能够正常运行吗? - Matt Dowle
你把它粘贴到一个新的 R 会话中了吗?我仍然得到相同的错误。 - Matt Dowle
8
不用担心。欢迎来到S.O.祝贺你提出第5000个data.table问题! - Matt Dowle
for 循环内的 set.seed(1) 看起来很奇怪。那是必要的吗?通常在可重复示例的开头调用一次 set.seed(),而不是多次调用。 - Matt Dowle
刚刚尝试使示例可重现 - 这并不是必要的。 - jayc
显示剩余3条评论
1个回答

40

好问题!

从一个新的 R 会话开始,展示包含 500 万行数据的演示数据,以下是你提出的函数以及在我的笔记本电脑上的时间。注释已经嵌入到代码中。

require(data.table)   # v1.10.0
n_smpl = 1e6
ni = 5
id = rep(1:n_smpl, each = ni)
smpl = data.table(id)
smpl[, time := 1:.N, by = id]
a_init = 1; b_init = 1
smpl[, ':=' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]

myfun = function (xb, a, b) {

  z = NULL
  # initializes a new length-0 variable

  for (t in 1:length(xb)) {

      if (t >= 2) { a[t] = b[t-1] + xb[t] }
      # if() on every iteration. t==1 could be done before loop

      z[t] = rnorm(1, mean = a[t])
      # z vector is grown by 1 item, each time

      b[t] = a[t] + z[t]
      # assigns to all of b vector when only really b[t-1] is
      # needed on the next iteration 
  }
  return(z)
}

set.seed(1); system.time(smpl[, z := myfun(xb, a, b), by = id][])
   user  system elapsed 
 19.216   0.004  19.212

smpl
              id time a b      xb            z
      1:       1    1 1 1       1 3.735462e-01
      2:       1    2 1 1       2 3.557190e+00
      3:       1    3 1 1       3 9.095107e+00
      4:       1    4 1 1       4 2.462112e+01
      5:       1    5 1 1       5 5.297647e+01
     ---                                      
4999996: 1000000    1 1 1 1000000 1.618913e+00
4999997: 1000000    2 1 1 2000000 2.000000e+06
4999998: 1000000    3 1 1 3000000 7.000003e+06
4999999: 1000000    4 1 1 4000000 1.800001e+07
5000000: 1000000    5 1 1 5000000 4.100001e+07

所以需要击败的时间是19.2秒。在所有这些时间中,我已经在本地运行了3次命令,以确保它是稳定的时间。在这个任务中,时间变化是微不足道的,因此我只报告一个时间,以使答案更快阅读。

解决上面myfun()中的内联注释:

myfun2 = function (xb, a, b) {

  z = numeric(length(xb))
  # allocate size up front rather than growing

  z[1] = rnorm(1, mean=a[1])
  prevb = a[1]+z[1]
  t = 2L
  while(t<=length(xb)) {
    at = prevb + xb[t]
    z[t] = rnorm(1, mean=at)
    prevb = at + z[t]
    t = t+1L
  }
  return(z)
}
set.seed(1); system.time(smpl[, z2 := myfun2(xb, a, b), by = id][])
   user  system elapsed 
 13.212   0.036  13.245 
smpl[,identical(z,z2)]
[1] TRUE

那已经很好了(从19.2秒降到13.2秒),但它仍然是R级别的for循环。乍一看,它不可能被矢量化,因为rnorm()调用取决于先前的值。事实上,通过使用属性 m+sd*rnorm(mean=0,sd=1) == rnorm(mean=m, sd=sd) 并调用矢量化的rnorm(n=5e6) 一次而不是5e6次,它可能可以被矢量化。但是,这可能涉及cumsum() 来处理组。所以我们不要尝试,因为那可能会使代码更难读,并且只适用于这个特定问题。
因此,让我们尝试 Rcpp,它看起来非常类似于您编写的风格,并且更广泛适用:
require(Rcpp)   # v0.12.8
cppFunction(
'NumericVector myfun3(IntegerVector xb, NumericVector a, NumericVector b) {
  NumericVector z = NumericVector(xb.length());
  z[0] = R::rnorm(/*mean=*/ a[0], /*sd=*/ 1);
  double prevb = a[0]+z[0];
  int t = 1;
  while (t<xb.length()) {
    double at = prevb + xb[t];
    z[t] = R::rnorm(at, 1);
    prevb = at + z[t];
    t++;
  }
  return z;
}')

set.seed(1); system.time(smpl[, z3 := myfun3(xb, a, b), by = id][])
   user  system elapsed 
  1.800   0.020   1.819 
smpl[,identical(z,z3)]
[1] TRUE

更好的情况是:从19.2秒降到1.8秒。但每次调用该函数都会调用第一行(NumericVector()),它会分配一个与组中行数相同长度的新向量。然后向其填充数据并返回该向量,最终将其复制到正确位置的最终列中(通过:=),仅用于释放。所有这些100万个小临时向量(每个组一个)的分配和管理都有点复杂。
为什么不一次性处理整列呢?您已经以for循环的方式编写了代码,这样没有任何问题。让我们调整C函数,使其也接受id列,并添加if语句以在达到新组时执行处理。
cppFunction(
'NumericVector myfun4(IntegerVector id, IntegerVector xb, NumericVector a, NumericVector b) {

  // ** id must be pre-grouped, such as via setkey(DT,id) **

  NumericVector z = NumericVector(id.length());
  int previd = id[0]-1;  // initialize to anything different than id[0]
  for (int i=0; i<id.length(); i++) {
    double prevb;
    if (id[i]!=previd) {
      // first row of new group
      z[i] = R::rnorm(a[i], 1);
      prevb = a[i]+z[i];
      previd = id[i];
    } else {
      // 2nd row of group onwards
      double at = prevb + xb[i];
      z[i] = R::rnorm(at, 1);
      prevb = at + z[i];
    }
  }
  return z;
}')

system.time(setkey(smpl,id))  # ensure grouped by id
   user  system elapsed
  0.028   0.004   0.033
set.seed(1); system.time(smpl[, z4 := myfun4(id, xb, a, b)][])
   user  system elapsed 
  0.232   0.004   0.237 
smpl[,identical(z,z4)]
[1] TRUE

现在好多了:从19.2秒降至0.27秒


12
那真是令人印象深刻的。 - zacdav
2
这非常有帮助!非常感谢,Matt。 - jayc
2
@jayc,你被好好地dowled了。一次彻底的dowleing。 - rosscova
1
是的,确实如此。祝贺他的第5000个问题,而我很荣幸能成为历史的一部分 :) - jayc

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接