用Python编写更快的insert()实现

3

本质上,我需要编写一个更快的实现来替换insert()方法,在列表中特定位置插入元素。

输入以列表形式给出,格式为[(索引, 值), (索引, 值), (索引, 值)]

例如:在100万个元素的列表中插入1万个元素大约需要2.7秒。

def do_insertions_simple(l, insertions):
    """Performs the insertions specified into l.
    @param l: list in which to do the insertions.  Is is not modified.
    @param insertions: list of pairs (i, x), indicating that x should
        be inserted at position i.
    """
    r = list(l)
    for i, x in insertions:
        r.insert(i, x)
    return r

我可以帮助您翻译中文,涉及IT技术相关内容。您需要提供的内容是需要将插入操作完成所需时间加快8倍或更多。

我目前的实现方式:

def do_insertions_fast(l, insertions):
    """Implement here a faster version of do_insertions_simple """
    #insert insertions[x][i] at l[i]
    result=list(l)
    for x,y in insertions:
      result = result[:x]+list(y)+result[x:]
    return result

样例输入:

import string
l = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
insertions = [(0, 'a'), (2, 'b'), (2, 'b'), (7, 'c')]
r1 = do_insertions_simple(l, insertions)
r2 = do_insertions_fast(l, insertions)
print("r1:", r1)
print("r2:", r2)
assert_equal(r1, r2)

is_correct = False
for _ in range(20):
    l, insertions = generate_testing_case(list_len=100, num_insertions=20)
    r1 = do_insertions_simple(l, insertions)
    r2 = do_insertions_fast(l, insertions)
    assert_equal(r1, r2)
    is_correct = True

我运行上述代码时遇到的错误:
r1: ['a', 0, 'b', 'b', 1, 2, 3, 'c', 4, 5, 6, 7, 8, 9]
r2: ['a', 0, 'b', 'b', 1, 2, 3, 'c', 4, 5, 6, 7, 8, 9]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-54e0c44a8801> in <module>()
     12     l, insertions = generate_testing_case(list_len=100, num_insertions=20)
     13     r1 = do_insertions_simple(l, insertions)
---> 14     r2 = do_insertions_fast(l, insertions)
     15     assert_equal(r1, r2)
     16     is_correct = True

<ipython-input-7-b421ee7cc58f> in do_insertions_fast(l, insertions)
      4     result=list(l)
      5     for x,y in insertions:
----> 6       result = result[:x]+list(y)+result[x:]
      7     return result
      8     #raise NotImplementedError()

TypeError: 'float' object is not iterable

该文件使用nose框架来检查我的答案等,所以如果有任何您不认识的函数,可能是该框架中的函数。
我知道它正确地插入了列表,但它仍然会出现“浮点对象不可迭代”的错误。
我也尝试过另一种方法,它确实起作用(切分了列表,添加了元素,添加了剩余的列表,然后更新了列表),但比insert()慢10倍。
我不确定如何继续下去。
编辑:我一直在错误地看待整个问题,现在我将尝试自己解决它,但如果我再次陷入困境,我将提出一个不同的问题并在此处链接。

list(y) 尝试迭代 y 并将其转换为列表。如果你只想将 y 转换为只有一个元素的列表,可以直接使用 [y] - b_c
2
我现在可以肯定地告诉你,你的方法并不比insert更快速 :) - iz_
顺便说一下,重新分配内存确实是一个问题(我夸大了它的重要性),但是只调整大小一次的想法在复制元素时并没有帮助。在上面链接的函数中,你需要“移动”现有的元素来为新元素腾出空间。这个“移动”过程非常耗时。 - iz_
1
如果您想要快速、任意插入,您需要另一种数据结构(如双向链表)。 - chepner
@chepner 如果他们不是一次性收到所有的插入请求,我会同意这个观点。 - Kelly Bundy
显示剩余11条评论
3个回答

3

根据你的问题,重点是:

我需要编写一个更快的实现来替换 insert() 函数,以在列表中特定位置插入元素。

这是不可能的。如果已经有更快的方法,那么现有的 insert() 函数一定会使用它。无论你做什么都不可能达到这个速度。

但你可以编写一个更快的方式来进行多次插入操作。

让我们看一个包含两个插入操作的示例:

>>> a = list(range(15))
>>> a.insert(5, 'X')
>>> a.insert(10, 'Y')
>>> a
[0, 1, 2, 3, 4, 'X', 5, 6, 7, 8, 'Y', 9, 10, 11, 12, 13, 14]

由于每次插入都会将其右侧的所有值向右移,因此通常这是一个O(m*(n+m))时间复杂度的算法,其中n是列表的原始大小,m是插入的数量。

另一种方法是根据插入点逐个构建结果:

>>> a = list(range(15))
>>> b = []
>>> b.extend(a[:5])
>>> b.append('X')
>>> b.extend(a[5:9])
>>> b.append('Y')
>>> b.extend(a[9:])
>>> b
[0, 1, 2, 3, 4, 'X', 5, 6, 7, 8, 'Y', 9, 10, 11, 12, 13, 14]

这需要O(n+m)的时间,因为所有值只复制一次,没有移动。但是正确确定每段的长度有些棘手,因为早期插入会影响后面的插入。特别是如果插入索引没有排序(在这种情况下,还需要O(m log m)的额外时间对它们进行排序)。这就是为什么我不得不使用[5:9]a[9:]而不是[5:10]a[10:]
(是的,我知道extend/append在内部如果容量耗尽会复制更多,但如果你足够理解这些东西指出来,那么你也理解它并不重要 :-))

这几乎可以肯定是预期的解决方案。然而,如果插入索引没有排序,那么它将需要 O(n + m log m) 的时间,因为你需要对它们进行排序,但是排序它们会很棘手,因为插入操作不可交换,所以如果你按照不同的顺序进行插入,实际上需要在不同的索引位置进行插入。 - kaya3
1
对的,我忽略了排序时间。我猜是因为我还希望它们已经排序了,只是在问题说明中被忽视了。我会提到这一点。是的,插入操作会相互影响,这就是为什么我的演示已经从索引9开始切片,而不是10。这也是我所说的棘手之处。 - Kelly Bundy
1
是的,这是正确的方法。我已经将你的想法转化为代码,并在这里提供了。如果你想把它融入到你的答案中,可以参考一下。 - iz_
1
@kaya3 我对此进行了一些扩展,并修复了天真方法的复杂度(如果m远大于n,则不是O(nm))。顺便说一下,我实际上在另一个案例中“利用”了插入操作相互影响的特性,我对此非常满意 :-) - Kelly Bundy
@iz_ 是的,预处理插入可能是解决问题的方法。特别是如果索引没有按顺序给出的话。我很怀疑这可能是一些编码网站上已经存在的问题,如果有提供的话,我可能会实际实现它。 - Kelly Bundy
显示剩余3条评论

2
一个选择是使用不同的数据结构,它支持更快的插入。
显而易见的建议是使用某种二叉树。只要能在O(log n)时间内找到正确的插入点,就可以将节点插入到平衡的二叉树中。解决这个问题的方法是让每个节点存储和维护其自己子树的基数,然后你就可以通过索引找到节点,而不必遍历整棵树。另一种可能性是使用跳表,它支持平均O(log n)时间的插入。
然而,问题在于你正在使用Python编写代码,因此在尝试编写比内置的list.insert方法更快的代码时,你会面临一个重大劣势,因为该方法是用C实现的,而Python代码比C代码慢得多。即使对于非常大的n,也不足以击败内置的O(n)实现,甚至n = 1,000,000可能也不够大,无法获得8倍或更多的优势。如果你尝试实现自己的数据结构,并且发现它速度不够快,那么这可能意味着很多浪费的努力。
我认为这个任务的预期解决方案会像 Heap Overflow 的答案一样。话虽如此,还有另一种方法来处理这个问题,值得考虑,因为它避免了在插入顺序错误时计算正确索引的复杂性。我的想法是利用 list.insert 的效率,但在较短的列表上调用它。
如果数据仍存储在Python列表中,则可以使用list.insert方法获得C实现的效率,但如果列表较短,则insert方法将更快。由于您只需要通过一个常数因子获胜,因此可以将输入列表分成大约256个大小相等的子列表。然后对于每个插入,将其插入到正确的子列表中的正确索引位置;最后将子列表合并在一起。时间复杂度为O(nm),与“naive”解决方案相同,但具有较低的常数因子。
为了计算正确的插入索引,我们需要从正在插入的子列表左侧减去子列表的长度;我们可以在prefix sum array中存储累积子列表长度,并使用numpy有效地更新该数组。这是我的实现:
from itertools import islice, chain, accumulate
import numpy as np

def do_insertions_split(lst, insertions, num_sublists=256):
    n = len(lst)
    sublist_len = n // num_sublists
    lst_iter = iter(lst)
    sublists = [list(islice(lst_iter, sublist_len)) for i in range(num_sublists-1)]
    sublists.append(list(lst_iter))
    lens = [0]
    lens.extend(accumulate(len(s) for s in sublists))
    lens = np.array(lens)

    for idx, val in insertions:
        # could use binary search, but num_sublists is small
        j = np.argmax(lens >= idx)
        sublists[j-1].insert(idx - lens[j-1], val)
        lens[j:] += 1

    return list(chain.from_iterable(sublists))

它的速度不如 @iz_ 的实现(从评论中链接),但它比简单算法快了近20倍,这已足够满足问题陈述的要求。下面的时间是使用 timeit 在一个长度为1,000,000的列表上进行10,000次插入操作测量得出的。
simple -> 2.1252768037122087 seconds
iz -> 0.041302349785668824 seconds
split -> 0.10893724981304054 seconds

请注意,我的解决方案仍然比@iz_的慢大约2.5倍。然而,@iz_的解决方案需要插入点被排序,而我的解决方案即使它们未排序也可以工作:
lst = list(range(1_000_000))
insertions = [(randint(0, len(lst)), "x") for _ in range(10_000)]

# uncomment if the insertion points should be sorted
# insertions.sort()

r1 = do_insertions_simple(lst, insertions)
r2 = do_insertions_iz(lst, insertions)
r3 = do_insertions_split(lst, insertions)

if r1 != r2: print('iz failed') # prints
if r1 != r3: print('split failed') # doesn't print

这是我的计时代码,以便其他人进行比较。我尝试了几个不同的 num_sublists 值; 200到1,000之间的任何值都似乎差不多好。
from timeit import timeit

algorithms = {
    'simple': do_insertions_simple,
    'iz': do_insertions_iz,
    'split': do_insertions_split,
}
reps = 10

for name, func in algorithms.items():
    t = timeit(lambda: func(lst, insertions), number=reps) / reps
    print(name, '->', t, 'seconds')

好的,这听起来像我预期的一样(我猜你也是)。如果numpy和list.insert的速度相同,那么大约sqrt(n)=1000应该是最好的选择,但由于移动内存可能仍然更快,所以稍微小一些的数字应该更好。 - Kelly Bundy
确实,你在我编辑的过程中抓住了我,我已经准备好了!我认为sqrt(n)也应该是一个高估值,因为我们在循环中使用了argmax以及对数组进行了+= 1操作。但是仔细考虑一下,我们可以动态选择num_sublistsint(sqrt(n))的某个常数倍,这样时间复杂度就是O(n + m sqrt(n)),从渐近上来看击败了朴素算法。 - kaya3

1

list(y) 试图迭代 y 并创建其元素的列表。如果 y 是一个整数,则不能迭代,会返回您提到的错误。您可能想要创建包含 y 的列表字面量,如下所示:[y]


谢谢!我把list(y)改成了[y],现在不再报错了。然而,就像iz_说的那样,它并不比insert()函数更快。 - Enscivwy
除非进行一些重大的欺骗,或者在很大程度上重新定义问题,否则你不太可能比内置的insert函数更快。任何加速都将来自于采用更高效的方法,而不是加快一个慢速的方法。 - Aaron
那么,我能采取更高效的方法来加快速度吗? (就像堆溢出指出的那样,如果问题没有解释清楚,我需要一种更快的方式来插入多个元素,而不是一个) - Enscivwy

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接