使用numpy.einsum进行转置乘矩阵:x^T * x

4
针对一个二维矩阵 X(形状为 (m,n)),我要计算 X.T * X,其中 * 表示矩阵乘法。参考这篇文章的解释,我原以为可以通过 np.einsum('ji,ik->jk', X, X) 实现。在左边写上 ji,首先会将第一个 X 参数转置,然后再乘以第二个 X 参数。

但是这样做会出现错误(当 (m,n) = (3,4) 时):

ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (4,3)->(4,newaxis,3) (4,3)->(3,4)

然而,下面这行代码却有效: np.einsum('ij,jk->ik', X.T, X)。我错过了什么?为什么它甚至还添加了一个中间轴?

1个回答

7
使用X.T * X(*为矩阵相乘),您正在将第一个X的转置的第二轴与第二个X的第一轴相加。现在,第一个X的转置的第二轴将与第一个X的第一轴相同。因此,我们只需从这两个X中减少第一个轴,而它们的其余轴保持不变。
要在einsum上复制它,请保持字符串符号中的第一个字符不变,但为两个输入的第二轴选择不同的字符,如下所示 -
np.einsum('ji,jk->ik', X, X)

因此,j 被求和约简,而其余轴 - ik 保留在输出中。
同样,这比本地矩阵乘法慢:X.T.dot(X)。但是,我猜想这篇文章更多地是作为对于 einsum 的学习之用。

我还有点困惑。在矩阵乘法A x B中,我们希望将A的每一行与B的每一列相乘,对吗?所以'ji,jk->ik'看起来像是我们正在取转置的,并且乘以矩阵的...是这样吗? - DilithiumMatrix
没事了...只是有点困惑。这很有道理,谢谢@Divakar! - DilithiumMatrix

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接