使用numpy进行独热编码

60
如果输入为零,我想要创建一个看起来像这样的数组:
[1,0,0,0,0,0,0,0,0,0]

如果输入为5:

[0,0,0,0,0,1,0,0,0,0]

我为上述内容写了:

np.put(np.zeros(10),5,1)

但它没有起作用。

有没有一种方法可以在一行中实现这个?


1
它为什么不起作用? - Mad Physicist
1
你为什么想要在一行代码中完成这个任务?如果你想让代码更简洁,可以写一个函数。 - PM 2Ring
当您至少获得一个解决问题的答案时,通常会选择其中一个答案。 - Mad Physicist
9个回答

121

通常,在机器学习中想要获得分类的一位有效编码时,您会有一个索引数组。

import numpy as np
nb_classes = 6
targets = np.array([[2, 3, 4, 0]]).reshape(-1)
one_hot_targets = np.eye(nb_classes)[targets]

one_hot_targets现在已经是独热编码形式

array([[[ 0.,  0.,  1.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  1.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  1.,  0.],
        [ 1.,  0.,  0.,  0.,  0.,  0.]]])

.reshape(-1)的作用是确保您具有正确的标签格式(您可能还有[[2],[3],[4],[0]])。-1是一个特殊值,表示“将所有剩余的东西放在这个维度中”。由于只有一个维度,它会将数组展平。

复制粘贴解决方案

def get_one_hot(targets, nb_classes):
    res = np.eye(nb_classes)[np.array(targets).reshape(-1)]
    return res.reshape(list(targets.shape)+[nb_classes])

软件包

你可以使用mpu.ml.indices2one_hot。它经过测试且易于使用:

import mpu.ml
one_hot = mpu.ml.indices2one_hot([1, 3, 0], nb_classes=5)

但是它是如何工作的?np.eye(nb_classes)应该是6x6矩阵,但它的形状变成了4x6。你能详细说明一下吗? - mrgloom
1
np.eye(nb_classes) 是一个 6x6 的矩阵。然后我选择目标中指定的行。我只选择了四行,所以它是一个 4x6 的矩阵。 - Martin Thoma
这似乎仅适用于二维目标,但可以通过执行.reshape(list(targets.shape)+[nb_classes])来推广到更多形状。 - siddhadev
你能解释一下为什么 np.eye(nb_classes)[np.array(targets).reshape(-1)] 能够工作吗?它是一个由 H*W 矩阵索引的 CxC 矩阵?!这里到底发生了什么? - gebbissimo
@gebbissimo 首先,尝试理解 np.eye(n) 的用途。然后看看 np.eye(5)[[3, 1]] 是什么意思。 - Martin Thoma

11

类似于:

np.array([int(i == 5) for i in range(10)])

这应该能解决问题。 但我想可能还有其他使用numpy的解决方案。

编辑:你的公式不能正常工作的原因是:np.put不返回任何内容,它只修改第一个参数中给定的元素。使用np.put() 的正确答案是:

a = np.zeros(10)
np.put(a,5,1)

问题在于无法在一行中完成,因为您需要在将其传递给np.put()之前定义数组


3
这是实现你想要的目标最为低效的方式。 - PM 2Ring
1
@PM2Ring 我知道我写的那个一行代码很糟糕,但你有没有任何关于列表和numpy数组应该做什么和不应该做什么的资料? - HolyDanna
1
@HolyDanna:在Python中,一般规则是Python循环比使用C代码执行的循环运行速度慢。因此,如果有明显的方法可以使用C循环而不是Python循环,则应该使用C循环。而使用Numpy的整个重点是在可能的情况下以C速度进行数组处理。我不熟悉numpy源代码,但numpy.zeros可能甚至比C for循环运行得更快,因为CPU可以非常快地用单个值填充一块内存。 - PM 2Ring
顺便说一句,我并不是在说你的第一个代码示例有问题。在非Numpy程序中,这是一种很好的方法,为了这个操作而导入Numpy是愚蠢的。但是,如果程序已经在使用Numpy,那么利用Numpy所提供的功能是有意义的。 - PM 2Ring

5
您可以使用列表推导式:
```python ```
[0 if i !=5 else 1 for i in range(10)]

转向
[0,0,0,0,0,1,0,0,0,0]

4

我不确定性能如何,但以下代码可以正常工作,而且很整洁。

x = np.array([0, 5])
x_onehot = np.identity(6)[x]

这是关于编程的内容,请将其从英语翻译成中文。请仅返回已翻译的文本,而不要进行解释。那基本上等同于接受的答案。谢谢您再次回答。 - Nik O'Lai

3

2

问题在于你没有将数组保存在任何地方。 put 函数直接在数组上进行操作并不返回任何值。由于你从未给数组命名,因此以后无法引用它。因此,这段代码:

one_pos = 5
x = np.zeros(10)
np.put(x, one_pos, 1)

使用 would work 也可以,但你也可以使用索引:

one_pos = 5
x = np.zeros(10)
x[one_pos] = 1

在我看来,如果没有特殊的原因要求成为一行代码,则这是正确的做法。这也更容易阅读,易读的代码是好的代码。


2
快速查看手册,您会发现np.put没有返回值。虽然您的技术很好,但是您访问的是None而不是结果数组。
对于1-D数组,最好直接使用直接索引,特别是对于这种简单情况。
以下是如何以最小修改重写您的代码:
arr = np.zeros(10)
np.put(arr, 5, 1)

以下是关于如何使用索引进行第二行的操作,而不是使用put的方法:
arr[5] = 1

2

np.put会对其数组参数进行原地变异。在Python中,执行原地变异的函数/方法通常会返回Nonenp.put遵循这个惯例。因此,如果a是一个一维数组,您可以执行以下操作:

a = np.put(a, 5, 1)

那么a将被None替换。

你的代码与此类似,但它将一个未命名的数组传递给np.put

一种简单而高效的方法是使用一个简单的函数,例如:

import numpy as np

def one_hot(i):
    a = np.zeros(10, 'uint8')
    a[i] = 1
    return a

a = one_hot(5) 
print(a)

输出

[0 0 0 0 0 1 0 0 0 0]

1
我会记住这个,以免对人不礼貌。 - HolyDanna

0
import time
start_time = time.time()
z=[]
for l in [1,2,3,4,5,6,1,2,3,4,4,6,]:
    a= np.repeat(0,10)
    np.put(a,l,1)
    z.append(a)
print("--- %s seconds ---" % (time.time() - start_time))

#--- 0.00174784660339 seconds ---

import time
start_time = time.time()
z=[]
for l in [1,2,3,4,5,6,1,2,3,4,4,6,]:
    z.append(np.array([int(i == l) for i in range(10)]))
print("--- %s seconds ---" % (time.time() - start_time))

#--- 0.000400066375732 seconds ---

使用 a=np.zeros(10),我得到了第一个版本稍微快一点的结果:第一个版本为 0.0007712841033935547 秒,而第二个版本为 0.0008835792541503906 秒 - HolyDanna
1
尝试使用a = np.zeros(10); a[l] = 1进行索引赋值,比调用函数更快。我的one_hot函数比这个内联版本稍慢,也是由于函数调用的开销,但它比其他技术更快。然而,这个时间信息并不是非常准确的,你应该使用timeit模块,并使用它的设施进行数百(或数千)次测试,以获得有意义的结果,这些结果不会被CPU执行的其他任务的“噪音”所淹没。 - PM 2Ring
谢谢。您知道有什么更好的方法来检查代码运行时间吗? - Abhijay Ghildyal
1
您IP地址为143.198.54.68,由于运营成本限制,当前对于免费用户的使用频率限制为每个IP每72小时10次对话,如需解除限制,请点击左下角设置图标按钮(手机用户先点击左上角菜单按钮)。 - PM 2Ring

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接