为什么numpy的einsum比numpy的内置函数慢?

18

我通常能够从numpy的einsum函数中获得良好的性能(我也喜欢它的语法)。@Ophion对这个问题的回答表明,在测试的情况下,einsum始终优于“内置”函数(有时少量,有时很多)。但是我遇到了一个einsum非常慢的情况。考虑以下等效函数:

(M, K) = (1000000, 20)
C = np.random.rand(K, K)
X = np.random.rand(M, K)

def func_dot(C, X):
    Y = X.dot(C)
    return np.sum(Y * X, axis=1)

def func_einsum(C, X):
    return np.einsum('ik,km,im->i', X, C, X)

def func_einsum2(C, X):
    # Like func_einsum but break it into two steps.
    A = np.einsum('ik,km', X, C)
    return np.einsum('ik,ik->i', A, X)

我希望func_einsum能够运行最快,但这不是我遇到的情况。在配备超线程的四核CPU上运行,使用numpy版本1.9.0.dev-7ae0206和OpenBLAS进行多线程处理,我得到了以下结果:

In [2]: %time y1 = func_dot(C, X)
CPU times: user 320 ms, sys: 312 ms, total: 632 ms
Wall time: 209 ms
In [3]: %time y2 = func_einsum(C, X)
CPU times: user 844 ms, sys: 0 ns, total: 844 ms
Wall time: 842 ms
In [4]: %time y3 = func_einsum2(C, X)
CPU times: user 292 ms, sys: 44 ms, total: 336 ms
Wall time: 334 ms

当我把K增加到200时,差异更加明显:

In [2]: %time y1= func_dot(C, X)
CPU times: user 4.5 s, sys: 1.02 s, total: 5.52 s
Wall time: 2.3 s
In [3]: %time y2= func_einsum(C, X)
CPU times: user 1min 16s, sys: 44 ms, total: 1min 16s
Wall time: 1min 16s
In [4]: %time y3 = func_einsum2(C, X)
CPU times: user 15.3 s, sys: 312 ms, total: 15.6 s
Wall time: 15.6 s

有人能解释一下为什么这里使用 einsum 很慢吗?

如果有影响的话,这是我的 numpy 配置:

In [6]: np.show_config()
lapack_info:
    libraries = ['openblas']
    library_dirs = ['/usr/local/lib']
    language = f77
atlas_threads_info:
    libraries = ['openblas']
    library_dirs = ['/usr/local/lib']
    define_macros = [('ATLAS_WITHOUT_LAPACK', None)]
    language = c
    include_dirs = ['/usr/local/include']
blas_opt_info:
    libraries = ['openblas']
    library_dirs = ['/usr/local/lib']
    define_macros = [('ATLAS_INFO', '"\\"None\\""')]
    language = c
    include_dirs = ['/usr/local/include']
atlas_blas_threads_info:
    libraries = ['openblas']
    library_dirs = ['/usr/local/lib']
    define_macros = [('ATLAS_INFO', '"\\"None\\""')]
    language = c
    include_dirs = ['/usr/local/include']
lapack_opt_info:
    libraries = ['openblas', 'openblas']
    library_dirs = ['/usr/local/lib']
    define_macros = [('ATLAS_WITHOUT_LAPACK', None)]
    language = f77
    include_dirs = ['/usr/local/include']
lapack_mkl_info:
  NOT AVAILABLE
blas_mkl_info:
  NOT AVAILABLE
mkl_info:
  NOT AVAILABLE

5
当比较np.einsumnp.tensordot时,我注意到了相同的事情。我怀疑这可能只是通用性所付出的代价 - np.dot 调用了高度优化的 BLAS 子例程(dgemm等),用于计算两个矩阵之间的点积特殊情况,而 np.einsum 处理各种可能涉及多个输入矩阵的情况。我不确定其确切细节,但我怀疑设计 np.einsum 在所有这些情况下都能充分利用 BLAS 将会很困难。 - ali_m
2个回答

23

你可以两全其美:

def func_dot_einsum(C, X):
    Y = X.dot(C)
    return np.einsum('ij,ij->i', Y, X)

在我的系统上:

In [7]: %timeit func_dot(C, X)
10 loops, best of 3: 31.1 ms per loop

In [8]: %timeit func_einsum(C, X)
10 loops, best of 3: 105 ms per loop

In [9]: %timeit func_einsum2(C, X)
10 loops, best of 3: 43.5 ms per loop

In [10]: %timeit func_dot_einsum(C, X)
10 loops, best of 3: 21 ms per loop

当可用时,np.dot会使用BLAS、MKL或您拥有的任何库。因此,对np.dot的调用几乎肯定是多线程的。np.einsum具有自己的循环,因此不使用任何这些优化,除了其自身对SIMD的使用,以加速超出香草C实现的速度。
然后还有一个多输入einsum调用,运行速度要慢得多...numpy einsum的源代码非常复杂,我并不完全理解它。因此请注意,以下内容充其量只是推测,但我认为正在发生的事情是这样的...
当您运行诸如np.einsum('ij, ij->i', a, b)之类的内容时,与执行np.sum(a*b, axis=1)相比,避免了必须实例化带有所有乘积的中间数组,并对其进行两次循环的开销。所以在低级别上所发生的事情类似于:
for i in range(I):
    out[i] = 0
    for j in range(J):
        out[i] += a[i, j] * b[i, j]

假设你现在正在追求这样的目标:

np.einsum('ij,jk,ik->i', a, b, c)

您可以执行与以下操作相同的操作:

np.sum(a[:, :, None] * b[None, :, :] * c[:, None, :], axis=(1, 2))

我认为einsum的作用是运行最后的代码,而不必实例化巨大的中间数组,这肯定能够让事情更快:
In [29]: a, b, c = np.random.rand(3, 100, 100)

In [30]: %timeit np.einsum('ij,jk,ik->i', a, b, c)
100 loops, best of 3: 2.41 ms per loop

In [31]: %timeit np.sum(a[:, :, None] * b[None, :, :] * c[:, None, :], axis=(1, 2))
100 loops, best of 3: 12.3 ms per loop

但是如果你仔细观察,去除中间存储可能会带来糟糕的结果。这就是我认为 einsum 在低级别上所做的事情:

for i in range(I):
    out[i] = 0
    for j in range(J):
        for k in range(K):
            out[i] += a[i, j] * b[j, k] * c[i, k]

但是你正在重复大量的操作!如果您改为执行以下操作:

for i in range(I):
    out[i] = 0
    for j in range(J):
        temp = 0
        for k in range(K):
            temp += b[j, k] * c[i, k]
        out[i] += a[i, j] * temp

你可以少做 I * J * (K-1) 次乘法运算(并多做 I * J 次加法运算),从而节省大量时间。我猜测einsum在这个层面上并没有足够的智能来进行优化。在源代码中有一个提示,它只会对具有 1 或 2 个操作数的操作进行优化,而不是 3 个。总之,对于一般输入自动化处理似乎并不简单...

我最终也采用了dot-einsum解决方案,但希望使用仅einsum的方法会更快。您的答案很好地解释了为什么不是这样。谢谢。 - bogatron
4
更新:在numpy版本1.12.0及以上,有一个名为“optimize”的参数可告诉numpy进行优化。之所以不将其设为默认选项是因为存在内存问题(可能还涉及向后兼容性)。 - Imperishable Night
3
在上面的示例中添加 optimize=True%timeit np.einsum('ij,jk,ik->i', a, b, c, optimize=True) 将时间从 2.4 毫秒 减少到 410 微秒 - n1k31t4

6
einsum有一个特定的情况,针对“2个操作数,ndim=2”的情况。在这种情况下,有3个操作数和总共3个维度。因此,它必须使用一般的nditer
当试图理解字符串输入是如何解析时,我编写了一个纯Python einsum模拟器:https://github.com/hpaulj/numpy-einsum/blob/master/einsum_py.py (简化后的)einsum和乘积之和函数如下:
def myeinsum(subscripts, *ops, **kwargs):
    # dropin preplacement for np.einsum (more or less)
    <parse subscript strings>
    <prepare op_axes>
    x = sum_of_prod(ops, op_axes, **kwargs)
    return x

def sum_of_prod(ops, op_axes,...):
    ...
    it = np.nditer(ops, flags, op_flags, op_axes)
    it.operands[nop][...] = 0
    it.reset()
    for (x,y,z,w) in it:
        w[...] += x*y*z
    return it.operands[nop]

对于语句myeinsum('ik,km,im->i',X,C,X,debug=True)进行调试输出,其中(M,K)=(10,5)

{'max_label': 109, 
 'min_label': 105, 
 'nop': 3, 
 'shapes': [(10, 5), (5, 5), (10, 5)], 
 ....}}
 ...
iter labels: [105, 107, 109],'ikm'

op_axes [[0, 1, -1], [-1, 0, 1], [0, -1, 1], [0, -1, -1]]

如果你在Cython中编写一个类似这样的sum-of-prod函数,那么你应该得到接近广义einsum的结果。
使用完整的(M,K),这个模拟的einsum要慢6-7倍。
一些建立在其他答案基础上的计时结果:
In [84]: timeit np.dot(X,C)
1 loops, best of 3: 781 ms per loop

In [85]: timeit np.einsum('ik,km->im',X,C)
1 loops, best of 3: 1.28 s per loop

In [86]: timeit np.einsum('im,im->i',A,X)
10 loops, best of 3: 163 ms per loop

这个'im,im->i'的步骤比其他步骤快得多。和维度加总的操作m只有20。我猜einsum将其看作特例处理了。
In [87]: timeit np.einsum('im,im->i',np.dot(X,C),X)
1 loops, best of 3: 950 ms per loop

In [88]: timeit np.einsum('im,im->i',np.einsum('ik,km->im',X,C),X)
1 loops, best of 3: 1.45 s per loop

这些复合计算的时间只是相应部分的总和。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接