我应该如何使用Cython更好地解决微分方程,以获得更快的计算速度?

8
我希望能够缩短Scipy的odeint求解微分方程所需的时间。
为了练习,我使用了Python科学计算中提到的示例作为模板。因为odeint需要一个函数f作为参数,所以我编写了一个静态类型的Cython版本,并希望odeint的运行时间能够显著减少。
函数f包含在名为ode.pyx的文件中,如下所示:
import numpy as np
cimport numpy as np
from libc.math cimport sin, cos

def f(y, t, params):
  cdef double theta = y[0], omega = y[1]
  cdef double Q = params[0], d = params[1], Omega = params[2]
  cdef double derivs[2]
  derivs[0] = omega
  derivs[1] = -omega/Q + np.sin(theta) + d*np.cos(Omega*t)
  return derivs

def fCMath(y, double t, params):
  cdef double theta = y[0], omega = y[1]
  cdef double Q = params[0], d = params[1], Omega = params[2]
  cdef double derivs[2]
  derivs[0] = omega
  derivs[1] = -omega/Q + sin(theta) + d*cos(Omega*t)
  return derivs

我随后创建了一个名为setup.py的文件来编译该函数:

from distutils.core import setup
from Cython.Build import cythonize

setup(ext_modules=cythonize('ode.pyx'))

解决微分方程的脚本(还包括Python版本的f)称为solveODE.py,其代码如下:

import ode
import numpy as np
from scipy.integrate import odeint
import time

def f(y, t, params):
    theta, omega = y
    Q, d, Omega = params
    derivs = [omega,
             -omega/Q + np.sin(theta) + d*np.cos(Omega*t)]
    return derivs

params = np.array([2.0, 1.5, 0.65])
y0 = np.array([0.0, 0.0])
t = np.arange(0., 200., 0.05)

start_time = time.time()
odeint(f, y0, t, args=(params,))
print("The Python Code took: %.6s seconds" % (time.time() - start_time))

start_time = time.time()
odeint(ode.f, y0, t, args=(params,))
print("The Cython Code took: %.6s seconds ---" % (time.time() - start_time))

start_time = time.time()
odeint(ode.fCMath, y0, t, args=(params,))
print("The Cython Code incorpoarting two of DavidW_s suggestions took: %.6s seconds ---" % (time.time() - start_time))

我接下来运行:

python setup.py build_ext --inplace
python solveODE.py 

在终端中。

Python版本的时间大约为0.055秒,而Cython版本需要大约0.04秒。

有人推荐使用Cython改进我的微分方程求解尝试吗?最好不要修改odeint例程本身。

编辑

我在两个文件ode.pyxsolveODE.py中采纳了DavidW的建议。使用这些建议只需大约0.015秒即可运行代码。


你应该将这个内容发布到codereview上。 - Farhan.K
1
我可能会尝试使用numba而不是cython,但任何差异可能很小。大部分计算时间可能是在odeint调用您的函数时发生的上下文切换。您可能会从编写自己的数值积分函数(再次使用cython或numba)以避免上下文切换中获得最佳收益。 - Aaron
1
@fabian 我没有阅读源代码本身,但是你的函数 fode.f 是 Python 对象,每次调用至少需要一次上下文切换(0-200 步长为 0.05 的 4000 次调用),否则 odeint 将无法使用任何旧的自定义用户函数。我已经通过 numba 获得了 4 倍的加速,但现在我正在努力获得更多... - Aaron
1
@Farhan.K 不要仅仅因为他们想让代码更快就建议使用CodeReview。请注意在各自论坛上标签的受欢迎程度。如果您想优化C++或Java代码,CR非常好,但是在处理像“Cython”这样的专业软件包时效果不佳。 - hpaulj
我想指出的是,你的代码正在运行python模式。它是静态编译的,但你可能需要考虑cpdef。不过,我认为(正如其他人指出的那样),重要的工作将由ode求解器完成。 - cvanelteren
显示剩余6条评论
4个回答

5

最简单的改变(可能会给你带来很多好处)是使用C数学库中的sincos代替对数字进行操作,而不是使用numpy。调用numpy并计算出它不是数组所花费的时间相当昂贵。

from libc.math cimport sin, cos

    # later
    -omega/Q + sin(theta) + d*cos(Omega*t)

我会倾向于为输入项d分配一个类型(在不改变接口的情况下,其他输入项都不容易进行类型分配):

def f(y, double t, params):

我认为我也会像你在Python版本中一样返回一个列表。我认为使用C数组并没有太多的优势。


谢谢您的建议!确实,通过使用C数学库,代码相对于我的版本提高了大约40%,总体上比Python代码快大约两倍。按照您的建议键入“t”进一步改善了代码几个百分点。 - fabian

3
简而言之,使用numba.jit可以使速度提高3倍...
我对cython没有太多经验,但我的机器似乎在你的纯python版本中获得了类似的计算时间,所以我们应该能够大致比较。我使用numba编译函数f(稍微重写了一下,以便更好地与编译器配合)。
def f(y, t, params):
    return np.array([y[1], -y[1]/params[0] + np.sin(y[0]) + params[1]*np.cos(params[2]*t)])

numba_f = numba.jit(f)

numba_f 替换你的 ode.f,输出如下...
The Python Code took: 0.0468 seconds
The Numba Code took: 0.0155 seconds

我想知道是否可以复制odeint并使用numba编译以进一步提高速度...(我不能这样做)

这是我的Runge-Kutta数值微分方程积分器:

#function f is provided inline (not as an arg)
def runge_kutta(y0, steps, dt, args=()): #improvement on euler's method. *note: time steps given in number of steps and dt
    Y = np.empty([steps,y0.shape[0]])
    Y[0] = y0
    t = 0
    n = 0
    for n in range(steps-1):
        #calculate coeficients
        k1 = f(Y[n], t, args) #(euler's method coeficient) beginning of interval
        k2 = f(Y[n] + (dt * k1 / 2), t + (dt/2), args) #interval midpoint A
        k3 = f(Y[n] + (dt * k2 / 2), t + (dt/2), args) #interval midpoint B
        k4 = f(Y[n] + dt * k3, t + dt, args) #interval end point

        Y[n + 1] = Y[n] + (dt/6) * (k1 + 2*k2 + 2*k3 + k4) #calculate Y(n+1)
        t += dt #calculate t(n+1)
    return Y

天真的循环函数通常在编译后是最快的,尽管这可能可以重新组织以获得更好的速度。值得注意的是,它给出了与odeint不同的答案,在大约2000步后偏差高达0.001,并且在3000步后完全不同。对于该函数的numba版本,我只是用numba_f替换了f,并添加了@numba.jit作为装饰器进行编译。在这种情况下,像预期的那样,纯Python版本非常慢,但是numba版本与使用odeint的numba版本一样快(再次提醒,结果可能因情况而异)。

using custom integrator
The Python Code took: 0.2340 seconds
The Numba Code took: 0.0156 seconds

这里是关于预先编译的一个例子。由于我没有必要的工具链进行编译,也没有管理员权限来安装它,所以出现了缺少必要编译器的错误提示,但除此之外应该可以正常工作。
import numpy as np
from numba.pycc import CC

cc = CC('diffeq')

@cc.export('func', 'f8[:](f8[:], f8, f8[:])')
def func(y, t, params):
    return np.array([y[1], -y[1]/params[0] + np.sin(y[0]) + params[1]*np.cos(params[2]*t)])

cc.compile()

非常感谢您详细的回答!我从中学到了很多。不过,当我使用@jit声明时,我的代码被严重减速。实际上,代码大约需要2秒的时间。我的Python版本随Anaconda 4.3.1一起发布,提供了numbda 0.30.1。您对这样慢速结果有什么想法吗? - fabian
1
@fabian 初始编译需要时间,但每次后续运行应该都会很快。有人已经在主要评论线程中提到了这一点。这就像在运行时进行Cython编译而不是事先进行。如果您查看numba的文档,它确实也支持预编译,但我从未使用过。 - Aaron
1
这是相关的文档,可以实现这个功能。你基本上需要创建一个库,并将其编译成扩展库。 - Aaron
1
@fabian 或者在 jit 中使用 cache=True(http://numba.pydata.org/numba-doc/0.30.1/reference/jit-compilation.html)。 - DavidW
1
@fabian 我添加了一个编译的示例。但是我现在这台电脑上没有编译器,所以无法进行测试。(缺少vcvarsall.bat错误) - Aaron
显示剩余2条评论

2

NumbaLSODA 花费了0.00088秒(比Cython快17倍)。

from NumbaLSODA import lsoda_sig, lsoda
import numba as nb
import numpy as np
import time

@nb.cfunc(lsoda_sig)
def f(t, y_, dy, p_):
    p = nb.carray(p_, (3,))
    y = nb.carray(y_, (2,))
    theta, omega = y
    Q, d, Omega = p
    dy[0] = omega
    dy[1] = -omega/Q + np.sin(theta) + d*np.cos(Omega*t)

funcptr = f.address # address to ODE function
y0 = np.array([0.0, 0.0])
data = np.array([2.0, 1.5, 0.65])
t = np.arange(0., 200., 0.05)

start_time = time.time()
usol, success = lsoda(funcptr, y0, t, data = data)
print("NumbaLSODA took: %.8s seconds ---" % (time.time() - start_time))

结果
NumbaLSODA took: 0.000880 seconds ---

2
如果其他人使用其他模块回答这个问题,我也可以加入讨论:
我是JiTCODE的作者,它接受用SymPy符号编写的ODE,然后将此ODE转换为Python模块的C代码,编译此C代码,加载结果并将其用作SciPy's ODE的导数。你的例子在JiTCODE中的翻译如下:
from jitcode import jitcode, provide_basic_symbols
import numpy as np
from sympy import sin, cos
import time

Q = 2.0
d = 1.5
Ω = 0.65

t, y = provide_basic_symbols()

f = [
    y(1),
    -y(1)/Q + sin(y(0)) + d*cos(Ω*t)
    ]

initial_state = np.array([0.0,0.0])

ODE = jitcode(f)
ODE.set_integrator("lsoda")
ODE.set_initial_value(initial_state,0.0)

start_time = time.time()
data = np.vstack(ODE.integrate(T) for T in np.arange(0.05, 200., 0.05))
end_time = time.time()
print("JiTCODE took: %.6s seconds" % (end_time - start_time))

这需要0.11秒,与基于odeint的解决方案相比速度非常慢,但这不是由于实际积分而是结果处理的方式:虽然odeint在内部高效地直接创建数组,但这里是通过Python完成的。根据您的操作,这可能是一个关键的劣势,但对于更粗糙的采样或更大的微分方程,这很快就变得不相关了。

因此,让我们删除数据收集,只看积分,将最后几行替换为以下内容:

ODE = jitcode(f)
ODE.set_integrator("lsoda", max_step=0.05, nsteps=1e10)
ODE.set_initial_value(initial_state,0.0)

start_time = time.time()
ODE.integrate(200.0)
end_time = time.time()
print("JiTCODE took: %.6s seconds" % (end_time - start_time))

请注意,我设置了 max_step=0.05,以强制积分器至少做和您示例中一样多的步骤,并确保唯一的区别是积分结果没有存储到某个数组中。这将在0.010秒内运行。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接