在Python中高效地合并重叠的元组列表

3
我正在尝试优化一个在Python中合并重叠区间的函数。给定一个区间列表,其中每个区间表示为一个元组(起始,结束),我需要合并所有重叠的区间,并返回一个覆盖输入列表中所有区间的非重叠区间列表。
例如,给定输入列表intervals = [(1, 3), (2, 6), (8, 10), (15, 18)],输出应为[(1, 6), (8, 10), (15, 18)]。
这是我尝试的方法:
def merge_intervals(intervals):
    sorted_by_lower_bound = sorted(intervals, key=lambda x: x[0])
    merged = []

    for higher in sorted_by_lower_bound:
        if not merged or merged[-1][1] < higher[0]:
            merged.append(higher)
        else:
            merged[-1] = (merged[-1][0], max(merged[-1][1], higher[1]))

    return merged

print(merge_intervals([(1, 3), (2, 6), (8, 10), (15, 18)]))

这个函数似乎工作正常,但我不确定它是否是最高效的方式或者最符合Python风格的方式。我期望能找到一个内置函数或者更直接的方法来实现这个功能,可能使用Python标准库。
问题:在Python中是否有更高效或更符合惯用法的方式来合并区间,也许可以使用一个库函数或者不同的算法来提高性能,特别是在处理大量区间时?

这里描述了一个算法(链接:https://stackoverflow.com/questions/7468948/problem-calculating-overlapping-date-ranges/7469347#7469347)。可能既不更高效也不更符合Python风格。 - undefined
3
这已经是最好的了,唯一的问题是,你不需要使用lambda,按第一个元素排序是默认的。 - undefined
  1. 过早优化是错误的。
  2. 你可以使用numba包来完成相同的工作,但是使用编译后的代码,速度更快。
- undefined
1个回答

0
你可以尝试使用Numba来运行以下代码:
from numba import jit
import numpy as np

@jit
def merge_intervals_numba(intervals):
    intervals = np.array(sorted(intervals))
    merged = []

    for interval in intervals:
        if not merged or merged[-1][1] < interval[0]:
            merged.append(interval)
        else:
            merged[-1][1] = max(merged[-1][1], interval[1])
    
    return np.array(merged)

通过使用numba,如果你处理的是一个非常大的间隔列表,你可以显著加快函数的速度。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接