给定两个任意形状的numpy.ndarray
对象A
和B
,我想计算一个numpy.ndarray
C
,使得对于所有的i
,C[i] == np.dot(A[i], B[i])
。如何做到这一点?
例1:A.shape==(2,3,4)
和B.shape==(2,4,5)
,那么我们应该有C.shape==(2,3,5)
。
例2:A.shape==(2,3,4)
和B.shape==(2,4)
,那么我们应该有C.shape==(2,3)
。
给定两个任意形状的numpy.ndarray
对象A
和B
,我想计算一个numpy.ndarray
C
,使得对于所有的i
,C[i] == np.dot(A[i], B[i])
。如何做到这一点?
例1:A.shape==(2,3,4)
和B.shape==(2,4,5)
,那么我们应该有C.shape==(2,3,5)
。
例2:A.shape==(2,3,4)
和B.shape==(2,4)
,那么我们应该有C.shape==(2,3)
。
np.einsum
。 einsum
在这里很有帮助,因为我们需要沿着输入数组的第一个轴对齐并减少最后轴上的值。实现看起来会像这样 -def dotprod_axis0(A,B):
N,nA,nB = A.shape[0], A.shape[-1], B.shape[1]
Ar = A.reshape(N,-1,nA)
Br = B.reshape(N,nB,-1)
return np.squeeze(np.einsum('ijk,ikl->ijl',Ar,Br))
I. A:2D,B:2D
In [119]: # Inputs
...: A = np.random.randint(0,9,(3,4))
...: B = np.random.randint(0,9,(3,4))
...:
In [120]: for i in range(A.shape[0]):
...: print np.dot(A[i], B[i])
...:
33
86
48
In [121]: dotprod_axis0(A,B)
Out[121]: array([33, 86, 48])
II. A : 3D,B : 3D
In [122]: # Inputs
...: A = np.random.randint(0,9,(2,3,4))
...: B = np.random.randint(0,9,(2,4,5))
...:
In [123]: for i in range(A.shape[0]):
...: print np.dot(A[i], B[i])
...:
[[ 74 70 53 118 43]
[ 47 43 29 95 30]
[ 41 37 26 23 15]]
[[ 50 86 33 35 82]
[ 78 126 40 124 140]
[ 67 88 35 47 83]]
In [124]: dotprod_axis0(A,B)
Out[124]:
array([[[ 74, 70, 53, 118, 43],
[ 47, 43, 29, 95, 30],
[ 41, 37, 26, 23, 15]],
[[ 50, 86, 33, 35, 82],
[ 78, 126, 40, 124, 140],
[ 67, 88, 35, 47, 83]]])
III. A : 3D,B : 2D
In [125]: # Inputs
...: A = np.random.randint(0,9,(2,3,4))
...: B = np.random.randint(0,9,(2,4))
...:
In [126]: for i in range(A.shape[0]):
...: print np.dot(A[i], B[i])
...:
[ 87 105 53]
[152 135 120]
In [127]: dotprod_axis0(A,B)
Out[127]:
array([[ 87, 105, 53],
[152, 135, 120]])
IV. A:2D,B:3D
In [128]: # Inputs
...: A = np.random.randint(0,9,(2,4))
...: B = np.random.randint(0,9,(2,4,5))
...:
In [129]: for i in range(A.shape[0]):
...: print np.dot(A[i], B[i])
...:
[76 93 31 75 16]
[ 33 98 49 117 111]
In [130]: dotprod_axis0(A,B)
Out[130]:
array([[ 76, 93, 31, 75, 16],
[ 33, 98, 49, 117, 111]])
dot
(而不是矩阵-向量或者dot
在更高维度上做的奇怪的东西),那么足够新的NumPy版本(1.10+)可以让你这样做。C = numpy.matmul(A, B)
而且足够新的 Python 版本(3.5+)可以让您将其编写为:
C = A @ B
假设您的NumPy也足够新。
A.shape[0] == B.shape[0]
为True。 - piRSquaredeinsum('ijk,ik...->ij...',A,B)
处理您的两种情况。它只限制A
为三维,而B
可以是2、3等。 - hpaulj