在向量中寻找第一个TRUE值的更快方法

27

在一个函数中,我经常需要使用类似以下的代码:

which(x==1)[1]
which(x>1)[1]
x[x>10][1]

其中x是一个数字向量。 summaryRprof()显示我在关系运算符上花费了>80%的时间。我想知道是否有一个函数仅在达到第一个TRUE值时进行比较,以加快我的代码速度。for循环比上述选项要慢。


which.minwhich.max考虑了类似的问题。 - James
1
@James:不完全正确,因为它们需要一个逻辑向量,而创建逻辑向量是耗时的。 - Joshua Ulrich
3个回答

20

我不知道有一种纯R的方法来实现这个功能,因此我编写了一个C函数来为quantstrat包执行此操作。该函数是为特定目的而编写的,所以它并不像我希望的那样通用。例如,您可能注意到它仅适用于实数/双精度/数字数据,因此在调用.firstCross函数之前,请确保将Data强制转换为这种类型。

#include <R.h>
#include <Rinternals.h>

SEXP firstCross(SEXP x, SEXP th, SEXP rel, SEXP start)
{
    int i, int_rel, int_start;
    double *real_x=NULL, real_th;

    if(ncols(x) > 1)
        error("only univariate data allowed");

    /* this currently only works for real x and th arguments
     * support for other types may be added later */
    real_th = asReal(th);
    int_rel = asInteger(rel);
    int_start = asInteger(start)-1;

    switch(int_rel) {
        case 1:  /* >  */
            real_x = REAL(x);
            for(i=int_start; i<nrows(x); i++)
                if(real_x[i] >  real_th)
                    return(ScalarInteger(i+1));
            break;
        case 2:  /* <  */
            real_x = REAL(x);
            for(i=int_start; i<nrows(x); i++)
                if(real_x[i] <  real_th)
                    return(ScalarInteger(i+1));
            break;
        case 3:  /* == */
            real_x = REAL(x);
            for(i=int_start; i<nrows(x); i++)
                if(real_x[i] == real_th)
                    return(ScalarInteger(i+1));
            break;
        case 4:  /* >= */
            real_x = REAL(x);
            for(i=int_start; i<nrows(x); i++)
                if(real_x[i] >= real_th)
                    return(ScalarInteger(i+1));
            break;
        case 5:  /* <= */
            real_x = REAL(x);
            for(i=int_start; i<nrows(x); i++)
                if(real_x[i] <= real_th)
                    return(ScalarInteger(i+1));
            break;
        default:
            error("unsupported relationship operator");
  }
  /* return number of observations if relationship is never TRUE */
  return(ScalarInteger(nrows(x)));
}

这里是调用它的R函数:

.firstCross <- function(Data, threshold=0, relationship, start=1) {
    rel <- switch(relationship[1],
            '>'    =  ,
            'gt'   = 1,
            '<'    =  ,
            'lt'   = 2,
            '=='   =  ,
            'eq'   = 3,
            '>='   =  ,
            'gte'  =  ,
            'gteq' =  ,
            'ge'   = 4,
            '<='   =  ,
            'lte'  =  ,
            'lteq' =  ,
            'le'   = 5)
    .Call('firstCross', Data, threshold, rel, start)
}

一些基准测试,只是为了好玩。

> library(quantstrat)
> library(microbenchmark)
> firstCross <- quantstrat:::.firstCross
> set.seed(21)
> x <- rnorm(1e6)
> microbenchmark(which(x > 3)[1], firstCross(x,3,">"), times=10)
Unit: microseconds
                  expr      min       lq    median       uq      max neval
       which(x > 3)[1] 9482.081 9578.072 9597.3870 9690.448 9820.176    10
 firstCross(x, 3, ">")   11.370   11.675   31.9135   34.443   38.614    10
> which(x>3)[1]
[1] 919
> firstCross(x,3,">")
[1] 919

请注意,firstCross 函数在 Data 越大时将产生更大的相对加速(因为 R 的关系运算符必须完成整个向量的比较)。

> x <- rnorm(1e7)
> microbenchmark(which(x > 3)[1], firstCross(x,3,">"), times=10)
Unit: microseconds
                  expr      min        lq    median        uq        max neval
       which(x > 3)[1] 94536.21 94851.944 95799.857 96154.756 113962.794    10
 firstCross(x, 3, ">")     5.08     5.507    25.845    32.164     34.183    10
> which(x>3)[1]
[1] 97
> firstCross(x,3,">")
[1] 97

如果第一个 TRUE 值在向量的末尾附近,它也不会显著加快速度。

> microbenchmark(which(x==last(x))[1], firstCross(x,last(x),"eq"),times=10)
Unit: milliseconds
                         expr      min       lq   median       uq       max neval
       which(x == last(x))[1] 92.56311 93.85415 94.38338 98.18422 106.35253    10
 firstCross(x, last(x), "eq") 86.55415 86.70980 86.98269 88.32168  92.97403    10
> which(x==last(x))[1]
[1] 10000000
> firstCross(x,last(x),"eq")
[1] 10000000

请检查以下代码:firstCross(as.numeric(1:100),as.numeric(200),"gte")。我期望会出现错误或NA,而不仅仅是最后一行。 - user1603038
@user1603038:你为什么会有这样的期望呢?代码很明确地写着,“如果关系永远不为真,则返回观察数量”。;) 如果你希望如此,只需将C函数更改为返回NA_INTEGER即可。正如我所说,该函数是针对特定用例编写的,并不特别通用。 - Joshua Ulrich
没有注意到那个评论,抱歉。 - user1603038
@user1603038:别担心,我只是在讽刺而已。你可能会遇到其他类似的问题,但你可以随时修改代码并使用“R CMD SHLIB”重新编译。 - Joshua Ulrich
@user1603038:NA_REAL是特殊的NaN,IEEE标准基本上规定与NaN进行比较会得到false - Joshua Ulrich
谢谢!我已经在http://adv-r.had.co.nz/C-interface.html中找到了它,因此已删除我的评论。 - user1603038

11

在基本的R语言中,提供了PositionFind函数,用于查找满足条件的第一个索引和对应的值。这些高阶函数会在找到第一个符合条件的结果后立即返回。

f<-function(x) {
  r<-vector("list",3)
  r[[1]]<-which(x==1)[1]
  r[[2]]<-which(x>1)[1]
  r[[3]]<-x[x>10][1]
  return(r)
}

p<-function(f,b) function(a) f(a,b)
g<-function(x) {
  r<-vector("list",3)
  r[[1]]<-Position(p(`==`,1),x)
  r[[2]]<-Position(p(`>`,1),x)
  r[[3]]<-Find(p(`>`,10),x)
  return(r)
}

相对性能很大程度上取决于尽早发现命中的概率相对于谓词成本与Position/Find开销的比较。

library(microbenchmark)
set.seed(1)
x<-sample(1:100,1e5,replace=TRUE)
microbenchmark(f(x),g(x))

Unit: microseconds
 expr      min        lq     mean    median        uq      max neval cld
 f(x) 5034.283 5410.1205 6313.861 5798.4780 6948.5675 26735.52   100   b
 g(x)  587.463  650.4795 1013.183  734.6375  950.9845 20285.33   100  a

y<-rep(0,1e5)
microbenchmark(f(y),g(y))

Unit: milliseconds
 expr        min         lq       mean     median         uq        max neval cld
 f(y)   3.470179   3.604831   3.791592   3.718752   3.866952   4.831073   100  a 
 g(y) 131.250981 133.687454 137.199230 134.846369 136.193307 177.082128   100   b

1
“Position”和“Find”只是“for”循环的语法糖。并没有什么不妥之处,只是OP提到“for”循环速度较慢。 - Hugh

4

这是一个很好的问题和答案......只是要补充说明any()并不比which()match()更快,但两者都比[]快,我猜可能会创建一个无用的大向量T,F。所以我猜答案是否定的......就像上面的答案一样。

    v=rep('A', 10e6)
    v[5e6]='B'
    v[10e6]='B'

    microbenchmark(which(v=='B')[1])
    Unit: milliseconds
                   expr      min       lq   median       uq      max neval
     which(v == "B")[1] 332.3788 337.6718 344.4076 347.1194 503.4022   100

    microbenchmark(any(v=='B'))
    Unit: milliseconds
              expr      min      lq   median       uq      max neval
     any(v == "B") 334.4466 335.114 335.6714 347.5474 356.0261   100

    microbenchmark(v[v=='B'][1])
    Unit: milliseconds
               expr      min       lq  median       uq      max neval
     v[v == "B"][1] 601.5923 605.3331 609.191 612.0689 707.1409   100

    microbenchmark(match("B", v))
    Unit: milliseconds
               expr      min       lq   median       uq      max neval
    match("B", v) 339.2872 344.7648 350.5444 359.6746 915.6446   100

有其他想法吗?

正如原帖所述,分析表明超过80%的时间都花在关系运算符上,因此只要您仍然使用==,您不应该期望除了边际速度改进之外的任何东西。 - Joshua Ulrich
事实上,我只是好奇是否可以编写any或其他函数以在找到某些内容时停止,从而减少==。 显然不行。 即使这样,它也不会是一个通用的解决方案。 - Stephen Henderson
创建逻辑向量需要耗费时间,因此任何需要逻辑向量(例如 any)的 R 函数都无法帮助。此外,任何对整个对象无条件操作的函数也无法帮助。正如我在答案中所说,我无法找到一种纯 R 的方法来解决这个问题... 这并不是说我没有尝试。在我的情况下,Data 已经排序,并且通常速度很快的 findInterval 并没有什么帮助。 - Joshua Ulrich

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接