快速矢量化函数以检查值是否在区间内

11

在R中,有没有一种函数可以高效地检查一个值是否大于一个数并且小于另一个数?它应该也能处理向量。

基本上,我正在寻找以下函数的更快版本:

> in.interval <- function(x, lo, hi) (x > lo & x < hi)
> in.interval(c(2,4,6), 3, 5)
[1] FALSE  TRUE FALSE
问题在于x必须被访问两次,而且与更高效的方法相比,计算会消耗两倍的内存。在内部,我会假设它的工作方式如下:
  1. 计算tmp1 <- (x > lo)
  2. 计算tmp2 <- (x < hi)
  3. 计算retval <- tmp1 & tmp2
现在,在步骤2之后,有两个布尔向量在内存中,并且必须两次查看x。我的问题是:是否有一个(内置的?)函数可以一次完成所有这些操作,而不需要分配额外的内存?
跟进此问题:R: Select values from data table in range 编辑: 我已经根据CauchyDistributedRV在https://gist.github.com/4344844的回答设置了一个Gist。

对于1e8个数值,该函数在我的电脑上需要大约12秒的时间。你希望它运行得更快多少呢?而且你打算如何通过仅访问x一次来检查2个条件?你能指出你心中所想的“更有效的方法”吗? - Joris Meys
@JorisMeys:大约6秒钟就可以了 :-) 稍后会编辑问题。 - krlmlr
也许可以使用findInterval来向量化你的问题? - Martin Morgan
@James:谢谢,这个实际上比&表现得更好(这让我很惊讶),但其他解决方案更快。请参见Gist。 - krlmlr
@user946850 你曾经尝试过使用Rcpp来实现这个算法吗?James的评论可能会让你惊讶... - Joris Meys
显示剩余2条评论
4个回答

6

正如评论中@James所说,关键是将low和high之间的中间值从x中减去,然后检查该差是否小于low和high之间距离的一半。或者,在代码中:

in.interval2 <- function(x, lo, hi) {
    abs(x-(hi+lo)/2) < (hi-lo)/2 
}

这几乎和 .bincode hack 一样快,是你要寻找的算法实现。你可以将其翻译成 C 或 C++,并尝试加速。

与其他解决方案的比较:

x <- runif(1e6,1,10)
require(rbenchmark)
benchmark(
  in.interval(x, 3, 5),
  in.interval2(x, 3, 5),
  findInterval(x, c(3, 5)) == 1,
  !is.na(.bincode(x, c(3, 5))),
  order='relative',
  columns=c("test", "replications", "elapsed", "relative")
) 

提供

                           test replications elapsed relative
4  !is.na(.bincode(x, c(3, 5)))          100    1.88    1.000
2         in.interval2(x, 3, 5)          100    1.95    1.037
3 findInterval(x, c(3, 5)) == 1          100    3.42    1.819
1          in.interval(x, 3, 5)          100    3.54    1.883

这个想法很好,但在我的机器上比.bincode慢得多,而且使用 & 内部的 Rcpp 版本表现就像使用最佳的其他 Rcpp 版本。请参见 Gist 的结果(这是测试 7 和 8)。 - krlmlr
(x-lo)*(hi-x) > 0怎么样? - Roland
@Roland: (x-lo)*(x-hi) <= 0 可能更好。但我认为我会选择 .bincode,除非有人提出更快的选项。 - krlmlr
顺便说一下:你和罗兰的方法似乎都不允许测试左包含右排除(或者反过来)。.bincode 可以做到除了左排除右排除之外的所有操作(通过 include.lowest 参数的帮助)。 - krlmlr
@JorisMeys:现在我想挑战你的“一次无法测试”的评论 :-) 请注意,我并不急于避免两次访问x,而是对(潜在大)x向量上的两次运行感到困扰。在一个紧密的for循环中两次访问x的元素是便宜的;但是两次迭代x可能需要两次通过内存层次结构拖动它。也许我没有表达清楚。 - krlmlr
显示剩余3条评论

5

findInterval 对于较长的 x 值比 in.interval 更快。

library(microbenchmark)

set.seed(123L)
x <- runif(1e6, 1, 10)
in.interval <- function(x, lo, hi) (x > lo & x < hi)

microbenchmark(
    findInterval(x, c(3, 5)) == 1L,
    in.interval(x, 3, 5),
    times=100)

使用

Unit: milliseconds
                            expr      min       lq   median       uq      max
1 findInterval(x, c(3, 5)) == 1L 23.40665 25.13308 25.17272 25.25361 27.04032
2           in.interval(x, 3, 5) 42.91647 45.51040 45.60424 45.75144 46.38389

如果不需要== 1L,则速度更快,如果要查找的“间隔”大于1,则很有用。

> system.time(findInterval(x, 0:10))
   user  system elapsed 
  3.644   0.112   3.763 

如果速度很重要,这个C语言实现非常快,但对于整数而不是数字参数则无法容忍。
library(inline)
in.interval_c <- cfunction(c(x="numeric", lo="numeric", hi="numeric"),
'    int len = Rf_length(x);
     double lower = REAL(lo)[0], upper = REAL(hi)[0],
            *xp = REAL(x);
     SEXP out = PROTECT(NEW_LOGICAL(len));
     int *outp = LOGICAL(out);

     for (int i = 0; i < len; ++i)
         outp[i] = (xp[i] - lower) * (xp[i] - upper) <= 0;

     UNPROTECT(1);
     return out;')

其他答案中提供的一些解决方案的时间如下:

microbenchmark(
    findInterval(x, c(3, 5)) == 1L,
    in.interval.abs(x, 3, 5),
    in.interval(x, 3, 5),
    in.interval_c(x, 3, 5),
    !is.na(.bincode(x, c(3, 5))),
    times=100)

使用

Unit: milliseconds
                            expr       min        lq    median        uq
1 findInterval(x, c(3, 5)) == 1L 23.419117 23.495943 23.556524 23.670907
2       in.interval.abs(x, 3, 5) 12.018486 12.056290 12.093279 12.161213
3         in.interval_c(x, 3, 5)  1.619649  1.641119  1.651007  1.679531
4           in.interval(x, 3, 5) 42.946318 43.050058 43.171480 43.407930
5   !is.na(.bincode(x, c(3, 5))) 15.421340 15.468946 15.520298 15.600758
        max
1 26.360845
2 13.178126
3  2.785939
4 46.187129
5 18.558425

重新审视速度问题,在一个名为bin.cpp的文件中。

#include <Rcpp.h>

using namespace Rcpp;

// [[Rcpp::export]]
SEXP bin1(SEXP x, SEXP lo, SEXP hi)
{
    const int len = Rf_length(x);
    const double lower = REAL(lo)[0], upper = REAL(hi)[0];
    SEXP out = PROTECT(Rf_allocVector(LGLSXP, len));

    double *xp = REAL(x);
    int *outp = LOGICAL(out);
    for (int i = 0; i < len; ++i)
    outp[i] = (xp[i] - lower) * (xp[i] - upper) <= 0;

    UNPROTECT(1);
    return out;
}

// [[Rcpp::export]]
LogicalVector bin2(NumericVector x, NumericVector lo, NumericVector hi)
{
    NumericVector xx(x);
    double lower = as<double>(lo);
    double upper = as<double>(hi); 

    LogicalVector out(x);
    for( int i=0; i < out.size(); i++ )
        out[i] = ( (xx[i]-lower) * (xx[i]-upper) ) <= 0;

    return out;
}

// [[Rcpp::export]]
LogicalVector bin3(NumericVector x, const double lower, const double upper)
{
    const int len = x.size();
    LogicalVector out(len);

    for (int i=0; i < len; i++)
        out[i] = ( (x[i]-lower) * (x[i]-upper) ) <= 0;

    return out;
}

带有时间

> library(Rcpp)
> sourceCpp("bin.cpp")
> microbenchmark(bin1(x, 3, 5), bin2(x, 3, 5), bin3(x, 3, 5),                   
+                in.interval_c(x, 3, 5), times=1000)                            
Unit: milliseconds                                                              
                    expr       min        lq    median        uq      max       
1          bin1(x, 3, 5)  1.546703  2.668171  2.785255  2.839225 144.9574       
2          bin2(x, 3, 5) 12.547456 13.583808 13.674477 13.792773 155.6594       
3          bin3(x, 3, 5)  2.238139  3.318293  3.357271  3.540876 144.1249       
4 in.interval_c(x, 3, 5)  1.545139  2.654809  2.767784  2.822722 143.7500       

通过使用常量 len 作为循环边界而不是 out.size(),以及在分配逻辑向量时不进行初始化(LogicalVector(len),因为它将在循环中初始化),约有相等部分的加速。


我已将您的解决方案嵌入到 https://gist.github.com/4344844 中。对于1e6个元素,它比&方法更快,但是使用C++仍然比它快两倍。 - krlmlr
Rcpp中大对象的复制 - Martin Morgan
现在有一件让我感到困惑的事情。你提供的 C 语言解决方案实际上比我的系统上简单的 x < hi 运行得更快(大约快了 2 倍)(尝试将 x > lo & x < hix < hi 添加到基准测试中查看)。这是怎么回事 - 我认为 R 中运算符的底层 C 实现已经非常优化了?还是与我编译该 C 函数时发生的任何事情相比,R 的二进制版本编译得更加“安全”? - Kevin Ushey
1
@CauchyDistributedRV x < hi的条件需要分配与我的代码同样数量的内存(用于返回逻辑值),这两个函数都需要迭代所有值,并且很可能C编译器已经优化了我的for循环体,使其操作数比高级语法所示的要少得多,因此两个循环的基本成本可能是可比的。R还会做很多我们认为理所当然的事情,例如处理NAs,回收hi(不仅仅是长度为1的特殊情况),检查是否需要强制转换数据类型等。 - Martin Morgan
@user946850 我稍微研究了一下速度差异,并在我的回答中添加了一节。 - Martin Morgan
有趣的是,我发现一个天真的C实现“x>lo && x<hi”几乎和使用乘法的解决方案一样快。 - Kevin Ushey

4

我能找到的主要加速方法是通过对函数进行字节编译。即使使用Rcpp解决方案(虽然使用Rcpp sugar,而不是更深入的C解决方案),也比编译后的解决方案慢。

library( compiler )
library( microbenchmark )
library( inline )

in.interval <- function(x, lo, hi) (x > lo & x < hi)
in.interval2 <- cmpfun( in.interval )
in.interval3 <- function(x, lo, hi) {
  sapply( x, function(xx) { 
    xx > lo && xx < hi }
          )
}
in.interval4 <- cmpfun( in.interval3 )
in.interval5 <- rcpp( signature(x="numeric", lo="numeric", hi="numeric"), '
NumericVector xx(x);
double lower = Rcpp::as<double>(lo);
double upper = Rcpp::as<double>(hi);

return Rcpp::wrap( xx > lower & xx < upper );
')

x <- c(2, 4, 6)
lo <- 3
hi <- 5

microbenchmark(
  in.interval(x, lo, hi),
  in.interval2(x, lo, hi),
  in.interval3(x, lo, hi),
  in.interval4(x, lo, hi),
  in.interval5(x, lo, hi)
)

让我拥有

Unit: microseconds
                     expr    min      lq  median      uq    max
1  in.interval(x, lo, hi)  1.575  2.0785  2.5025  2.6560  7.490
2 in.interval2(x, lo, hi)  1.035  1.4230  1.6800  2.0705 11.246
3 in.interval3(x, lo, hi) 25.439 26.2320 26.7350 27.2250 77.541
4 in.interval4(x, lo, hi) 22.479 23.3920 23.8395 24.3725 33.770
5 in.interval5(x, lo, hi)  1.425  1.8740  2.2980  2.5565 21.598

编辑:根据其他评论,这里有一个更快的Rcpp解决方案,使用给定的绝对值技巧:
library( compiler )
library( inline )
library( microbenchmark )

in.interval.oldRcpp <- rcpp( 
  signature(x="numeric", lo="numeric", hi="numeric"), '
    NumericVector xx(x);
    double lower = Rcpp::as<double>(lo);
    double upper = Rcpp::as<double>(hi);

    return Rcpp::wrap( (xx > lower) & (xx < upper) );
    ')

in.interval.abs <- rcpp( 
  signature(x="numeric", lo="numeric", hi="numeric"), '
    NumericVector xx(x);
    double lower = as<double>(lo);
    double upper = as<double>(hi); 

    LogicalVector out(x);
    for( int i=0; i < out.size(); i++ ) {
      out[i] = ( (xx[i]-lower) * (xx[i]-upper) ) <= 0;
    }
    return wrap(out);
    ')

in.interval.abs.sugar <- rcpp( 
  signature( x="numeric", lo="numeric", hi="numeric"), '
    NumericVector xx(x);
    double lower = as<double>(lo);
    double upper = as<double>(hi); 

    return wrap( ((xx-lower) * (xx-upper)) <= 0 );
    ')

x <- runif(1E5)
lo <- 0.5
hi <- 1

microbenchmark(
  in.interval.oldRcpp(x, lo, hi),
  in.interval.abs(x, lo, hi),
  in.interval.abs.sugar(x, lo, hi)
)

all.equal( in.interval.oldRcpp(x, lo, hi), in.interval.abs(x, lo, hi) )
all.equal( in.interval.oldRcpp(x, lo, hi), in.interval.abs.sugar(x, lo, hi) )

给我。
1       in.interval.abs(x, lo, hi)  662.732  666.4855  669.939  690.6585 1580.707
2 in.interval.abs.sugar(x, lo, hi)  722.789  726.0920  728.795  742.6085 1671.093
3   in.interval.oldRcpp(x, lo, hi) 1870.784 1876.4890 1892.854 1935.0445 2859.025

> all.equal( in.interval.oldRcpp(x, lo, hi), in.interval.abs(x, lo, hi) )
[1] TRUE

> all.equal( in.interval.oldRcpp(x, lo, hi), in.interval.abs.sugar(x, lo, hi) )
[1] TRUE

4
你检查了你的函数返回什么了吗?它们不一样;&& 运算符只评估其操作数的第一个元素。 - Hong Ooi
哎呀 - 你说得完全正确。可以想象将调用包装在sapplymap中,但仍然比其他解决方案慢。 - Kevin Ushey
我已将您的代码放入gist中:https://gist.github.com/4344844。然而,在我的系统上(Ubuntu 12.10,最新的CRAN R),它无法编译:Error in compileCode(f, code, language = language, verbose = verbose) : ...error: no match for ‘operator&’ in ... - krlmlr
1
@user946850,我在gist中添加了一个可能的解决方案。它应该可以与0.10版本之前的Rcpp编译,但可能会稍微慢一点。或者,您可以在R会话中使用install.packages("Rcpp", type="source")从CRAN获取最新版本。 - Kevin Ushey
@CauchyDistributedRV:非常感谢您提供测试程序的核心代码。到目前为止,结果表明.bincode与最快的Rcpp解决方案相当。让我们看看还有什么其他的! - krlmlr
显示剩余3条评论

4
如果您可以处理“NA”,您可以使用“.bincode”:
.bincode(c(2,4,6), c(3, 5))
[1] NA  1 NA

library(microbenchmark)
set.seed(42)
x = runif(1e8, 1, 10)
microbenchmark(in.interval(x, 3, 5),
               findInterval(x,  c(3, 5)),
               .bincode(x, c(3, 5)),
               times=5)

Unit: milliseconds
                      expr       min        lq    median       uq      max
1     .bincode(x, c(3, 5))  930.4842  934.3594  955.9276 1002.857 1047.348
2 findInterval(x, c(3, 5)) 1438.4620 1445.7131 1472.4287 1481.380 1551.419
3     in.interval(x, 3, 5) 2977.8460 3046.7720 3075.8381 3182.013 3288.020

内部函数的巧妙使用。通过 !is.na(.bincode(...)) 获得答案。 - Joris Meys
2
.bincode 将其参数转换为整数,因此在当前上下文中具有令人惊讶的结果 -- .bincode(3.1, 3, 5) 是 'NA';测试每种方法的结果是否相同。 - Martin Morgan
糟糕,我的错,对不起。 - Martin Morgan
适用于平面,比Rcpp-ed解决方案略慢但效果良好。结果在Gist中。 - krlmlr

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接