在TensorFlow中,函数'tf.one_hot'中的参数'axis'是什么意思?

27

有人能帮忙解释一下在TensorFlow的one_hot函数中,axis是什么吗?

根据文档

axis:要填充的轴(默认值:-1,新的最内层轴)

我找到的最接近答案的SO解释Pandas相关:

不确定上下文是否适用。

3个回答

23

这是一个例子:

x = tf.constant([0, 1, 2])

...是输入张量,其中N=4(每个索引都被转换为4D向量)。

axis=-1

计算one_hot_1 = tf.one_hot(x, 4).eval()会生成一个大小为(3,4)的张量:

[[ 1.  0.  0.  0.]
 [ 0.  1.  0.  0.]
 [ 0.  0.  1.  0.]]

...其中最后一个维度是单热编码(清晰可见)。这相应于默认的axis=-1,即最后一个维度。

axis=0

现在,计算one_hot_2 = tf.one_hot(x, 4, axis=0).eval()会得到一个(4, 3)张量,它不立即被识别为单热编码:

[[ 1.  0.  0.]
 [ 0.  1.  0.]
 [ 0.  0.  1.]
 [ 0.  0.  0.]]

这是因为one-hot编码沿着0轴完成,必须转置矩阵才能看到先前的编码。当输入维度更高时,情况变得更加复杂,但是思想是相同的:区别在于用于one-hot编码的额外维度的放置。


1
谢谢您的解释,虽然对我来说有些难以理解,所以我会逐个问题地问。您是如何得出 x = tf.constant([[1, 1, 2], [0, 1, 2]]) 会产生一个 4D 向量的呢?...是因为它是由两个二维数组作为元素组成的数组吗? - tinonetic
我觉得我需要阅读。我有太多问题。您能否用二维数组简化您的答案?如果失败,我认为我将不得不回到第一原则......毫无疑问,您的答案可能是正确的。 - tinonetic
没关系。我尝试为每个维度选择不同的尺寸,以避免误解。输入“x”是(2,3)。编码结果为4D,因为我们设置了“N=4”(考虑类别数)。这就是为什么结果是(2,3,4)或(4,3,2),具体取决于位置。 - Maxim
实际上,你也可以通过 x = tf.constant([0, 1, 2]) 看到差异:结果要么是 (3, 4) 要么是 (4, 3) - Maxim
哈哈,我要回去从头开始学习了。我完全不明白这个问题。这与你的答案无关,而是我的知识深度不够。无论如何,谢谢你。我会把它标记为答案,相信你对这个主题有足够的了解...然后我会回来重新阅读 :) - tinonetic
没问题。慢慢来,我已经更新了答案并提供了更简单的示例。 - Maxim

22

对我来说,轴的意思是“你在哪里添加附加数字以增加维度”。至少这是我的解释,并且作为一种记忆提示。

例如,您有 [1,2,3,0,2,1],它的形状为 (1,6)。这意味着它是一个一维数组。one_hot 添加零,并将原始数组中的位置转换为 1,使得原始数组必须比原始数组多一个维度,并且轴告诉函数在哪里添加它,这个新维度将标识示例


axis=1

您添加了第二个维度,第一个维度保持不变。这将导致一个(6,4)的数组。因此,在结果数组中,您使用第一个维度(0)来确定您看到的示例,使用第二个维度(1,新维度)来确定该类是否处于活动状态。newArr[0][1]=1 表示示例 0,类别 1,在这种情况下表示示例 0 属于类别 1。

   0   1   2   3  <- class

[[ 0.  1.  0.  0.]   <- example 0
 [ 0.  0.  1.  0.]   <- example 1
 [ 0.  0.  0.  1.]   <- example 2
 [ 1.  0.  0.  0.]   <- example 3
 [ 0.  0.  1.  0.]   <- example 4
 [ 0.  1.  0.  0.]]  <- example 5

axis=0

当你添加第一个维度时,现有的维度会向后移动。这将产生一个(4,6)数组。因此在结果数组中,您使用第一个维度(0,新维度)来确定该类是否处于活动状态,第二个维度(1)用于确定您看到的是哪个示例。newArr [0] [1] = 0表示类0,示例1,在这种情况下表示示例1不属于类0。
   0   1   2   3   4   5  <- example

[[ 0.  0.  0.  1.  0.  0.]   <- class 0
 [ 1.  0.  0.  0.  0.  1.]   <- class 1
 [ 0.  1.  0.  0.  1.  0.]   <- class 2
 [ 0.  0.  1.  0.  0.  0.]]  <- class 3

解释得很好,但输入数组不是(6,1),而是(1,6)吗? - Nijan
我对你的回答不满意...抱歉。 - nitin bakaya

0
对我来说,我是这样理解的 - (注意documentation中引用的索引仅是类别标签的信息,可以是标量、向量或矩阵) 如果您的索引只是一个标量,则不需要轴。 但是,如果它是一个向量,则可以选择特征和类别的方向2。 在这种情况下,one-hot向量的图像将以行作为深度(类别)和列作为相应的特征(标签),因此该轴的值为0。 同样,如果希望特征x深度,则轴的值为-1。
同样,如果索引是矩阵,则有以下方向选择。

(批处理是您索引中的行)3

batch x features x depth if axis == -1
batch x depth x features if axis == 1
depth x batch x features if axis == 0

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接