使用R data.table包能否解决复杂的汇总函数问题?

20

我正在重新编写一些R脚本,用于分析大量数据(大约1700万行),并尝试使用data.table包来提高其内存效率(我刚开始学习这个包!)。

其中一部分代码让我感到困惑。我无法发布我的原始解决方案,因为(1)它很糟糕(速度慢!)(2)关于数据的细微差别需要用到很多细节,只会使问题更加复杂。

相反,我制作了这个玩具示例(而且它确实是一个玩具示例):

ds <- data.table(ID=c(1,1,1,1,2,2,2,3,3,3),
Obs=c(1.5,2.5,0.0,1.25,1.45,1.5,2.5,0.0,1.25,1.45), 
Pos=c(1,3,5,6,2,3,5,2,3,4))

它长这样:

    ID  Obs Pos
 1:  1 1.50   1
 2:  1 2.50   3
 3:  1 0.00   5
 4:  1 1.25   6
 5:  2 1.45   2
 6:  2 1.50   3
 7:  2 2.50   5
 8:  3 0.00   2
 9:  3 1.25   3
10:  3 1.45   4

为了更容易解释,我会假设我们观察火车(每辆火车都有自己的ID),在一条单向轨道上移动,并在轨道上的固定位置(pos,这里是1-6)进行有关于该火车的观测(某个值,对于问题来说不重要)。预计火车不会一直到达轨道的尽头(也许它在到达位置6之前就爆炸了),有时观察者会错过一个观测... 这些位置是连续的(因此,如果我们在位置4错过了观察火车的机会,但在位置5观察到了它,那么我们知道它必定已经通过了位置4)。

从上述数据表中,我需要生成像这样的表格:

   Pos Count
1:   1     3
2:   2     3
3:   3     3
4:   4     3
5:   5     2
6:   6     1

对于我数据表ds中每个唯一的Pos,我都有一份火车到达该位置(或更远位置)的数量计数,无论观察是否在轨道上该位置。

如果有人有任何想法或建议如何解决这个问题,将不胜感激。不幸的是,我对data.table不太熟悉,不知道是否可以完成!或者这可能是一个非常简单的问题要解决而我只是比较慢 :)

4个回答

15

好问题!!这个例子数据特别清晰易懂,解释得也很好。

首先我将展示答案,然后逐步解释。

> ids = 1:3   # or from the data: unique(ds$ID)
> pos = 1:6   # or from the data: unique(ds$Pos)
> setkey(ds,ID,Pos)

> ds[CJ(ids,pos), roll=-Inf, nomatch=0][, .N, by=Pos]
   Pos N
1:   1 3
2:   2 3
3:   3 3
4:   4 3
5:   5 2
6:   6 1
> 

这对于您的大数据也应该非常高效。

逐步执行

首先我尝试了一个交叉连接(CJ); 即,对于每个训练数据的每个位置。

> ds[CJ(ids,pos)]
    ID Pos  Obs
 1:  1   1 1.50
 2:  1   2   NA
 3:  1   3 2.50
 4:  1   4   NA
 5:  1   5 0.00
 6:  1   6 1.25
 7:  2   1   NA
 8:  2   2 1.45
 9:  2   3 1.50
10:  2   4   NA
11:  2   5 2.50
12:  2   6   NA
13:  3   1   NA
14:  3   2 0.00
15:  3   3 1.25
16:  3   4 1.45
17:  3   5   NA
18:  3   6   NA

我看到每列火车有6排。我看到3列火车。我得到了我预期的18排。在那列火车没有观察到时,我看到NA。很好。检查一下。笛卡尔积似乎正常工作。现在让我们构建查询。

你写道,如果在位置n观察到一列火车,则必须经过之前的位置。我立即想到了roll。让我们试试。

ds[CJ(ids,pos), roll=TRUE]
    ID Pos  Obs
 1:  1   1 1.50
 2:  1   2 1.50
 3:  1   3 2.50
 4:  1   4 2.50
 5:  1   5 0.00
 6:  1   6 1.25
 7:  2   1   NA
 8:  2   2 1.45
 9:  2   3 1.50
10:  2   4 1.50
11:  2   5 2.50
12:  2   6 2.50
13:  3   1   NA
14:  3   2 0.00
15:  3   3 1.25
16:  3   4 1.45
17:  3   5 1.45
18:  3   6 1.45

哦,这使得每个列车的观测结果向前滚动了。 对于第2和第3列车,在位置1留下了一些NA,但是您说如果在位置2观察到火车,则必须经过位置1。 它还将第2和第3辆火车的最后一次观测向前滚动到位置6,但是您说火车可能会爆炸。 因此,我们想要向后滚动! 这就是roll=-Inf。 这是一个复杂的-Inf,因为您还可以控制向后滚动多远,但对于这个问题我们不需要; 我们只想无限向后滚动。 让我们尝试roll=-Inf,看看会发生什么。

> ds[CJ(ids,pos), roll=-Inf]
    ID Pos  Obs
 1:  1   1 1.50
 2:  1   2 2.50
 3:  1   3 2.50
 4:  1   4 0.00
 5:  1   5 0.00
 6:  1   6 1.25
 7:  2   1 1.45
 8:  2   2 1.45
 9:  2   3 1.50
10:  2   4 2.50
11:  2   5 2.50
12:  2   6   NA
13:  3   1 0.00
14:  3   2 0.00
15:  3   3 1.25
16:  3   4 1.45
17:  3   5   NA
18:  3   6   NA

好了,差不多了。现在我们只需要计数。但是,在第2和第3次火车爆炸之后,那些讨厌的NA还在那里。让我们把它们移除。

> ds[CJ(ids,pos), roll=-Inf, nomatch=0]
    ID Pos  Obs
 1:  1   1 1.50
 2:  1   2 2.50
 3:  1   3 2.50
 4:  1   4 0.00
 5:  1   5 0.00
 6:  1   6 1.25
 7:  2   1 1.45
 8:  2   2 1.45
 9:  2   3 1.50
10:  2   4 2.50
11:  2   5 2.50
12:  3   1 0.00
13:  3   2 0.00
14:  3   3 1.25
15:  3   4 1.45

顺便提一下,data.table 喜欢尽可能多地放在一个单独的 DT[...] 中,因为这样它可以优化查询。内部并不会创建 NA 然后把它们删除;它从一开始就不会创建 NA。这个概念对于效率很重要。

最后,我们只需要计数即可。我们可以将此作为复合查询附加在末尾。

> ds[CJ(ids,pos), roll=-Inf, nomatch=0][, .N, by=Pos]
   Pos N
1:   1 3
2:   2 3
3:   3 3
4:   4 3
5:   5 2
6:   6 1

1
+1 真的是非常好的解决方案,而且解释得更好。您能否谈谈当数据变得更大时,您希望这种方法与 ds[ , list( Pos = 1:Pos[.N] ) , by = ID ][ , .N , by = Pos ] 相比如何? - Simon O'Hanlon
1
@SimonO'Hanlon 不错的替代方案。Pos[.N]将是一个新的长度为1的向量,传递给:函数以创建一个新的1:Pos[.N]向量。我预计所有这些小向量会使内存卡顿并导致更多的垃圾回收。随着列车数量的增加,这将比位置数量的增加更加严重(更多的组)。如果你测试它,我对结果很感兴趣! - Matt Dowle
我并不真正理解data.table的语法,但CJ看起来很昂贵(在概念上,如果不是实际上的话?);是否有像我的解决方案一样的解决方案,其中data.table通过ID识别最大Pos,即迅速将我们带到nAtMax?也许这就是@SimonO'Hanlon正在做的事情? - Martin Morgan
2
@MartinMorgan 是的,说得好。也许可以这样:ds[,max(Pos),by=ID][,rev(cumsum(rev(tabulate(V1))))]。我还在你的答案上添加了更长的评论。 - Matt Dowle
1
@MattDowle 很棒的例子和很好的解释!帮助我更好地理解 data.table 的工作原理。希望能在文档中看到这个例子。谢谢。 - Uwe
显示剩余3条评论

9

data.table听起来是一个很好的解决方案。从数据排序的方式中,我们可以找到每个列车的最大值。

maxPos = ds$Pos[!duplicated(ds$ID, fromLast=TRUE)]

然后将到达该位置的火车制成表格。
nAtMax = tabulate(maxPos)

并计算每个位置从末尾开始的火车累计总和

rev(cumsum(rev(nAtMax)))
## [1] 3 3 3 3 2 1

我认为这对于大数据来说会很快,但并非完全节省内存。

我从问题和标题中的印象是要求一个data.table演示,因为Meep解释说他的示例数据和任务非常简化。rev(cumsum(rev(tabulate())))执行了与所要求的完全相同的任务,但是如果火车从不同的点开始,观察值的价值变得重要,火车不再爆炸,或者还有卡车(2列ID)呢?这些都是对data.table查询的简单更改(开关),而在基础部分可能需要一些思考? - Matt Dowle
谢谢你提供的解决方案,它比我想到的要好得多! :) Matt 提出的建议是正确的,数据可能更加复杂,这就是为什么我接受了他的答案。如果你感兴趣,我正在处理的实际上是 DNA 序列跟踪数据,与火车无关 :) - Meep

4
您可以尝试以下步骤。为了更好地理解,我特意将其分成多个步骤的解决方案。您也可以通过链接[]将它们全部合并为一个步骤。
这里的逻辑是先找到每个ID的最终位置。然后我们聚合数据以查找每个最终位置的ID计数。由于最终位置6的所有ID也应计入最终位置5,因此我们使用cumsum将所有较高ID计数添加到它们各自的较低ID中。
ds2 <- ds[, list(FinalPos=max(Pos)), by=ID]

ds2 
##    ID FinalPos
## 1:  1        6
## 2:  2        5
## 3:  3        4

ds3 <- ds2[ , list(Count = length(ID)), by = FinalPos][order(FinalPos, decreasing=TRUE), list(FinalPos, Count = cumsum(Count))]

ds3
##    FinalPos Count
## 1:        4     3
## 2:        5     2
## 3:        6     1

setkey(ds3, FinalPos)

ds3[J(c(1:6)), roll = 'nearest']

##    FinalPos Count
## 1:        1     3
## 2:        2     3
## 3:        3     3
## 4:        4     3
## 5:        5     2
## 6:        6     1

非常好的使用了 roll="nearest"。我认为 ds3 是不必要的吗?- setkey(ds[, list(N=max(Pos)), keyby=ID], N)[J(1:6), roll="nearest"] - Arun
1
仔细思考一下,“roll =“nearest””会给出错误的结果。例如,如果数据中根本没有“6”,并且您从1:6进行连接,那么它不是(将给出“2”而不是NA或0)? - Arun

1
一些参考时间:

timing code:

library(data.table)
set.seed(0L)
nr <- 2e7
nid <- 1e6
npos <- 20
ds <- unique(data.table(ID=sample(nid, nr, TRUE), Pos=sample(npos, nr, TRUE)))
# ds <- data.table(ID=c(1,1,1,1,2,2,2,3,3,3),
#     Obs=c(1.5,2.5,0.0,1.25,1.45,1.5,2.5,0.0,1.25,1.45),
#     Pos=c(1,3,5,6,2,3,5,2,3,4))
setkey(ds, ID, Pos)

ids = ds[, sort(unique(ID))]   # or from the data: unique(ds$ID)
pos = ds[, sort(unique(Pos))]   # or from the data: unique(ds$Pos)

mtd0 <- function() ds[CJ(ids, pos), roll=-Inf, nomatch=0][, .N, by=Pos]
mtd1 <- function() ds[,max(Pos),by=ID][,rev(cumsum(rev(tabulate(V1))))]
mtd2 <- function() ds[, .(Pos=1:Pos[.N]), ID][, .N, by=Pos]
bench::mark(mtd0(), mtd1(), mtd2(), check=FALSE)

identical(mtd0()$N, mtd2()$N)
#[1] TRUE

identical(mtd1(), mtd2()$N)
#[1] TRUE

时间:

# A tibble: 3 x 13
  expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result            memory               time     gc              
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list>            <list>               <list>   <list>          
1 mtd0()        2.14s    2.14s     0.468    1.26GB     1.40     1     3      2.14s <df[,2] [20 x 2]> <df[,3] [41 x 3]>    <bch:tm> <tibble [1 x 3]>
2 mtd1()     281.54ms 284.89ms     3.51   209.24MB     1.76     2     1   569.78ms <int [20]>        <df[,3] [24 x 3]>    <bch:tm> <tibble [2 x 3]>
3 mtd2()        1.63s    1.63s     0.613  785.65MB     7.35     1    12      1.63s <df[,2] [20 x 2]> <df[,3] [9,111 x 3]> <bch:tm> <tibble [1 x 3]>

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接