numpy einsum的替代方案

7

当我计算一个有 N 行和 n 列的矩阵 X 的三阶矩时,我通常使用 einsum

M3 = sp.einsum('ij,ik,il->jkl',X,X,X) /N

这通常能正常工作,但现在我正在使用更大的值,即n = 120N = 100000,而einsum会返回以下错误:

ValueError: 迭代器太大

嵌套3个循环的替代方法不可行,所以我想知道是否有任何替代方法。

1个回答

5
请注意,计算这个需要至少进行 ~n3 × N = 1730亿次操作(不考虑对称性),因此它将很慢,除非numpy可以访问GPU或其他设备。在具有 ~3 GHz CPU的现代计算机上,整个计算预计需要大约60秒才能完成,假设没有SIMD /并行速度提升。
为了测试,让我们从N = 1000开始。我们将使用此来检查正确性和性能:
#!/usr/bin/env python3

import numpy
import time

numpy.random.seed(0)

n = 120
N = 1000
X = numpy.random.random((N, n))

start_time = time.time()

M3 = numpy.einsum('ij,ik,il->jkl', X, X, X)

end_time = time.time()

print('check:', M3[2,4,6], '= 125.401852515?')
print('check:', M3[4,2,6], '= 125.401852515?')
print('check:', M3[6,4,2], '= 125.401852515?')
print('check:', numpy.sum(M3), '= 218028826.631?')
print('total time =', end_time - start_time)

这需要大约8秒钟时间。这是基准。

让我们从三层嵌套循环作为替代方案开始:

M3 = numpy.zeros((n, n, n))
for j in range(n):
    for k in range(n):
        for l in range(n):
            M3[j,k,l] = numpy.sum(X[:,j] * X[:,k] * X[:,l])
# ~27 seconds

这需要大约半分钟,不好!一个原因是因为这实际上是四个嵌套循环:numpy.sum也可以被视为一个循环。

我们注意到,可以将总和转换为点积以去除第四个循环:

M3 = numpy.zeros((n, n, n))
for j in range(n):
    for k in range(n):
        for l in range(n):
            M3[j,k,l] = X[:,j] * X[:,k] @ X[:,l]
# 14 seconds

现在好多了,但仍然很慢。但我们注意到点积可以改为矩阵乘法来消除一个循环:

M3 = numpy.zeros((n, n, n))
for j in range(n):
    for k in range(n):
        M3[j,k] = X[:,j] * X[:,k] @ X
# ~0.5 seconds

什么?现在这比 einsum 更加高效!我们也可以检查答案是否确实正确。

我们能否进一步优化?可以!我们可以通过以下方式消除 k 循环:

M3 = numpy.zeros((n, n, n))
for j in range(n):
    Y = numpy.repeat(X[:,j], n).reshape((N, n))
    M3[j] = (Y * X).T @ X
# ~0.3 seconds

我们也可以使用广播(即对X的每一行执行a * [b,c] == [a*b, a*c])来避免使用numpy.repeat(感谢@Divakar):

M3 = numpy.zeros((n, n, n))
for j in range(n):
    Y = X[:,j].reshape((N, 1))
    ## or, equivalently: 
    # Y = X[:, numpy.newaxis, j]
    M3[j] = (Y * X).T @ X
# ~0.16 seconds

如果我们将N扩展到100000,程序的预计执行时间为16秒,这已经在理论限制范围内,因此消除j可能帮助不大(但这可能会使代码变得非常难以理解)。我们可以接受这个作为最终解决方案。
注意:如果您正在使用Python 2,则a @ b等同于a.dot(b)。

非常好的想法。如果我可以在这里添加一些广播,我们可以避免创建 Y 并直接获得迭代输出:(X[:,None,j]*X).T @ X。这应该能够进一步提高性能。 - Divakar
@Divakar:谢谢!已更新。 - kennytm

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接