如何在两个已排序数组的并集中找到第K小的元素?

116

这是一个作业问题,已经介绍了二分查找:

给定两个数组,分别包含NM个升序排列的元素,不一定唯一:
在两个数组的并集中查找第k小的元素的时间有效算法是什么?

据说需要 O(logN + logM) 的时间,其中 NM 是数组长度。

让我们把这两个数组命名为 ab。显然,我们可以忽略所有的 a[i]b[i],其中 i > k。
首先比较 a[k/2]b[k/2]。假设 b[k/2] > a[k/2],因此我们也可以丢弃所有 b[i],其中 i >k/2。

现在,我们可以在所有的 a[i](其中i < k)和所有的 b[i](其中i < k/2)中寻找答案。

接下来是什么步骤?


O(logN + logM) 只是指查找第k个元素所需的时间吗?联合操作之前可以进行预处理吗? - David Weiser
1
@David。不需要预处理。 - Michael
3
数组中允许重复吗? - David Weiser
@David 是的,允许重复。 - Michael
如果N和/或M小于k/2怎么办? - kentor
17个回答

86

希望我没有在解答你的作业问题,因为这个问题已经超过一年了。这是一个尾递归解决方案,其时间复杂度为log(len(a)+len(b))

假设:输入正确,即k[0,len(a)+len(b)]范围内。

基本情况:

  • 如果其中一个数组的长度为0,则答案是第二个数组的第k个元素。

规约步骤:

  • 如果a的中间索引加上b的中间索引小于k
    • 如果a的中间元素大于b的中间元素,则可以忽略b的前半部分,调整k
    • 否则,忽略a的前半部分,调整k
  • 如果k小于ab的中间索引之和:
    • 如果a的中间元素大于b的中间元素,则可以安全地忽略a的后半部分。
    • 否则,我们可以忽略b的后半部分。

代码:

def kthlargest(arr1, arr2, k):
    if len(arr1) == 0:
        return arr2[k]
    elif len(arr2) == 0:
        return arr1[k]

    mida1 = len(arr1) // 2  # integer division
    mida2 = len(arr2) // 2
    if mida1 + mida2 < k:
        if arr1[mida1] > arr2[mida2]:
            return kthlargest(arr1, arr2[mida2+1:], k - mida2 - 1)
        else:
            return kthlargest(arr1[mida1+1:], arr2, k - mida1 - 1)
    else:
        if arr1[mida1] > arr2[mida2]:
            return kthlargest(arr1[:mida1], arr2, k)
        else:
            return kthlargest(arr1, arr2[:mida2], k)
请注意,我的解决方案在每次调用时都会创建较小数组的新副本,这可以通过仅在原始数组上传递起始和结束索引来轻松消除。

5
为什么称其为 kthlargest(),它返回第 (k+1) 小的元素,例如在 0,1,2,3 中,1 是第二小的元素,即您的函数返回 sorted(a+b)[k] - jfs
2
我已将您的代码转换为C ++。它似乎可以工作。 - jfs
1
请问为什么将数组a和b的中间索引之和与k进行比较很重要? - Maggie
3
在缩减步骤中,重要的是要摆脱其中一个数组中与其长度成比例的若干元素,以使运行时间为对数级别。(这里我们要摆脱一半)。为了做到这一点,我们需要选择一个数组,其中一个半部分可以安全地忽略。如何做到这一点?通过自信地消除那半个我们确定不会有第k个元素的数组。 - lambdapilgrim
1
将k与数组半长度之和进行比较,可以得知哪个数组的哪一半可以被消除。如果k大于半长度之和,则可以消除其中一个数组的前半部分。反之亦然。请注意,我们不能同时从每个数组中消除一半。为了决定消除哪个数组的哪一半,我们利用了两个数组都已排序的事实,因此如果k大于半长度之和,则可以消除中间元素较小的那个数组的前半部分。反之亦然。 - lambdapilgrim
显示剩余7条评论

55

你已经明白了,继续前进!但要注意索引...

为了简化问题,我假设N和M均大于k,因此这里的复杂度为O(log k),即O(log N + log M)。

伪代码:

i = k/2
j = k - i
step = k/4
while step > 0
    if a[i-1] > b[j-1]
        i -= step
        j += step
    else
        i += step
        j -= step
    step /= 2

if a[i-1] > b[j-1]
    return a[i-1]
else
    return b[j-1]

你可以用循环不变式 i + j = k 进行演示,但我不会帮你做所有的作业 :)


15
这不是真正的证明,但算法背后的思想是保持 i + j = k,并找到这样的 i 和 j,使得 a[i-1] < b[j-1] < a[i](或者反过来)。现在由于在 b[j-1] 之前有 i 个元素在数组 'a' 中比它小,在 b[j-1] 之前有 j-1 个元素在数组 'b' 中比它小,因此 b[j-1] 是第 k 小的元素,即 i + j-1 + 1 = k。为了找到这样的 i 和 j,该算法对数组进行二分搜索。听起来有道理吗? - Jules Olléon
9
为什么 O(log k) 是 O(log n + log m)?(翻译注释:O是计算复杂度的符号,log是对数函数,k、n、m是变量。) - Rajendra Uppal
7
如果数组1中的所有值都在数组2中的值之前,那么这种方法就行不通了。 - John Kurlak
3
你最初为什么使用k/4作为步长? - Maggie
2
正如@JohnKurlak所提到的,它不适用于整个a小于b的值,请参见https://repl.it/HMYf/0 - Jeremy S.
显示剩余8条评论

40
许多人回答了“从两个已排序数组中找到第k小的元素”的问题,但通常只是提供了一般性思路,没有清晰的可工作代码或边界条件分析。
在这里,我想仔细阐述我的方法,并带上我的正确有效的Java代码,以帮助一些初学者理解。假设A1A2是两个已排序的升序数组,它们的长度分别为size1size2。我们需要从这两个数组的并集中找到第k小的元素。在此,我们合理地假设(k > 0 && k <= size1 + size2),这意味着A1A2不能都为空。
首先,让我们用一个慢速的O(k)算法来解决这个问题。该方法是比较两个数组的第一个元素A1 [0]A2 [0]。取较小的一个,比如将A1 [0]放入我们的口袋里。然后将A1 [1]A2 [0]进行比较,以此类推。重复此操作,直到我们的口袋里有了k个元素。非常重要的一点:在第一步中,我们只能将A1 [0]放入我们的口袋里。我们不能包括或排除A2 [0]
下面的O(k)代码给出了正确答案之前的一个元素。我使用它来展示我的想法和分析边界条件。在这之后,我有正确的代码:
private E kthSmallestSlowWithFault(int k) {
    int size1 = A1.length, size2 = A2.length;

    int index1 = 0, index2 = 0;
    // base case, k == 1
    if (k == 1) {
        if (size1 == 0) {
            return A2[index2];
        } else if (size2 == 0) {
            return A1[index1];
        } else if (A1[index1].compareTo(A2[index2]) < 0) {
            return A1[index1];
        } else {
            return A2[index2];
        }
    }

    /* in the next loop, we always assume there is one next element to compare with, so we can
     * commit to the smaller one. What if the last element is the kth one?
     */
    if (k == size1 + size2) {
        if (size1 == 0) {
            return A2[size2 - 1];
        } else if (size2 == 0) {
            return A1[size1 - 1];
        } else if (A1[size1 - 1].compareTo(A2[size2 - 1]) < 0) {
            return A1[size1 - 1];
        } else {
            return A2[size2 - 1];
        }
    }

    /*
     * only when k > 1, below loop will execute. In each loop, we commit to one element, till we
     * reach (index1 + index2 == k - 1) case. But the answer is not correct, always one element
     * ahead, because we didn't merge base case function into this loop yet.
     */
    int lastElementFromArray = 0;
    while (index1 + index2 < k - 1) {
        if (A1[index1].compareTo(A2[index2]) < 0) {
            index1++;
            lastElementFromArray = 1;
            // commit to one element from array A1, but that element is at (index1 - 1)!!!
        } else {
            index2++;
            lastElementFromArray = 2;
        }
    }
    if (lastElementFromArray == 1) {
        return A1[index1 - 1];
    } else {
        return A2[index2 - 1];
    }
}

在每次循环中,最强大的想法是始终使用基本情况方法。在已经确定了当前最小元素之后,我们离目标——第k个最小元素——就更近了一步。永远不要跳到中间并使自己感到困惑和迷失!
通过观察上述代码基本情况 k == 1, k == size1+size2,并结合 A1A2 不能同时为空的事实,我们可以将逻辑转化为以下更简洁的样式。
以下是一个缓慢但正确的工作代码:
private E kthSmallestSlow(int k) {
    // System.out.println("this is an O(k) speed algorithm, very concise");
    int size1 = A1.length, size2 = A2.length;

    int index1 = 0, index2 = 0;
    while (index1 + index2 < k - 1) {
        if (size1 > index1 && (size2 <= index2 || A1[index1].compareTo(A2[index2]) < 0)) {
            index1++; // here we commit to original index1 element, not the increment one!!!
        } else {
            index2++;
        }
    }
    // below is the (index1 + index2 == k - 1) base case
    // also eliminate the risk of referring to an element outside of index boundary
    if (size1 > index1 && (size2 <= index2 || A1[index1].compareTo(A2[index2]) < 0)) {
        return A1[index1];
    } else {
        return A2[index2];
    }
}

现在我们可以尝试一个更快的算法,其时间复杂度为O(log k)。同样地,比较A1[k/2]A2[k/2];如果A1[k/2]小于A2[k/2],则A1[0]A1[k/2]中的所有元素都应该被包含在我们的“口袋”里。这个想法是在每次循环中不仅仅承诺一个元素;第一步包含了k/2个元素。同样地,我们不能选择或排除A2[0]A2[k/2]中的任何元素。因此,在第一步中,我们不能超过k/2个元素。对于第二步,我们不能超过k/4个元素...。
每一步后,我们离第k个元素就更近了。同时,每一步都会变得越来越小,直到我们达到了(step == 1),也就是(k-1 == index1+index2)。然后,我们可以再次参考简单而强大的基本情况。
以下是正确的工作代码:
private E kthSmallestFast(int k) {
    // System.out.println("this is an O(log k) speed algorithm with meaningful variables name");
    int size1 = A1.length, size2 = A2.length;

    int index1 = 0, index2 = 0, step = 0;
    while (index1 + index2 < k - 1) {
        step = (k - index1 - index2) / 2;
        int step1 = index1 + step;
        int step2 = index2 + step;
        if (size1 > step1 - 1
                && (size2 <= step2 - 1 || A1[step1 - 1].compareTo(A2[step2 - 1]) < 0)) {
            index1 = step1; // commit to element at index = step1 - 1
        } else {
            index2 = step2;
        }
    }
    // the base case of (index1 + index2 == k - 1)
    if (size1 > index1 && (size2 <= index2 || A1[index1].compareTo(A2[index2]) < 0)) {
        return A1[index1];
    } else {
        return A2[index2];
    }
}

有些人可能会担心如果 (index1+index2) 跳过了 k-1,我们会错过基本情况 (k-1 == index1+index2) 吗?那是不可能的。你可以把 0.5+0.25+0.125... 相加,你永远不会超过 1。

当然,将上述代码转换为递归算法非常容易:

private E kthSmallestFastRecur(int k, int index1, int index2, int size1, int size2) {
    // System.out.println("this is an O(log k) speed algorithm with meaningful variables name");

    // the base case of (index1 + index2 == k - 1)
    if (index1 + index2 == k - 1) {
        if (size1 > index1 && (size2 <= index2 || A1[index1].compareTo(A2[index2]) < 0)) {
            return A1[index1];
        } else {
            return A2[index2];
        }
    }

    int step = (k - index1 - index2) / 2;
    int step1 = index1 + step;
    int step2 = index2 + step;
    if (size1 > step1 - 1 && (size2 <= step2 - 1 || A1[step1 - 1].compareTo(A2[step2 - 1]) < 0)) {
        index1 = step1;
    } else {
        index2 = step2;
    }
    return kthSmallestFastRecur(k, index1, index2, size1, size2);
}

希望上述分析和Java代码能帮助您理解。但是不要将我的代码抄作您的家庭作业!干杯 ;)

1
非常感谢您的出色解释和回答,+1 :) - Hengameh
在第一段代码中,应该是 else if (A1[size1 - 1].compareTo(A2[size2 - 1]) < 0) 而不是 else if (A1[size1 - 1].compareTo(A2[size2 - 1]) > 0) 吗?(在 kthSmallestSlowWithFault 代码中) - Hengameh
可能在一些步骤后(比如说15步),可以将复杂度降至O(k),因为步骤范围会非常快地缩小。 - Tianwei Chen
1
在递归调用中,A1或A2的大小都没有被减小。 - Aditya Joshee
1
我不理解步骤-1的必要性。 - Abhijit Sarkar
显示剩余4条评论

5
这里是C++迭代版本的@lambdapilgrim解决方案(请参见算法的详细解释:@lambdapilgrim's solution):
#include <cassert>
#include <iterator>

template<class RandomAccessIterator, class Compare>
typename std::iterator_traits<RandomAccessIterator>::value_type
nsmallest_iter(RandomAccessIterator firsta, RandomAccessIterator lasta,
               RandomAccessIterator firstb, RandomAccessIterator lastb,
               size_t n,
               Compare less) {
  assert(issorted(firsta, lasta, less) && issorted(firstb, lastb, less));
  for ( ; ; ) {
    assert(n < static_cast<size_t>((lasta - firsta) + (lastb - firstb)));
    if (firsta == lasta) return *(firstb + n);
    if (firstb == lastb) return *(firsta + n);

    size_t mida = (lasta - firsta) / 2;
    size_t midb = (lastb - firstb) / 2;
    if ((mida + midb) < n) {
      if (less(*(firstb + midb), *(firsta + mida))) {
        firstb += (midb + 1);
        n -= (midb + 1);
      }
      else {
        firsta += (mida + 1);
        n -= (mida + 1);
      }
    }
    else {
      if (less(*(firstb + midb), *(firsta + mida)))
        lasta = (firsta + mida);
      else
        lastb = (firstb + midb);
    }
  }
}

该算法适用于所有 0 <= n < (size(a) + size(b)) 的索引,并具有 O(log(size(a)) + log(size(b))) 的复杂度。

示例

#include <functional> // greater<>
#include <iostream>

#define SIZE(a) (sizeof(a) / sizeof(*a))

int main() {
  int a[] = {5,4,3};
  int b[] = {2,1,0};
  int k = 1; // find minimum value, the 1st smallest value in a,b

  int i = k - 1; // convert to zero-based indexing
  int v = nsmallest_iter(a, a + SIZE(a), b, b + SIZE(b),
                         SIZE(a)+SIZE(b)-1-i, std::greater<int>());
  std::cout << v << std::endl; // -> 0
  return v;
}

4

我尝试着解决以下问题:前k个数字、两个已排序数组中的第k个数字以及n个已排序数组中的第k个数字:

// require() is recognizable by node.js but not by browser;
// for running/debugging in browser, put utils.js and this file in <script> elements,
if (typeof require === "function") require("./utils.js");

// Find K largest numbers in two sorted arrays.
function k_largest(a, b, c, k) {
    var sa = a.length;
    var sb = b.length;
    if (sa + sb < k) return -1;
    var i = 0;
    var j = sa - 1;
    var m = sb - 1;
    while (i < k && j >= 0 && m >= 0) {
        if (a[j] > b[m]) {
            c[i] = a[j];
            i++;
            j--;
        } else {
            c[i] = b[m];
            i++;
            m--;
        }
    }
    debug.log(2, "i: "+ i + ", j: " + j + ", m: " + m);
    if (i === k) {
        return 0;
    } else if (j < 0) {
        while (i < k) {
            c[i++] = b[m--];
        }
    } else {
        while (i < k) c[i++] = a[j--];
    }
    return 0;
}

// find k-th largest or smallest number in 2 sorted arrays.
function kth(a, b, kd, dir){
    sa = a.length; sb = b.length;
    if (kd<1 || sa+sb < kd){
        throw "Mission Impossible! I quit!";
    }

    var k;
    //finding the kd_th largest == finding the smallest k_th;
    if (dir === 1){ k = kd;
    } else if (dir === -1){ k = sa + sb - kd + 1;}
    else throw "Direction has to be 1 (smallest) or -1 (largest).";

    return find_kth(a, b, k, sa-1, 0, sb-1, 0);
}

// find k-th smallest number in 2 sorted arrays;
function find_kth(c, d, k, cmax, cmin, dmax, dmin){

    sc = cmax-cmin+1; sd = dmax-dmin+1; k0 = k; cmin0 = cmin; dmin0 = dmin;
    debug.log(2, "=k: " + k +", sc: " + sc + ", cmax: " + cmax +", cmin: " + cmin + ", sd: " + sd +", dmax: " + dmax + ", dmin: " + dmin);

    c_comp = k0-sc;
    if (c_comp <= 0){
        cmax = cmin0 + k0-1;
    } else {
        dmin = dmin0 + c_comp-1;
        k -= c_comp-1;
    }

    d_comp = k0-sd;
    if (d_comp <= 0){
        dmax = dmin0 + k0-1;
    } else {
        cmin = cmin0 + d_comp-1;
        k -= d_comp-1;
    }
    sc = cmax-cmin+1; sd = dmax-dmin+1;

    debug.log(2, "#k: " + k +", sc: " + sc + ", cmax: " + cmax +", cmin: " + cmin + ", sd: " + sd +", dmax: " + dmax + ", dmin: " + dmin + ", c_comp: " + c_comp + ", d_comp: " + d_comp);

    if (k===1) return (c[cmin]<d[dmin] ? c[cmin] : d[dmin]);
    if (k === sc+sd) return (c[cmax]>d[dmax] ? c[cmax] : d[dmax]);

    m = Math.floor((cmax+cmin)/2);
    n = Math.floor((dmax+dmin)/2);

    debug.log(2, "m: " + m + ", n: "+n+", c[m]: "+c[m]+", d[n]: "+d[n]);

    if (c[m]<d[n]){
        if (m === cmax){ // only 1 element in c;
            return d[dmin+k-1];
        }

        k_next = k-(m-cmin+1);
        return find_kth(c, d, k_next, cmax, m+1, dmax, dmin);
    } else {
        if (n === dmax){
            return c[cmin+k-1];
        }

        k_next = k-(n-dmin+1);
        return find_kth(c, d, k_next, cmax, cmin, dmax, n+1);
    }
}

function traverse_at(a, ae, h, l, k, at, worker, wp){
    var n = ae ? ae.length : 0;
    var get_node;
    switch (at){
        case "k": get_node = function(idx){
                var node = {};
                var pos = l[idx] + Math.floor(k/n) - 1;
                if (pos<l[idx]){ node.pos = l[idx]; }
                else if (pos > h[idx]){ node.pos = h[idx];}
                else{ node.pos = pos; }

                node.idx = idx;
                node.val = a[idx][node.pos];
                debug.log(6, "pos: "+pos+"\nnode =");
                debug.log(6, node);
                return node;
            };
            break;
        case "l": get_node = function(idx){
                debug.log(6, "a["+idx+"][l["+idx+"]]: "+a[idx][l[idx]]);
                return a[idx][l[idx]];
            };
            break;
        case "h": get_node = function(idx){
                debug.log(6, "a["+idx+"][h["+idx+"]]: "+a[idx][h[idx]]);
                return a[idx][h[idx]];
            };
            break;
        case "s": get_node = function(idx){
                debug.log(6, "h["+idx+"]-l["+idx+"]+1: "+(h[idx] - l[idx] + 1));
                return h[idx] - l[idx] + 1;
            };
            break;
        default: get_node = function(){
                debug.log(1, "!!! Exception: get_node() returns null.");
                return null;
            };
            break;
    }

    worker.init();

    debug.log(6, "--* traverse_at() *--");

    var i;
    if (!wp){
        for (i=0; i<n; i++){
            worker.work(get_node(ae[i]));
        }    
    } else {
        for (i=0; i<n; i++){
            worker.work(get_node(ae[i]), wp);
        }
    }

    return worker.getResult();
}

sumKeeper = function(){
    var res = 0;
    return {
        init     : function(){ res = 0;},
        getResult: function(){
                debug.log(5, "@@ sumKeeper.getResult: returning: "+res);
                return res;
            },
        work     : function(node){ if (node!==null) res += node;}
    };
}();

maxPicker = function(){
    var res = null;
    return {
        init     : function(){ res = null;},
        getResult: function(){
                debug.log(5, "@@ maxPicker.getResult: returning: "+res);
                return res;
            },
        work     : function(node){
            if (res === null){ res = node;}
            else if (node!==null && node > res){ res = node;}
        }
    };    
}();

minPicker = function(){
    var res = null;
    return {
        init     : function(){ res = null;},
        getResult: function(){
                debug.log(5, "@@ minPicker.getResult: returning: ");
                debug.log(5, res);
                return res;
            },
        work     : function(node){
            if (res === null && node !== null){ res = node;}
            else if (node!==null &&
                node.val !==undefined &&
                node.val < res.val){ res = node; }
            else if (node!==null && node < res){ res = node;}
        }
    };  
}();

// find k-th smallest number in n sorted arrays;
// need to consider the case where some of the subarrays are taken out of the selection;
function kth_n(a, ae, k, h, l){
    var n = ae.length;
    debug.log(2, "------**  kth_n()  **-------");
    debug.log(2, "n: " +n+", k: " + k);
    debug.log(2, "ae: ["+ae+"],  len: "+ae.length);
    debug.log(2, "h: [" + h + "]");
    debug.log(2, "l: [" + l + "]");

    for (var i=0; i<n; i++){
        if (h[ae[i]]-l[ae[i]]+1>k) h[ae[i]]=l[ae[i]]+k-1;
    }
    debug.log(3, "--after reduction --");
    debug.log(3, "h: [" + h + "]");
    debug.log(3, "l: [" + l + "]");

    if (n === 1)
        return a[ae[0]][k-1]; 
    if (k === 1)
        return traverse_at(a, ae, h, l, k, "l", minPicker);
    if (k === traverse_at(a, ae, h, l, k, "s", sumKeeper))
        return traverse_at(a, ae, h, l, k, "h", maxPicker);

    var kn = traverse_at(a, ae, h, l, k, "k", minPicker);
    debug.log(3, "kn: ");
    debug.log(3, kn);

    var idx = kn.idx;
    debug.log(3, "last: k: "+k+", l["+kn.idx+"]: "+l[idx]);
    k -= kn.pos - l[idx] + 1;
    l[idx] = kn.pos + 1;
    debug.log(3, "next: "+"k: "+k+", l["+kn.idx+"]: "+l[idx]);
    if (h[idx]<l[idx]){ // all elements in a[idx] selected;
        //remove a[idx] from the arrays.
        debug.log(4, "All elements selected in a["+idx+"].");
        debug.log(5, "last ae: ["+ae+"]");
        ae.splice(ae.indexOf(idx), 1);
        h[idx] = l[idx] = "_"; // For display purpose only.
        debug.log(5, "next ae: ["+ae+"]");
    }

    return kth_n(a, ae, k, h, l);
}

function find_kth_in_arrays(a, k){

    if (!a || a.length<1 || k<1) throw "Mission Impossible!";

    var ae=[], h=[], l=[], n=0, s, ts=0;
    for (var i=0; i<a.length; i++){
        s = a[i] && a[i].length;
        if (s>0){
            ae.push(i); h.push(s-1); l.push(0);
            ts+=s;
        }
    }

    if (k>ts) throw "Too few elements to choose from!";

    return kth_n(a, ae, k, h, l);
}

/////////////////////////////////////////////////////
// tests
// To show everything: use 6.
debug.setLevel(1);

var a = [2, 3, 5, 7, 89, 223, 225, 667];
var b = [323, 555, 655, 673];
//var b = [99];
var c = [];

debug.log(1, "a = (len: " + a.length + ")");
debug.log(1, a);
debug.log(1, "b = (len: " + b.length + ")");
debug.log(1, b);

for (var k=1; k<a.length+b.length+1; k++){
    debug.log(1, "================== k: " + k + "=====================");

    if (k_largest(a, b, c, k) === 0 ){
      debug.log(1, "c = (len: "+c.length+")");
      debug.log(1, c);
    }

    try{
        result = kth(a, b, k, -1);
        debug.log(1, "===== The " + k + "-th largest number: " + result);
    } catch (e) {
        debug.log(0, "Error message from kth(): " + e);
    }
    debug.log("==================================================");
}

debug.log(1, "################# Now for the n sorted arrays ######################");
debug.log(1, "####################################################################");

x = [[1, 3, 5, 7, 9],
     [-2, 4, 6, 8, 10, 12],
     [8, 20, 33, 212, 310, 311, 623],
     [8],
     [0, 100, 700],
     [300],
     [],
     null];

debug.log(1, "x = (len: "+x.length+")");
debug.log(1, x);

for (var i=0, num=0; i<x.length; i++){
    if (x[i]!== null) num += x[i].length;
}
debug.log(1, "totoal number of elements: "+num);

// to test k in specific ranges:
var start = 0, end = 25;
for (k=start; k<end; k++){
    debug.log(1, "=========================== k: " + k + "===========================");

    try{
        result = find_kth_in_arrays(x, k);
        debug.log(1, "====== The " + k + "-th smallest number: " + result);
    } catch (e) {
        debug.log(1, "Error message from find_kth_in_arrays: " + e);
    }
    debug.log(1, "=================================================================");
}
debug.log(1, "x = (len: "+x.length+")");
debug.log(1, x);
debug.log(1, "totoal number of elements: "+num);

完整的带有调试工具的代码可以在以下链接中找到:https://github.com/brainclone/teasers/tree/master/kth


当问题要求在“查找第k小元素”中进行“下一步”时,这是一个有趣的方法。(不过,如果值是唯一的,你可以使用k2 = N+M-k。) - greybeard

3

我在这里找到的大多数答案都专注于两个数组。虽然这样做很好,但实现起来更加困难,因为我们需要处理很多边缘情况。此外,大多数实现都是递归的,这增加了递归栈的空间复杂度。因此,我决定只关注较小的数组,并在较小的数组上执行二进制搜索,在第一个数组中的指针值基础上调整第二个数组的指针。通过以下实现,我们具有O(log(min(n,m)))的时间复杂度和O(1)的空间复杂度。

    public static int kth_two_sorted(int []a, int b[],int k){
    if(a.length > b.length){
        return kth_two_sorted(b,a,k);
    }
    if(a.length + a.length < k){
        throw new RuntimeException("wrong argument");
    }
    int low = 0;
    int high = k;
    if(a.length <= k){
        high = a.length-1;
    }
    while(low <= high){
        int sizeA = low+(high - low)/2;
        int sizeB = k - sizeA;
        boolean shrinkLeft = false;
        boolean extendRight = false;
        if(sizeA != 0){
            if(sizeB !=b.length){
                if(a[sizeA-1] > b[sizeB]){
                    shrinkLeft = true;
                    high = sizeA-1;
                }
            }
        }
        if(sizeA!=a.length){
            if(sizeB!=0){
                if(a[sizeA] < b[sizeB-1]){
                    extendRight = true;
                    low = sizeA;
                }
            }
        }
        if(!shrinkLeft && !extendRight){
            return Math.max(a[sizeA-1],b[sizeB-1]) ;
        }
    }
    throw  new IllegalArgumentException("we can't be here");
}

我们有一个数组a的区间范围为[low, high],随着算法的进行,我们会逐渐缩小这个范围。sizeA表示k项中有多少项来自于a,它是由lowhigh计算得出的。同样地,sizeB也是这样定义的,但是我们需要按照sizeA+sizeB=k的方式计算该值。基于这两个边界的值,我们得出结论:我们必须向数组a的右侧扩展或向左侧收缩。如果我们停留在同一位置,则意味着我们已经找到了解决方案,我们将返回a中位置sizeA-1b中位置sizeB-1的最大值作为答案。

如果(a.length + a.length < k),那么一个地方应该是b吗? - greybeard

2

这是我基于Jules Olleon的解决方案编写的代码:

int getNth(vector<int>& v1, vector<int>& v2, int n)
{
    int step = n / 4;

    int i1 = n / 2;
    int i2 = n - i1;

    while(!(v2[i2] >= v1[i1 - 1] && v1[i1] > v2[i2 - 1]))
    {                   
        if (v1[i1 - 1] >= v2[i2 - 1])
        {
            i1 -= step;
            i2 += step;
        }
        else
        {
            i1 += step;
            i2 -= step;
        }

        step /= 2;
        if (!step) step = 1;
    }

    if (v1[i1 - 1] >= v2[i2 - 1])
        return v1[i1 - 1];
    else
        return v2[i2 - 1];
}

int main()  
{  
    int a1[] = {1,2,3,4,5,6,7,8,9};
    int a2[] = {4,6,8,10,12};

    //int a1[] = {1,2,3,4,5,6,7,8,9};
    //int a2[] = {4,6,8,10,12};

    //int a1[] = {1,7,9,10,30};
    //int a2[] = {3,5,8,11};
    vector<int> v1(a1, a1+9);
    vector<int> v2(a2, a2+5);


    cout << getNth(v1, v2, 5);
    return 0;  
}  

1
这种方法并非适用于所有情况。例如,int a2 [] = {1,2,3,4,5}; int a1 [] = {5,6,8,10,12}; getNth(a1,a2,7)。数组的索引将超出边界。 - Jay
@Jay: 数组的索引将超出[边界] … 在 vector<int> v1(a1, a1+9) 中?糟糕 :-/(如果其中一个数组短于一半 k (n),它确实会失败:downvoting.) - greybeard

2
基本上,通过这种方法,每一步可以丢弃k/2个元素。K将递归地从k => k/2 => k/4 => ...改变,直到达到1。 因此,时间复杂度为O(logk)。
在k=1时,我们得到两个数组中的最低值。
以下代码是JAVA代码。请注意,在代码中,我们从索引中减去1(-1),因为Java数组的索引从0开始而不是1,例如:k=3由数组中第2个索引表示。
private int kthElement(int[] arr1, int[] arr2, int k) {
        if (k < 1 || k > (arr1.length + arr2.length))
            return -1;
        return helper(arr1, 0, arr1.length - 1, arr2, 0, arr2.length - 1, k);
    }


private int helper(int[] arr1, int low1, int high1, int[] arr2, int low2, int high2, int k) {
    if (low1 > high1) {
        return arr2[low2 + k - 1];
    } else if (low2 > high2) {
        return arr1[low1 + k - 1];
    }
    if (k == 1) {
        return Math.min(arr1[low1], arr2[low2]);
    }
    int i = Math.min(low1 + k / 2, high1 + 1);
    int j = Math.min(low2 + k / 2, high2 + 1);
    if (arr1[i - 1] > arr2[j - 1]) {
        return helper(arr1, low1, high1, arr2, j, high2, k - (j - low2));
    } else {
        return helper(arr1, i, high1, arr2, low2, high2, k - (i - low1));
    }
}

2

这是我的解决方案。C++代码使用循环打印第k小的值以及获取第k小的值所需的迭代次数,我认为迭代次数大约是log(k)的数量级。但是,该代码要求k必须小于第一个数组的长度,这是一个限制条件。

#include <iostream>
#include <vector>
#include<math.h>
using namespace std;

template<typename comparable>
comparable kthSmallest(vector<comparable> & a, vector<comparable> & b, int k){

int idx1; // Index in the first array a
int idx2; // Index in the second array b
comparable maxVal, minValPlus;
float iter = k;
int numIterations = 0;

if(k > a.size()){ // Checks if k is larger than the size of first array
    cout << " k is larger than the first array" << endl;
    return -1;
}
else{ // If all conditions are satisfied, initialize the indexes
    idx1 = k - 1;
    idx2 = -1;
}

for ( ; ; ){
    numIterations ++;
    if(idx2 == -1 || b[idx2] <= a[idx1] ){
        maxVal = a[idx1];
        minValPlus = b[idx2 + 1];
        idx1 = idx1 - ceil(iter/2); // Binary search
        idx2 = k - idx1 - 2; // Ensures sum of indices  = k - 2
    }
    else{
        maxVal = b[idx2];
        minValPlus = a[idx1 + 1];
        idx2 = idx2 - ceil(iter/2); // Binary search
        idx1 = k - idx2 - 2; // Ensures sum of indices  = k - 2
    }
    if(minValPlus >= maxVal){ // Check if kth smallest value has been found
        cout << "The number of iterations to find the " << k << "(th) smallest value is    " << numIterations << endl;
        return maxVal;

    }
    else
        iter/=2; // Reduce search space of binary search
   }
}

int main(){
//Test Cases
    vector<int> a = {2, 4, 9, 15, 22, 34, 45, 55, 62, 67, 78, 85};
    vector<int> b = {1, 3, 6, 8, 11, 13, 15, 20, 56, 67, 89};
    // Input k < a.size()
    int kthSmallestVal;
    for (int k = 1; k <= a.size() ; k++){
        kthSmallestVal = kthSmallest<int>( a ,b ,k );
        cout << k <<" (th) smallest Value is " << kthSmallestVal << endl << endl << endl;
    }
}

2

这是我在C语言中的实现,你可以参考@Jules Olléon对算法的解释:算法的思想是我们维护i + j = k,并找到这样的i和j,使得a [i-1] < b [j-1] < a [i](或者反过来)。现在由于'a'中有i个元素比b[j-1]小,在'b'中有j-1个元素比b[j-1]小,因此b[j-1]是第i + j-1 + 1 = k个最小的元素。为了找到这样的i、j,该算法在数组上进行二分搜索。

int find_k(int A[], int m, int B[], int n, int k) {
   if (m <= 0 )return B[k-1];
   else if (n <= 0) return A[k-1];
   int i =  ( m/double (m + n))  * (k-1);
   if (i < m-1 && i<k-1) ++i;
   int j = k - 1 - i;

   int Ai_1 = (i > 0) ? A[i-1] : INT_MIN, Ai = (i<m)?A[i]:INT_MAX;
   int Bj_1 = (j > 0) ? B[j-1] : INT_MIN, Bj = (j<n)?B[j]:INT_MAX;
   if (Ai >= Bj_1 && Ai <= Bj) {
       return Ai;
   } else if (Bj >= Ai_1 && Bj <= Ai) {
       return Bj;
   }
   if (Ai < Bj_1) { // the answer can't be within A[0,...,i]
       return find_k(A+i+1, m-i-1, B, n, j);
   } else { // the answer can't be within A[0,...,i]
       return find_k(A, m, B+j+1, n-j-1, i);
   }
 }

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接