两组区间的差异

11
我正在尝试编写一些代码来计算两组区间 A - B 的差,区间端点为整数,但我很难想出有效的解决方案,如有任何建议,将不胜感激。
示例:[(1, 4), (7, 9)] - [(3,5)] = [(1, 3), (7, 9)]

这是迄今为止我所尝试的最佳方法(两个列表已经排序)

class tp():
   def __repr__(self):
       return '(%.2f,%.2f)' % (self.start, self.end)
   def __init__(self,start,end): 
       self.start=start
       self.end=end



z=[tp(3,5)] #intervals to be subtracted
s=[tp(1, 4)),tp(7, 9), tp(3,4),tp(4,6)]

for x in s[:]:
   if z.end < x.start:
    break
   elif z.start < x.start and z.end > x.start and z.end < x.end:
    x.start=z.end
   elif z.start < x.start and z.end > x.end:
    s.remove(x)
   elif z.start > x.start and z.end < x.end:
    s.append(tp(x.start,z.start))
    s.append(tp(z.end,x.end))
    s.remove(x)
   elif z.start > x.start and z.start < x.end and z.end > x.end:
    x.end=z.start
   elif z.start > x.end:
    continue

我尝试对集合1中的每个元素与集合2中的每个元素之间进行差异,这样可以得到正确的答案,但我希望有更高效的方法。 - desprategstreamer
如果你想要更高效的解决方案,你应该向我们展示你已经完成的工作,我们会给予帮助。 - Gabriel L.
好的,我已经编辑了问题。 - desprategstreamer
你需要Python方面的帮助吗?另外,最后一段代码看起来缩进不正确。 - godel9
你说“这两个列表已经排序了”,但是你的例子s=[tp(1, 4)),tp(7, 9), tp(3,4),tp(4,6)]并没有按照任何明显的排序方式进行排序。此外,你的代码似乎假设z是一个单独的tp,但是你的例子显示它是一个列表。能否请你澄清一下你的期望? - rici
3个回答

15

唯一让这个操作变得高效的方法是保持区间列表的排序和非重叠 (可以在 O(n log n)内完成)。详见下文注释。

当两个列表都是排序和非重叠时,任何集合操作(并集、交集、差集、对称差集)都可以使用简单的合并完成。

合并操作很简单:同时按顺序循环遍历两个参数的端点。(注意,由于我们要求区间不重叠,所以每个区间列表的端点已经排序。) 对于发现的每个端点,决定是否将其包括在结果中。若结果目前有奇数个端点且新端点不在结果中,则将其添加到结果中;同样地,如果结果目前具有偶数个端点且新端点在结果中,则将其添加到结果中。在此操作结束后,结果是一个端点列表,交替包含区间的开始和结束。

以下是Python示例:

# In all of the following, the list of intervals must be sorted and 
# non-overlapping. We also assume that the intervals are half-open, so
# that x is in tp(start, end) iff start <= x and x < end.

def flatten(list_of_tps):
    """Convert a list of intervals to a list of endpoints"""
    return reduce(lambda ls, ival: ls + [ival.start, ival.end],
                  list_of_tps, [])
    
def unflatten(list_of_endpoints):
    """Convert a list of endpoints, with an optional terminating sentinel,
       into a list of intervals"""
    return [tp(list_of_endpoints[i], list_of_endpoints[i + 1])
            for i in range(0, len(list_of_endpoints) - 1, 2)]
    
def merge(a_tps, b_tps, op):
    """Merge two lists of intervals according to the boolean function op"""
    a_endpoints = flatten(a_tps)
    b_endpoints = flatten(b_tps)
    
    sentinel = max(a_endpoints[-1], b_endpoints[-1]) + 1
    a_endpoints += [sentinel]
    b_endpoints += [sentinel]
    
    a_index = 0
    b_index = 0
    
    res = []
    
    scan = min(a_endpoints[0], b_endpoints[0])
    while scan < sentinel:
        in_a = not ((scan < a_endpoints[a_index]) ^ (a_index % 2))
        in_b = not ((scan < b_endpoints[b_index]) ^ (b_index % 2))
        in_res = op(in_a, in_b)
        
        if in_res ^ (len(res) % 2): res += [scan]
        if scan == a_endpoints[a_index]: a_index += 1
        if scan == b_endpoints[b_index]: b_index += 1
        scan = min(a_endpoints[a_index], b_endpoints[b_index])
    
    return unflatten(res)

def interval_diff(a, b):
    return merge(a, b, lambda in_a, in_b: in_a and not in_b)

def interval_union(a, b):
    return merge(a, b, lambda in_a, in_b: in_a or in_b)

def interval_intersect(a, b):
    return merge(a, b, lambda in_a, in_b: in_a and in_b)

注意事项

  1. 区间 [a, b)[b, c) 不会重叠,因为它们是不相交的;b 只属于第二个区间。将这两个区间合并后仍然是 [a,c)。但为了本答案中的函数更加精确,最好将区间定义延伸至不相邻的情况,包括区间相邻的情况;否则,可能会在输出中不必要地包含相邻点。(这严格来说不算错误,但如果函数的输出是确定性的,则易于测试。)

    下面是一个示例函数实现,将任意一组区间标准化为排序的、不重叠的区间。

def interval_normalise(a):
    rv = sorted(a, key = lambda x: x.start)
    out = 0
    for scan in range(1, len(rv)):
        if rv[scan].start > rv[out].end:
            if rv[out].end > rv[out].start: out += 1
            rv[out] = rv[scan]
        elif rv[scan].end > rv[out].end:
            rv[out] = tp(rv[out].start, rv[scan].end)
    if rv and rv[out].end > rv[out].start: out += 1
    return rv[:out]

谢谢,您的答案真的帮了我,再次抱歉问题没有表达清楚,我在这方面是新手。 - desprategstreamer
1
@tommy.carstensen:这并不是一种替代方案,实际上它是完全相同的函数,只是因为某些我不完全理解的原因而被放置在functools模块中。我在我的代码中添加了一个注释。 - rici
1
@boris:我的答案的第一行说列表需要“排序且不重叠”。你的示例不符合这些限制。我在我的答案中添加了一个归一化函数以防有用。 - rici
@boris:事实上,重新阅读答案后我发现它在叙述中提到了三次不重叠的要求,还在Python代码的注释中提到了一次。 :-) - rici
我同意其他人的看法,这是一个很好的解决方案。一个小的简化是将 in_a = not ((scan < a_endpoints[a_index]) ^ (a_index % 2)) 替换为 in_a = (scan < a_endpoints[a_index]) == (a_index % 2),并且对于 in_b 赋值也是如此。这基于 not(x ^ y)x == y 的等价性。 - Peter Simon
显示剩余9条评论

2
这可以通过扫描线算法来解决。其思想是将两个集合中所有区间的起点存放在一个排序数组中,将终点存放在另一个排序数组中,并标记它们属于哪个集合。例如:
       A              B
[(1, 4), (7, 9)] - [(3,5)]
A: start:[1,7] end:[4,9], B: start:[3]end:[5]
start:[(1,a),(3,b),(7,a)]
end: [(4,a),(5,b),(9,a)]

现在有两个指针,一个指向每个数组的开头。在循环中,递增指向最低值的指针,添加从a开始到b或a结束的间隔。例如,对于上面的示例,我们将按以下顺序迭代点。
(1,a) (3,b) (4,a) (5,b) (7,a) (9,a)
# and adding intervals where we have seen an start a and an end a or b
(1,3) (7,9)

这导致在区间数量方面有线性解决方案。

谢谢回答,但我认为这个方法在处理 [4 7] - [2,5] 时会失败。答案应该是 [5 7],但我们不考虑它,因为它不以 a 开头。 - desprategstreamer

0

使用numpy的另一种实现方式。我认为使用整数端点更自然,因此假设区间是闭合的。 对于下面我建议的方法,我们绝对需要注意包括负无穷和正无穷在内的(半闭合)区间。

def merge_intervals(intervals):
    # Normalize to sorted non-overlaping intervals. Aimilar idea as in
    # https://www.geeksforgeeks.org/merging-intervals/
    if len(intervals)==0: return intervals 
    assert np.all(intervals[:,0]<=intervals[:,1]), f"merge_intervals: intervals not well defined. intervals={intervals}"
    if len(intervals)==1: return intervals    
    intervals = np.sort(intervals.copy(),axis=0)
    stack = []
    # insert first interval into stack
    stack.append(intervals[0])
    for i in intervals[1:]:
        # Check for overlapping interval,
        # if interval overlap
        if i[0] > stack[-1][1]+1:
            stack.append(i)
        else:
            stack[-1][1] = max(stack[-1][1], i[1])
    return np.array(stack)

def union_intervals(a,b):
    return merge_intervals(np.r_[a,b])

# The folowing is the key function. Needs to work  
# well with infinities and empty sets.
def complement_intervals(a): 
    if len(a)==0: return np.array([[-np.inf,np.inf]])
    a_normalized = merge_intervals(a)
    result0 = np.r_[-np.inf,a_normalized[:,1]+1]
    result1 = np.r_[a_normalized[:,0]-1,np.inf] 
    non_empty = np.logical_and(result0 < np.inf, result1 > -np.inf)
    result = np.c_[result0[non_empty],result1[non_empty]]
    if np.array_equal(result,np.array([[np.inf,-np.inf]])):
        result = np.array([])
    return merge_intervals(result)

def intersection_intervals(a,b):
    union_of_complements = union_intervals(complement_intervals(a),complement_intervals(b))
    return  complement_intervals(union_of_complements)

def difference_intervals(a,b):
    return intersection_intervals(a,complement_intervals(b))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接