这里有几个要点:numpy的向量操作、添加单例轴和广播。
首先,您应该能够看到==
如何发挥其作用。
假设我们从一个简单的标签数组开始。 ==
采用矢量化方式运算,这意味着我们可以将整个数组与标量进行比较,并获得由每个元素比较值构成的数组。例如:
>>> labels = np.array([1,2,0,0,2])
>>> labels == 0
array([False, False, True, True, False], dtype=bool)
>>> (labels == 0).astype(np.float32)
array([ 0., 0., 1., 1., 0.], dtype=float32)
首先我们得到一个布尔数组,然后将其强制转换为浮点数:在Python中,False == 0,True == 1。因此,我们最终得到的数组在labels
不等于0时为0,在等于0时为1。
但是,与0进行比较并没有什么特别之处,我们可以相应地将其与1、2或3进行比较以获得类似的结果:
>>> (labels == 2).astype(np.float32)
array([ 0., 1., 0., 0., 1.], dtype=float32)
实际上,我们可以循环遍历每个可能的标签并生成此数组。我们可以使用列表推导式:
>>> np.array([(labels == i).astype(np.float32) for i in np.arange(3)])
array([[ 0., 0., 1., 1., 0.],
[ 1., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 1.]], dtype=float32)
但这并没有充分利用numpy。我们想做的是将每个可能的标签与每个元素进行比较,也就是说要进行比较。
>>> np.arange(3)
array([0, 1, 2])
随着
>>> labels
array([1, 2, 0, 0, 2])
这里就是 numpy 广播魔法发挥作用的地方。目前,labels
是一个形状为 (5,) 的一维对象。如果我们将其变成一个形状为 (5,1) 的二维对象,那么操作将会“广播”在最后一个轴上,并且我们将得到一个形状为 (5,3) 的输出,其中包含了每个范围中每个元素与 labels 中每个元素进行比较的结果。
首先,我们可以使用 None
(或 np.newaxis
)向 labels
添加一个“额外”轴,改变其形状:
>>> labels[:,None]
array([[1],
[2],
[0],
[0],
[2]])
>>> labels[:,None].shape
(5, 1)
然后我们可以进行比较(这是之前我们正在查看的排列的转置,但这并不重要)。
>>> np.arange(3) == labels[:,None]
array([[False, True, False],
[False, False, True],
[ True, False, False],
[ True, False, False],
[False, False, True]], dtype=bool)
>>> (np.arange(3) == labels[:,None]).astype(np.float32)
array([[ 0., 1., 0.],
[ 0., 0., 1.],
[ 1., 0., 0.],
[ 1., 0., 0.],
[ 0., 0., 1.]], dtype=float32)
在NumPy中,广播是非常强大的,值得深入学习。