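Here is a new solution that is faster than cjerdonek's solution, on which it is based, plus a benchmark. Benchmark first; my solution is the green one. Note that the "total size" is the same in all cases: two million values. The x-axis is the number of iterables. It goes from 1 iterable with two million values, through 2 iterables with one million values each, all the way up to 100,000 iterables with 20 values each.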
![benchmark plot](https://istack.dev59.com/tO7En.webp)
The black one is Python's `zip`. I used Python 3.8 here, so it doesn't check for equal lengths as this question asks, but I include it as a reference for the maximum speed one can hope for. You can see that my solution comes pretty close.

For the perhaps most common case of zipping two iterables, my solution is almost three times as fast as cjerdonek's previously fastest solution, and not much slower than `zip`. Here are the times in text form (in milliseconds, as in the plot):

    number of iterables              1      2      3      4      5     10    100   1000  10000  50000 100000
    --------------------------------------------------------------------------------------------------------
    more_itertools__pylang       209.3  132.1  105.8   93.7   87.4   74.4   54.3   51.9   53.9   66.9   84.5
    fillvalue__Martijn_Pieters   159.1  101.5   85.6   74.0   68.8   59.0   44.1   43.0   44.9   56.9   72.0
    chain_raising__cjerdonek      58.5   35.1   26.3   21.9   19.7   16.6   10.4   12.7   34.4  115.2  223.2
    ziptail__Stefan_Pochmann      10.3   12.4   10.4    9.2    8.7    7.8    6.7    6.8    9.4   22.6   37.8
    zip                           10.3    8.5    7.8    7.4    7.4    7.1    6.4    6.8    9.0   19.4   32.3

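My code (Try it online!):

    from itertools import chain

    def zip_equal(*iterables):

        # With zero or one iterable there is nothing to compare.
        if len(iterables) < 2:
            return zip(*iterables)

        # Set to True once the first iterable runs out of elements.
        first_stopped = False

        def first_tail():
            # Empty generator: only records that the first iterable was exhausted.
            nonlocal first_stopped
            first_stopped = True
            return
            yield

        def zip_tail():
            # Empty generator: runs the length checks after zip has stopped.
            if not first_stopped:
                raise ValueError('zip_equal: first iterable is longer')
            for _ in chain.from_iterable(rest):
                raise ValueError('zip_equal: first iterable is shorter')
                yield

        iterables = iter(iterables)
        first = chain(next(iterables), first_tail())
        rest = list(map(iter, iterables))
        return chain(zip(first, *rest), zip_tail())
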
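The basic idea is to let `zip(*iterables)` do all the work, and then, once it has stopped, check whether all iterables had the same length. They did if and only if:

- `zip` stopped because the first iterable had no further element (i.e., no other iterable is shorter).
- None of the other iterables has any further element (i.e., no other iterable is longer).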
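
Both conditions rely on `zip` asking its arguments for elements from left to right and stopping at the first one that is exhausted. A small sketch of that behaviour, with sample inputs and variable names chosen only for illustration:

```python
# Case 1: the first iterable is shorter. zip stops on `a` and never
# asks `b` for a third element, so that element is still available.
a, b = iter([1, 2]), iter('xyz')
pairs_1 = list(zip(a, b))
leftover_b = next(b, None)

# Case 2: the first iterable is longer. zip takes a third element
# from `c` before discovering that `d` is exhausted, so `c` has
# nothing left afterwards.
c, d = iter([1, 2, 3]), iter('xy')
pairs_2 = list(zip(c, d))
leftover_c = next(c, None)

print(pairs_1, leftover_b)  # [(1, 'x'), (2, 'y')] z
print(pairs_2, leftover_c)  # [(1, 'x'), (2, 'y')] None
```

In case 2 the leftover element of the longer first iterable is gone by the time `zip` stops, which is why looking at the iterators afterwards is not enough to detect it.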
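
How I check these conditions:

- Since I need to check these conditions after `zip` has finished, I can't simply return the `zip` object. Instead, I chain an empty `zip_tail` iterator behind it that does the checking.
- To support checking the first condition, I chain an empty `first_tail` iterator behind the first iterable. Its sole job is to record that iteration of the first iterable stopped (i.e., it was asked for another element and had none left, so the `first_tail` iterator was asked for one instead).
- To support checking the second condition, I take iterators of all the other iterables and keep them in a list before passing them to `zip`.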
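
To see the checks in action, here is a short usage sketch. It repeats the `zip_equal` definition from above so that it runs on its own; the sample inputs and the helper `outcome` are made up for illustration.

```python
from itertools import chain

def zip_equal(*iterables):
    if len(iterables) < 2:
        return zip(*iterables)
    first_stopped = False

    def first_tail():
        nonlocal first_stopped
        first_stopped = True
        return
        yield

    def zip_tail():
        if not first_stopped:
            raise ValueError('zip_equal: first iterable is longer')
        for _ in chain.from_iterable(rest):
            raise ValueError('zip_equal: first iterable is shorter')
            yield

    iterables = iter(iterables)
    first = chain(next(iterables), first_tail())
    rest = list(map(iter, iterables))
    return chain(zip(first, *rest), zip_tail())

def outcome(*iterables):
    """Hypothetical helper: return the zipped list, or the error message."""
    try:
        return list(zip_equal(*iterables))
    except ValueError as exc:
        return str(exc)

print(outcome('abc', [1, 2, 3]))   # [('a', 1), ('b', 2), ('c', 3)]
print(outcome('abcd', [1, 2, 3]))  # zip_equal: first iterable is longer
print(outcome('ab', [1, 2, 3]))    # zip_equal: first iterable is shorter
```

Note that the error is raised only when iteration reaches the end of the shortest input, so the pairs before that point have already been yielded.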

Side note: more-itertools uses pretty much the same method as Martijn, but does proper `is` checks instead of Martijn's not quite correct `sentinel in combo`. That is probably the main reason it is slower.
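
A minimal sketch of why `sentinel in combo` is not quite correct: a membership test falls back to `==`, so an element whose `__eq__` claims equality with everything is mistaken for the sentinel. The class `AlwaysEqual` is a made-up example and does not come from either implementation.

```python
from itertools import zip_longest

class AlwaysEqual:
    """Hypothetical element type that compares equal to anything."""
    def __eq__(self, other):
        return True

sentinel = object()
# Both inputs have length 1, so no fill value is actually inserted.
combo = next(zip_longest([AlwaysEqual()], ['x'], fillvalue=sentinel))

# Membership test: uses identity or ==, so the AlwaysEqual element matches.
found_by_in = sentinel in combo
# Identity test: only the sentinel object itself would match.
found_by_is = any(value is sentinel for value in combo)

print(found_by_in, found_by_is)  # True False
```

With the `in` check, these two equally long inputs would be reported as having different lengths; the per-element `is` check avoids that at the cost of an extra Python-level loop.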

Benchmark code (Try it online!):

    import timeit
    import itertools
    import warnings
    from itertools import repeat, chain, zip_longest
    from collections import deque
    from sys import hexversion, maxsize

    def zip_equal__fillvalue__Martijn_Pieters(*iterables):
        sentinel = object()
        for combo in zip_longest(*iterables, fillvalue=sentinel):
            if sentinel in combo:
                raise ValueError('Iterables have different lengths')
            yield combo

    def zip_equal__more_itertools__pylang(*iterables):
        return more_itertools__zip_equal(*iterables)

    _marker = object()

    class UnequalIterablesError(ValueError):
        # Minimal stand-in (an assumption, not the library's code) for the
        # more-itertools exception that the copied code below raises.
        def __init__(self, details=None):
            super().__init__('Iterables have different lengths')
            self.details = details

    def _zip_equal_generator(iterables):
        for combo in zip_longest(*iterables, fillvalue=_marker):
            for val in combo:
                if val is _marker:
                    raise UnequalIterablesError()
            yield combo

    def more_itertools__zip_equal(*iterables):
        """``zip`` the input *iterables* together, but raise
        ``UnequalIterablesError`` if they aren't all the same length.

            >>> it_1 = range(3)
            >>> it_2 = iter('abc')
            >>> list(zip_equal(it_1, it_2))
            [(0, 'a'), (1, 'b'), (2, 'c')]

            >>> it_1 = range(3)
            >>> it_2 = iter('abcd')
            >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL
            Traceback (most recent call last):
            ...
            more_itertools.more.UnequalIterablesError: Iterables have different
            lengths

        """
        if hexversion >= 0x30A00A6:
            warnings.warn(
                (
                    'zip_equal will be removed in a future version of '
                    'more-itertools. Use the builtin zip function with '
                    'strict=True instead.'
                ),
                DeprecationWarning,
            )
        try:
            first_size = len(iterables[0])
            for i, it in enumerate(iterables[1:], 1):
                size = len(it)
                if size != first_size:
                    break
            else:
                return zip(*iterables)
            raise UnequalIterablesError(details=(first_size, i, size))
        except TypeError:
            return _zip_equal_generator(iterables)

    class ExhaustedError(Exception):
        def __init__(self, index):
            """The index is the 0-based index of the exhausted iterable."""
            self.index = index

    def raising_iter(i):
        """Return an iterator that raises an ExhaustedError."""
        raise ExhaustedError(i)
        yield

    def terminate_iter(i, iterable):
        """Return an iterator that raises an ExhaustedError at the end."""
        return itertools.chain(iterable, raising_iter(i))

    def zip_equal__chain_raising__cjerdonek(*iterables):
        iterators = [terminate_iter(*args) for args in enumerate(iterables)]
        try:
            yield from zip(*iterators)
        except ExhaustedError as exc:
            index = exc.index
            if index > 0:
                raise RuntimeError('iterable {} exhausted first'.format(index)) from None
            for i, iterator in enumerate(iterators[1:], start=1):
                try:
                    next(iterator)
                except ExhaustedError:
                    pass
                else:
                    raise RuntimeError('iterable {} is longer'.format(i)) from None

    def zip_equal__ziptail__Stefan_Pochmann(*iterables):
        if len(iterables) < 2:
            return zip(*iterables)
        first_stopped = False
        def first_tail():
            nonlocal first_stopped
            first_stopped = True
            return
            yield
        def zip_tail():
            if not first_stopped:
                raise ValueError('zip_equal: first iterable is longer')
            for _ in chain.from_iterable(rest):
                raise ValueError('zip_equal: first iterable is shorter')
                yield
        iterables = iter(iterables)
        first = chain(next(iterables), first_tail())
        rest = list(map(iter, iterables))
        return chain(zip(first, *rest), zip_tail())

    solutions = [
        zip_equal__more_itertools__pylang,
        zip_equal__fillvalue__Martijn_Pieters,
        zip_equal__chain_raising__cjerdonek,
        zip_equal__ziptail__Stefan_Pochmann,
        zip,
    ]

    def name(solution):
        # Strip the 'zip_equal__' prefix; the builtin zip has none.
        return solution.__name__[11:] or 'zip'

    def test(m, n):
        """Speedtest all solutions with m iterables of n elements each."""

        all_times = {solution: [] for solution in solutions}
        def show_title():
            print(f'{m} iterators of length {n:,}:')
        if verbose: show_title()
        def show_times(times, solution):
            print(*('%3d ms ' % t for t in times),
                  name(solution))

        for _ in range(3):
            for solution in solutions:
                times = sorted(timeit.repeat(lambda: deque(solution(*(repeat(i, n) for i in range(m))), 0), number=1, repeat=5))[:3]
                times = [round(t * 1e3, 3) for t in times]
                all_times[solution].append(times)
                if verbose: show_times(times, solution)
            if verbose: print()

        if verbose:
            print('best by min:')
            show_title()
            for solution in solutions:
                show_times(min(all_times[solution], key=min), solution)
            print('best by max:')
            show_title()
            for solution in solutions:
                show_times(min(all_times[solution], key=max), solution)
            print()

        stats.append((m,
                      [min(all_times[solution], key=min)
                       for solution in solutions]))

    stats = []
    verbose = False
    total_elements = 2 * 10**6
    for m in 1, 2, 3, 4, 5, 10, 100, 1000, 10000, 50000, 100000:
        test(m, total_elements // m)

    print('data for plotting by https://replit.com/@pochmann/zipequal-plot')
    names = [name(solution) for solution in solutions]
    print(f'{names = }')
    print(f'{stats = }')

Code for the plot and the table (also available at Replit):

    import matplotlib.pyplot as plt

    names = ['more_itertools__pylang', 'fillvalue__Martijn_Pieters', 'chain_raising__cjerdonek', 'ziptail__Stefan_Pochmann', 'zip']
    stats = [(1, [[208.762, 211.211, 214.189], [159.568, 162.233, 162.24], [57.668, 58.94, 59.23], [10.418, 10.583, 10.723], [10.057, 10.443, 10.456]]), (2, [[130.065, 130.26, 130.52], [100.314, 101.206, 101.276], [34.405, 34.998, 35.188], [12.152, 12.473, 12.773], [8.671, 8.857, 9.395]]), (3, [[106.417, 107.452, 107.668], [90.693, 91.154, 91.386], [26.908, 27.863, 28.145], [10.457, 10.461, 10.789], [8.071, 8.157, 8.228]]), (4, [[97.547, 98.686, 98.726], [77.076, 78.31, 79.381], [23.134, 23.176, 23.181], [9.321, 9.4, 9.581], [7.541, 7.554, 7.635]]), (5, [[86.393, 88.046, 88.222], [68.633, 69.649, 69.742], [19.845, 20.006, 20.135], [8.726, 8.935, 9.016], [7.201, 7.26, 7.304]]), (10, [[70.384, 71.762, 72.473], [57.87, 58.149, 58.411], [15.808, 16.252, 16.262], [7.568, 7.57, 7.864], [6.732, 6.888, 6.911]]), (100, [[53.108, 54.245, 54.465], [44.436, 44.601, 45.226], [10.502, 11.073, 11.109], [6.721, 6.733, 6.847], [6.753, 6.774, 6.815]]), (1000, [[52.119, 52.476, 53.341], [42.775, 42.808, 43.649], [12.538, 12.853, 12.862], [6.802, 6.971, 7.002], [6.679, 6.724, 6.838]]), (10000, [[54.802, 55.006, 55.187], [45.981, 46.066, 46.735], [34.416, 34.672, 35.009], [9.485, 9.509, 9.626], [9.036, 9.042, 9.112]]), (50000, [[66.681, 66.98, 67.441], [56.593, 57.341, 57.631], [113.988, 114.022, 114.106], [22.088, 22.412, 22.595], [19.412, 19.431, 19.934]]), (100000, [[86.846, 88.111, 88.258], [74.796, 75.431, 75.927], [218.977, 220.182, 223.343], [38.89, 39.385, 39.88], [32.332, 33.117, 33.594]])]

    colors = {
        'more_itertools__pylang': 'm',
        'fillvalue__Martijn_Pieters': 'red',
        'chain_raising__cjerdonek': 'gold',
        'ziptail__Stefan_Pochmann': 'lime',
        'zip': 'black',
    }

    ns = [n for n, _ in stats]
    print('%28s' % 'number of iterables', *('%5d' % n for n in ns))
    print('-' * 95)
    x = range(len(ns))
    for i, name in enumerate(names):
        # Best (minimum) time of each solution for every number of iterables.
        ts = [min(tss[i]) for _, tss in stats]
        color = colors[name]
        if color:
            plt.plot(x, ts, '.-', color=color, label=name)
            print('%29s' % name, *('%5.1f' % t for t in ts))
    plt.xticks(x, ns, size=9)
    plt.ylim(0, 133)
    plt.title('zip_equal(m iterables with 2,000,000/m values each)', weight='bold')
    plt.xlabel('Number of zipped *iterables* (not their lengths)', weight='bold')
    plt.ylabel('Time (for complete iteration) in milliseconds', weight='bold')
    plt.legend(loc='upper center')
    plt.savefig('zip_equal_plot.png', dpi=200)