在Python中使用zip迭代器断言长度相等

49

我正在寻找一种美好的方式来zip多个可迭代对象,并在可迭代对象长度不相等时引发异常。

如果可迭代对象是列表或具有len方法,则此解决方案干净简单:

def zip_equal(it1, it2):
    if len(it1) != len(it2):
        raise ValueError("Lengths of iterables are different")
    return zip(it1, it2)

然而,如果it1it2是生成器,则上一个函数会失败,因为长度未定义TypeError: object of type 'generator' has no len()

我想象一下itertools模块提供了一种简单的实现方式,但到目前为止我还没有找到它。我已经想出了这个自制的解决方案:

def zip_equal(it1, it2):
    exhausted = False
    while True:
        try:
            el1 = next(it1)
            if exhausted: # in a previous iteration it2 was exhausted but it1 still has elements
                raise ValueError("it1 and it2 have different lengths")
        except StopIteration:
            exhausted = True
            # it2 must be exhausted too.
        try:
            el2 = next(it2)
            # here it2 is not exhausted.
            if exhausted:  # it1 was exhausted => raise
                raise ValueError("it1 and it2 have different lengths")
        except StopIteration:
            # here it2 is exhausted
            if not exhausted:
                # but it1 was not exhausted => raise
                raise ValueError("it1 and it2 have different lengths")
            exhausted = True
        if not exhausted:
            yield (el1, el2)
        else:
            return

该解决方案可以使用以下代码进行测试:

it1 = (x for x in ['a', 'b', 'c'])  # it1 has length 3
it2 = (x for x in [0, 1, 2, 3])     # it2 has length 4
list(zip_equal(it1, it2))           # len(it1) < len(it2) => raise
it1 = (x for x in ['a', 'b', 'c'])  # it1 has length 3
it2 = (x for x in [0, 1, 2, 3])     # it2 has length 4
list(zip_equal(it2, it1))           # len(it2) > len(it1) => raise
it1 = (x for x in ['a', 'b', 'c', 'd'])  # it1 has length 4
it2 = (x for x in [0, 1, 2, 3])          # it2 has length 4
list(zip_equal(it1, it2))                # like zip (or izip in python2)

有没有其他替代方案被我忽略了? 有没有更简单的实现我的zip_equal函数?

更新:

  • 需要Python 3.10或更新版本,请参见Asocia的答案
  • 彻底的性能基准测试和在Python<3.10上表现最佳的解决方案:Stefan的答案
  • 没有外部依赖的简单答案: Martijn Pieters'的答案(请查看注释以获得一些角落情况的错误修复)
  • 比Martijn's更复杂,但性能更好:cjerdonek's的答案
  • 如果您不介意使用包依赖,请参见pylang's的答案

3
这个问题(我猜它的答案)被 PEP 618 -- Add Optional Length-Checking To zip 引用,这也引入了 Python 3.10。这证明了“手写一个能够正确解决这个问题的健壮方案并不是一件容易的事情” :-) - Stefan Pochmann
6个回答

51
PEP 618中,为内置的zip函数引入了一个可选的布尔关键字参数strict
引用自Python 3.10的新特性

zip()函数现在有一个可选的strict标志,用于要求所有可迭代对象具有相等的长度。

当启用时,如果其中一个参数先耗尽,则会引发ValueError
>>> list(zip('ab', range(3)))
[('a', 0), ('b', 1)]
>>> list(zip('ab', range(3), strict=True))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: zip() argument 2 is longer than argument 1

32

我可以想到一个更简单的解决方案,使用itertools.zip_longest(),并且如果用于填充较短迭代器的哨兵值出现在生成的元组中,则引发异常:

from itertools import zip_longest

def zip_equal(*iterables):
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError('Iterables have different lengths')
        yield combo

很遗憾,我们不能使用yield fromzip()来避免Python代码中的循环和每次迭代的测试。一旦最短的迭代器用完了,zip()将推进所有前面的迭代器,从而吞噬额外的单个项目的证据。


yield from 的解决方案非常好。谢谢您提供了两种不同的解决方案。 - colidyre
3
额外提醒一件事,第二种解决方案在一个情况下不适用:假设有两个迭代器,第二个比第一个短一个元素。因为zip已经在第一个迭代器上调用了__next__,所以两个迭代器都被消耗完了,即使第一个迭代器更长也是这样。 - magu_
4
另外,最好将 if sentinel in combo 替换为 if any(sentinel is c for c in combo),因为 a in bany(bi==a for b_ in b) 相同 - 等号有时会被覆盖(并且在 combo 中的元素是 numpy 数组时会导致错误)。 - Peter
1
@Peter 实际上 sentinel in combo 检查 身份相等性(至少现在是这样的:-),但是如果有一个元素声称与这个哨兵相等,那么这是错误的,而且 numpy 数组甚至可能会导致这种情况 崩溃,比如 np.ones(2) in [object()]。more-itertools 函数实际上像 Martijn 的函数一样,在循环中只检查身份。这可能是它速度较慢的主要原因,请参见我的答案中的基准测试 - Stefan Pochmann
1
@zeehio:这超级公平,他的解决方案很棒! - Martijn Pieters
显示剩余5条评论

7

一种比基于cjerdonek的解决方案更快的新解决方案和一个基准测试。首先进行基准测试,我的解决方案是绿色的。请注意,在所有情况下,“总大小”都相同,即两百万个值。x轴是可迭代对象的数量。从具有两百万个值的1个可迭代对象开始,然后是具有每个一百万个值的2个可迭代对象,一直到具有每个20个值的10万个可迭代对象。

benchmark plot

黑色的是Python中的zip函数,我在这里使用的是Python 3.8版本,所以它不会像这个问题要求的那样检查迭代器长度是否相等。但我把它作为参考/极限速度来提供。你可以看到我的解决方案非常接近。
对于可能是最常见的两个迭代器进行压缩的情况,我的解决方案比cjerdonek之前最快的解决方案快近三倍,并且比zip慢不了多少。以下是文本格式的时间:
         number of iterables     1     2     3     4     5    10   100  1000 10000 50000 100000
-----------------------------------------------------------------------------------------------
       more_itertools__pylang 209.3 132.1 105.8  93.7  87.4  74.4  54.3  51.9  53.9  66.9  84.5
   fillvalue__Martijn_Pieters 159.1 101.5  85.6  74.0  68.8  59.0  44.1  43.0  44.9  56.9  72.0
     chain_raising__cjerdonek  58.5  35.1  26.3  21.9  19.7  16.6  10.4  12.7  34.4 115.2 223.2
     ziptail__Stefan_Pochmann  10.3  12.4  10.4   9.2   8.7   7.8   6.7   6.8   9.4  22.6  37.8
                          zip  10.3   8.5   7.8   7.4   7.4   7.1   6.4   6.8   9.0  19.4  32.3

我的代码在线尝试!):

def zip_equal(*iterables):

    # For trivial cases, use pure zip.
    if len(iterables) < 2:
        return zip(*iterables)

    # Tail for the first iterable
    first_stopped = False
    def first_tail():
        nonlocal first_stopped 
        first_stopped = True
        return
        yield

    # Tail for the zip
    def zip_tail():
        if not first_stopped:
            raise ValueError('zip_equal: first iterable is longer')
        for _ in chain.from_iterable(rest):
            raise ValueError('zip_equal: first iterable is shorter')
            yield

    # Put the pieces together
    iterables = iter(iterables)
    first = chain(next(iterables), first_tail())
    rest = list(map(iter, iterables))
    return chain(zip(first, *rest), zip_tail())

基本思想是让zip(*iterables)完成所有工作,然后在它停止时检查是否所有可迭代对象的长度都相等。只有当:

  1. zip停止,因为第一个可迭代对象没有其他元素(即没有其他可迭代对象更)。
  2. 其他可迭代对象都没有进一步的元素(即没有其他可迭代对象更)。

我如何检查这些条件:

  • 由于我需要在 zip 结束后检查这些条件,所以我不能纯粹地返回 zip 对象。相反,我在其后面链接一个空的 zip_tail 迭代器来进行检查。
  • 为了支持检查第一个条件,我在其后面链接了一个空的 first_tail 迭代器,其唯一的工作是记录第一个可迭代对象的迭代停止情况(即它被要求提供另一个元素,但没有更多元素可提供,因此要求 first_tail 迭代器提供)。
  • 为了支持检查第二个条件,在将其他所有可迭代对象传递给 zip 之前,我获取它们的迭代器并将它们保存在列表中。

附注:more-itertools 基本上使用了 Martijn 相同的方法,但使用了正确的 is 检查,而不是 Martijn 的 not quite correct sentinel in combo。这可能是它速度较慢的主要原因。

基准代码(在线尝试!):
import timeit
import itertools
from itertools import repeat, chain, zip_longest
from collections import deque
from sys import hexversion, maxsize

#-----------------------------------------------------------------------------
# Solution by Martijn Pieters
#-----------------------------------------------------------------------------

def zip_equal__fillvalue__Martijn_Pieters(*iterables):
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError('Iterables have different lengths')
        yield combo

#-----------------------------------------------------------------------------
# Solution by pylang
#-----------------------------------------------------------------------------

def zip_equal__more_itertools__pylang(*iterables):
    return more_itertools__zip_equal(*iterables)

_marker = object()

def _zip_equal_generator(iterables):
    for combo in zip_longest(*iterables, fillvalue=_marker):
        for val in combo:
            if val is _marker:
                raise UnequalIterablesError()
        yield combo

def more_itertools__zip_equal(*iterables):
    """``zip`` the input *iterables* together, but raise
    ``UnequalIterablesError`` if they aren't all the same length.

        >>> it_1 = range(3)
        >>> it_2 = iter('abc')
        >>> list(zip_equal(it_1, it_2))
        [(0, 'a'), (1, 'b'), (2, 'c')]

        >>> it_1 = range(3)
        >>> it_2 = iter('abcd')
        >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        more_itertools.more.UnequalIterablesError: Iterables have different
        lengths

    """
    if hexversion >= 0x30A00A6:
        warnings.warn(
            (
                'zip_equal will be removed in a future version of '
                'more-itertools. Use the builtin zip function with '
                'strict=True instead.'
            ),
            DeprecationWarning,
        )
    # Check whether the iterables are all the same size.
    try:
        first_size = len(iterables[0])
        for i, it in enumerate(iterables[1:], 1):
            size = len(it)
            if size != first_size:
                break
        else:
            # If we didn't break out, we can use the built-in zip.
            return zip(*iterables)

        # If we did break out, there was a mismatch.
        raise UnequalIterablesError(details=(first_size, i, size))
    # If any one of the iterables didn't have a length, start reading
    # them until one runs out.
    except TypeError:
        return _zip_equal_generator(iterables)

#-----------------------------------------------------------------------------
# Solution by cjerdonek
#-----------------------------------------------------------------------------

class ExhaustedError(Exception):
    def __init__(self, index):
        """The index is the 0-based index of the exhausted iterable."""
        self.index = index

def raising_iter(i):
    """Return an iterator that raises an ExhaustedError."""
    raise ExhaustedError(i)
    yield

def terminate_iter(i, iterable):
    """Return an iterator that raises an ExhaustedError at the end."""
    return itertools.chain(iterable, raising_iter(i))

def zip_equal__chain_raising__cjerdonek(*iterables):
    iterators = [terminate_iter(*args) for args in enumerate(iterables)]
    try:
        yield from zip(*iterators)
    except ExhaustedError as exc:
        index = exc.index
        if index > 0:
            raise RuntimeError('iterable {} exhausted first'.format(index)) from None
        # Check that all other iterators are also exhausted.
        for i, iterator in enumerate(iterators[1:], start=1):
            try:
                next(iterator)
            except ExhaustedError:
                pass
            else:
                raise RuntimeError('iterable {} is longer'.format(i)) from None
            
#-----------------------------------------------------------------------------
# Solution by Stefan Pochmann
#-----------------------------------------------------------------------------

def zip_equal__ziptail__Stefan_Pochmann(*iterables):

    # For trivial cases, use pure zip.
    if len(iterables) < 2:
        return zip(*iterables)

    # Tail for the first iterable
    first_stopped = False
    def first_tail():
        nonlocal first_stopped 
        first_stopped = True
        return
        yield

    # Tail for the zip
    def zip_tail():
        if not first_stopped:
            raise ValueError(f'zip_equal: first iterable is longer')
        for _ in chain.from_iterable(rest):
            raise ValueError(f'zip_equal: first iterable is shorter')
            yield

    # Put the pieces together
    iterables = iter(iterables)
    first = chain(next(iterables), first_tail())
    rest = list(map(iter, iterables))
    return chain(zip(first, *rest), zip_tail())

#-----------------------------------------------------------------------------
# List of solutions to be speedtested
#-----------------------------------------------------------------------------

solutions = [
    zip_equal__more_itertools__pylang,
    zip_equal__fillvalue__Martijn_Pieters,
    zip_equal__chain_raising__cjerdonek,
    zip_equal__ziptail__Stefan_Pochmann,
    zip,
]

def name(solution):
    return solution.__name__[11:] or 'zip'

#-----------------------------------------------------------------------------
# The speedtest code
#-----------------------------------------------------------------------------

def test(m, n):
    """Speedtest all solutions with m iterables of n elements each."""

    all_times = {solution: [] for solution in solutions}
    def show_title():
        print(f'{m} iterators of length {n:,}:')
    if verbose: show_title()
    def show_times(times, solution):
        print(*('%3d ms ' % t for t in times),
              name(solution))
        
    for _ in range(3):
        for solution in solutions:
            times = sorted(timeit.repeat(lambda: deque(solution(*(repeat(i, n) for i in range(m))), 0), number=1, repeat=5))[:3]
            times = [round(t * 1e3, 3) for t in times]
            all_times[solution].append(times)
            if verbose: show_times(times, solution)
        if verbose: print()
        
    if verbose:
        print('best by min:')
        show_title()
        for solution in solutions:
            show_times(min(all_times[solution], key=min), solution)
        print('best by max:')
    show_title()
    for solution in solutions:
        show_times(min(all_times[solution], key=max), solution)
    print()

    stats.append((m,
                  [min(all_times[solution], key=min)
                   for solution in solutions]))

#-----------------------------------------------------------------------------
# Run the speedtest for several numbers of iterables
#-----------------------------------------------------------------------------

stats = []
verbose = False
total_elements = 2 * 10**6
for m in 1, 2, 3, 4, 5, 10, 100, 1000, 10000, 50000, 100000:
    test(m, total_elements // m)

#-----------------------------------------------------------------------------
# Print the speedtest results for use in the plotting script
#-----------------------------------------------------------------------------

print('data for plotting by https://replit.com/@pochmann/zipequal-plot')
names = [name(solution) for solution in solutions]
print(f'{names = }')
print(f'{stats = }')

绘图/表格的代码(也可在Replit上找到):

import matplotlib.pyplot as plt

names = ['more_itertools__pylang', 'fillvalue__Martijn_Pieters', 'chain_raising__cjerdonek', 'ziptail__Stefan_Pochmann', 'zip']
stats = [(1, [[208.762, 211.211, 214.189], [159.568, 162.233, 162.24], [57.668, 58.94, 59.23], [10.418, 10.583, 10.723], [10.057, 10.443, 10.456]]), (2, [[130.065, 130.26, 130.52], [100.314, 101.206, 101.276], [34.405, 34.998, 35.188], [12.152, 12.473, 12.773], [8.671, 8.857, 9.395]]), (3, [[106.417, 107.452, 107.668], [90.693, 91.154, 91.386], [26.908, 27.863, 28.145], [10.457, 10.461, 10.789], [8.071, 8.157, 8.228]]), (4, [[97.547, 98.686, 98.726], [77.076, 78.31, 79.381], [23.134, 23.176, 23.181], [9.321, 9.4, 9.581], [7.541, 7.554, 7.635]]), (5, [[86.393, 88.046, 88.222], [68.633, 69.649, 69.742], [19.845, 20.006, 20.135], [8.726, 8.935, 9.016], [7.201, 7.26, 7.304]]), (10, [[70.384, 71.762, 72.473], [57.87, 58.149, 58.411], [15.808, 16.252, 16.262], [7.568, 7.57, 7.864], [6.732, 6.888, 6.911]]), (100, [[53.108, 54.245, 54.465], [44.436, 44.601, 45.226], [10.502, 11.073, 11.109], [6.721, 6.733, 6.847], [6.753, 6.774, 6.815]]), (1000, [[52.119, 52.476, 53.341], [42.775, 42.808, 43.649], [12.538, 12.853, 12.862], [6.802, 6.971, 7.002], [6.679, 6.724, 6.838]]), (10000, [[54.802, 55.006, 55.187], [45.981, 46.066, 46.735], [34.416, 34.672, 35.009], [9.485, 9.509, 9.626], [9.036, 9.042, 9.112]]), (50000, [[66.681, 66.98, 67.441], [56.593, 57.341, 57.631], [113.988, 114.022, 114.106], [22.088, 22.412, 22.595], [19.412, 19.431, 19.934]]), (100000, [[86.846, 88.111, 88.258], [74.796, 75.431, 75.927], [218.977, 220.182, 223.343], [38.89, 39.385, 39.88], [32.332, 33.117, 33.594]])]

colors = {
    'more_itertools__pylang': 'm',
    'fillvalue__Martijn_Pieters': 'red',
    'chain_raising__cjerdonek': 'gold',
    'ziptail__Stefan_Pochmann': 'lime',
    'zip': 'black',
}

ns = [n for n, _ in stats]
print('%28s' % 'number of iterables', *('%5d' % n for n in ns))
print('-' * 95)
x = range(len(ns))
for i, name in enumerate(names):
    ts = [min(tss[i]) for _, tss in stats]
    color = colors[name]
    if color:
        plt.plot(x, ts, '.-', color=color, label=name)
        print('%29s' % name, *('%5.1f' % t for t in ts))
plt.xticks(x, ns, size=9)
plt.ylim(0, 133)
plt.title('zip_equal(m iterables with 2,000,000/m values each)', weight='bold')
plt.xlabel('Number of zipped *iterables* (not their lengths)', weight='bold')
plt.ylabel('Time (for complete iteration) in milliseconds', weight='bold')
plt.legend(loc='upper center')
#plt.show()
plt.savefig('zip_equal_plot.png', dpi=200)

我不这么认为。你为什么这样想? - Stefan Pochmann
嗯,restzip_tail 中被使用,但是在一个封闭但不是全局的作用域中声明。显然,即使没有它,它也可以正常工作,但是明确声明 rest 的非本地性对我来说更加明确和清晰。 - Dan Lenski
是的,只有在你对它进行赋值时才需要使用它,而我在那里并没有这样做。我认为这就是它的意义所在,我从来没有看到过人们在不需要的情况下将变量声明为nonlocal或global。我实际上会觉得这是误导性的。PEP 8根本没有提到nonlocal/global。即使您没有将first_stoppedchainValueError声明为这样,尽管它们都不是局部变量。 - Stefan Pochmann
话虽如此,在某些情况下可能会更快(但在这种情况下不是,因为我只访问了一次)。当我回到电脑而不是手机上时,我会尝试记得测试/研究它。 - Stefan Pochmann
1
@StefanPochmann:一种微优化方法,可以让Python层面处理更多的工作:用以下两行代码替换从iterables = iter(iterables)rest = list(map(iter, iterables))(总共三行)的代码:first, *rest = map(iter, iterables)first = chain(first, first_tail())(如果需要,第二行可以内联到return行中)。这样做可以作为Python提供的优化字节码的批量操作。在生产代码中可能不值得这么折腾,但我觉得很有趣。 :-) - ShadowRanger
显示剩余2条评论

7
使用 more_itertools.zip_equal (v8.3.0+): 代码
import more_itertools as mit

演示

list(mit.zip_equal(range(3), "abc"))
# [(0, 'a'), (1, 'b'), (2, 'c')]

list(mit.zip_equal(range(3), "abcd"))
# UnequalIterablesError

more_itertools 是一个第三方软件包,可以通过 λ pip install more_itertools 安装。


6
以下是一种方法,不需要在每次迭代循环中进行任何额外的检查。这对于长迭代尤其有用。其思路是在可迭代对象的末尾填充一个“值”,当到达该值时引发异常,然后仅在最后执行所需的验证。此方法使用了zip()和itertools.chain()。下面的代码适用于Python 3.5。
import itertools

class ExhaustedError(Exception):
    def __init__(self, index):
        """The index is the 0-based index of the exhausted iterable."""
        self.index = index

def raising_iter(i):
    """Return an iterator that raises an ExhaustedError."""
    raise ExhaustedError(i)
    yield

def terminate_iter(i, iterable):
    """Return an iterator that raises an ExhaustedError at the end."""
    return itertools.chain(iterable, raising_iter(i))

def zip_equal(*iterables):
    iterators = [terminate_iter(*args) for args in enumerate(iterables)]
    try:
        yield from zip(*iterators)
    except ExhaustedError as exc:
        index = exc.index
        if index > 0:
            raise RuntimeError('iterable {} exhausted first'.format(index)) from None
        # Check that all other iterators are also exhausted.
        for i, iterator in enumerate(iterators[1:], start=1):
            try:
                next(iterator)
            except ExhaustedError:
                pass
            else:
                raise RuntimeError('iterable {} is longer'.format(i)) from None

以下是使用示例。
>>> list(zip_equal([1, 2], [3, 4], [5, 6]))
[(1, 3, 5), (2, 4, 6)]

>>> list(zip_equal([1, 2], [3], [4]))
RuntimeError: iterable 1 exhausted first

>>> list(zip_equal([1], [2, 3], [4]))
RuntimeError: iterable 1 is longer

>>> list(zip_equal([1], [2], [3, 4]))
RuntimeError: iterable 2 is longer

我更喜欢这种方法。虽然它比被接受的答案更复杂,但它使用了EAFP而不是LBYL,并且还提供了更好的错误消息。太棒了。 - Rick
2
我编辑了我的问题,并简要讨论了当出现性能问题时指向您的答案。感谢您的解决方案! - zeehio
这种方法的唯一问题是,即使所有迭代器的长度相同,它也会生成不必要的异常。 - Shital Shah
你可以通过直接压缩第一个可迭代对象,而不是使用链接的raising-iter来使它更快。也就是说,在terminate_iter中执行if i == 0: return iterable。然后在except ExhaustedError中无条件引发异常,并将Check that all other部分的缩进级别减少一级。话虽如此,我现在采取了更大的步骤我的解决方案 - Stefan Pochmann

1

我想出了一个使用哨兵可迭代对象的解决方案,供您参考:

class _SentinelException(Exception):
    def __iter__(self):
        raise _SentinelException


def zip_equal(iterable1, iterable2):
    i1 = iter(itertools.chain(iterable1, _SentinelException()))
    i2 = iter(iterable2)
    try:
        while True:
            yield (next(i1), next(i2))
    except _SentinelException:  # i1 reaches end
        try:
            next(i2)  # check whether i2 reaches end
        except StopIteration:
            pass
        else:
            raise ValueError('the second iterable is longer than the first one')
    except StopIteration: # i2 reaches end, as next(i1) has already been called, i1's length is bigger than i2
        raise ValueError('the first iterable is longger the second one.')

这个解决方案相对于已接受的解决方案有什么优势? - zeehio
只是另一种解决方案。对我来说,作为来自C++世界的人,我不喜欢每个yield中都有“if sentinel in combo”的检查。但是由于我们在Python世界中,没有人关心性能。 - XU Weijiang
谢谢你的回答,但如果你真的关心性能,你应该对其进行基准测试。你的解决方案慢了80%。这是一个基准测试:https://gist.github.com/zeehio/cdf7d881cc7f612b2c853fbd3a18ccbe - zeehio
1
谢谢您的友好基准测试。对于误导感到抱歉。是的,它较慢,我应该早点想到,因为izip_longest是本地的。 - XU Weijiang
抱歉如果我的回复有些严厉。感谢你的答案,我们比较了性能。我们发现被接受的答案比其他解决方案更快。现在我们也有一种快速的方式来基准测试任何未来的解决方案。我们现在比一个星期前更加了解。 :-) - zeehio
显示剩余2条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接