在最坏情况下时间复杂度为O(n)的2D峰值查找算法?

19

我正在学习 MIT 的算法课程(链接)。在第一节课中,教授提出了以下问题:

在一个二维数组中,如果一个值的四个相邻元素都小于或等于它,那么它就是一个“峰”,即对于 a[i][j] 来说,它是一个局部最大值。

a[i+1][j] <= a[i][j] 
&& a[i-1][j] <= a[i][j]
&& a[i][j+1] <= a[i][j]
&& a[i+1][j-1] <= a[i][j]

现在,给定一个NxN的二维数组,在数组中找到一个峰值

可以通过遍历所有元素并返回一个峰值来很容易地在O(N^2)时间内解决此问题。

然而,可以通过使用分治解决方案(如此处所述)来将其优化为在O(NlogN)时间内解决。

但他们说存在一种O(N)时间算法来解决此问题。请建议如何在O(N)时间内解决此问题。

PS(对于那些了解Python的人)课程工作人员已经在他们的问题集中此处(问题1-5. Peak-Finding Proof)解释了一种方法,并且还提供了一些Python代码。但所解释的方法完全不明显,非常难以理解。Python代码同样令人困惑。因此,我已经复制了以下代码的主要部分,供那些知道Python并且可以从代码中了解使用的算法的人使用。

def algorithm4(problem, bestSeen = None, rowSplit = True, trace = None):
    # if it's empty, we're done 
    if problem.numRow <= 0 or problem.numCol <= 0:
        return None

    subproblems = []
    divider = []

    if rowSplit:
        # the recursive subproblem will involve half the number of rows
        mid = problem.numRow // 2

        # information about the two subproblems
        (subStartR1, subNumR1) = (0, mid)
        (subStartR2, subNumR2) = (mid + 1, problem.numRow - (mid + 1))
        (subStartC, subNumC) = (0, problem.numCol)

        subproblems.append((subStartR1, subStartC, subNumR1, subNumC))
        subproblems.append((subStartR2, subStartC, subNumR2, subNumC))

        # get a list of all locations in the dividing column
        divider = crossProduct([mid], range(problem.numCol))
    else:
        # the recursive subproblem will involve half the number of columns
        mid = problem.numCol // 2

        # information about the two subproblems
        (subStartR, subNumR) = (0, problem.numRow)
        (subStartC1, subNumC1) = (0, mid)
        (subStartC2, subNumC2) = (mid + 1, problem.numCol - (mid + 1))

        subproblems.append((subStartR, subStartC1, subNumR, subNumC1))
        subproblems.append((subStartR, subStartC2, subNumR, subNumC2))

        # get a list of all locations in the dividing column
        divider = crossProduct(range(problem.numRow), [mid])

    # find the maximum in the dividing row or column
    bestLoc = problem.getMaximum(divider, trace)
    neighbor = problem.getBetterNeighbor(bestLoc, trace)

    # update the best we've seen so far based on this new maximum
    if bestSeen is None or problem.get(neighbor) > problem.get(bestSeen):
        bestSeen = neighbor
        if not trace is None: trace.setBestSeen(bestSeen)

    # return when we know we've found a peak
    if neighbor == bestLoc and problem.get(bestLoc) >= problem.get(bestSeen):
        if not trace is None: trace.foundPeak(bestLoc)
        return bestLoc

    # figure out which subproblem contains the largest number we've seen so
    # far, and recurse, alternating between splitting on rows and splitting
    # on columns
    sub = problem.getSubproblemContaining(subproblems, bestSeen)
    newBest = sub.getLocationInSelf(problem, bestSeen)
    if not trace is None: trace.setProblemDimensions(sub)
    result = algorithm4(sub, newBest, not rowSplit, trace)
    return problem.getLocationInSelf(sub, result)

#Helper Method
def crossProduct(list1, list2):
    """
    Returns all pairs with one item from the first list and one item from 
    the second list.  (Cartesian product of the two lists.)

    The code is equivalent to the following list comprehension:
        return [(a, b) for a in list1 for b in list2]
    but for easier reading and analysis, we have included more explicit code.
    """

    answer = []
    for a in list1:
        for b in list2:
            answer.append ((a, b))
    return answer

2
只需要一个随机峰值还是所有峰值? - Andrey
2
一个随机峰很简单,但似乎没什么用处。 - bosnjak
4
可能这个问题看起来没什么用,但它是一个非常好的算法问题。 - Nikunj Banka
3个回答

15
  1. 假设数组的宽度大于高度,否则我们将在另一个方向上进行划分。
  2. 将数组分成三部分:中心列、左侧和右侧。
  3. 遍历中心列和两个相邻的列,并查找最大值。
    • 如果它在中心列中 - 这是我们的峰值
    • 如果它在左边,运行该算法在子数组left_side+central_column
    • 如果它在右边,运行该算法在子数组right_side+central_column

为什么这有效:

对于最大元素位于中心列的情况 - 很明显。如果不是这样,我们可以从那个最大值到递增的元素步进,肯定不会越过中心行,因此峰值肯定存在于相应的一半中。

为什么这是O(n):

步骤#3需要少于或等于max_dimension次迭代,并且每两个算法步骤中的max_dimension至少减半。这给出了n+n/2+n/4+...,即O(n)。重要细节:我们按最大方向分割。对于正方形数组,这意味着分割方向将交替出现。这是与您链接的PDF中的最后一次尝试的不同之处。

注意:我不确定它是否完全匹配您提供的代码中的算法,可能是不同的方法。


4
每次数组都会变小,所以它不是 n+n+n+...log(n) 次方。而是 n+n/2+n/4+...<2n - maxim1000
1
我似乎现在理解了这个算法,并且它将是O(n)。但是,您能否在您的答案中添加一个递归关系。这将进一步说明运行时间是O(n)。 - Nikunj Banka
2
我已经添加了一个关于交替方向的注释,这应该澄清了情况。对于正方形数组,递推关系将是 T(n)=T.horiz.(n)+T.vert.(n), T.horiz(n)=T.horiz.(n/2)+O(n) and T.vert.(n)=T.vert.(n/2)+O(n) - maxim1000
1
感谢您出色的回答。您能否也告诉我一些关于您如何解决这个问题并得出这个算法的方法? - Nikunj Banka
2
这个答案是错误的,因为您没有正确考虑列数和行数。设n为列数,m为行数,您正在将“m+m+m+...”添加“logn”次,但不是“n+n+n+...”。因此,该算法实际上是O(mlogn)。当矩阵为n x n时,该算法为O(nlogn)。在Youtube上MIT 6.006讲座视频的第1讲中提供了分析。 - yyFred
显示剩余6条评论

3

查看tha(n):

计算步骤在图片中

查看算法实现:

1)从1a)或1b)开始

1a)设置左半部分、分隔符、右半部分。

1b)设置上半部分、分隔符、下半部分。

2) 在分隔符上找到全局最大值。 [theta n]

3) 找到它的邻居的值,并将最大节点记录为bestSeen节点。[theta 1]

# update the best we've seen so far based on this new maximum
if bestSeen is None or problem.get(neighbor) > problem.get(bestSeen):
    bestSeen = neighbor
    if not trace is None: trace.setBestSeen(bestSeen)

4) 检查全局最大值是否大于bestSeen及其邻居。[theta1]

//第4步是此算法有效的主要关键。

# return when we know we've found a peak
if neighbor == bestLoc and problem.get(bestLoc) >= problem.get(bestSeen):
    if not trace is None: trace.foundPeak(bestLoc)
    return bestLoc

5) 如果 4) 成立,则返回全局最大值作为 2-D 峰值。

否则,如果这次是 1a),选择 BestSeen 的一半,回到步骤 1b)。

否则,选择 BestSeen 的一半,回到步骤 1a)。


为了直观地理解算法的原理,就好像抓住最高值的那一侧,不断缩小边界并最终得到 BestSeen 值。

# 可视化模拟

第一轮

第二轮

第三轮

第四轮

第五轮

第六轮

最后一轮

对于这个 10*10 的矩阵,我们只用了 6 步来搜索 2-D 峰值,这相当令人信服,它确实是 theta n。


作者:Falcon


一个小修正,实现步骤2将沿着矩阵的维度逐渐减少。因此,仅在第一轮中为theta n,然后变为theta n/2,n/4等。 - Falcon Lin

1

这里是实现@maxim1000算法的工作Java代码。以下代码在线性时间内找到2D数组中的峰值。

import java.util.*;

class Ideone{
    public static void main (String[] args) throws java.lang.Exception{
        new Ideone().run();
    }
    int N , M ;

    void run(){
        N = 1000;
        M = 100;

        // arr is a random NxM array
        int[][] arr = randomArray();
        long start = System.currentTimeMillis();
//      for(int i=0; i<N; i++){   // TO print the array. 
//          System. out.println(Arrays.toString(arr[i]));
//      }
        System.out.println(findPeakLinearTime(arr));
        long end = System.currentTimeMillis();
        System.out.println("time taken : " + (end-start));
    }

    int findPeakLinearTime(int[][] arr){
        int rows = arr.length;
        int cols = arr[0].length;
        return kthLinearColumn(arr, 0, cols-1, 0, rows-1);
    }

    // helper function that splits on the middle Column
    int kthLinearColumn(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){
        if(loCol==hiCol){
            int max = arr[loRow][loCol];
            int foundRow = loRow;
            for(int row = loRow; row<=hiRow; row++){
                if(max < arr[row][loCol]){
                    max = arr[row][loCol];
                    foundRow = row;
                }
            }
            if(!correctPeak(arr, foundRow, loCol)){
                System.out.println("THIS PEAK IS WRONG");
            }
            return max;
        }
        int midCol = (loCol+hiCol)/2;
        int max = arr[loRow][loCol];
        for(int row=loRow; row<=hiRow; row++){
            max = Math.max(max, arr[row][midCol]);
        }
        boolean centralMax = true;
        boolean rightMax = false;
        boolean leftMax  = false;

        if(midCol-1 >= 0){
            for(int row = loRow; row<=hiRow; row++){
                if(arr[row][midCol-1] > max){
                    max = arr[row][midCol-1];
                    centralMax = false;
                    leftMax = true;
                }
            }
        }

        if(midCol+1 < M){
            for(int row=loRow; row<=hiRow; row++){
                if(arr[row][midCol+1] > max){
                    max = arr[row][midCol+1];
                    centralMax = false;
                    leftMax = false;
                    rightMax = true;
                }
            }
        }

        if(centralMax) return max;
        if(rightMax)  return kthLinearRow(arr, midCol+1, hiCol, loRow, hiRow);
        if(leftMax)   return kthLinearRow(arr, loCol, midCol-1, loRow, hiRow);
        throw new RuntimeException("INCORRECT CODE");
    }

    // helper function that splits on the middle 
    int kthLinearRow(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){
        if(loRow==hiRow){
            int ans = arr[loCol][loRow];
            int foundCol = loCol;
            for(int col=loCol; col<=hiCol; col++){
                if(arr[loRow][col] > ans){
                    ans = arr[loRow][col];
                    foundCol = col;
                }
            }
            if(!correctPeak(arr, loRow, foundCol)){
                System.out.println("THIS PEAK IS WRONG");
            }
            return ans;
        }
        boolean centralMax = true;
        boolean upperMax = false;
        boolean lowerMax = false;

        int midRow = (loRow+hiRow)/2;
        int max = arr[midRow][loCol];

        for(int col=loCol; col<=hiCol; col++){
            max = Math.max(max, arr[midRow][col]);
        }

        if(midRow-1>=0){
            for(int col=loCol; col<=hiCol; col++){
                if(arr[midRow-1][col] > max){
                    max = arr[midRow-1][col];
                    upperMax = true;
                    centralMax = false;
                }
            }
        }

        if(midRow+1<N){
            for(int col=loCol; col<=hiCol; col++){
                if(arr[midRow+1][col] > max){
                    max = arr[midRow+1][col];
                    lowerMax = true;
                    centralMax = false;
                    upperMax   = false;
                }
            }
        }

        if(centralMax) return max;
        if(lowerMax)   return kthLinearColumn(arr, loCol, hiCol, midRow+1, hiRow);
        if(upperMax)   return kthLinearColumn(arr, loCol, hiCol, loRow, midRow-1);
        throw new RuntimeException("Incorrect code");
    }

    int[][] randomArray(){
        int[][] arr = new int[N][M];
        for(int i=0; i<N; i++)
            for(int j=0; j<M; j++)
                arr[i][j] = (int)(Math.random()*1000000000);
        return arr;
    }

    boolean correctPeak(int[][] arr, int row, int col){//Function that checks if arr[row][col] is a peak or not
        if(row-1>=0 && arr[row-1][col]>arr[row][col])  return false;
        if(row+1<N && arr[row+1][col]>arr[row][col])   return false;
        if(col-1>=0 && arr[row][col-1]>arr[row][col])  return false;
        if(col+1<M && arr[row][col+1]>arr[row][col])   return false;
        return true;
    }
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接