将Python代码向量化以提高性能

3
我正在用Python编写科学代码,计算系统的能量。 这是我的函数: cte1、cte2、cte3、cte4是之前计算过的常量; pii是np.pi(预先计算的,否则会减慢循环速度)。我计算总能量的3个分量,然后将它们相加。
def calc_energy(diam): 
    Energy1 = cte2*((pii*diam**2/4)*t)
    Energy2 = cte4*(pii*diam)*t
    d=diam/t
    u=np.sqrt((d)**2/(1+d**2))
    cc= u**2
    E = sp.special.ellipe(cc) 
    K = sp.special.ellipk(cc) 
    Id=cte3*d*(d**2+(1-d**2)*E/u-K/u)
    Energy3 = cte*t**3*Id
    total_energy = Energy1+Energy2+Energy3
    return (total_energy,Energy1)

我的第一个想法是简单地循环遍历直径的所有值:

start_diam, stop_diam, step_diam = 1e-10, 500e-6, 1e-9 #Diametre
diametres = np.arange(start_diam,stop_diam,step_diam)

for d in diametres:  
    res1,res2 = calc_energy(d)
    totalEnergy.append(res1)
    Energy1.append(res2)

为了加快计算速度,我决定使用NumPy进行向量化,代码如下:

diams = diametres.reshape(-1,1) #If not reshaped, calculations won't run
r1 = np.apply_along_axis(calc_energy,1,diams)

然而,“向量化”解决方案不能正常工作。在计时时,第一种解决方案需要5秒,而第二种解决方案需要18秒。

我想我做错了一些事情,但无法确定具体是什么。


我尝试将我的代码进行Cython化。只需很少的更改,我就获得了40%到50%的性能提升。然而,我猜我可以通过简单地进行向量化来获得同等的改进。 - Mike
1
要进行矢量化,您需要避免在数组的每个项上应用函数,因为这样会带来Python的所有开销。相反,您可以批量进行计算,例如Energy1 = cte2*((pii*diametres**2/4)*t)将返回一个Energy1值数组。 - roganjosh
1
您的函数 calc_energy,目前可以直接接受数组 diametres 作为输入,不需要进行修改。这样是否能够得到正确的输出呢?我无法对其进行检查以确定它是否提供了正确的值。 - roganjosh
1
@roganjosh,是的,我刚刚检查了calc_energy(diametres) - 它将返回产生相同总和的数组。我认为你应该把它发布为答案。 - MaxU - stand with Ukraine
2
抱歉我离开电脑了。解决方案有效。现在计算时间不到0.1秒,而不是5秒!性能非常好。 - Mike
显示剩余2条评论
1个回答

2

使用当前的方法,您正在将Python函数应用于数组的每个元素,这会带来额外的开销。相反,您可以将整个数组传递给函数,并获得一个答案数组。您现有的函数似乎没有任何修改就可以正常工作。

import numpy as np
from scipy import special
cte = 2
cte1 = 2
cte2 = 2
cte3 = 2
cte4 = 2
pii = np.pi

t = 2

def calc_energy(diam): 
    Energy1 = cte2*((pii*diam**2/4)*t)
    Energy2 = cte4*(pii*diam)*t
    d=diam/t
    u=np.sqrt((d)**2/(1+d**2))
    cc= u**2
    E = special.ellipe(cc) 
    K = special.ellipk(cc) 
    Id=cte3*d*(d**2+(1-d**2)*E/u-K/u)
    Energy3 = cte*t**3*Id
    total_energy = Energy1+Energy2+Energy3
    return (total_energy,Energy1)

start_diam, stop_diam, step_diam = 1e-10, 500e-6, 1e-9 #Diametre
diametres = np.arange(start_diam,stop_diam,step_diam)

a = calc_energy(diametres) # Pass the whole array

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接