使用data.table按条件分组

6
在R中,我有一个大的data.table。对于每一行,我想要计算具有类似x1值(+/- 一些容差tol)的行数。我可以通过adply来实现这个目标,但速度太慢了。看起来像是data.table非常擅长的事情 - 实际上,我已经在部分计算中使用了data.table。
是否有一种完全使用data.table来完成这项任务的方法?下面是一个示例:
library(data.table)
library(plyr)
my.df = data.table(x1 = 1:1000,
                   x2 = 4:1003)
tol = 3
adply(my.df, 1, function(df) my.df[x1 > (df$x1 - tol) & x1 < (df$x1 + tol), .N])

结果:

        x1   x2 V1
   1:    1    4  3
   2:    2    5  4
   3:    3    6  5
   4:    4    7  5
   5:    5    8  5
  ---             
 996:  996  999  5
 997:  997 1000  5
 998:  998 1001  5
 999:  999 1002  4
1000: 1000 1003  3

更新:

这里是一个样本数据集,更接近于我的真实数据:

set.seed(10)
x = seq(1,100000000,100000)
x = x + sample(1:50000, length(x), replace=T)
x2 = x + sample(1:50000, length(x), replace=T)
my.df = data.table(x1 = x,
                   x2 = x2)
setkey(my.df,x1)
tol = 100000

og = function(my.df) {
  adply(my.df, 1, function(df) my.df[x1 > (df$x1 - tol) & x1 < (df$x1 + tol), .N])
}

microbenchmark(r_ed <- ed(copy(my.df)),
               r_ar <- ar(copy(my.df)),
               r_og <- og(copy(my.df)),
               times = 1)

Unit: milliseconds
                    expr         min          lq      median          uq         max neval
 r_ed <- ed(copy(my.df))    8.553137    8.553137    8.553137    8.553137    8.553137     1
 r_ar <- ar(copy(my.df))   10.229438   10.229438   10.229438   10.229438   10.229438     1
 r_og <- og(copy(my.df)) 1424.472844 1424.472844 1424.472844 1424.472844 1424.472844     1

显然,@eddi和@Arun的解决方案比我的快得多。现在我只需要尝试理解“rolls”的含义。
4个回答

9

查看@eddi的答案,可以得到更快的解决方案(针对这个特定问题)。当x1不是整数时,它也适用。

你正在寻找的算法是区间树。还有一个叫做IRanges的生物信息学包可以完成这个任务。很难超越它。

require(IRanges)
require(data.table)
my.df[, res := countOverlaps(IRanges(my.df$x1, width=1), 
           IRanges(my.df$x1-tol+1, my.df$x1+tol-1))]

一些解释:

如果你把代码分解开来,你可以用三行来编写它:

ir1 <- IRanges(my.df$x1, width=1)
ir2 <- IRanges(my.df$x1-tol+1, my.df$x1+tol-1)
cnt <- countOverlaps(ir1, ir2)

我们所做的实质上是创建两个“范围”(只需键入ir1ir2即可查看它们)。然后,对于ir1中的每个条目,我们询问它们在ir2中重叠了多少次(这是“区间树”的部分)。这非常高效。隐式地,countOverlaps函数的参数type默认为“type = any”。如果您想要探索其他类型,可以尝试一下。这非常有用。还与findOverlaps函数相关。
注意:对于这种特定情况,可能存在更快的解决方案(事实上,@eddi提供了一种),其中ir1的宽度= 1。但对于宽度可变和/或> 1的问题,这应该是最快的解决方案。

基准测试:

ag <- function(my.df) my.df[, res := sum(abs(my.df$x1-x1) < tol), by=x1]
ro <- function(my.df) {
            my.df[,res:= { y = my.df$x1
            sum(y > (x1 - tol) & y < (x1 + tol))
            }, by=x1]
      }
ar <- function(my.df) {
           my.df[, res := countOverlaps(IRanges(my.df$x1, width=1), 
            IRanges(my.df$x1-tol+1, my.df$x1+tol-1))]
      }


require(microbenchmark)
microbenchmark(r1 <- ag(copy(my.df)), r2 <- ro(copy(my.df)), 
               r3 <- ar(copy(my.df)), times=100)

Unit: milliseconds
                  expr      min       lq   median       uq       max neval
 r1 <- ag(copy(my.df)) 33.15940 39.63531 41.61555 44.56616 208.99067   100
 r2 <- ro(copy(my.df)) 69.35311 76.66642 80.23917 84.67419 344.82031   100
 r3 <- ar(copy(my.df)) 11.22027 12.14113 13.21196 14.72830  48.61417   100 <~~~

identical(r1, r2) # TRUE
identical(r1, r3) # TRUE

Arun,我用一个更类似于我的数据集的数据集尝试了你的解决方案,但出现了错误: set.seed(10) x = seq(1,100000000,100000) x = x + sample(1:50000, length(x), replace=T) x2 = x + sample(1:50000, length(x), replace=T) my.df = data.table(x1 = x, x2 = x2) setkey(my.df,x1) tol = 100000 - benjamin
抱歉 - 我在解析注释标记时遇到了问题。错误是:Error in countOverlaps(IRanges(my.df$x1, width = 1), IRanges(pmax(1, my.df$x1 - : 在选择函数'countOverlaps'的方法时评估参数'subject'时出错:Error in .Call2("solve_user_SEW0", start, end, width, PACKAGE = "IRanges") : 解决第二行时:不允许负宽度。 - benjamin
找到原因了。进行了修改。现在试试看。问题出在 pmax(nrow(my.df)...)... 它并不是必要的。修改应该可以正常工作。 - Arun
好的,现在可以了。你经常使用Bioconductor吗?我一直想试试它。几周前我无意中参加了西雅图的一个Bioconductor会议。 - benjamin
是的,我在我的工作中广泛使用了相当多的软件包。它们大多与基因组学相关,尽管像IRanges这样的软件包可以用于任何间隔数据。据我所知,一些软件包也由于许可限制/CRAN中的差异而提交到那里。 - Arun

4

以下是更快的data.table解决方案。思路是利用data.table的滚动合并功能,但在此之前,我们需要稍微修改数据,并将列x1从整数转换为数字。这是因为原作者使用了严格不等式,为了使用滚动连接,我们必须稍微降低公差,使其成为浮点数。

my.df[, x1 := as.numeric(x1)]

# set the key to x1 for the merges and to sort
# (note, if data already sorted can make this step instantaneous using setattr)
setkey(my.df, x1)

# and now we're going to do two rolling merges, one with the upper bound
# and one with lower, then get the index of the match and subtract the ends
# (+1, to get the count)
my.df[, res := my.df[J(x1 + tol - 1e-6), list(ind = .I), roll = Inf]$ind -
               my.df[J(x1 - tol + 1e-6), list(ind = .I), roll = -Inf]$ind + 1]


# and here's the bench vs @Arun's solution
ed = function(my.df) {
  my.df[, x1 := as.numeric(x1)]
  setkey(my.df, x1)
  my.df[, res := my.df[J(x1 + tol - 1e-6), list(ind = .I), roll = Inf]$ind -
                 my.df[J(x1 - tol + 1e-6), list(ind = .I), roll = -Inf]$ind + 1]
}

microbenchmark(ed(copy(my.df)), ar(copy(my.df)))
#Unit: milliseconds
#            expr       min       lq   median       uq      max neval
# ed(copy(my.df))  7.297928 10.09947 10.87561 11.80083 23.05907   100
# ar(copy(my.df)) 10.825521 15.38151 16.36115 18.15350 21.98761   100

注意:正如Arun和Matthew指出的那样,如果x1是整数,就不必转换为数字并从tol中减去一个小量,可以使用tol - 1L而不是上面的tol - 1e-6


+1 x1 可以保持为整数,只有 +1L 和 -1L 处理严格不等式,即 J(x1 + tol - 1L)?保持为整数应该更快。 - Matt Dowle
@MatthewDowle 我们与 @Arun 大致有这样的评论,尽管我没有意识到 1L1 的区别,可能本应该把它们留在这里——1L 比较确实适用于整数 x1,但如果不同 x1 之间的最小距离小于 1,则会失败。 - eddi
@eddi 哦,我明白了。当 tol 小于 1(比如说 0.8 数字),但 x1 是整数时,可以使用 J(x1+as.integer(ceiling(tol)-1))。如果我理解正确的话,也许 rollequal TRUE/FALSE 可以添加到 data.table 中,以严格不等式形式构建 rollrollends - Matt Dowle
1
顺便提一下:“DT[,list(ind = .I),]$ind” 可以简写为 “DT[,.I,]$I”。当“.I”没有命名时,点号会自动删除,因此它可以在结果的任何复合查询中与“.I”区分开。 - Matt Dowle
@MatthewDowle 我认为 rollequal 将是一个有用的标志。 - eddi

2

这里是一个纯data.table的解决方案:

my.df[, res:=sum(my.df$x1 > (x1 - tol) & my.df$x1 < (x1 + tol)), by=x1]

my.df <- adply(my.df, 1, 
           function(df) my.df[x1 > (df$x1 - tol) & x1 < (df$x1 + tol), .N])

identical(my.df[,res],my.df[,V1])
#[1] TRUE

然而,如果你有很多唯一的x1,这仍然会比较慢。毕竟,你需要进行大量的比较,我目前想不到避免这种情况的方法。


2
利用这个事实,
 abs(x-y) < tol ~    y-tol <= x <= y+ tol 

您可以将性能提升2倍。

## wrap codes in 2 function for benchmarking
library(data.table)
library(plyr)
my.df = data.table(x1 = 1:1000,
                   x2 = 4:1003)
tol = 3
ag <- function()
my.df[, res := sum(abs(my.df$x1-x1) < tol), by=x1]
ro <- function()
  my.df[,res:= { y = my.df$x1
          sum(y > (x1 - tol) & y < (x1 + tol))
          }, by=x1]
## check equal results
identical(ag(),ro())
TRUE
library(microbenchmark)
## benchmarks 
microbenchmark(ag(),
               ro(),times=1)

Unit: milliseconds
 expr      min       lq   median       uq      max neval
 ag() 32.75638 32.75638 32.75638 32.75638 32.75638     1
 ro() 63.50043 63.50043 63.50043 63.50043 63.50043     1

1
其中一部分是子集。使用 $ 符号,您的函数速度大约快了两倍,这是可以预料的。我总是记不住哪个子集函数最快。在我的答案中进行了更改。 - Roland

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接