我能否加速这个基本的线性代数代码?

11

我在想是否有可能使用Numpy或数学技巧来优化以下内容。

def f1(g, b, dt, t1, t2):
  p = np.copy(g)
  for i in range(dt):
    p += t1*np.tanh(np.dot(p, b)) + t2*p
  return p

其中g是长度为n的向量,b是一个 nxn 的矩阵,dt是迭代次数,t1t2是标量。

我已经很快地想不出如何进一步优化这个函数了,因为在循环中,在方程的所有三个项中都使用了p:当加到自身时;在点积中;在标量乘法中。

但也许有一种不同的方式来表示这个函数或者有其他的技巧来提高它的效率。如果可能的话,我更倾向于不使用Cython等工具,但如果速度提升显著,我愿意使用它。预先感谢你们的帮助,如果这个问题在某些方面超出范围,请原谅。

更新:

到目前为止提供的答案更加关注输入/输出值的取值,以避免不必要的操作。我现在已经更新了MWE,为变量提供了适当的初始化值(我没有期望优化的想法来自那边 - 抱歉)。g将在范围[-1, 1]内,b将在范围[-无穷大,无穷大]内。近似输出不是一个选择,因为返回的向量稍微相似的输入可能会得到相同的向量,因此这不是一个选择。


MWE:

import numpy as np
import timeit

iterations = 10000

setup = """
import numpy as np
n  = 100
g  = np.random.uniform(-1, 1, (n,)) # Updated.
b  = np.random.uniform(-1, 1, (n,n)) # Updated.
dt = 10
t1 = 1
t2 = 1/2

def f1(g, b, dt, t1, t2):
  p = np.copy(g)
  for i in range(dt):
    p += t1*np.tanh(np.dot(p, b)) + t2*p
  return p
"""

functions = [
  """
    p = f1(g, b, dt, t1, t2)
  """
]

if __name__ == '__main__':
  for function in functions:
    print(function)
    print('Time = {}'.format(timeit.timeit(function, setup=setup,
                                           number=iterations)))

1
你的意思是一个长度为n的向量,而不是“n维向量”吗? - zhangxaochen
3
非常好的问题!然而,加速这个可能会很困难... 这是 numba.jit(http://numba.pydata.org/)应该适用的类型(Cython也是如此)。 然而,这是一个重要的依赖项,如果可能的话,您确实表示更愿意坚持使用“纯粹”的numpy。 - Joe Kington
1
@JoeKington 谢谢。我记得大约一个月前在这段代码上尝试使用Cython,但速度提升微不足道。我想这是因为 (a) 我对 Cython 不熟悉; (b) 我保持了大部分相同的东西,而没有例如在点积中使用 C 循环。如果Cython可以显著提高性能,我会非常高兴切换到它。现在看看numba.jit -- 谢谢! :-) - sudosensei
2
如果迭代次数只有10次,Cython/Numba帮助不大。没有数学技巧,DGEMM调用的效率将对整体速度产生最大的贡献。您当前正在使用优化的BLAS与NumPy吗?当然,这是在您使用比您描述的更大的矩阵的情况下,因为每次调用仅需要250微秒。 - Daniel
1
@Ophion 当我使用Cython运行我的测试时,我进行了1000多次迭代,但是我没有看到任何显著的改进 - 但是,我将这归因于我对Cython知识的不足,而不是潜力的不足。另外,是的,我正在使用优化的BLAS与Numpy。迭代次数和矩阵长度通常都会小于1000,但是这个函数被调用的次数超过一百万次。即使仅仅减少10%的执行时间,也将是一个显著的改进。 - sudosensei
显示剩余4条评论
2个回答

4
让代码在没有cython或jit的情况下更快运行将非常困难,一些数学技巧可能更容易。我认为如果我们定义一个k(g, b) = f1(g, b, n + 1, t1, t2)/f1(g, b, n, t1, t2)(当n>0),那么k函数应该有一个极限t1+t2(我还没有确凿的证据,只是一种直觉;这也可能是E(g)=0和E(p)=0的一个特例)。对于t1=1和t2=0.5,k()似乎很快地接近极限,在N>100时,它几乎是1.5的恒定值。因此,我认为数值逼近法应该是最简单的方法。enter image description here
In [81]:

t2=0.5
data=[f1(g, b, i+2, t1, t2)/f1(g, b, i+1, t1, t2) for i in range(1000)]
In [82]:

plt.figure(figsize=(10,5))
plt.plot(data[0], '.-', label='1')
plt.plot(data[4], '.-', label='5')
plt.plot(data[9], '.-', label='10')
plt.plot(data[49], '.-', label='50')
plt.plot(data[99], '.-', label='100')
plt.plot(data[999], '.-', label='1000')
plt.xlim(xmax=120)
plt.legend()
plt.savefig('limit.png')

In [83]:

data[999]
Out[83]:
array([ 1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5])

1
这也支持了我的回答,因为函数关系本质上是线性递归。 然而需要注意的是,OP指出迭代次数大约为10,因此大N的极限可能无效(尽管在这种情况下它仍将有效)。 - Hooked
+1!好主意!但请看我的更新帖子,了解期望的输入是什么。对于混淆感到抱歉——我只是没有预料到优化想法会以输入/输出值和近似形式出现。你是正确的:输出将饱和,但值并不总是足够大才会发生这种情况——实际上,输入始于非常小的幅度,通常随着每次函数调用而增加。 - sudosensei

4

我有些犹豫地给出这个答案,因为我认为它可能是你提供的输入数据的结果。尽管如此,请注意当 x>>1 时,tanh(x) ~ 1。在我运行过的所有时间里,你的输入数据都满足 x = np.dot(p,b) >> 1,因此我们可以用 f2 替换 f1

def f1(g, b, dt, t1, t2):
  p = np.copy(g)
  for i in range(dt):
      p += t1*np.tanh(np.dot(p, b)) + t2*p
  return p

def f2(g, b, dt, t1, t2):
  p = np.copy(g)
  for i in range(dt):
      p += t1 + t2*p
  return p

print np.allclose(f1(g,b,dt,t1,t2), f2(g,b,dt,t1,t2))

的确表明这两个函数在数值上是等价的。请注意,f2是一个非齐次线性递推关系,如果选择这样做,可以一步解决。


+1!谢谢。但是,点积的tanh不一定近似于1。请查看我的更新帖子,了解预期输入以及为什么近似不是我情况下可行的选项。对于造成的困惑,我表示歉意。 - sudosensei

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接