类似于dplyr,如何对data.table进行按行求和、平均值、最小值和最大值的计算?

38

有关datatable的逐行运算符的其他帖子。它们要么过于简单,要么解决了特定场景

我这里的问题更通用。有一个使用dplyr的解决方案。我尝试了一下,但未能找到一个等效的data.table语法解决方案。您能否提供一个优雅的data.table解决方案,以重现与dplyr版本相同的结果?

编辑 1:在真实数据集(10MB,73000行,统计信息基于24个数字列)上对建议解决方案的基准测试结果进行总结。基准测试结果是主观的。但是,经过时间已经得到了一致的再现。

| Solution By | Speed compared to dplyr     |
|-------------|-----------------------------|
| Metrics v1  |  4.3 times SLOWER (use .SD) |
| Metrics v2  |  5.6 times FASTER           |
| ExperimenteR| 15   times FASTER           |
| Arun v1     |  3   times FASTER (Map func)|
| Arun v2     |  3   times FASTER (foo func)|
| Ista        |  4.5 times FASTER           |

编辑2:一天后我添加了NACount列。这就是为什么各位贡献者提出的解决方案中没有找到该列的原因。

数据设置

library(data.table)
dt <- data.table(ProductName = c("Lettuce", "Beetroot", "Spinach", "Kale", "Carrot"),
    Country = c("CA", "FR", "FR", "CA", "CA"),
    Q1 = c(NA, 61, 40, 54, NA), Q2 = c(22,  8, NA,  5, NA),
    Q3 = c(51, NA, NA, 16, NA), Q4 = c(79, 10, 49, NA, NA))

#    ProductName Country Q1 Q2 Q3 Q4
# 1:     Lettuce      CA NA 22 51 79
# 2:    Beetroot      FR 61  8 NA 10
# 3:     Spinach      FR 40 NA NA 49
# 4:        Kale      CA 54  5 16 NA
# 5:      Carrot      CA NA NA NA NA

使用dplyr + rowwise()的解决方案

library(dplyr) ; library(magrittr)
dt %>% rowwise() %>% 
    transmute(ProductName, Country, Q1, Q2, Q3, Q4,
     AVG = mean(c(Q1, Q2, Q3, Q4), na.rm=TRUE),
     MIN = min (c(Q1, Q2, Q3, Q4), na.rm=TRUE),
     MAX = max (c(Q1, Q2, Q3, Q4), na.rm=TRUE),
     SUM = sum (c(Q1, Q2, Q3, Q4), na.rm=TRUE),
     NAcnt= sum(is.na(c(Q1, Q2, Q3, Q4))))

#   ProductName Country Q1 Q2 Q3 Q4      AVG MIN  MAX SUM NAcnt
# 1     Lettuce      CA NA 22 51 79 50.66667  22   79 152     1
# 2    Beetroot      FR 61  8 NA 10 26.33333   8   61  79     1
# 3     Spinach      FR 40 NA NA 49 44.50000  40   49  89     2
# 4        Kale      CA 54  5 16 NA 25.00000   5   54  75     1
# 5      Carrot      CA NA NA NA NA      NaN Inf -Inf   0     4

数据表出现错误(计算整个列而不是每行)

dt[, .(ProductName, Country, Q1, Q2, Q3, Q4,
    AVG = mean(c(Q1, Q2, Q3, Q4), na.rm=TRUE),
    MIN = min (c(Q1, Q2, Q3, Q4), na.rm=TRUE),
    MAX = max (c(Q1, Q2, Q3, Q4), na.rm=TRUE),
    SUM = sum (c(Q1, Q2, Q3, Q4), na.rm=TRUE),
    NAcnt= sum(is.na(c(Q1, Q2, Q3, Q4))))]

#    ProductName Country Q1 Q2 Q3 Q4      AVG MIN MAX SUM NAcnt
# 1:     Lettuce      CA NA 22 51 79 35.90909   5  79 395     9
# 2:    Beetroot      FR 61  8 NA 10 35.90909   5  79 395     9
# 3:     Spinach      FR 40 NA NA 49 35.90909   5  79 395     9
# 4:        Kale      CA 54  5 16 NA 35.90909   5  79 395     9
# 5:      Carrot      CA NA NA NA NA 35.90909   5  79 395     9

几乎是解决方案,但更复杂且缺少Q1、Q2、Q3、Q4输出列。
dtmelt <- reshape2::melt(dt, id=c("ProductName", "Country"),
            variable.name="Quarter", value.name="Qty")

dtmelt[, .(AVG = mean(Qty, na.rm=TRUE),
    MIN = min (Qty, na.rm=TRUE),
    MAX = max (Qty, na.rm=TRUE),
    SUM = sum (Qty, na.rm=TRUE),
    NAcnt= sum(is.na(Qty))), by = list(ProductName, Country)]

#    ProductName Country      AVG MIN  MAX SUM NAcnt
# 1:     Lettuce      CA 50.66667  22   79 152     1
# 2:    Beetroot      FR 26.33333   8   61  79     1
# 3:     Spinach      FR 44.50000  40   49  89     2
# 4:        Kale      CA 25.00000   5   54  75     1
# 5:      Carrot      CA      NaN Inf -Inf   0     4

1
dt[, AVG := rowMeans(.SD, na.rm=T),.SDcols=c(Q1, Q2,Q3,Q4)] - ExperimenteR
@ExperimenteR 谢谢(SDcols应该是一个字符向量吗?)我尝试了这个 dt[, .(Q1, Q2, Q3, Q4, AVG = rowMeans(.SD, na.rm=T), MIN = pmin(Q1,Q2,Q3,Q4, na.rm=T), MAX = pmax(Q1,Q2,Q3,Q4, na.rm=T) ), .SDcols=c("Q1","Q2","Q3","Q4")] 但仍然缺少SUM并且没有ProductName和Country列。 - Polymerase
@Metrics 由于评估错误,没有输出:dt[, \:=` (AVG = rowMeans(.SD, na.rm=TRUE), MIN = min(.SD, na.rm=TRUE), MAX = max(.SD, na.rm=TRUE), SUM = sum(.SD, na.rm=TRUE)), .SDcols = c("Q1","Q2","Q3","Q4"), by=1:nrow(dt)] 警告信息: 1: 在 min(c(NA_real_, NA_real_, NA_real_, NA_real_), na.rm = TRUE) 中: 没有非缺失参数可用于 min;返回 Inf 2: 在 max(c(NA_real_, NA_real_, NA_real_, NA_real_), na.rm = TRUE) 中: 没有非缺失参数可用于 max;返回 -Inf` - Polymerase
看我的回答。我已经更新了代码并删除了注释。Dplyr和data.table都会对NaN和-Inf发出警告。 - Metrics
3
data.table 尽可能使用基本的 R 函数,以避免实行“封闭园区”的方式。但是,基本的 R 没有一个很好的函数来执行这个操作 :-(。因此,我们将把 colwise()rowwise() 函数作为 #1063 文件中的内容进行实现...... 我已经将其标记为下一个版本发布的内容。 - Arun
5个回答

48

您可以使用来自matrixStats包的高效逐行函数。

library(matrixStats)
dt[, `:=`(MIN = rowMins(as.matrix(.SD), na.rm=T),
          MAX = rowMaxs(as.matrix(.SD), na.rm=T),
          AVG = rowMeans(.SD, na.rm=T),
          SUM = rowSums(.SD, na.rm=T)), .SDcols=c(Q1, Q2,Q3,Q4)]

dt
#    ProductName Country Q1 Q2 Q3 Q4 MIN  MAX      AVG SUM
# 1:     Lettuce      CA NA 22 51 79  22   79 50.66667 152
# 2:    Beetroot      FR 61  8 NA 10   8   61 26.33333  79
# 3:     Spinach      FR 40 NA 79 49  40   79 56.00000 168
# 4:        Kale      CA 54  5 16 NA   5   54 25.00000  75
# 5:      Carrot      CA NA NA NA NA Inf -Inf      NaN   0

对于具有500000行的数据集(使用CRAN中的data.table)

dt <- rbindlist(lapply(1:100000, function(i)dt))
system.time(dt[, `:=`(MIN = rowMins(as.matrix(.SD), na.rm=T),
                      MAX = rowMaxs(as.matrix(.SD), na.rm=T),
                      AVG = rowMeans(.SD, na.rm=T),
                      SUM = rowSums(.SD, na.rm=T)), .SDcols=c("Q1", "Q2","Q3","Q4")])
#  user  system elapsed 
# 0.089   0.004   0.093

rowwise(或by=1:nrow(dt))是“婉辞”,表示使用for loop,例如:

library(dplyr) ; library(magrittr)
system.time(dt %>% rowwise() %>% 
  transmute(ProductName, Country, Q1, Q2, Q3, Q4,
            MIN = min (c(Q1, Q2, Q3, Q4), na.rm=TRUE),
            MAX = max (c(Q1, Q2, Q3, Q4), na.rm=TRUE),
            AVG = mean(c(Q1, Q2, Q3, Q4), na.rm=TRUE),
            SUM = sum (c(Q1, Q2, Q3, Q4), na.rm=TRUE)))
#   user  system elapsed 
# 80.832   0.111  80.974 

system.time(dt[, `:=`(AVG= mean(as.numeric(.SD),na.rm=TRUE),MIN = min(.SD, na.rm=TRUE),MAX = max(.SD, na.rm=TRUE),SUM = sum(.SD, na.rm=TRUE)),.SDcols=c("Q1", "Q2","Q3","Q4"),by=1:nrow(dt)] )
#    user  system elapsed 
# 141.492   0.196 141.757

你的解决方案是最快的!(请参见原问题中的基准测试)感谢介绍matrixStats包。我想知道与Arun和Metrics第二个解决方案相比,你的解决方案对内存资源的影响如何。 - Polymerase
@ExperimenteR 这个能够工作吗?dt <- rbindlist(lapply(1:100000, function(i)dt))。我尝试分解它,但是返回了错误dt(list(1))。不过这个解决方案很优雅。 - sahuno
哦,我明白了!你把原始数据表复制了多次,并将它们的所有行合并在一起。 - sahuno

19

对于 data.table,使用 by=1:nrow(dt) 实现逐行操作。

 library(data.table)
dt[, `:=`(AVG= mean(as.numeric(.SD),na.rm=TRUE),MIN = min(.SD, na.rm=TRUE),MAX = max(.SD, na.rm=TRUE),SUM = sum(.SD, na.rm=TRUE)),.SDcols=c(Q1, Q2,Q3,Q4),by=1:nrow(dt)] 
   ProductName Country Q1 Q2 Q3 Q4      AVG MIN  MAX SUM
1:     Lettuce      CA NA 22 51 79 50.66667  22   79 152
2:    Beetroot      FR 61  8 NA 10 26.33333   8   61  79
3:     Spinach      FR 40 NA 79 49 56.00000  40   79 168
4:        Kale      CA 54  5 16 NA 25.00000   5   54  75
5:      Carrot      CA NA NA NA NA      NaN Inf -Inf   0

Warning messages:
1: In min(c(NA_real_, NA_real_, NA_real_, NA_real_), na.rm = TRUE) :
  no non-missing arguments to min; returning Inf
2: In max(c(NA_real_, NA_real_, NA_real_, NA_real_), na.rm = TRUE) :
  no non-missing arguments to max; returning -Inf

你收到了警告消息,因为在第5行中,你正在计算空值的最大值、总和、最小值和最大值。例如,如下所示:

min(c(NA,NA,NA,NA),na.rm=TRUE)
[1] Inf
Warning message:
In min(c(NA, NA, NA, NA), na.rm = TRUE) :
  no non-missing arguments to min; returning Inf

哦,没错。我忘记打印(dt)了。抱歉!顺便问一下,如果在“.SDcols=c(Q1,Q2,Q3,Q4)”中不加引号,为什么会出现“找不到对象'Q1'”的情况呢?(data.table 1.9.4,R v3.2.0) - Polymerase
刚刚在一个10MB的数据集上尝试了你提供的解决方案,共有73000行。使用dplyr版本比你建议的实现快了4倍。这可能是因为在计算平均值时使用了as.numeric(.SD)吗? - Polymerase
3
在如此小的数据集上进行基准测试是毫无意义的。 - David Arenburg
@Polymerase:我认为这与.SD.有关。试试这个:在你需要输入所有列名的地方:dt[,:=(AVF = mean (c(Q1, Q2, Q3, Q4), na.rm=TRUE),MIN = min (c(Q1, Q2, Q3, Q4), na.rm=TRUE),MAX = max (c(Q1, Q2, Q3, Q4), na.rm=TRUE),AVG = mean(c(Q1, Q2, Q3, Q4), na.rm=TRUE),SUM = sum (c(Q1, Q2, Q3, Q4), na.rm=TRUE)),by=1:nrow(dt)]。对于你的小样本数据,这比你的dplyr更快。 - Metrics
@Metrics,你建议的第二个版本非常快。让我在这里测试所有解决方案,然后我会总结我所有的测试结果。 - Polymerase
显示剩余3条评论

8

还有另一种方式(虽然不太高效,因为每次都调用了 na.omit(),而且还会有很多内存分配):

require(data.table)
new_cols = c("MIN", "MAX", "SUM", "AVG")
dt[, (new_cols) := Map(function(x, f) f(x), 
                       list(na.omit(c(Q1,Q2,Q3,Q4))), 
                       list(min, max, sum, mean)),
   by = 1:nrow(dt)]

#    ProductName Country Q1 Q2 Q3 Q4 MIN  MAX SUM      AVG
# 1:     Lettuce      CA NA 22 51 79  22   79 152 50.66667
# 2:    Beetroot      FR 61  8 NA 10   8   61  79 26.33333
# 3:     Spinach      FR 40 NA 79 49  40   79 168 56.00000
# 4:        Kale      CA 54  5 16 NA   5   54  75 25.00000
# 5:      Carrot      CA NA NA NA NA Inf -Inf   0      NaN

但是就像我之前提到的一样,一旦 colwise()rowwise() 实现了,这将变得更加简单。在这种情况下,语法可能如下所示:
dt[, rowwise(.SD, list(MIN=min, MAX=max, SUM=sum, AVG=mean), na.rm=TRUE), by = 1:nrow(dt)]
# `by = ` is really not necessary in this case.

甚至针对这种情况更加简单明了的方法是:
rowwise(dt, list(...), na.rm=TRUE)

编辑:

另一种变体:

myNACount <- function(x, ...) length(attributes(x)$na.action)
foo <- function(x, ...) {
    funs = c(min, max, mean, sum, myNACount)
    lapply(funs, function(f) f(x, ...))
}

dt[, (new_cols) := foo(na.omit(c(Q1, Q2, Q3, Q4)), na.rm=TRUE), by=1:nrow(dt)]
#    ProductName Country Q1 Q2 Q3 Q4 MIN  MAX      SUM AVG NAs
# 1:     Lettuce      CA NA 22 51 79  22   79 50.66667 152   1
# 2:    Beetroot      FR 61  8 NA 10   8   61 26.33333  79   1
# 3:     Spinach      FR 40 NA NA 49  40   49 44.50000  89   2
# 4:        Kale      CA 54  5 16 NA   5   54 25.00000  75   1
# 5:      Carrot      CA NA NA NA NA Inf -Inf      NaN   0   4

是的,为什么您在“逐行”潜在解决方案中添加了“by”? - David Arenburg
可能会出现复杂的情况,例如 dt[, if (TRUE) do_bla else rowwise(...), by=some_cols](就像我说的,在这种情况下这并不必要)。 - Arun
1
@Arun 那个 myNACount <- function(x) length(attributes(x)$na.action) 函数很出色。谢谢。但愿我能理解这种优化机制。你提出的第二种变化速度非常快。 - Polymerase
1
@Arun Ahem...抱歉,我在基准测试中犯了一个错误。你所做的第二个变化比第一个版本稍微快一些。最快的执行时间来自ExperimenteR的解决方案。 - Polymerase
1
@Polymerase,别担心。我认为我们在这里都学到了很多 :-)。好问题。 - Arun
显示剩余2条评论

2
< p > apply 函数可用于执行逐行计算。将函数单独定义可以使代码更清晰:

dstats <- function(x){
    c(mean(x,na.rm=TRUE),
      min(x, na.rm=TRUE),
      max(x, na.rm=TRUE),
      sum(x, na.rm=TRUE))
}

该函数现在可应用于数据表的行。

(dt[,
   c("AVG", "MIN", "MAX", "SUM") := data.frame(t(apply(.SD, 1, dstats))),
   .SDcols=c("Q1", "Q2","Q3","Q4"),
])

请注意,使用[.data.table 的唯一优势是它允许使用:=进行快速引用添加。
这比matrixStats解决方案慢但更灵活,并且比@ExperimenteR的dplyr解决方案更快,计时为36秒(我对其他方法的计时与@ExperimenteR的答案中的计时类似)。

1
  1. apply()函数将.SD转换为矩阵,需要进行内存分配。
  2. t()函数对结果进行转置,会产生另一个副本。
  3. data.frame()函数会再次进行内存分配。不确定在这里是否需要使用with = FALSE参数。我们可以通过避免所有这些副本来改进代码。
- Arun
@Arun 或许是这样,但它已经相当快了,如果我们需要更快的速度,我们可以使用 matrixStats。我使用 with = FALSE 是因为 help(":=") 暗示当 RHS 返回一个列表时需要这样做。 - Ista
相当快并不够好,特别是当更高效的方法非常简单时。我已经在 Github 项目页面上回复了你的回复,并详细说明了原因。关于 with=FALSE,它并不意味着那个意思,但我理解造成的困惑。我会进行修复。 - Arun
@Ista,你的解决方案是第二快的,可以在原始问题中看到基准测试结果。 - Polymerase

0

我希望其他人在遇到相同问题时,能够找到有用的信息。

第一种方法:结合基本R语言

dt[,`:=`(MIN = apply(dt[, Q1:Q4], 1, FUN = min, na.rm=TRUE),
       MAX = apply(dt[, Q1:Q4], 1, FUN = max, na.rm = TRUE),
       AVG = rowMeans(dt[, Q1:Q4], na.rm = TRUE),
       SUM = rowSums(dt[, Q1:Q4], na.rm = TRUE))][]
# ProductName Country Q1 Q2 Q3 Q4 MIN  MAX      AVG SUM
# 1:     Lettuce      CA NA 22 51 79  22   79 50.66667 152
# 2:    Beetroot      FR 61  8 NA 10   8   61 26.33333  79
# 3:     Spinach      FR 40 NA NA 49  40   49 44.50000  89
# 4:        Kale      CA 54  5 16 NA   5   54 25.00000  75
# 5:      Carrot      CA NA NA NA NA Inf -Inf      NaN   0

第二种方法:基于@ExperimenteR的想法,使用matrixStats包

dt1 <- dt[,`:=`(MIN = rowMins(as.matrix(dt[, Q1:Q4]), na.rm=TRUE),
                MAX = rowMaxs(as.matrix(dt[, Q1:Q4]), na.rm = TRUE),
                AVG = rowMeans(dt[, Q1:Q4], na.rm = TRUE),
                SUM = rowSums(dt[, Q1:Q4], na.rm = TRUE))][]
# ProductName Country Q1 Q2 Q3 Q4 MIN  MAX      AVG SUM
# 1:     Lettuce      CA NA 22 51 79  22   79 50.66667 152
# 2:    Beetroot      FR 61  8 NA 10   8   61 26.33333  79
# 3:     Spinach      FR 40 NA NA 49  40   49 44.50000  89
# 4:        Kale      CA 54  5 16 NA   5   54 25.00000  75
# 5:      Carrot      CA NA NA NA NA Inf -Inf      NaN   0

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接