矩阵乘法:在Python中将矩阵的每一行与另一个二维矩阵相乘

6

我想从这个矩阵乘法中删除循环(并了解更多关于优化代码的知识),我认为我需要某种形式的np.broadcastingnp.einsum,但在阅读它们后,我仍然不确定如何将它们用于我的问题。

A = np.array([[1, 2, 3, 4, 5],
         [6, 7, 8, 9, 10],
         [11,12,13,14,15]])
#A is a 3x5 matrix, such that the shape of A is (3, 5) (and A[0] is (5,))

B = np.array([[1,0,0],
         [0,2,0],
         [0,0,3]])
#B is a 3x3 (diagonal) matrix, with a shape of (3, 3)

C = np.zeros(5)
for i in range(5):
    C[i] = np.linalg.multi_dot([A[:,i].T, B, A[:,i]])

#Each row of matrix math is [1x3]*[3x3]*[3x1] to become a scaler value in each row
#C becomes a [5x1] matrix with a shape of (5,)

我知道我不能只使用np.multidot,因为那会得到一个(5,5)的数组。

我还找到了这个:在Numpy中将矩阵乘以另一个矩阵的每一行,但我不确定它是否与我的问题实际上相同。


是的,C是我的期望输出,但我希望能够在不需要循环的情况下得到它。 - LED
3个回答

5
In [601]: C
Out[601]: array([436., 534., 644., 766., 900.])

对于einsum来说,这是一件自然而然的事情。我使用和你一样的i来表示传递到结果的索引。jk是用于乘积求和的索引。

In [602]: np.einsum('ji,jk,ki->i',A,B,A)
Out[602]: array([436, 534, 644, 766, 900])

可能也可以使用mutmul完成,但可能需要添加一个维度,然后再压缩。

使用diagdot方法比必要的工作要多得多。 diag会丢掉很多值。

要使用matmul,我们必须使i维度成为3D数组的第一维。那是“被动”的一个,它会传递到结果中:

In [603]: A.T[:,None,:]@B@A.T[:,:,None]
Out[603]: 
array([[[436]],     # (5,1,1) result

       [[534]],

       [[644]],

       [[766]],

       [[900]]])
In [604]: (A.T[:,None,:]@B@A.T[:,:,None]).squeeze()
Out[604]: array([436, 534, 644, 766, 900])

或者将额外的维度索引掉:(A.T[:,None,:]@B@A.T[:,:,None])[:,0,0]


没有参数的squeeze()是危险的,如果AB中的其他维度为1,则会产生错误结果。看起来你想在这里使用.squeeze(axis=(1,2)) - Eric
作为注意事项,np.einsum('ji,jk,ki->i', A, B, A, optimize=True)将自动处理matmul调用。 - Daniel

1
你可以将调用 dot 链接在一起,然后获得对角线:
# your original output:
# >>> C
# array([436., 534., 644., 766., 900.])

>>> np.diag(np.dot(np.dot(A.T,B), A))
array([436, 534, 644, 766, 900])

或者等价地,使用您原始的multi_dot思路,但取得结果5x5数组的对角线。这可能会有一些性能提升(根据文档)。
>>> np.diag(np.linalg.multi_dot([A.T, B, A]))
array([436, 534, 644, 766, 900])

1
取对角线意味着你要放弃一些计算,因此这种方法不如其他方法高效。 - Eric

0

补充一下答案。如果你想要矩阵相乘,可以使用广播技术。编辑:请注意这是元素级别的乘法,不是点积。如果需要点积,可以使用点积方法。

 B [...,None] * A

给出:

array([[[ 1,  2,  3,  4,  5],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0]],

       [[ 0,  0,  0,  0,  0],
        [12, 14, 16, 18, 20],
        [ 0,  0,  0,  0,  0]],

       [[ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [33, 36, 39, 42, 45]]])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接