深度复制嵌套可迭代对象(或针对可迭代对象的迭代器改进版itertools.tee)

7

前言

我正在处理一个测试,其中涉及到嵌套可迭代对象(所谓的嵌套可迭代对象是指只包含可迭代对象作为元素的可迭代对象)。

假设我们有以下测试级联:

from itertools import tee
from typing import (Any,
                    Iterable)


def foo(nested_iterable: Iterable[Iterable[Any]]) -> Any:
    ...


def test_foo(nested_iterable: Iterable[Iterable[Any]]) -> None:
    original, target = tee(nested_iterable)  # this doesn't copy iterators elements

    result = foo(target)

    assert is_contract_satisfied(result, original)


def is_contract_satisfied(result: Any,
                          original: Iterable[Iterable[Any]]) -> bool:
    ...

例如,foo可能是一个简单的标识函数。
def foo(nested_iterable: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:
    return nested_iterable

合同只是检查扁平化的可迭代对象是否具有相同的元素。

from itertools import (chain,
                       starmap,
                       zip_longest)
from operator import eq
...
flatten = chain.from_iterable


def is_contract_satisfied(result: Iterable[Iterable[Any]],
                          original: Iterable[Iterable[Any]]) -> bool:
    return all(starmap(eq,
                       zip_longest(flatten(result), flatten(original),
                                   # we're assuming that ``object()``
                                   # will create some unique object
                                   # not presented in any of arguments
                                   fillvalue=object())))

但是,如果nested_iterable中的某些元素是一个迭代器,那么当使用tee进行浅拷贝时,它可能会被耗尽。也就是说,对于给定的foois_contract_satisfied语句而言,它们并不是深层次的拷贝。

>>> test_foo([iter(range(10))])

导致可预测的

Traceback (most recent call last):
  ...
    test_foo([iter(range(10))])
  File "...", line 19, in test_foo
    assert is_contract_satisfied(result, original)
AssertionError

问题

如何深度复制任意嵌套可迭代对象?

注意

我知道有 copy.deepcopy 函数,但它对文件对象无效。


你是否有什么不愿意将嵌套迭代器简单地转换为嵌套列表的原因呢? - juanpa.arrivillaga
1
@juanpa.arrivillaga: 是的,我正在编写一个库,可以处理任意可迭代对象(包括有限和无限的、用户定义的或来自标准库),并编写基于属性的测试。 - Azat Ibrakov
2个回答

5

朴素解法

直接的算法是:

  1. 对原始嵌套可迭代对象进行逐元素复制。
  2. 对逐元素复制体进行 n 次复制。
  3. 获取每个独立副本相关的坐标。

可以像下面这样实现:

from itertools import tee
from operator import itemgetter
from typing import (Any,
                    Iterable,
                    Tuple,
                    TypeVar)

Domain = TypeVar('Domain')


def copy_nested_iterable(nested_iterable: Iterable[Iterable[Domain]],
                         *,
                         count: int = 2
                         ) -> Tuple[Iterable[Iterable[Domain]], ...]:
    def shallow_copy(iterable: Iterable[Domain]) -> Tuple[Iterable[Domain], ...]:
        return tee(iterable, count)

    copies = shallow_copy(map(shallow_copy, nested_iterable))
    return tuple(map(itemgetter(index), iterables)
                 for index, iterables in enumerate(copies))

优点:

  • 相当易于阅读和解释。

缺点:

  • 如果我们想要扩展我们的方法来处理更深层次的可迭代对象(如嵌套的可迭代对象等),这种方法似乎并不有用。

我们可以做得更好。

改进后的解决方案

如果我们查看itertools.tee函数文档,它包含了Python的示例代码,通过functools.singledispatch装饰器的帮助,可以进行重写,如下所示:

from collections import (abc,
                         deque)
from functools import singledispatch
from itertools import repeat
from typing import (Iterable,
                    Tuple,
                    TypeVar)

Domain = TypeVar('Domain')


@functools.singledispatch
def copy(object_: Domain,
         *,
         count: int) -> Iterable[Domain]:
    raise TypeError('Unsupported object type: {type}.'
                    .format(type=type(object_)))

# handle general case
@copy.register(object)
# immutable strings represent a special kind of iterables
# that can be copied by simply repeating
@copy.register(bytes)
@copy.register(str)
# mappings cannot be copied as other iterables
# since they are iterable only by key
@copy.register(abc.Mapping)
def copy_object(object_: Domain,
                *,
                count: int) -> Iterable[Domain]:
    return itertools.repeat(object_, count)


@copy.register(abc.Iterable)
def copy_iterable(object_: Iterable[Domain],
                  *,
                  count: int = 2) -> Tuple[Iterable[Domain], ...]:
    iterator = iter(object_)
    # we are using `itertools.repeat` instead of `range` here
    # due to efficiency of the former
    # more info at
    # https://dev59.com/VWox5IYBdhLWcg3wql9t#9098860
    queues = [deque() for _ in repeat(None, count)]

    def replica(queue: deque) -> Iterable[Domain]:
        while True:
            if not queue:
                try:
                    element = next(iterator)
                except StopIteration:
                    return
                element_copies = copy(element,
                                           count=count)
                for sub_queue, element_copy in zip(queues, element_copies):
                    sub_queue.append(element_copy)
            yield queue.popleft()

    return tuple(replica(queue) for queue in queues)

优点:

  • 处理深层次嵌套甚至是同时包含可迭代和非可迭代元素的混合元素,
  • 可以扩展为用户定义的结构(例如用于制作它们的独立深度副本)。

缺点:

  • 可读性较差(但正如我们所知道的"实用胜于纯粹"),
  • 提供与分派相关的一些开销(但这没关系,因为它基于字典查找,具有O(1)复杂度)。

测试

准备工作

让我们将嵌套迭代器定义如下

nested_iterable = [range(10 ** index) for index in range(1, 7)]

由于迭代器的创建并不涉及底层副本性能,因此让我们定义迭代器耗尽的函数(在这里描述)。

exhaust_iterable = deque(maxlen=0).extend

时间

使用 timeit

import timeit

def naive(): exhaust_iterable(copy_nested_iterable(nested_iterable))

def improved(): exhaust_iterable(copy_iterable(nested_iterable))

print('naive approach:', min(timeit.repeat(naive)))
print('improved approach:', min(timeit.repeat(improved)))

我在我的Windows 10 x64笔记本电脑上安装了Python 3.5.4

naive approach: 5.1863865
improved approach: 3.5602296000000013

内存

使用memory_profiler

Line #    Mem usage    Increment   Line Contents
================================================
    78     17.2 MiB     17.2 MiB   @profile
    79                             def profile_memory(nested_iterable: Iterable[Iterable[Any]]) -> None:
    80     68.6 MiB     51.4 MiB       result = list(flatten(flatten(copy_nested_iterable(nested_iterable))))

对于“naive”方法和

Line #    Mem usage    Increment   Line Contents
================================================
    78     17.2 MiB     17.2 MiB   @profile
    79                             def profile_memory(nested_iterable: Iterable[Iterable[Any]]) -> None:
    80     68.7 MiB     51.4 MiB       result = list(flatten(flatten(copy_iterable(nested_iterable))))

用于“改进”的一个。

注意:我运行了不同的脚本,因为一次性创建它们不会具有代表性,因为第二个语句将重复使用之前在幕后创建的int对象。


结论

我们可以看到这两个函数的性能类似,但最后一个支持更深层次的嵌套,而且看起来很容易扩展。

广告

我已经在lz0.4.0版本中添加了“改进”的解决方案,可以像这样使用:

>>> from lz.replication import replicate
>>> iterable = iter(range(5))
>>> list(map(list, replicate(iterable,
                             count=3)))
[[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]

它使用 hypothesis 框架 进行基于属性的测试,因此我们可以确信它按预期工作。

0
回答你的问题:如何深度复制嵌套可迭代对象? 你可以使用标准库中的deepcopy
>>> from copy import deepcopy
>>> 
>>> ni = [1, [2,3,4]]
>>> ci = deepcopy(ni)
>>> ci[1][0] = "Modified"
>>> ci
[1, ['Modified', 3, 4]]
>>> ni
[1, [2,3,4]]

更新

@Azat Ibrakov说:你正在使用序列,请尝试对文件对象进行深复制(提示:它将失败)

不,对文件对象进行深复制不会失败,你可以深度复制文件对象,演示:

import copy

with open('example.txt', 'w') as f:
     f.writelines(["{}\n".format(i) for i in range(100)])

with open('example.txt', 'r') as f:
    l = [1, [f]]
    c = copy.deepcopy(l)
    print(isinstance(c[1][0], file))  # Prints  True.
    print("\n".join(dir(c[1][0])))

输出:

True
__class__
__delattr__
__doc__
__enter__
__exit__
__format__
__getattribute__
...
write
writelines
xreadlines

问题在于概念。

根据Python迭代器协议,通过执行next函数可以获取某些容器中包含的项目,请参见此处的文档

您不会拥有实现迭代器协议的对象(如文件对象)的所有项目,直到遍历整个迭代器(执行next()直到引发StopIteration异常)。

这是因为您无法确定执行迭代器的next(对于Python 2.x的__next__)方法的结果。

请参见以下示例:

import random

class RandomNumberIterator:

    def __init__(self):
        self.count = 0
        self.internal_it = range(10)  # For later demostration on deepcopy

    def __iter__(self):
        return self

    def next(self):
        self.count += 1
        if self.count == 10:
            raise StopIteration
        return random.randint(0, 1000)

ri = RandomNumberIterator()

for i in ri:
    print(i)  # This will print randor numbers each time.
              # Can you come out with some sort of mechanism to be able
              # to copy **THE CONTENT** of the `ri` iterator? 

你可以再次:

from copy import deepcopy

cri = deepcopy(ri)

for i in cri.internal_it:
    print(i)   # Will print numbers 0..9
               # Deepcopy on ri successful!

一个文件对象在这里是一个特殊情况,涉及到文件处理程序,在此之前,您可以深度复制一个文件对象,但它将处于“已关闭”状态。
另一种方法是,您可以在可迭代对象上调用list,这将自动评估可迭代对象,然后您就可以测试可迭代对象的内容
回到文件:
with open('example.txt', 'w') as f:
         f.writelines(["{}\n".format(i) for i in range(5)])

with open('example.txt', 'r') as f:
    print(list(f))  # Prints ['0\n', '1\n', '2\n', '3\n', '4\n']

所以,恢复

您可以深度复制嵌套的可迭代对象,但是在复制它们时无法评估可迭代对象,这是没有意义的(请记住RandomNumberIterator)。

如果您需要测试可迭代对象的内容,则需要对其进行评估。


2
你正在处理序列,尝试对文件对象进行深拷贝(提示:会失败)。 - Azat Ibrakov
你使用的是哪个Python版本?对于Python 3,deepcopy文件对象将会出现TypeError: cannot serialize '_io.TextIOWrapper' object错误。 - Azat Ibrakov
“在复制可迭代对象时无法进行评估”是什么意思?我可以使用itertools.tee成功地复制普通的可迭代对象,然后独立地对每个对象进行评估,甚至是潜在的无限对象。 - Azat Ibrakov
在Python 2.7中,如果我执行copy.deepcopy(file),我会得到<closed file '<uninitialized file>', mode '<uninitialized file>' at 0x7f4a99ac3930>,并且在尝试像list(file_copy)这样迭代它时,会引发ValueError: I/O operation on closed file,而原始文件按预期工作,因此,不能使用copy.deepcopy函数创建文件对象的可用副本。 - Azat Ibrakov
@Azat Ibrakov,根据您的代码示例,您正在使用类型注释,因此我假设您正在使用Python 3.x,所以我也在我的答案中使用了Python 3x。另一方面,文件对象与文件内容不同。 - Raydel Miranda
我知道文件对象是什么以及文件内容是什么,不过我不知道你在用哪个版本的Python,但是你的代码在Windows和Linux下的Python3.3+和Python2.7都无法工作,所以我不知道你要卖给我什么,另外有些用户自定义的可迭代结构可能不支持 copy.deepcopy - Azat Ibrakov

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接