我需要对两个4D数组(m和n)执行矩阵乘法,它们的维度分别为2x2x2x2和2x3x2x2,结果应该是一个2x3x2x2的数组。经过大量研究(主要在此网站上),似乎可以使用np.einsum或np.tensordot以高效的方式完成此操作,但我无法复制Matlab给出的答案(手动验证)。我理解如何在2D数组上执行矩阵乘法时这些方法(einsum和tensordot)的工作原理(在此清楚地解释),但我无法正确地获取4D数组的轴索引。显然我错过了什么!我的实际问题涉及两个23x23x3x3的复数数组,但我的测试数组是:
a = np.array([[1, 7], [4, 3]])
b = np.array([[2, 9], [4, 5]])
c = np.array([[3, 6], [1, 0]])
d = np.array([[2, 8], [1, 2]])
e = np.array([[0, 0], [1, 2]])
f = np.array([[2, 8], [1, 0]])
m = np.array([[a, b], [c, d]]) # (2,2,2,2)
n = np.array([[e, f, a], [b, d, c]]) # (2,3,2,2)
我知道复数可能会带来更多问题,但现在我只想理解如何使用einsum和tensordot进行索引。我追求的答案是这个2x3x2x2的数组:
+----+-----------+-----------+-----------+
| | 0 | 1 | 2 |
+====+===========+===========+===========+
| 0 | [[47 77] | [[22 42] | [[44 40] |
| | [31 67]] | [27 74]] | [33 61]] |
+----+-----------+-----------+-----------+
| 1 | [[42 70] | [[24 56] | [[41 51] |
| | [10 19]] | [ 6 20]] | [ 6 13]] |
+----+-----------+-----------+-----------+
我最接近的尝试是使用np.tensordot:
mn = np.tensordot(m,n, axes=([1,3],[0,2]))
这让我得到了一个2x2x3x2的数组,其中包含正确的数字但顺序不正确:
+----+-----------+-----------+
| | 0 | 1 |
+====+===========+===========+
| 0 | [[47 77] | [[31 67] |
| | [22 42] | [24 74] |
| | [44 40]] | [33 61]] |
+----+-----------+-----------+
| 1 | [[42 70] | [[10 19] |
| | [24 56] | [ 6 20] |
| | [41 51]] | [ 6 13]] |
+----+-----------+-----------+
我也尝试了从这里实现的一些解决方案,但是没有成功。
如果您有任何改进意见,将不胜感激,谢谢。