In [331]: A=np.random.rand(100,200,300)
In [332]: B=A
建议的
einsum
直接从原始数据中工作。
C[i,j,k] = np.dot(A[i,k,:], B[j,k,:]
表达式:
In [333]: np.einsum( 'ikm, jkm-> ijk', A, B).shape
Out[333]: (100, 100, 200)
In [334]: timeit np.einsum( 'ikm, jkm-> ijk', A, B).shape
800 ms ± 25.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
matmul
会在最后两个维度上进行点积运算,并将前面的一个或多个维度作为批量维度。 在你的情况下,'k'是批量维度,而'm'应该遵守“最后一个A和B倒数第二个”规则。因此,需要重写ikm,jkm ...
以适应规则,并相应地转置A
和B
:
In [335]: np.einsum('kim,kmj->kij', A.transpose(1,0,2), B.transpose(1,2,0)).shape
Out[335]: (200, 100, 100)
In [336]: timeit np.einsum('kim,kmj->kij',A.transpose(1,0,2), B.transpose(1,2,0)).shape
774 ms ± 22.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
在性能上没有太大区别。但现在使用matmul
:
In [337]: (A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0).shape
Out[337]: (100, 100, 200)
In [338]: timeit (A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0).shape
64.4 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
并验证值是否匹配(虽然往往情况下,如果形状匹配,则值也会匹配)。
In [339]: np.allclose((A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0),np.einsum( 'ikm, jkm->
...: ijk', A, B))
Out[339]: True
我不会试图测量内存使用情况,但时间上的改进也表明它是更好的。
在某些情况下,einsum
会被优化为使用 matmul
。在这里似乎不是这种情况,尽管我们可以玩一下它的参数。我有点惊讶 matmul
的性能如此好。
===
我模糊地记得另一个关于当两个数组是同一物品时,matmul
会走捷径的SO问题,A@A
我在这些测试中使用了 B=A
。
In [350]: timeit (A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0).shape
60.6 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [352]: B2=np.random.rand(100,200,300)
In [353]: timeit (A.transpose(1,0,2)@B2.transpose(1,2,0)).transpose(1,2,0).shape
97.4 ms ± 164 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
但是这只有一点点的改变。
In [356]: np.__version__
Out[356]: '1.16.4'
我的BLAS等库是标准的Linux,没有什么特别的。
einsum
和@
(np ver 1.15.3)得到了类似的时间。你的改进是因为你使用的是np>1.16.0吗? - Brenllanumpy
,现在我得到了类似的时间。可能是因为 这个。 - BrenllaB=A
,而matmul在A@A
情况下采取了适度的快捷方式,但这并不能解释大部分时间差异。 - hpaulj