寻找排序数组中第一个大于目标值的元素

71

在一般的二分查找中,我们要寻找出现在数组中的值。然而有时候,我们需要找到第一个比目标值大或小的元素。

这里是我丑陋、不完整的解决方案:

// Assume all elements are positive, i.e., greater than zero
int bs (int[] a, int t) {
  int s = 0, e = a.length;
  int firstlarge = 1 << 30;
  int firstlargeindex = -1;
  while (s < e) {
    int m = (s + e) / 2;
    if (a[m] > t) {
      // how can I know a[m] is the first larger than
      if(a[m] < firstlarge) {
        firstlarge = a[m];
        firstlargeindex = m;
      }
      e = m - 1; 
    } else if (a[m] < /* something */) {
      // go to the right part
      // how can i know is the first less than  
    }
  }
}

有没有更优雅的解决方案来解决这种问题?

2
数组已排序?如果是,则使用二分查找,否则使用线性查找... - YXD
这是一个已排序的数组,因此我们可以进行二分查找。在上述代码中,我使用比较来查找数组中第一个大于或小于的元素。 - SecureFish
3
为什么不使用C++ STL中的upper_bound函数? - Shivendra
8个回答

105

解决这个问题的一种思路是将数组进行转换,然后在转换后的数组上执行二分查找。具体来说,就是对数组应用某个函数进行修改。

f(x) = 1 if x > target
       0 else

现在的目标是找到这个函数第一次取值为1的位置。我们可以通过二分查找来实现:

int low = 0, high = numElems; // numElems is the size of the array i.e arr.size() 
while (low != high) {
    int mid = (low + high) / 2; // Or a fancy way to avoid int overflow
    if (arr[mid] <= target) {
        /* This index, and everything below it, must not be the first element
         * greater than what we're looking for because this element is no greater
         * than the element.
         */
        low = mid + 1;
    }
    else {
        /* This element is at least as large as the element, so anything after it can't
         * be the first element that's at least as large.
         */
        high = mid;
    }
}
/* Now, low and high both point to the element in question. */

为了证明这个算法是正确的,考虑每次比较。如果我们找到一个不大于目标元素的元素,则它和它下面的所有元素都不可能匹配,因此没有必要搜索该区域。我们可以递归地搜索右半部分。如果我们找到一个大于所询问元素的元素,则其后的任何元素都必须更大,因此它们不能是第一个更大的元素,因此我们不需要搜索它们。因此,中间元素是最后可能的位置。
请注意,在每次迭代中,我们至少会丢弃剩余元素的一半。如果执行顶部分支,则范围内的元素[low,(low + high) / 2]都被丢弃,导致我们失去floor((low + high) / 2) - low + 1 >= (low + high) / 2 - low = (high - low) / 2个元素。
如果执行底部分支,则范围内的元素[(low + high) / 2 + 1,high]都被丢弃。这使我们失去high - floor(low + high) / 2 + 1 >= high - (low + high) / 2 = (high - low) / 2个元素。
因此,我们将在O(lg n)次这个过程中找到第一个大于目标的元素。
以下是算法在数组0 0 1 1 1 1上运行的跟踪。
最初,我们有
0 0 1 1 1 1
L = 0       H = 6

因此,我们计算出 mid = (0 + 6) / 2 = 3,然后检查位于位置 3 的元素,其值为 1。由于 1 > 0,所以我们将 high = mid = 3。现在我们有:

0 0 1
L     H

我们计算出 mid = (0 + 3) / 2 = 1,所以我们检查元素 1。由于它的值为 0 <= 0,因此我们设置 mid = low + 1 = 2。现在我们只剩下 L = 2H = 3
0 0 1
    L H

现在,我们计算 mid = (2 + 3) / 2 = 2。索引 2 处的元素为 1,因为 10,所以我们将 H = mid = 2,此时停止搜索,确实我们找到了第一个大于 0 的元素。

1
请运行一下您的解决方案示例,例如我们有输入0,0,1,1,1,1。让我们找到第一个大于0的元素。 - SecureFish
1
同样地,我们可以组合函数来查找第一个小于t的元素:如果(a[m]>=t),则high=m-1;否则,low=m。 - SecureFish
2
@SecureFish:只是补充一下这个小问题:对于这个相反的问题,也需要调整mid的计算。由于除法中的向下取整效应和减法的组合,可能会出现负高值而不进行修改。可以通过在此计算中改变为向上取整行为来解决这个问题,例如再添加模2项。 - bluenote10
如果我们进行普通的二分查找,当array[mid]==target时,同时low=mid+1,那么low将始终指向第一个大于目标元素的位置,对吗? - dsfdf
我不理解你的示例。你给了一个有六个元素的数组,所以numElems=6,但你正在初始化H=5。 - Ed Avis
显示剩余12条评论

13

如果数组已经排序好了(假设n是数组a[]的大小),你可以使用std::upper_bound

int* p = std::upper_bound( a, a + n, x );
if( p == a + n )
     std::cout << "No element greater";
else
     std::cout << "The first element greater is " << *p
               << " at position " << p - a;

13

经过多年教授算法后,我解决二分查找问题的方法是将起始点和结束点设置在元素上而不是在数组外部。这样我可以感受到发生了什么,一切都在掌控之中,而不会对解决方案感到神秘。

解决二分查找问题(以及许多其他基于循环的解决方案)的关键点是良好的不变量集合。选择正确的不变量使得问题解决不费吹灰之力。虽然我很早就在大学里学习了不变量的概念,但花了我很多年才理解它。

即使您想通过在数组外部选择起始或结束点来解决二分查找问题,只要有适当的不变量也可以实现。话虽如此,我的选择如上所述,总是将起始点放在数组的第一个元素上,将结束点放在最后一个元素上。

因此,总结一下,我们迄今为止所拥有的内容如下:

int start = 0; 
int end = a.length - 1; 

现在讲不变量。我们当前拥有的数组是[start,end]。我们还不知道元素的任何信息。它们中的所有元素都可能大于目标值,也可能都小于目标值,或者有些小于目标值而有些大于目标值。因此,到目前为止,我们不能对这些元素做出任何假设。我们的目标是找到第一个大于目标值的元素。因此,我们选择以下不变量

右侧所有元素均大于目标值。
左侧所有元素小于或等于目标值。

我们很容易看出,在开始时(即进入任何循环之前),我们的不变量是正确的。所有在开始左侧的元素(基本上没有元素)都小于或等于目标值,对于结束位置同理。

有了这个不变量,当循环完成时,结束位置后面的第一个元素将是答案(记住不变量,即结束位置右侧的所有元素都大于目标值?)。因此,answer = end + 1

此外,我们需要注意,当循环完成时,开始位置将比结束位置多一。也就是说,start = end + 1。因此,我们可以同样地说,开始位置也是答案(不变量是左侧所有元素小于或等于目标值,因此开始位置本身就是第一个大于目标值的元素)。

综上所述,以下是代码。

public static int find(int a[], int target) {
    int st = 0; 
    int end = a.length - 1; 
    while(st <= end) {
        int mid = (st + end) / 2;   // or elegant way of st + (end - st) / 2; 
        if (a[mid] <= target) {
            st = mid + 1; 
        } else { // mid > target
            end = mid - 1; 
        }
    }
    return st; // or return end + 1
}
关于解决二分查找问题的这种方式,有一些额外的注意事项:
这种解决方案总是至少将子数组的大小缩小 1。这在代码中很明显。新的开始或结束位置要么是 +1,要么是 -1 在中间位置。我喜欢这种方法比在两边或一边包括中间位置然后再推理为什么算法正确。这样更具体,更不容易出错。
while 循环的条件是 st <= end,而不是 st < end。这意味着进入 while 循环的最小尺寸是大小为 1 的数组。这完全符合我们的预期。在其他解决二分查找问题的方法中,有时最小尺寸是大小为 2 的数组(如果 st < end),老实说,我发现始终考虑包括大小为 1 的所有数组尺寸要容易得多。
因此,希望这可以澄清这个问题以及许多其他二分查找问题的解决方案。把这个解决方案当作一种专业的方式来理解和解决更多的二分查找问题,而不会摇摆不定算法是否适用于边缘情况。

谢谢您使用不变量进行解释!这让一切都更加清晰明了。 - simha
@apadana 如果你想找到第一个等于目标的元素,你会如何处理起始和结束位置?你需要一种方式来指示该元素不在数组中,通常使用-1来表示(但不一定)。 - Seth
1
@Seth 将不变量定义为:在起始位置左侧的任何元素都小于目标值(因此在结束位置右侧的任何元素都大于或等于目标值)。当循环结束时,end + 1 将指向第一个等于或大于目标值的元素。循环结束后,您可以检查 "end + 1" 是否等于目标值。如果不是,则返回 -1。 - apadana
1
@Seth,我在关于二分查找的回答中也有这个回答:https://dev59.com/wF0a5IYBdhLWcg3wAEo3#54374256,它对谓词作了更明确的解释。请记住,在二分查找问题中,目标是找到0和1之间的边界。如果你能想到分配0和1(换句话说,为每个元素提供0或1的好谓词),那么解决二分查找问题将会非常容易。在你提到的问题中,小于目标的任何数字都是0,大于或等于目标的任何数字都是1。所以目标是找到第一个1。[0 0 0 0 1 1 1]。希望这澄清了这一点。 - apadana

2
这是一段经过修改的JAVA二分搜索代码,时间复杂度为O(logn),它能够实现以下功能:
  • 如果要查找的元素存在,则返回该元素的索引
  • 如果要查找的元素不存在,则返回下一个更大元素的索引
  • 如果查找的元素大于数组中最大的元素,则返回-1
public static int search(int arr[],int key) {
    int low=0,high=arr.length,mid=-1;
    boolean flag=false;
    
    while(low<high) {
        mid=(low+high)/2;
        if(arr[mid]==key) {
            flag=true;
            break;
        } else if(arr[mid]<key) {
            low=mid+1;
        } else {
            high=mid;
        }
    }
    if(flag) {
        return mid;
    }
    else {
        if(low>=arr.length)
            return -1;
        else
        return low;
        //high will give next smaller
    }
}

public static void main(String args[]) throws IOException {
    BufferedReader br=new BufferedReader(new InputStreamReader(System.in));
    //int n=Integer.parseInt(br.readLine());
    int arr[]={12,15,54,221,712};
    int key=71;
    System.out.println(search(arr,key));
    br.close();
}

2
以下是一个递归方法,您觉得怎么样:
public static int minElementGreaterThanOrEqualToKey(int A[], int key,
    int imin, int imax) {

    // Return -1 if the maximum value is less than the minimum or if the key
    // is great than the maximum
    if (imax < imin || key > A[imax])
        return -1;

    // Return the first element of the array if that element is greater than
    // or equal to the key.
    if (key < A[imin])
        return imin;

    // When the minimum and maximum values become equal, we have located the element. 
    if (imax == imin)
        return imax;

    else {
        // calculate midpoint to cut set in half, avoiding integer overflow
        int imid = imin + ((imax - imin) / 2);

        // if key is in upper subset, then recursively search in that subset
        if (A[imid] < key)
            return minElementGreaterThanOrEqualToKey(A, key, imid + 1, imax);

        // if key is in lower subset, then recursively search in that subset
        else
            return minElementGreaterThanOrEqualToKey(A, key, imin, imid);
    }
}

1

我的实现使用条件 bottom <= top,这与 templatetypedefanswer 不同。

int FirstElementGreaterThan(int n, const vector<int>& values) {
  int B = 0, T = values.size() - 1, M = 0;
  while (B <= T) { // B strictly increases, T strictly decreases
    M = B + (T - B) / 2;
    if (values[M] <= n) { // all values at or before M are not the target
      B = M + 1;
    } else {
      T = M - 1;// search for other elements before M
    }
  }
  return T + 1;
}

1
public static int search(int target, int[] arr) {
        if (arr == null || arr.length == 0)
            return -1;
        int lower = 0, higher = arr.length - 1, last = -1;
        while (lower <= higher) {
            int mid = lower + (higher - lower) / 2;
            if (target == arr[mid]) {
                last = mid;
                lower = mid + 1;
            } else if (target < arr[mid]) {
                higher = mid - 1;
            } else {
                lower = mid + 1;
       }
    }
    return (last > -1 && last < arr.length - 1) ? last + 1 : -1;
}

如果我们找到target == arr[mid],那么任何之前的元素都会小于或等于目标。因此,将下限设置为lower = mid + 1。此外,last是'target'的最后一个索引。最后,我们返回last+1,注意边界条件。

0
  • kind=0:精确匹配
  • kind=1:大于x的值
  • kind=-1:小于x的值
  • 如果没有找到匹配项,则返回-1
#include <iostream>
#include <algorithm>

using namespace std;


int g(int arr[], int l , int r, int x, int kind){
    switch(kind){
    case 0: // for exact match
        if(arr[l] == x) return l;
        else if(arr[r] == x) return r;
        else return -1;
        break;
    case 1: // for just greater than x
        if(arr[l]>=x) return l;
        else if(arr[r]>=x) return r;
        else return -1;
        break;
    case -1: // for just smaller than x
        if(arr[r]<=x) return r;
        else if(arr[l] <= x) return l;
        else return -1;
        break;
    default:
        cout <<"please give "kind" as 0, -1, 1 only" << ednl;
    }
}

int f(int arr[], int n, int l, int r, int x, int kind){
    if(l==r) return l;
    if(l>r) return -1;
    int m = l+(r-l)/2;
    while(m>l){
        if(arr[m] == x) return m;
        if(arr[m] > x) r = m;
        if(arr[m] < x) l = m;
        m = l+(r-l)/2;
    }
    int pos = g(arr, l, r, x, kind);
    return pos;
}

int main()
{
    int arr[] = {1,2,3,5,8,14, 22, 44, 55};
    int n = sizeof(arr)/sizeof(arr[0]);
    sort(arr, arr+n);
    int tcs;
    cin >> tcs;
    while(tcs--){
        int l = 0, r = n-1, x = 88, kind = -1; // you can modify these values
        cin >> x;
        int pos = f(arr, n, l, r, x, kind);
        // kind =0: exact match, kind=1: just grater than x, kind=-1: just smaller than x;
        cout <<"position"<< pos << " Value ";
        if(pos >= 0) cout << arr[pos];
        cout << endl;
    }
    return 0;
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接