在二分查找中计算中间值

93
我正在阅读一本关于算法的书,其中提到了下面这个二分查找的算法:
public class BinSearch {
  static int search ( int [ ] A, int K ) {
    int l = 0 ;
    int u = A. length −1;
    int m;
    while (l <= u ) {
      m = (l+u) /2;
      if (A[m] < K) {
        l = m + 1 ;
      } else if (A[m] == K) {
        return m;
        } else {
          u = m−1;
        }
       }
       return1;
      }
 }

作者说:"该错误在赋值语句m = (l+u)/2;中,它可能会导致溢出,应该替换为m = l + (u-l)/2。"

我看不出来它会导致溢出。当我在脑海中运行算法以获取几个不同的输入时,我没有看到中间值超出数组索引。

那么,在哪些情况下会发生溢出?


4
对两个数进行加、减、乘操作都会产生更多的比特位,因此很明显存在溢出的风险。 - phuclv
1
可能是binary search middle value calculation的重复问题。 - Jingguo Yao
14个回答

132

这篇文章详细介绍了这个著名的bug。正如其他人所说,这是一个溢出问题。链接中推荐的修复方法如下:

int mid = low + ((high - low) / 2);

// Alternatively
int mid = (low + high) >>> 1;

值得一提的是,如果允许使用负索引,或者可能不是在搜索数组(例如,在满足某些条件的整数范围内搜索值),上面的代码也可能不正确。在这种情况下,可能需要使用以下代码。

(low < 0 && high > 0) ? (low + high) / 2 : low + (high - low) / 2

可能需要这样做。一个很好的例子是 在线性时间和常数空间内搜索未排序数组的中位数,只需在整个Integer.MIN_VALUE - Integer.MAX_VALUE 范围上执行二分查找。


1
您提供的链接清晰地解释了这个问题。谢谢! - Bharat
1
使用(high / 2 + low / 2)可以吗? - Fakru
7
为什么上述的可替代方法中的 (low + high),即 int mid = (low + high) >>> 1 不会导致溢出? - Vineet Kapoor
提议的解决方案 (low + high) >>> 1 在 C 中使用时可能不安全,其中索引可以定义为无符号类型。具体来说,当 lowhigh 都大于 SIZE_MAX / 2 时,(low + high) 会溢出。 - Explorer09
3
@Fakrudeen说(high / 2 + low / 2)会舍去最低有效位并产生错误的结果。例如,当low=3high=5时,mid变成了3,而它应该是4。 - Explorer09
显示剩余3条评论

51
下面的C++程序可以向您展示32位无符号整数如何发生溢出:
#include <iostream>
using namespace std;

int main ()
{
  unsigned int  low = 33,  
                high = 4294967290, 
                mid;

  cout << "The value of low is " << low << endl;
  cout << "The value of high is " << high << endl;

  mid = (low + high) / 2;

  cout << "The value of mid is " << mid << endl;
  
  return 0;
}

如果你在Mac上运行它:
$ g++ try.cpp && ./a.out
The value of low is 33
The value of high is 4294967290
The value of mid is 13

mid的值可能预期为2147483661,但是low + high溢出了,因为32位无符号整数无法包含正确的值,并返回27,所以mid变成了13

mid的计算方式改变为

mid = low + (high - low) / 2;

然后它会显示
The value of mid is 2147483661

简单来说,加法操作 `l + u` 可能会溢出,并且在某些语言中具有未定义的行为,正如Joshua Bloch 的博客文章所描述的关于 Java 二分查找实现中的一个 bug
有些读者可能不明白这是什么意思:
l + (u - l) / 2

请注意,在某些代码中,变量名可能不同,并且这是正常的。
low + (high - low) / 2

答案是:假设你有两个数字:200和210,现在你想要得到"中间的数字"。假设如果你将任意两个数字相加的结果大于255,那么它可能会溢出并且行为是未定义的,那么你该怎么办呢?一个简单的方法就是只添加它们之间的差值的一半到较小的值上:看一下200和210之间的差值是多少。它是10。(你可以将其视为它们之间的"差值"或"长度")。所以你只需要将10除以2得到5,然后加到200上,得到205。你不需要先将200和210相加 -- 这就是我们如何得到计算式:(u - l)是差值,(u - l) / 2是它的一半。将它加到l上,我们就得到了l + (u - l) / 2

就好像我们在看两棵树,一棵高200英尺,另一棵高210英尺,那么"中点"或者"平均值"是多少呢?我们不需要先把它们加在一起。我们只需要知道它们的差距是10英尺,然后再加上一半的差距,也就是5英尺,于是我们知道结果是205英尺。

从历史的角度来看,Robert Sedgewick提到第一个二分查找是在1946年提出的,直到1964年才被证明正确。Jon Bentley在他的书《Programming Pearls》中描述了在1988年,超过90%的专业程序员在几个小时内都无法正确编写二分查找算法。但即使是Jon Bentley自己也有20年时间存在溢出错误。一项在1988年发表的研究显示,在20本教科书中,只有5本包含了正确的二分查找代码。在2006年,Joshua Bloch写了一篇关于计算"mid"值的bug的博文。所以这段代码花了60年时间才变得正确。但是现在,在下一次求职面试时,请记住要在5分钟内正确编写它。


我认为你的意思是使用 std::int32_t,而不是 int(后者可能具有比你预期更大的范围)。 - Toby Speight
是这样的...在我的 Mac 上,它是 32 位的。在某些平台上,它是 64 位的,这是真的吗? - nonopolarity
我可能有点过于强硬了 - 或者忽略了您指定的平台。如果您使用固定宽度类型进行演示,则可以在提供该类型的_任何平台_上重现该问题。 - Toby Speight
1
顺便提一下,C++20引入了std::midpoint()来解决这个问题,而不是让每个程序员都重新发明它 - 阅读源代码的GNU实现是有益的,以了解它实际上是多么不直观。 - Toby Speight
2
我在StackOverflow上最喜欢的答案之一 :) - Arka Mukherjee
2
@nonopolarity,请为您如此清晰的解释自我表扬。我很幸运能够遇到这个解释。这就是你需要理解整数溢出修复的全部内容。 - Aseem Sharma

11

问题在于先计算(l+u),可能会导致int溢出,所以(l+u)/2会返回错误的值。


5

Jeff建议阅读这篇非常好的文章来了解这个bug,如果你想快速了解,下面是摘要。

在编程珠玑中,Bentley说:“将m设置为l和u的平均值,向下取整到最近的整数。”表面上看,这个断言可能是正确的,但对于int变量low和high的大值,它会失败。具体来说,如果low和high的和大于最大正int值(2^31-1),则总和会溢出为负值,并且在除以二时保持为负值。在C语言中,这会导致数组索引越界并产生不可预测的结果。在Java中,它会抛出ArrayIndexOutOfBoundsException异常。


5

这里举个例子,假设你有一个非常大的数组,大小为2,000,000,00010 (10^9 + 10),左边的index2,000,000,000处,右边的index2,000,000,000 + 1处。

使用lo + hi将相加为2,000,000,000 + 2,000,000,001 = 4,000,000,001。由于integer的最大值为2,147,483,647,因此你将得到integer overflow而不是4,000,000,000 + 1

但是low + ((high - low) / 2)可以解决这个问题。2,000,000,000 + ((2,000,000,001 - 2,000,000,000) / 2) = 2,000,000,000


4

潜在的溢出问题在于 l+u 这个加法本身。

这实际上是 JDK 中二分查找早期版本中的 一个 bug


3

这篇答案提供了一个实际的例子来说明为什么需要进行l + (r-l)/2计算。

如果你想知道这两个数学上等价的证明,那么这里是证明过程。关键在于加上0,然后将其分割成l/2 - l/2

(l+r)/2 =
l/2 + r/2 =
l/2 + r/2 + 0 =
l/2 + r/2 + (l/2 - l/2) =
(l/2 + l/2) + (r/2 - l/2) =
l + (r-l)/2

3

实际上,在计算mid时,以下语句可能导致INT范围溢出。

mid = (start + end) /2

假设给定有序输入列表非常大且超过了INT范围(-2 ^ 31到2 ^ 31-1)start + end可能会引发异常。为了解决这个问题,编写了以下语句:

mid = start + (end-start)/2

最终结果是相同的表达式。但是通过这种技巧避免了异常。


1

这是因为如果我们添加: [mid = low + high],并且mid和high都很大,它们的相加可能超出整数范围。

另外为什么不是[mid = low/2 + high/2],因为这是一个整数除法,所以如果[low = 5 and high = 11],那么[mid = low/2 + high/2]将是mid = 5/2 + 11/2 => 2+ 5 => 9,这将导致错误的答案。这就是为什么mid被视为low + (high-low)/2的原因。


1
为避免溢出,您也可以这样做: int midIndex = (int)(startIndex / 2.0 + endIndex / 2.0); 您将两个索引都除以2.0 -> 您得到的是两个小于或等于Integer.MAX_VALUE / 2的double,它们的和也小于或等于Integer.MAXVALUE并且是double。Integer.MIN_VALUE同理。最后,您将总和转换为int并防止了溢出 ;)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接