在Cython中定义NumPy数组而不产生Python开销

3
我一直在尝试学习Cython以加速我的一些计算。以下是我正在尝试做的一部分内容:使用递归公式积分微分方程,并利用NumPy数组。我已经比纯Python版本实现了大约100倍的速度提升。但是,通过查看-a Cython命令生成的HTML文件,我发现可以进一步提高速度。下面是我的代码(在HTML文件中变成黄色的行需要改为白色):
%%cython
import numpy as np
cimport numpy as np
cimport cython
from libc.math cimport exp,sqrt

@cython.boundscheck(False)
cdef double riccati_int(double j, double w, double h, double an, double d):
    cdef:
        double W
        double an1
    W = sqrt(w**2 + d**2)
    #dark_yellow
    an1 = ((d - (W + w) * an) * exp(-2 * W * h / j ) - d - (W - w) * an) / 
          ((d * an - W + w) * exp(-2 * W * h / j) - d * an - W - w) 
    return an1


def acalc(double j, double w):
    cdef:
        int xpos, i, n
        np.ndarray[np.int_t, ndim=1] xvals
        np.ndarray[np.double_t, ndim=1] h, a
    xpos = 74
    xvals = np.array([0, 8, 23, 123, 218], dtype=np.int)     #dark_yellow
    h = np.array([1, .1, .01, .1], dtype=np.double)          #dark_yellow
    a = np.empty(219, dtype=np.double)                       #dark_yellow
    a[0] = 1 / (w + sqrt(w**2 + 1))                          #light_yellow

    for i in range(h.size):                                  #dark_yellow
        for n in range(xvals[i], xvals[i + 1]):              #light_yellow
            if n < xpos:
                a[n+1] = riccati_int(j, w, h[i], a[n], 1.)   #light_yellow
            else:
                a[n+1] = riccati_int(j, w, h[i], a[n], 0.)   #light_yellow
    return a  

我认为上面标记的所有9行都可以通过适当的调整变为白色。一个问题是能否正确地定义NumPy数组。但更重要的是让第一行标记的代码高效地运行,因为这是大部分计算所在的地方。我尝试阅读生成的C代码,它显示在点击黄线后的HTML文件中,但我实在不知道如何阅读那段代码。如果有人能帮忙解决,将不胜感激。
2个回答

1

我认为您不需要关心不在循环中的黄色线。添加以下编译指令将使循环中的三条线更快:

@cython.cdivision(True)
cdef double riccati_int(double j, double w, double h, double an, double d):
    pass

@cython.boundscheck(False)
@cython.wraparound(False)
def acalc(double j, double w):
    pass

添加cdivision指令解决了最大的问题,但只提高了约10%的速度,所以似乎我已经尽可能地高效了。感谢您的帮助! - crr

0

我不确定是否有区别,但你可以使用内存视图来处理数组,例如:

cdef double [:] h = np.array([1, .1, .01, .1], dtype=np.double) #dark_yellow
cdef double [:] a = np.empty(219, dtype=np.double)              #dark_yellow

同时为四个静态值创建一个numpy数组有点过度了。这可以用静态C数组来替换

cdef double *h = [1, .1, .01, .1]

然而,正如提到的那样,循环中的内容最为重要。由于 line profiler 不适用于 cython(据我所知),因此使用 time 模块来对函数进行基准测试,除了使用 cProfile 之外,可能会给您一个想法,即必须在上下文中评估 cython 日志中线条颜色的强度。
建议使用 Python 类型进行索引,就像我学到的那样
size_t i, n
Py_ssize_t i, n

第二个是有符号版本


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接