如何使这个 R data.table 的 join+group+summarise 操作更快?

3

现实问题概述

本质上,这是一个线性方程组的情景评估。 我有两个数据表。

  • s_dt包含每个观察到的情景(o)的情景、驱动程序(d)和值(v)。
  • c_dt包含一系列适用于多个拟合模型基础(b)的项(n)。
    驱动程序的各个幂以及相关系数被编码为名称-值对(dt)。
    每个基础(b)本质上都是一个具有n个术语的多项式。

问题

下面的重现案例提供了所需的输出格式。 但对于所需的用例来说太慢了,即使在简化的问题上也是如此。 数字是垃圾,但我不能分享实际数据。 在真实数据上运行需要类似的时间。

在我的系统上(12个线程),“lil”问题大约需要3秒钟。 但“big”问题要大4000倍。 因此预计需要大约3小时。 痛苦!
目标是使“big”问题在子5分钟内运行(或者理想情况下更快!)

那么,厉害的聪明人们,如何让这个过程变得更快
(而减速的根本原因是什么?)

如果基于base/tidyverse的解决方案能够满足性能需求,我也很乐意接受。我只是认为对于这个问题的规模,data.table是最好的选择。

当前解决方案

s_dt上运行fun,按o进行分组。
fun:将c_dt与每个组数据连接,以填充v,从而使计算每个多项式方程的结果r成为可能。

data.table的说法:

s_dt[, fun(.SD), keyby = .(o)]

复现案例

  • 创建两个数据表,其组合和字段类型与实际问题相匹配。
    但为了说明目的而缩小了规模。
  • 定义fun,然后运行以填充所有场景的r
library(data.table)

# problem sizing ----
dims <- list(o = 50000, d = 50, b = 250, n = 200) # "big" problem - real-life size
dims <- list(o =   100, d = 50, b =  25, n = 200) # "lil" problem (make runtime shorter as example)

# build some test data tables ----
build_s <- function() {
  o <- seq_len(dims$o)
  d <- paste0("d",seq_len(dims$d))
  v <- as.double(seq_len(dims$o * dims$d))/10000
  CJ(o, d)[, `:=`(v = v)]
}
s_dt <- build_s()

build_c <- function() {
  b <- paste0("c", seq_len(dims$b))
  n <- seq_len(dims$n)
  d <- c("c", paste0("d", seq_len(dims$d)))
  t <- as.double(rep_len(0:6, dims$b * dims$n * (dims$d+1)))
  dt <- CJ(d, b, n)[, `:=`(t = t)]
  dt <- dt[t != 0]
}
c_dt <- build_c()

# define fun and evaluate ---- 
# (this is what needs optimising)
profvis::profvis({
  fun <- function(dt) {
    # don't use chaining here, for more useful profvis output
    dt <- dt[c_dt, on = .(d)]
    dt <- dt[, r := fcase(d == "c", t,
                          is.na(v), 0,
                          rep(TRUE, .N), v^t)]
    dt <- dt[, .(r = prod(r)), keyby = .(b, n)]
    dt <- dt[, .(r = sum(r)),  keyby = .(b)]
  }
  res <- s_dt[, fun(.SD), keyby = .(o)]
})

示例输入和输出

> res
        o   b            r
   1:   1  c1 0.000000e+00
   2:   1 c10 0.000000e+00
   3:   1 c11 0.000000e+00
   4:   1 c12 0.000000e+00
   5:   1 c13 0.000000e+00
  ---                     
2496: 100  c5 6.836792e-43
2497: 100  c6 6.629646e-43
2498: 100  c7 6.840915e-43
2499: 100  c8 6.624668e-43
2500: 100  c9 6.842608e-43

> s_dt
        o   d      v
   1:   1  d1 0.0001
   2:   1 d10 0.0002
   3:   1 d11 0.0003
   4:   1 d12 0.0004
   5:   1 d13 0.0005
  ---               
4996: 100 d50 0.4996
4997: 100  d6 0.4997
4998: 100  d7 0.4998
4999: 100  d8 0.4999
5000: 100  d9 0.5000

> c_dt
         d  b   n t
     1:  c c1   2 1
     2:  c c1   3 2
     3:  c c1   4 3
     4:  c c1   5 4
     5:  c c1   6 5
    ---            
218567: d9 c9 195 5
218568: d9 c9 196 6
218569: d9 c9 198 1
218570: d9 c9 199 2
218571: d9 c9 200 3

你可能需要检查 collapse,因为它可以提高性能。另一个选择是使用基于 Rust 的 Python Polars 进行分组,速度很快。 - akrun
请澄清一下,只需要优化 profvis 调用中的部分,对吗?前两部分只是数据创建? - Cole
是的@Cole,只需要 profvis 部分。其余部分确实只是构建一些测试数据,以便它成为一个可运行的示例。 - mb147
1
在这个函数中,我们过滤t!=0。我们能事先做到这一点吗?由于它是在c_dt中定义的,这意味着我们可以提前过滤,这样我们就不必每次都进行过滤了。在较小的示例中,该过滤占用了总时间的相当大一部分,因此这种优化非常重要。 - Cole
@Cole:是的,你说得对。过滤器可以提前进行处理。 但我在性能分析中没有看到大幅度的速度提升。 我会更新复现操作,这样你就可以看到跟我一样的性能分析输出。 - mb147
现实世界中的数据或模型是否完全使用CJ制作?还是仅用于说明? - Cole
1个回答

2

这将很难完全向量化。 "大"问题需要执行的操作非常多,因此并行处理可能是达到约5分钟的最直接方式。

但首先,我们可以通过使用RcppArmadillo来进行乘积和求和计算,而不是使用data.table的分组操作,从而获得约3倍的速度提升。

library(data.table)
library(parallel)

Rcpp::cppFunction(
  "std::vector<double> sumprod(arma::cube& a) {
  for(unsigned int i = 1; i < a.n_slices; i++) a.slice(0) %= a.slice(i);
  return(as<std::vector<double>>(wrap(sum(a.slice(0), 0))));
}",
  depends = "RcppArmadillo",
  plugins = "cpp11"
)

cl <- makeForkCluster(detectCores() - 1L)

以下方法需要进行广泛的预处理。其好处在于可以轻松并行化。但是,它只适用于每个os_dt$d的值与MRE中的值相同的情况。
identical(s_dt$d, rep(s_dt[o == 1]$d, length.out = nrow(s_dt)))
#> [1] TRUE

现在让我们构建接受 s_dtc_dt 的函数:
# slightly modified original function for comparison
fun1 <- function(dt, c_dt) {
  # don't use chaining here, for more useful profvis output
  dt <- dt[c_dt, on = .(d)]
  dt <- dt[, r := fcase(d == "c", t,
                        is.na(v), 0,
                        rep(TRUE, .N), v^t)]
  dt <- dt[, .(r = prod(r)), keyby = .(b, n)]
  dt <- dt[, .(r = sum(r)),  keyby = .(b)]
}

fun2 <- function(s_dt, c_dt, cl = NULL) {
  s_dt <- copy(s_dt)
  c_dt <- copy(c_dt)
  # preprocess to get "a", "tt", "i", and "idxs"
  i_dt <- s_dt[o == 1][, idxs := .I][c_dt, on = .(d)][, ic := .I][!is.na(v)]
  ub <- unique(c_dt$b)
  un <- unique(c_dt$n)
  nb <- length(ub)
  nn <- length(un)
  c_dt[, `:=`(i = match(n, un) + nn*(match(b, ub) - 1L), r = 0)]
  c_dt[, `:=`(i = i + (0:(.N - 1L))*nn*nb, ni = .N), i]
  c_dt[d == "c", r := t]
  a <- array(1, c(nn, nb, max(c_dt$ni)))
  a[c_dt$i] <- c_dt$r # 3-d array to store v^t (updated for each unique "o")
  i <- c_dt$i[i_dt$ic] # the indices of "a" to update (same for each unique "o")
  tt <- c_dt$t[i_dt$ic] # c_dt$t ordered for "a" (same for each unique "o")
  idxs <- i_dt$idxs # the indices to order s_dt$v (same for each unique "o")
  uo <- unique(s_dt$o)
  v <- collapse::gsplit(s_dt$v, s_dt$o)
  
  if (is.null(cl)) {
    # non-parallel solution
    data.table(
      o = rep(uo, each = length(ub)),
      b = rep(ub, length(v)),
      r = unlist(
        lapply(
          v,
          function(x) {
            a[i] <- x[idxs]^tt
            sumprod(a)
          }
        )
      ),
      key = "o"
    )
  } else {
    # parallel solution
    clusterExport(cl, c("a", "tt", "i", "idxs"), environment())
    
    data.table(
      o = rep(uo, each = length(ub)),
      b = rep(ub, length(v)),
      r = unlist(
        parLapply(
          cl,
          v,
          function(x) {
            a[i] <- x[idxs]^tt
            sumprod(a)
          }
        )
      ),
      key = "o"
    )
  }
}

现在来看数据:
# problem sizing ----
bigdims <- list(o = 50000, d = 50, b = 250, n = 200) # "big" problem - real-life size
lildims <- list(o =   100, d = 50, b =  25, n = 200) # "lil" problem (make runtime shorter as example)

# build some test data tables ----
build_s <- function(dims) {
  o <- seq_len(dims$o)
  d <- paste0("d",seq_len(dims$d))
  v <- as.double(seq_len(dims$o * dims$d))/10000
  CJ(o, d)[, `:=`(v = v)]
}

build_c <- function(dims) {
  b <- paste0("c", seq_len(dims$b))
  n <- seq_len(dims$n)
  d <- c("c", paste0("d", seq_len(dims$d)))
  t <- as.double(rep_len(0:6, dims$b * dims$n * (dims$d+1)))
  dt <- CJ(d, b, n)[, `:=`(t = t)]
  dt <- dt[t != 0]
}

计时一个很小的问题,这个问题太小了以至于并行化没有帮助:

s_dt <- build_s(lildims)
c_dt <- build_c(lildims)

microbenchmark::microbenchmark(fun1 = s_dt[, fun1(.SD, c_dt), o],
                               fun2 = fun2(s_dt, c_dt),
                               times = 10,
                               check = "equal")
#> Unit: seconds
#>  expr      min       lq     mean   median       uq      max neval
#>  fun1 3.204402 3.237741 3.383257 3.315450 3.404692 3.888289    10
#>  fun2 1.134680 1.138761 1.179907 1.179872 1.210293 1.259249    10

现在的大问题是:
s_dt <- build_s(bigdims)
c_dt <- build_c(bigdims)

system.time(dt2p <- fun2(s_dt, c_dt, cl))
#>    user  system elapsed 
#>  24.937   9.386 330.600

stopCluster(cl)

比5分钟略长,使用了31个核心。

非常棒的答案。非常感谢您花费时间将其分解开来。我根本没有考虑过这个犰狳,所以这特别有帮助。我还可以在我写的其他代码中使用它。 - mb147

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接