这里是原始输入变量:
A = np.array([[1,1,1,1],[2,2,2,2]])
B = np.array([[1,2,3,4],[1,1,1,1],[1,2,1,9]])
A
B
A是一个2x4的数组。
B是一个3x4的数组。
我们想要通过一个完全向量化的操作来计算欧几里得距离矩阵运算,其中dist[i,j]
包含A中第i个实例和B中第j个实例之间的距离。因此,在这个例子中,dist
是2x3的。
距离
![enter image description here](https://istack.dev59.com/syuCq.webp)
可以用numpy来表述
dist = np.sqrt(np.sum(np.square(A-B))) # DOES NOT WORK
# Traceback (most recent call last):
# File "<stdin>", line 1, in <module>
# ValueError: operands could not be broadcast together with shapes (2,4) (3,4)
然而,如上所示,问题在于逐元素减法操作
A-B
涉及不兼容的数组大小,特别是第一维中的 2 和 3。
A has dimensions 2 x 4
B has dimensions 3 x 4
为了进行按元素减法,我们必须填充A或B中的一个以满足numpy的广播规则。我选择用额外的一维来填充A,使其变成2 x 1 x 4,这样可以使数组的维度对齐以进行广播。有关numpy广播的更多信息,请参见
scipy手册中的教程和
本教程的最后一个示例。
您可以使用np.newaxis值或np.reshape命令执行填充。我两种方法都会展示:
A[:,np.newaxis,:] has dimensions 2 x 1 x 4
B has dimensions 3 x 4
np.reshape(A, (2,1,4)) has dimensions 2 x 1 x 4
B has dimensions 3 x 4
如您所见,使用任何一种方法都可以使维度对齐。我将使用第一种方法,即 np.newaxis
。因此,现在可以使用以下代码创建一个 2x3x4 的数组 A-B:
diff = A[:,np.newaxis,:] - B
diff.shape
现在我们可以将这个差异表达式放入“dist”方程语句中,以获得最终结果:
dist = np.sqrt(np.sum(np.square(A[:,np.newaxis,:] - B), axis=2))
dist
请注意,
sum
是在
axis=2
上进行的,这意味着对 2x3x4 数组的第三个轴求和(其中轴 id 从 0 开始)。
如果您的数组很小,则上述命令将正常工作。但是,如果您有大型数组,则可能会遇到内存问题。请注意,在上面的示例中,numpy 在内部创建了一个 2x3x4 数组来执行广播。如果我们将 A 的维度通用化为
a x z
,将 B 的维度通用化为
b x z
,则 numpy 将在广播时在内部创建一个
a x b x z
数组。
我们可以通过进行一些数学操作来避免创建此中间数组。因为您正在计算欧几里得距离作为平方差的总和,所以我们可以利用平方差可重写的数学事实。
![enter image description here](https://istack.dev59.com/Hf8Gj.webp)
注意中间项涉及按元素相乘求和。这个乘法求和更为常见的称呼是点乘。因为A和B都是矩阵,所以这个操作实际上是矩阵乘法。我们可以将上述式子重写为:
![enter image description here](https://istack.dev59.com/QxNaK.webp)
我们可以编写以下NumPy代码:
threeSums = np.sum(np.square(A)[:,np.newaxis,:], axis=2) - 2 * A.dot(B.T) + np.sum(np.square(B), axis=1)
dist = np.sqrt(threeSums)
dist
请注意,上面的答案与之前的实现完全相同。这里的优点是我们不需要为广播创建中间的2x3x4数组。
为了完整起见,让我们再次检查
threeSums
中每个加数的维度是否允许广播。
np.sum(np.square(A)[:,np.newaxis,:], axis=2) has dimensions 2 x 1
2 * A.dot(B.T) has dimensions 2 x 3
np.sum(np.square(B), axis=1) has dimensions 1 x 3
因此,正如预期的那样,最终的dist
数组具有2x3的尺寸。
使用点积代替逐元素乘法之和的方法也在this tutorial中讨论过。