Cython优化

5

我正在使用Python编写一个相当大的模拟程序,希望从Cython中获得更好的性能。然而,在下面的代码中,即使它包含一个相当大的循环(约100k次迭代),我似乎并没有得到太多的性能提升。我是犯了一些初学者的错误还是这个循环大小太小以至于没有太大的影响?(在我的测试中,Cython代码只快了大约2倍)。

import numpy as np;
cimport numpy as np;
import math

ctypedef np.complex64_t cpl_t
cpl = np.complex64

def example(double a, np.ndarray[cpl_t,ndim=2] A):

    cdef int N = 100

    cdef np.ndarray[cpl_t,ndim=3] B = np.zeros((3,N,N),dtype = cpl)

    cdef Py_ssize_t n, m;
    for n in range(N):
        for m in range(N):

            if np.sqrt(A[0,n]) > 1:
                B[0,n,m] = A[0,n] + 1j * A[0,m]

    return B;

7
你在循环中调用了 np.sqrt,这看起来会造成性能问题。反正变量 a 的值从未改变,为什么不在循环之前加上 if a <= 1: return B 呢?这样可以避免在循环中重复计算。 - user2357112
@GWW:在我看来,那像是第一行。 - user2357112
@user2357112 谢谢,那是我忽略了的事情,我可以将其移出。实际上,在Cython中,像np.sqrt()或np.exp()这样的数学运算是我应该避免使用的吗? - physicsGuy
1
这些调用会回到Python,所以如果你想在GIL之外运行(例如多线程),你可能会想要避免它们。 - Sergei Lebedev
在Cython中,通常最好使用C标准库的数学运算。 - DavidW
1个回答

9
你应该使用编译指令。我用Python写了你的函数。
import numpy as np

def example_python(a, A):
    N = 100
    B = np.zeros((3,N,N),dtype = np.complex)
    aux = np.sqrt(A[0])
    for n in range(N):
        if aux[n] > 1:
            for m in range(N):
                B[0,n,m] = A[0,n] + 1j * A[0,m]
return B

而在Cython中(您可以在此处了解有关编译器指令的信息

import cython
import numpy as np
cimport numpy as np

ctypedef np.complex64_t cpl_t
cpl = np.complex64

@cython.boundscheck(False) # compiler directive
@cython.wraparound(False) # compiler directive
def example_cython(double a, np.ndarray[cpl_t,ndim=2] A):

    cdef int N = 100
    cdef np.ndarray[cpl_t,ndim=3] B = np.zeros((3,N,N),dtype = cpl)
    cdef np.ndarray[float, ndim=1] aux
    cdef Py_ssize_t n, m
    aux = np.sqrt(A[0,:]).real
    for n in range(N):
        if aux[n] > 1.:
            for m in range(N):
                B[0,n,m] = A[0,n] + 1j * A[0,m]
    return B

我比较这两个函数
c = np.array(np.random.rand(100,100)+1.5+1j*np.random.rand(100,100), dtype=np.complex64)

%timeit example_python(100, c)
10 loops, best of 3: 61.8 ms per loop

%timeit example_cython(100, c)
10000 loops, best of 3: 134 µs per loop

在这种情况下,Cython比Python快大约450倍。

非常感谢,那正是我正在寻找的!我还不知道编译器指令。 - physicsGuy

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接