获取k个已排序数组的交集最有效的方法是什么?

15

给定 k 个已排序的数组,如何最有效地获取这些列表的交集

示例

输入:

[[1,3,5,7], [1,1,3,5,7], [1,4,7,9]] 

输出:

[1,7]
根据我在编程面试要点书中所读的内容,有一种方法可以在nlogk时间内获取k个已排序数组的并。我想知道是否有一种类似的方法可以用于交集。

有一种方法可以根据我在《编程面试要点》这本书中所读到的内容,在nlogk的时间内获取k个已排序数组的并集。我想知道是否有一种类似的方法可以用于交集。

## merge sorted arrays in nlogk time [ regular appending and merging is nlogn time ]
import heapq
def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]
    
    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem:
            heapq.heappush(heap, (elem, idx))
    
    res = []
 
    # collect results in nlogK time
    while heap:
        elem, ary = heapq.heappop(heap)
        it = srtd_iters[ary]
        res.append(elem)
        nxt = next(it, None)
        if nxt:
            heapq.heappush(heap, (nxt, ary))

编辑:显然,这是一道我正在尝试解决的算法问题,因此我不能使用任何内置函数,如集合交集等。


如果您观察到在所有k个数组的交集中出现的元素必须至少连续出现k次,那么仍然可以应用优先队列方法。如何有效地确定每个数组中已经看到了>=k个连续元素是留给读者自己思考的问题。 - wLui155
1
这些数字很小吗?在 [0, 127] 范围内吗? - Neil
似乎不需要优先队列-您可以为每个数组维护一个索引,并跟踪任何数组当前索引的最大值,再加上该值已出现在多少个数组中的计数器。这应该以O(n)时间呈现,但使用集合的简单一行代码也是O(n),因此唯一的渐进差异在于辅助空间。至少在Python中,使用集合的解决方案肯定会更快,因为工作由C编写的代码完成。 - kaya3
1
我可以读取k的值:那么n到底是什么? - greybeard
1
“列表的交集”到底是什么?在[[1, 3, 3, 5, 7, 7, 9], [2, 2, 3, 3, 5, 5, 7, 7]]中,交集是[3, 3, 5, 7, 7]还是[3, 5, 7] - greybeard
2
请看下面,我找到了一种不需要堆和索引追踪的改进方法。它的运行时间为O(n)。 - Raymond Hettinger
13个回答

16

利用排序顺序

这里有一种单次遍历的O(n)方法,不需要任何特殊的数据结构或辅助内存,只需一个输入迭代器的基本要求。

from itertools import cycle, islice

def intersection(inputs):
    "Yield the intersection of elements from multiple sorted inputs."
    # intersection(['ABBCD', 'BBDE', 'BBBDDE']) --> B B D
    n = len(inputs)
    iters = cycle(map(iter, inputs))
    try:
        candidate = next(next(iters))
        while True:
            for it in islice(iters, n-1):
                while (value := next(it)) < candidate:
                    pass
                if value != candidate:
                    candidate = value
                    break
            else:
                yield candidate
                candidate = next(next(iters))
    except StopIteration:
        return

这是一个示例会话:

>>> data = [[1,3,5,7], [1,1,3,5,7], [1,4,7,9]]
>>> list(intersection(data))
[1, 7]

>>> data = [[1,1,2,3], [1,1,4,4]]
>>> list(intersection(data))
[1, 1]

用文字描述的算法

该算法从下一个迭代器中选择下一个值作为候选项。

主循环假定已选择候选项,并循环遍历下一个n - 1个迭代器。对于这些迭代器中的每一个,它消耗值直到找到一个不小于候选项的值。如果该值大于候选项,则该值成为新的候选项,并且主循环重新开始。如果所有n - 1个值都等于候选项,则发出候选项并获取新的候选项。

任何输入迭代器耗尽后,算法就完成了。

不使用库(仅使用核心语言)实现算法

同样的算法可以在不使用itertools时正常工作(虽然不太美观)。只需将cycleislice替换为它们基于列表的等效项:

def intersection(inputs):
    "Yield the intersection of elements from multiple sorted inputs."
    # intersection(['ABBCD', 'BBDE', 'BBBDDE']) --> B B D
    n = len(inputs)
    iters = list(map(iter, inputs))
    curr_iter = 0
    try:
        it = iters[curr_iter]
        curr_iter = (curr_iter + 1) % n
        candidate = next(it)
        while True:
            for i in range(n - 1):
                it = iters[curr_iter]
                curr_iter = (curr_iter + 1) % n
                while (value := next(it)) < candidate:
                    pass
                if value != candidate:
                    candidate = value
                    break
            else:
                yield candidate
                it = iters[curr_iter]
                curr_iter = (curr_iter + 1) % n
                candidate = next(it)
    except StopIteration:
        return

谢谢,我会仔细看一下。这个方法仍然比nlogk的方法更差,对吧?我假设n是指总元素数而不是每个数组的平均元素数。 - identical123456
在对另一个答案的评论中,OP指出他们希望重复的输入值在输出中出现多次:[[1,1,2,3],[1,1,4,4]]应该给出输出[1,1],而不是[1] - Oli
VALUE是常量,不是吗? - grovkin
@identical123456 这比 n log k 方法更快。此外,它适用于迭代器输入,仅进行一次遍历,并且不需要辅助内存。我相信它是可以证明最优的。 - Raymond Hettinger

5

是的,这是可能的!我已经修改了您的示例代码以实现此功能。

我的答案假设您的问题是关于算法的——如果您想要使用set的最快运行代码,请参阅其他答案。

这将保持O(n log(k))时间复杂度:从if lowest != elem or ary != times_seen:unbench_all = False之间的所有代码都是O(log(k))。主循环内有一个嵌套循环(for unbenched in range(times_seen):),但这仅运行times_seen次,并且times_seen最初为0,在每次运行此内部循环后重置为0,并且每次主循环迭代只能增加一次,因此内部循环无法总共执行更多的迭代比主循环。因此,由于内部循环内的代码是O(log(k)),并且运行的次数最多与外部循环一样多,而外部循环是O(log(k)),并且运行n次,该算法是O(n log(k))

这个算法依赖于Python中元组的比较方式。它比较元组的第一项,如果它们相等,则比较第二项(即(x, a)<(x, b)当且仅当a<b时为true)。 在这个算法中,与问题中示例代码不同的是,当从堆中弹出一个项目时,它不一定会在同一次迭代中再次推回。由于我们需要检查所有子列表是否包含相同的数字,因此在从堆中弹出数字后,它的子列表就被我称为"加入队列",意味着它不会被再次添加到堆中了。这是因为我们需要检查其他子列表是否包含相同的项,所以现在不需要添加这个子列表的下一个项目。

如果一个数字确实存在于所有子列表中,那么堆将看起来像[(2,0),(2,1),(2,2),(2,3)],所有元组的第一个元素都相同,所以heappop将选择具有最低子列表索引的索引0,这意味着首先弹出索引0并增加times_seen到1,然后弹出索引1并增加times_seen到2——如果ary不等于times_seen,则该数字不在所有子列表的交集中。这导致条件if lowest != elem or ary != times_seen:,它决定了一个数字何时不应该在结果中。这个if语句的else分支是为仍然可能出现在结果中的情况而设计的。

unbench_all 布尔值用于从工作台中移除所有子列表 - 这可能是因为以下原因:

  1. 当前的数字已知不在子列表的交集中
  2. 已知它在子列表的交集中

unbench_allTrue 时,被移除的所有子列表将重新添加到堆中。已知这些子列表的索引范围为 range(times_seen),因为算法仅在它们拥有相同数字时才从堆中移除条目,所以它们必须按照索引顺序连续从索引 0 开始移除,并且其中必须有 times_seen 个。这意味着我们不需要存储被禁用的子列表的索引,只需存储被禁用的子列表数量即可。

import heapq


def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]

    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem:
            heapq.heappush(heap, (elem, idx))

    res = []

    # the number of tims that the current number has been seen
    times_seen = 0

    # the lowest number from the heap - currently checking if the first numbers in all sub-lists are equal to this
    lowest = heap[0][0] if heap else None

    # collect results in nlogK time
    while heap:
        elem, ary = heap[0]
        unbench_all = True

        if lowest != elem or ary != times_seen:
            if lowest == elem:
                heapq.heappop(heap)
                it = srtd_iters[ary]
                nxt = next(it, None)
                if nxt:
                    heapq.heappush(heap, (nxt, ary))
        else:
            heapq.heappop(heap)
            times_seen += 1

            if times_seen == len(srtd_arys):
                res.append(elem)
            else:
                unbench_all = False

        if unbench_all:
            for unbenched in range(times_seen):
                unbenched_it = srtd_iters[unbenched]
                nxt = next(unbenched_it, None)
                if nxt:
                    heapq.heappush(heap, (nxt, unbenched))
            times_seen = 0
            if heap:
                lowest = heap[0][0]

    return res


if __name__ == '__main__':
    a1 = [[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]
    a2 = [[1, 1], [1, 1, 2, 2, 3]]
    for arys in [a1, a2]:
        print(mergeArys(arys))

如果你愿意,可以写出一个等效的算法,如下所示:

def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]

    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem:
            heapq.heappush(heap, (elem, idx))

    res = []

    # collect results in nlogK time
    while heap:
        elem, ary = heap[0]
        lowest = elem
        keep_elem = True
        for i in range(len(srtd_arys)):
            elem, ary = heap[0]
            if lowest != elem or ary != i:
                if ary != i:
                    heapq.heappop(heap)
                    it = srtd_iters[ary]
                    nxt = next(it, None)
                    if nxt:
                        heapq.heappush(heap, (nxt, ary))

                keep_elem = False
                i -= 1
                break
            heapq.heappop(heap)

        if keep_elem:
            res.append(elem)

        for unbenched in range(i+1):
            unbenched_it = srtd_iters[unbenched]
            nxt = next(unbenched_it, None)
            if nxt:
                heapq.heappush(heap, (nxt, unbenched))

        if len(heap) < len(srtd_arys):
            heap = []

    return res


这是正确的。但是,我复制粘贴了你的代码,对于你的输入,它只输出[1]而不是[1,1]。[[1,1],[1,1,2,2,3]] 我错过了什么吗? - identical123456
@identical123456,你没有错过任何东西,我还没有实现这个功能。正在努力中! - Oli
@identical123456 经过算法编辑,可以正确处理重复项。 - Oli
谢谢。我会尽快研究这个。 - identical123456
看下面,我找到了一种不需要堆和不需要跟踪索引的改进方法。它的运行时间是O(n)。 - Raymond Hettinger
显示剩余3条评论

3

您可以使用内置集合和集合交集:

d = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]] 
result = set(d[0]).intersection(*d[1:])
{1, 7}

这绝对应该是OP所要求的最有效的方法。值得一加。 - Onyambu
@Onyambu:这个问题涉及到多重集合,但在修订版2之前并没有说得太多。(然后,它并不一定比其他O(n)方法更快。) - greybeard

2
您可以使用reduce
from functools import reduce

a = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]] 
reduce(lambda x, y: x & set(y), a[1:], set(a[0]))
 {1, 7}

这也是我首先想到的解决方案,但是(1)OP正在从头实现算法,不能使用内置函数如intersection(它隐式地使用了x & set(y)),而且(2)这些列表被视为多重集合,因此,例如,如果所有列表中都出现了1, 1,那么结果中也应该出现1, 1,但是这种解决方案会失败。 - Richard Ambler
@RixhardAmbler 这个问题在提出后5分钟就得到了回答。问题已经因为清晰度的原因进行了多次编辑,并添加了一些信息。 - Onyambu

1
我想出了这个算法。它的时间复杂度不超过O(nk),但不确定是否足够好。该算法的重点是,你可以为每个数组设置k个索引,在每次迭代中找到交集中下一个元素的索引,并增加每个索引,直到超出数组的边界且交集中没有更多的项。技巧在于,由于数组已排序,你可以查看两个不同数组中的两个元素,如果其中一个大于另一个,则可以立即放弃另一个,因为你知道你不能拥有比正在查看的数字更小的数字。该算法的最坏情况是每个索引都将增加到边界,这需要k*n的时间,因为索引不能减少其值。
  inter = []

  for n in range(len(arrays[0])):
    if indexes[0] >= len(arrays[0]):
        return inter
    for i in range(1,k):
      if indexes[i] >= len(arrays[i]):
        return inter
      while indexes[i] < len(arrays[i]) and arrays[i][indexes[i]] < arrays[0][indexes[0]]:
        indexes[i] += 1
      while indexes[i] < len(arrays[i]) and indexes[0] < len(arrays[0]) and arrays[i][indexes[i]] > arrays[0][indexes[0]]:
        indexes[0] += 1
    if indexes[0] < len(arrays[0]):
      inter.append(arrays[0][indexes[0]])
    indexes = [idx+1 for idx in indexes]
  return inter

1
你说我们不能使用集合,但是字典/哈希表呢?(是的,我知道它们基本上是一样的):D
如果可以的话,这是一个相当简单的方法(请原谅py2语法):
arrays = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
counts = {}

for ar in arrays:
  last = None
  for i in ar:
    if (i != last):
      counts[i] = counts.get(i, 0) + 1
    last = i

N = len(arrays)
intersection = [i for i, n in counts.iteritems() if n == N]
print intersection

1

以上方法中,有些并没有涵盖到列表的每个子集中存在重复元素的情况。下面的代码实现了这种交集,如果列表的子集中有很多重复元素,它将更加高效 :) 如果不确定是否存在重复元素,建议使用来自 collections 的 Counter from collections import Counter。自定义计数器函数是为了提高处理大量重复元素的效率。但仍然无法击败 Raymond Hettinger 的实现。

def counter(my_list):
    my_list = sorted(my_list)
    first_val, *all_val = my_list
    p_index = my_list.index(first_val)
    my_counter = {}
    for item in all_val:
         c_index = my_list.index(item)
         diff = abs(c_index-p_index)
         p_index = c_index
         my_counter[first_val] = diff 
         first_val = item
    c_index = my_list.index(item)
    diff = len(my_list) - c_index
    my_counter[first_val] = diff 
    return my_counter

def my_func(data):
    if not data or not isinstance(data, list):
        return
    # get the first value
    first_val, *all_val = data
    if not isinstance(first_val, list):
        return
    # count items in first value
    p = counter(first_val) # counter({1: 2, 3: 1, 5: 1, 7: 1})
    # collect all common items and calculate the minimum occurance in intersection
    for val in all_val:
        # collecting common items
        c = counter(val)
        # calculate the minimum occurance in intersection
        inner_dict = {}
        for inner_val in set(c).intersection(set(p)):
            inner_dict[inner_val] = min(p[inner_val], c[inner_val])
        p = inner_dict
    # >>>p
    # {1: 2, 7: 1}
    # Sort by keys of counter
    sorted_items = sorted(p.items(), key=lambda x:x[0]) # [(1, 2), (7, 1)]
    result=[i[0] for i in sorted_items for _ in range(i[1])] # [1, 1, 7]
    return result

这里是示例示例

>>> data = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
>>> my_func(data=data)
[1, 7]
>>> data = [[1,1,3,5,7],[1,1,3,5,7],[1,1,4,7,9]]
>>> my_func(data=data)
[1, 1, 7]

1
你可以使用函数heapq.merge, chain.from_iterablegroupby来执行以下操作。
from heapq import merge
from itertools import groupby, chain

ls = [[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]


def index_groups(lst):
    """[1, 1, 3, 5, 7] -> [(1, 0), (1, 1), (3, 0), (5, 0), (7, 0)]"""
    return chain.from_iterable(((e, i) for i, e in enumerate(group)) for k, group in groupby(lst))


iterables = (index_groups(li) for li in ls)
flat = merge(*iterables)
res = [k for (k, _), g in groupby(flat) if sum(1 for _ in g) == len(ls)]
print(res)

输出

[1, 7]

这个想法是使用枚举(enumerate)给相同列表中的相等值赋予额外的价值,以区分它们(请参见函数index_groups)。
该算法的复杂度为O(n),其中n是输入中每个列表长度的总和。
请注意,对于(每个列表额外增加1)的输出:
ls = [[1, 1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 1, 4, 7, 9]]

is:

[1, 1, 7]

1
与Raymond Hettinger的解决方案相同,但使用更基本的Python代码实现:
def intersection(arrays, unique: bool=False):
    result = []
    if not len(arrays) or any(not len(array) for array in arrays):
        return result

    pointers = [0] * len(arrays)

    target = arrays[0][0]
    start_step = 0
    current_step = 1
    while True:
        idx = current_step % len(arrays)
        array = arrays[idx]

        while pointers[idx] < len(array) and array[pointers[idx]] < target:
            pointers[idx] += 1

        if pointers[idx] < len(array) and array[pointers[idx]] > target:
            target = array[pointers[idx]]
            start_step = current_step
            current_step += 1
            continue

        if unique:
            while (
                pointers[idx] + 1 < len(array)
                and array[pointers[idx]] == array[pointers[idx] + 1]
            ):
                pointers[idx] += 1

        if (current_step - start_step) == len(arrays):
            result.append(target)
            for other_idx, other_array in enumerate(arrays):
                pointers[other_idx] += 1
            if pointers[idx] < len(array):
                target = array[pointers[idx]]
                start_step = current_step

        if pointers[idx] == len(array):
            return result

        current_step += 1

1
这里提供一个O(n)的答案(其中n = sum(len(sublist) for sublist in data))。
from itertools import cycle

def intersection(data):
    result = []    
    maxval = float("-inf")
    consecutive = 0
    try:
        for sublist in cycle(iter(sublist) for sublist in data):

            value = next(sublist)
            while value < maxval:
                value = next(sublist)

            if value > maxval:
                maxval = value
                consecutive = 0
                continue

            consecutive += 1
            if consecutive >= len(data)-1:
                result.append(maxval)
                consecutive = 0

    except StopIteration:
        return result

print(intersection([[1,3,5,7], [1,1,3,5,7], [1,4,7,9]]))

[1, 7]


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接