已排序矩阵中的第K小元素

26
这是一个面试问题。
找到一个有排序的行和列的矩阵中第K小的元素。 是否正确,第K小的元素是像a[i, j]这样的元素,满足i + j = K

1
矩阵是如何排序的?每行或每列中的数字是否都是递增的? - V-X
是的,每行和每列中的数字都按递增顺序排序。 - Michael
1
很容易举出反例来证明该陈述是错误的。 - NPE
这个解决方案显然是不正确的。例如,第一个元素可以在角落找到,但第二个数字可能是两个相邻元素中的一个。第三个元素可能在5个可能的索引之一。你必须采用某种修改的二分查找算法。 - V-X
7个回答

42

错误。

考虑一个简单的矩阵,例如这个:

1 3 5
2 4 6
7 8 9

9是最大(第9小)的元素。但是9在A [3,3]处,而3 + 3!= 9。(无论使用什么索引约定,都不可能成立)。


你可以通过逐步合并行并使用堆来高效地找到最小元素以O(k log n)的时间解决此问题。
基本上,你将第一列的元素放入堆中并跟踪它们来自哪一行。在每个步骤中,你从堆中删除最小元素并推送来自它所在行的下一个元素(如果到达该行的末尾,则不推送任何内容)。删除最小值和添加新元素都需要O(log n)的成本。在第j步中,你删除第j小的元素,因此经过k步后,你完成了总共O(k log n)的操作(其中n是矩阵中的行数)。
对于上面的矩阵,您最初在堆中使用1,2,7。您删除1并添加3(因为第一行是1 3 5)以得到2,3,7。您删除2并添加4以获得3,4,7。删除3并添加5以获得4,5,7。删除4并添加6以获得5,6,7。请注意,我们按全局排序顺序删除元素。您可以看到,继续这个过程将在k次迭代后产生第k个最小元素。
(如果矩阵的行数大于列数,则改为对列进行操作可减少运行时间。)

@grijeshchauhan:好吧,如果这样假设是正确的。但是这种假设太过严格了。 - nneonneo
不必使用大小为n的堆,我们可以使用大小为k的堆,并将第一列的前k个元素推入其中,然后继续上述步骤。这将给我们带来O(k log k)的复杂度。 - Sanjay Verma
2
如果只有行或列被排序(本质上,这是外部排序中的n路合并),则此解决方案效果最佳。@user1987143的更好,因为它利用了行和列都已排序的事实。 - sinoTrinity
3
如果你将行数定义为n,然后用第一列初始化你的最小堆,那么运行时间不是n+k log(n)吗?(你似乎没有考虑到初始化步骤在计算运行时间时)。 - ChaimKut
@ChaimKut 我很确定它是O(n + klog n)。 - user5965026
显示剩余8条评论

33

O(k log(k)) 解决方案。

  • 构建一个最小堆。

  • (0,0) 添加到堆中。当我们还没有找到第 k 小的元素时,从堆中删除顶部元素 (x,y),并添加下一组还未被访问过的两个元素 [(x+1,y)(x,y+1)]

我们在大小为 O(k) 的堆上执行了 O(k) 次操作,因此复杂度为 O(k log(k))


你能给这个加点格式吗?现在的样子有点难以阅读。 - StormeHawke
你确定这是正确的吗?我的意思是,即使我也认为一样,但是你的答案获得的投票数与其他答案相比令人惊讶,尽管你的解决方案的复杂性比其他答案更好。 - Akashdeep Saluja
我认为这是正确的,而且运行时间比被接受的答案更好。 - Meow
1
同意复杂度为O(k log (k))。简单解释:堆弹出的复杂度为O(log(heapsize))。这里的堆大小从1开始,每次迭代增加1直到k。堆大小在大多数迭代中增加一个单位,因为在每个阶段都会删除一个元素并添加两个即右侧和下方的单元格(除了矩阵的边缘)。因此,时间复杂度约为O(log(1))+O(log(2))+...+O(log(k))~=k log(k)。 - Pranjal Mittal
1
@user1987143,我们不需要维护已访问的节点以避免重复吗? - Govind Prabhu
显示剩余3条评论

7

这个问题可以使用二分查找和优化计数在已排序矩阵中解决。二分查找需要O(log(n))的时间,并且对于每个搜索值,它平均需要n次迭代才能找到小于所搜索单个数的数字。二分查找的搜索空间被限制在矩阵中最小值mat[0][0]和最大值mat[n-1][n-1]之间。

对于从二分查找中选择的每个数字,我们需要计算小于或等于该特定数字的数字。因此,第k^th个最小数字可以被找到。

为更好地理解,您可以参考此视频:

https://www.youtube.com/watch?v=G5wLN4UweAM&t=145s


1
从左上角(0,0)开始遍历矩阵,并使用二叉堆存储"frontier"(已访问部分与其余部分之间的边界)。
Java实现:
private static class Cell implements Comparable<Cell> {

    private final int x;
    private final int y;
    private final int value;

    public Cell(int x, int y, int value) {
        this.x = x;
        this.y = y;
        this.value = value;
    }

    @Override
    public int compareTo(Cell that) {
        return this.value - that.value;
    }

}

private static int findMin(int[][] matrix, int k) {

    int min = matrix[0][0];

    PriorityQueue<Cell> frontier = new PriorityQueue<>();
    frontier.add(new Cell(0, 0, min));

    while (k > 1) {

        Cell poll = frontier.remove();

        if (poll.y + 1 < matrix[poll.x].length) frontier.add(new Cell(poll.x, poll.y + 1, matrix[poll.x][poll.y + 1]));
        if (poll.x + 1 < matrix.length) frontier.add(new Cell(poll.x + 1, poll.y, matrix[poll.x + 1][poll.y]));

        if (poll.value > min) {
            min = poll.value;
            k--;
        }

    }

    return min;

}

0

正如之前提到的,最简单的方法是构建一个小根堆。这里是使用PriorityQueue实现的Java代码:

private int kthSmallestUsingHeap(int[][] matrix, int k) {

    int n = matrix.length;

    // This is not necessary since this is the default Int comparator behavior
    Comparator<Integer> comparator = new Comparator<Integer>() {
        @Override
        public int compare(Integer o1, Integer o2) {
            return o1 - o2;
        }
    };

    // building a minHeap
    PriorityQueue<Integer> pq = new PriorityQueue<>(n*n, comparator);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            pq.add(matrix[i][j]);
        }
    }

    int ans = -1;
    // remove the min element k times
    for (int i = 0; i < k; i++) {
        ans = pq.poll();
    }

    return ans;
}

0

矩阵中第K小的元素:

该问题可以通过以下方式缩小范围。

如果k为20,则取k*k的矩阵(答案一定位于其中)。

现在,您可以反复将行成对合并以构建排序数组,然后找到第K个最小的数字。


-1
//int arr[][] = {{1, 5, 10, 14},
//        {2, 7, 12, 16},
//        {4, 10, 15, 20},
//        {6, 13, 19, 22}
//};
// O(k) Solution
public static int myKthElement(int arr[][], int k) {
    int lRow = 1;
    int lCol = 0;
    int rRow = 0;
    int rCol = 1;
    int count = 1;

    int row = 0;
    int col = 0;

    if (k == 1) {
        return arr[row][col];
    }

    int n = arr.length;
    if (k > n * n) {
        return -1;
    }

    while (count < k) {
        count++;

        if (arr[lRow][lCol] < arr[rRow][rCol]) {
            row = lRow;
            col = lCol;

            if (lRow < n - 1) {
                lRow++;
            } else {
                if (lCol < n - 1) {
                    lCol++;
                }

                if (rRow < n - 1) {
                    lRow = rRow + 1;
                }
            }
        } else {
            row = rRow;
            col = rCol;

            if (rCol < n - 1) {
                rCol++;
            } else {
                if (rRow < n - 1) {
                    rRow++;
                }
                if (lCol < n - 1) {
                    rCol = lCol + 1;
                }
            }
        }
    }

    return arr[row][col];
}

请在您的答案中添加一些内容,以阐述您的方法或解决方案,除了代码之外,这样对于任何正在查看答案的人来说都会更有意义。 - kabirbaidhya

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接