如何在Pytorch中将One-Hot向量转换为标签索引并进行反向转换?

10
如何在Pytorch中将标签向量转换为one-hot编码并进行反向转换?
解决该问题的方法是在浏览整个论坛讨论后复制到这里,而不是通过谷歌搜索找到简单的解决方案。

1
我真的不明白为什么要创建一个线程来复制粘贴另一个论坛上的解决方案。 - Ivan
1
@Ivan https://meta.stackoverflow.com/a/347922/913098 - Gulzar
重点是允许谷歌搜索。在论坛中搜索是浪费时间的,对每个人都是如此。 - Gulzar
2个回答

16

来自Pytorch论坛

import torch
import numpy as np


labels = torch.randint(0, 10, (10,))

# labels --> one-hot 
one_hot = torch.nn.functional.one_hot(labels)
# one-hot --> labels
labels_again = torch.argmax(one_hot, dim=1)

np.testing.assert_equals(labels.numpy(), labels_again.numpy())

2
请查看此答案,了解需要指定类数量的情况(这是大多数情况)。 - Gulzar

8

因为我不能对已接受的答案进行评论,所以我想补充一下:如果您的目标不包括所有类别(例如,因为您是分批训练),您可以将类别数量作为参数指定:

# labels --> one-hot 
one_hot = torch.nn.functional.one_hot(target, num_classes=7)

您现在拥有足够的声望来发表评论。我建议您这样做。 - DiMithras

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接