Cython 优化 NumPy 数组求和的关键部分

6

假设有一个列表 L = [A_1, A_2, ..., A_n],每个A_i都是长度为1024的numpy.int32数组。

(大多数情况下 1000 < n < 4000)。

经过一些分析,我发现其中最耗时的操作之一是加和:

def summation():
    # L is a global variable, modified outside of this function
    b = numpy.zeros(1024, numpy.int32)
    for a in L:
        b += a
    return b

注意:我不认为我可以定义大小为1024 x n的2D数组,因为n是不固定的:某些元素在运行时动态添加/删除到L中,因此len(L) = n可能在1000和4000之间变化。

我能通过使用Cython获得显着的改进吗?如果可以,应该如何重构这个小函数(是否需要添加一些cdef类型?)

或者你能看到其他可能的改进吗?

1个回答

2

以下是Cython代码,请确保L中的每个数组都是C_CONTIGUOUS:

import cython
import numpy as np
cimport numpy as np

@cython.boundscheck(False)
@cython.wraparound(False)
def sum_list(list a):
    cdef int* x
    cdef int* b
    cdef int i, j
    cdef int count
    count = len(a[0])
    res = np.zeros_like(a[0])
    b = <int *>((<np.ndarray>res).data)
    for j in range(len(a)):
        x = <int *>((<np.ndarray>a[j]).data)
        for i in range(count):
            b[i] += x[i]
    return res

在我的电脑上,它大约快了4倍。


非常感谢!这对我帮助很大!您知道如果a [0],a [1]等是int16 numpy数组,而我希望结果res仍然是int32 numpy数组,我该如何修改此代码吗? - Basj
如果输入数组是 int16,而输出仍然是 int32,我将 cdef int* x 替换为 cdef short* x,并将 x = <int *>((<np.ndarray>a[j]).data) 替换为 <short *>。您认为这是最好的方法吗? - Basj

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接