在4D numpy数组上进行矩阵乘法

5
我需要对两个4D数组(m和n)执行矩阵乘法,它们的维度分别为2x2x2x2和2x3x2x2,结果应该是一个2x3x2x2的数组。经过大量研究(主要在此网站上),似乎可以使用np.einsum或np.tensordot以高效的方式完成此操作,但我无法复制Matlab给出的答案(手动验证)。我理解如何在2D数组上执行矩阵乘法时这些方法(einsum和tensordot)的工作原理(在此清楚地解释),但我无法正确地获取4D数组的轴索引。显然我错过了什么!我的实际问题涉及两个23x23x3x3的复数数组,但我的测试数组是:
a = np.array([[1, 7], [4, 3]]) 
b = np.array([[2, 9], [4, 5]]) 
c = np.array([[3, 6], [1, 0]]) 
d = np.array([[2, 8], [1, 2]]) 
e = np.array([[0, 0], [1, 2]])
f = np.array([[2, 8], [1, 0]])

m = np.array([[a, b], [c, d]])              # (2,2,2,2)
n = np.array([[e, f, a], [b, d, c]])        # (2,3,2,2)

我知道复数可能会带来更多问题,但现在我只想理解如何使用einsum和tensordot进行索引。我追求的答案是这个2x3x2x2的数组:

+----+-----------+-----------+-----------+
|    | 0         | 1         | 2         |
+====+===========+===========+===========+
|  0 | [[47 77]  | [[22 42]  | [[44 40]  |
|    |  [31 67]] |  [27 74]] |  [33 61]] |
+----+-----------+-----------+-----------+
|  1 | [[42 70]  | [[24 56]  | [[41 51]  |
|    |  [10 19]] |  [ 6 20]] |  [ 6 13]] |
+----+-----------+-----------+-----------+

我最接近的尝试是使用np.tensordot:

mn = np.tensordot(m,n, axes=([1,3],[0,2]))

这让我得到了一个2x2x3x2的数组,其中包含正确的数字但顺序不正确:

+----+-----------+-----------+
|    | 0         | 1         |
+====+===========+===========+
|  0 | [[47 77]  | [[31 67]  |
|    |  [22 42]  |  [24 74]  |
|    |  [44 40]] |  [33 61]] |
+----+-----------+-----------+
|  1 | [[42 70]  | [[10 19]  |
|    |  [24 56]  |  [ 6 20]  |
|    |  [41 51]] |  [ 6 13]] |
+----+-----------+-----------+

我也尝试了从这里实现的一些解决方案,但是没有成功。
如果您有任何改进意见,将不胜感激,谢谢。


在APL(一种编程语言)中,矩阵乘法运算符被推广到高维对象。乘法/加法是跨左侧对象的最后一个维度和第二个对象的第一个维度进行的。对于2x2x2x2 +.x 2x3x2x2的情况,“内部”(...x2,2x...)维度被删除,结果将是一个2x2x2x3x2x2对象。我想知道Matlab是否遵循同样的规则。 - rcgldr
3个回答

3

由于你的降维维度既不匹配(这会允许广播),也不是“内部”维度(这会与np.tensordot原生工作),因此最好的选择是使用np.einsum

np.einsum('ijkl,jmln->imkn', m, n)

array([[[[47, 77],
         [31, 67]],

        [[22, 42],
         [24, 74]],

        [[44, 40],
         [33, 61]]],


       [[[42, 70],
         [10, 19]],

        [[24, 56],
         [ 6, 20]],

        [[41, 51],
         [ 6, 13]]]])

我认为这种方法最容易应用/理解,但与@Divakar的tensordot解决方案相比,我很惊讶它要慢多少。我看到的所有信息都表明einsum将是最有效的方法。 - 4bears
@AndrewForbes 只是基于BLAS的tensordot太高效了。 - Divakar
很不幸,np.einsum虽然对于了解爱因斯坦求和符号的人非常清晰,但并不像各种*dot运算符那样被优化。但是由于您的玩具问题似乎与实际问题没有相同的形式,我认为过早进行优化是不值得的。 - Daniel F

3

您可以简单地交换tensordot结果上的轴,这样我们仍然可以利用基于BLAS的求和缩减与tensordot一起使用。

np.tensordot(m,n, axes=((1,3),(0,2))).swapaxes(1,2)

或者,我们可以在tensordot调用中交换mn的位置,并进行转置以重新排列所有轴 -

np.tensordot(n,m, axes=((0,2),(1,3))).transpose(2,0,3,1)

通过手动重塑和交换轴的方式,我们可以使用np.dot2D矩阵相乘,代码如下 -

m0,m1,m2,m3 = m.shape
n0,n1,n2,n3 = n.shape
m2D = m.swapaxes(1,2).reshape(-1,m1*m3)
n2D = n.swapaxes(1,2).reshape(n0*n2,-1)
out = m2D.dot(n2D).reshape(m0,m2,n1,n3).swapaxes(1,2)

运行时测试 -

将输入数组扩展为10x的形状:

In [85]: m = np.random.rand(20,20,20,20)

In [86]: n = np.random.rand(20,30,20,20)

# @Daniel F's soln with einsum
In [87]: %timeit np.einsum('ijkl,jmln->imkn', m, n)
10 loops, best of 3: 136 ms per loop

In [126]: %timeit np.tensordot(m,n, axes=((1,3),(0,2))).swapaxes(1,2)
100 loops, best of 3: 2.31 ms per loop

In [127]: %timeit np.tensordot(n,m, axes=((0,2),(1,3))).transpose(2,0,3,1)
100 loops, best of 3: 2.37 ms per loop

In [128]: %%timeit
     ...: m0,m1,m2,m3 = m.shape
     ...: n0,n1,n2,n3 = n.shape
     ...: m2D = m.swapaxes(1,2).reshape(-1,m1*m3)
     ...: n2D = n.swapaxes(1,2).reshape(n0*n2,-1)
     ...: out = m2D.dot(n2D).reshape(m0,m2,n1,n3).swapaxes(1,2)
100 loops, best of 3: 2.36 ms per loop

我可以看出某种轴交换可能会做到这一点,但我没有选择这条路,因为我认为tensordot可以在不重新排列的情况下完成它。为什么einsum和tensordot之间的速度差异如此之大? - 4bears
1
@AndrewForbes 要了解tensordot如何“展开”剩余的轴,请参考相关文章-https://dev59.com/NlgR5IYBdhLWcg3wAJT7#41871402/。 - Divakar

2

只是为了展示广播也可以工作:

(m[:, :, None, :, :, None] * n[None, :, :, None, :, :]).sum(axis=(1,4))

但是其他已发布的解决方案可能更快,至少对于大型数组而言。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接