理解如何应用于NumPy数组

19

我刚接触Python,正在学习TensorFlow。在一个使用notMNIST数据集的教程中,他们提供了示例代码来将标签矩阵转换为one-of-n编码数组。

目标是将由标签整数0…9组成的数组,转换为每个整数被转换为像这样的one-of-n编码数组的矩阵:

0 -> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
1 -> [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
2 -> [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
...

他们提供的代码如下:

# Map 0 to [1.0, 0.0, 0.0 ...], 1 to [0.0, 1.0, 0.0 ...]
labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)

然而,我完全不明白这段代码是如何做到这一点的。看起来它只是在生成一个0到9范围内的整数数组,然后将其与标签矩阵进行比较,并将结果转换为浮点数。那么,一个==运算符怎么会产生一个one-of-n编码矩阵呢?

2个回答

35

这里有几个要点:numpy的向量操作、添加单例轴和广播。

首先,您应该能够看到==如何发挥其作用。

假设我们从一个简单的标签数组开始。 ==采用矢量化方式运算,这意味着我们可以将整个数组与标量进行比较,并获得由每个元素比较值构成的数组。例如:

>>> labels = np.array([1,2,0,0,2])
>>> labels == 0
array([False, False,  True,  True, False], dtype=bool)
>>> (labels == 0).astype(np.float32)
array([ 0.,  0.,  1.,  1.,  0.], dtype=float32)

首先我们得到一个布尔数组,然后将其强制转换为浮点数:在Python中,False == 0,True == 1。因此,我们最终得到的数组在labels不等于0时为0,在等于0时为1。

但是,与0进行比较并没有什么特别之处,我们可以相应地将其与1、2或3进行比较以获得类似的结果:

>>> (labels == 2).astype(np.float32)
array([ 0.,  1.,  0.,  0.,  1.], dtype=float32)

实际上,我们可以循环遍历每个可能的标签并生成此数组。我们可以使用列表推导式:

>>> np.array([(labels == i).astype(np.float32) for i in np.arange(3)])
array([[ 0.,  0.,  1.,  1.,  0.],
       [ 1.,  0.,  0.,  0.,  0.],
       [ 0.,  1.,  0.,  0.,  1.]], dtype=float32)

但这并没有充分利用numpy。我们想做的是将每个可能的标签与每个元素进行比较,也就是说要进行比较。

>>> np.arange(3)
array([0, 1, 2])

随着

>>> labels
array([1, 2, 0, 0, 2])

这里就是 numpy 广播魔法发挥作用的地方。目前,labels 是一个形状为 (5,) 的一维对象。如果我们将其变成一个形状为 (5,1) 的二维对象,那么操作将会“广播”在最后一个轴上,并且我们将得到一个形状为 (5,3) 的输出,其中包含了每个范围中每个元素与 labels 中每个元素进行比较的结果。

首先,我们可以使用 None(或 np.newaxis)向 labels 添加一个“额外”轴,改变其形状:

>>> labels[:,None]
array([[1],
       [2],
       [0],
       [0],
       [2]])
>>> labels[:,None].shape
(5, 1)

然后我们可以进行比较(这是之前我们正在查看的排列的转置,但这并不重要)。

>>> np.arange(3) == labels[:,None]
array([[False,  True, False],
       [False, False,  True],
       [ True, False, False],
       [ True, False, False],
       [False, False,  True]], dtype=bool)
>>> (np.arange(3) == labels[:,None]).astype(np.float32)
array([[ 0.,  1.,  0.],
       [ 0.,  0.,  1.],
       [ 1.,  0.,  0.],
       [ 1.,  0.,  0.],
       [ 0.,  0.,  1.]], dtype=float32)

在NumPy中,广播是非常强大的,值得深入学习。


1
非常详细和清晰的解释。大多数参加Udacity深度学习课程的人都会遇到这个答案。 - AgentX

1
简而言之,应用于numpy数组的==意味着对数组应用逐元素==。结果是一个布尔数组。以下是一个示例:
>>> b = np.array([1,0,0,1,1,0])
>>> b == 1
array([ True, False, False,  True,  True, False], dtype=bool)

要计算数组b中有多少个1,您不需要将数组转换为浮点数,也就是说,可以省略.astype(np.float32),因为在Python中布尔值是整数的子类,在Python 3中True == 1 False == 0。因此,这里是如何计算b中有多少个1:

>>> np.sum((b == 1))
3

或者:

>>> np.count_nonzero(b == 1)
3

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接