为什么标准R语言中的中位数函数比简单的C++替代方法慢得多?

5

我在C++中实现了以下的中位数算法,并通过RcppR中使用它:

// [[Rcpp::export]]
double median2(std::vector<double> x){
  double median;
  size_t size = x.size();
  sort(x.begin(), x.end());
  if (size  % 2 == 0){
      median = (x[size / 2 - 1] + x[size / 2]) / 2.0;
  }
  else {
      median = x[size / 2];
  }
  return median;
}

如果我随后将性能与标准内置的R中位数函数进行比较,通过microbenchmark我得到以下结果。
> x = rnorm(100)
> microbenchmark(median(x),median2(x))
Unit: microseconds
       expr    min     lq     mean median     uq     max neval
  median(x) 25.469 26.990 34.96888 28.130 29.081 518.126   100
 median2(x)  1.140  1.521  2.47486  1.901  2.281  47.897   100

为什么标准中位数函数如此缓慢?这不是我所期望的...

3
首先,看一下 median.default 实际做了什么,然后尝试使用更公平的东西进行测试。 - joran
好的,我猜这是因为周围的所有事情,但实际上计算中位数根本不需要时间。 - Ruben
3
顺便提一下,对向量进行排序有点过头了。您不需要关心前n/2个元素的顺序——您只需要关心第n/2个元素是什么。算法std::nth_element可以比排序更快地完成这个任务。如果想要在R中实现它,可以使用递归中位数-中位数-5和分区来高效地实现,并使用一个短长度的备用算法。其次,在std::vector迭代器上明确地使用std::sort(没有保证它们被定义在namespace std中,而您的代码依赖于此)。 - Yakk - Adam Nevraumont
1
@Yakk 的确,需要注意的是 median.default 函数使用了 sortpartial 参数,这个参数执行的操作与你所描述的类似。 - joran
@joran,它实际上做了更多的事情——它对向量的前一半进行排序。nth_element仅将其分区,以便第n个元素位于位置n,并且所有在其之前的元素都比它小,而所有在其之后的元素都比它大。您可以使用中位数的中位数方法比半排序更快地找到一个几乎是中位数的元素,然后进行分区以找出它的位置。重复此过程,直到找到第n个元素。 - Yakk - Adam Nevraumont
4个回答

16

正如 @joran 所指出的那样,你的代码非常专门化,一般来说,不太通用的函数、算法等往往更具性能。看看 median.default

median.default
# function (x, na.rm = FALSE) 
# {
#   if (is.factor(x) || is.data.frame(x)) 
#     stop("need numeric data")
#   if (length(names(x))) 
#     names(x) <- NULL
#   if (na.rm) 
#     x <- x[!is.na(x)]
#   else if (any(is.na(x))) 
#     return(x[FALSE][NA])
#   n <- length(x)
#   if (n == 0L) 
#     return(x[FALSE][NA])
#   half <- (n + 1L)%/%2L
#   if (n%%2L == 1L) 
#     sort(x, partial = half)[half]
#   else mean(sort(x, partial = half + 0L:1L)[half + 0L:1L])
# }

为了适应缺失值的可能性,有几个操作已经放置。这些操作肯定会影响函数的整体执行时间。由于您的函数没有复制此行为,因此可以消除许多计算,但是对于带有缺失值的向量将不会提供相同的结果:

median(c(1, 2, NA))
#[1] NA

median2(c(1, 2, NA))
#[1] 2

还有一些因素可能没有处理NA那么重要,但是值得指出:

  • median和它使用的几个函数都是S3通用函数,所以会花费一点时间在方法派发上
  • median不仅适用于整数和数字向量;它还可以处理DatePOSIXt和可能还有其他类,并正确地保留属性:

median(Sys.Date() + 0:4)
#[1] "2016-01-15"

median(Sys.time() + (0:4) * 3600 * 24)
#[1] "2016-01-15 11:14:31 EST"

编辑: 需要注意的是下面的函数将会导致原始向量被排序,因为NumericVector是代理对象。如果你想避免这种情况,你可以使用Rcpp::clone来克隆输入向量并对克隆进行操作,或者使用原始签名(使用std::vector<double>),它在从SEXPstd::vector的转换中隐式地需要一次副本。

还要注意,你可以通过使用NumericVector而不是std::vector<double>来节省更多的时间:

#include <Rcpp.h>

// [[Rcpp::export]]
double cpp_med(Rcpp::NumericVector x){
  std::size_t size = x.size();
  std::sort(x.begin(), x.end());
  if (size  % 2 == 0) return (x[size / 2 - 1] + x[size / 2]) / 2.0;
  return x[size / 2];
}

microbenchmark::microbenchmark(
  median(x),
  median2(x),
  cpp_med(x),
  times = 200L
)
# Unit: microseconds
#       expr    min      lq      mean  median      uq     max neval
#  median(x) 74.787 81.6485 110.09870 92.5665 129.757 293.810   200
# median2(x)  6.474  7.9665  13.90126 11.0570  14.844 151.817   200
# cpp_med(x)  5.737  7.4285  11.25318  9.0270  13.405  52.184   200

在上面的评论中,Yakk提出了一个很好的观点 - 也由Jerry Coffin详细阐述 - 关于进行完整排序的低效性。这里是使用std :: nth_element重写的代码,在一个更大的向量上进行基准测试:

#include <Rcpp.h>

// [[Rcpp::export]]
double cpp_med2(Rcpp::NumericVector xx) {
  Rcpp::NumericVector x = Rcpp::clone(xx);
  std::size_t n = x.size() / 2;
  std::nth_element(x.begin(), x.begin() + n, x.end());

  if (x.size() % 2) return x[n]; 
  return (x[n] + *std::max_element(x.begin(), x.begin() + n)) / 2.;
}

set.seed(123)
xx <- rnorm(10e5)

all.equal(cpp_med2(xx), median(xx))
all.equal(median2(xx), median(xx))

microbenchmark::microbenchmark(
  cpp_med2(xx), median2(xx), 
  median(xx), times = 200L
)
# Unit: milliseconds
#         expr      min       lq     mean   median       uq       max neval
# cpp_med2(xx) 10.89060 11.34894 13.15313 12.72861 13.56161  33.92103   200
#  median2(xx) 84.29518 85.47184 88.57361 86.05363 87.70065 228.07301   200
#   median(xx) 46.18976 48.36627 58.77436 49.31659 53.46830 250.66939   200

4
如果只使用median.default函数的最后四行,并将其中的mean()替换为.Internal(mean()),我很好奇会发生什么。我猜这样做应该会得到与median2相当接近甚至更快的速度。 - joran
2
经过测试,它绝对不如 median2 快,但它更接近。 - joran
1
@joran 也许值得在更大的向量上进行测试;当我将 median.default 与两个 C++ 版本在一个 rnorm(1e5) 向量上进行比较时,计时结果更接近。 - nrussell
我也在想median.default的作者为什么选择使用any(is.na(x))而不是anyNA(x),因为后者更快... - nrussell
一次调用nth_element对于偶数长度的列表是不够的。先调用一次,然后在右侧区间上调用min_element(或在左侧上调用max_element,具体取决于您的数学计算方式)。此外,您还必须使用特殊情况代码处理大小为0的列表。 - Yakk - Adam Nevraumont
显示剩余7条评论

2

这更像是对你实际提出的问题的一个扩展评论,而不是答案。

甚至您的代码可能有很大的改进空间。特别是,您正在对整个输入进行排序,尽管您只关心一个或两个元素。

您可以通过使用std::nth_element而不是std::sort将其从O(n log n)更改为O(n)。在元素数量为偶数的情况下,通常要使用std::nth_element找到中间位置之前的元素,然后使用std::min_element找到紧接着的元素 - 但是std::nth_element还会对输入项进行分区,因此std::min_element只需要在nth_element之后运行在中间位置之上的项目上,而不是整个输入数组。也就是说,在nth_element之后,您会得到这样的情况:

enter image description here

std::nth_element的复杂度“平均线性”,(当然)std::min_element也是线性的,因此总体复杂度是线性的。

因此,对于简单情况(元素数量为奇数),您会得到以下结果:

auto pos = x.begin() + x.size()/2;

std::nth_element(x.begin(), pos, x.end());
return *pos;

对于更复杂的情况(元素数量为偶数):

std::nth_element(x.begin(), pos, x.end());
auto pos2 = std::min_element(pos+1, x.end());
return (*pos + *pos2) / 2.0;

-1

这里可以期望max_element( ForwardIt first, ForwardIt last )提供了从first到last的最大值,但是通过执行:return (x[n] + *std::max_element(x.begin(), x.begin() + n)) / 2.x.begin() + n元素似乎被排除在计算之外。为什么会有这种差异?

例如:cpp_med2({6, 2, 1, 5, 3, 4})产生x={2, 1, 3, 4, 5, 6},其中:

n = 3
*x[n] = 4
*x.begin() = 2
*(x.begin() + n) = 4
*std::max_element(x.begin(), x.begin() + n) = 3

这样,cpp_med2({6, 2, 1, 5, 3, 4}) 返回 (4+3)/2=3.5,这是正确的中位数。 但是为什么 *std::max_element(x.begin(), x.begin() + n) 等于3而不是4?该函数似乎在最大值计算中排除了最后一个元素(4)。

已解决(我想):在:

在范围 [first, last) 中查找最大元素

中,右括号)表示排除最后一个元素。这是正确的吗?

此致敬礼


-1

我不确定您所指的“标准”实现是什么。

无论如何:如果有一个,它肯定不允许更改向量中元素的顺序(就像您的实现一样),因为它是标准库的一部分。因此,它肯定必须在副本上工作。

创建这个副本需要时间和CPU(以及大量内存),这将影响运行时间。


2
C++代码也会进行复制,因此复制时间应该大致相同。 - NathanOliver
1
他通过传递向量,而不是通过const引用。 - Martin Bonner supports Monica
我的意思是统计包中的中位数函数(一个标准包)。谢谢你注意到我改变了x变量,我没有注意到。编辑:这是按值传递,因此会创建一个副本。 - Ruben

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接