如何在Pytorch中将标签向量转换为one-hot编码并进行反向转换?
解决该问题的方法是在浏览整个论坛讨论后复制到这里,而不是通过谷歌搜索找到简单的解决方案。
解决该问题的方法是在浏览整个论坛讨论后复制到这里,而不是通过谷歌搜索找到简单的解决方案。
import torch
import numpy as np
labels = torch.randint(0, 10, (10,))
# labels --> one-hot
one_hot = torch.nn.functional.one_hot(labels)
# one-hot --> labels
labels_again = torch.argmax(one_hot, dim=1)
np.testing.assert_equals(labels.numpy(), labels_again.numpy())
因为我不能对已接受的答案进行评论,所以我想补充一下:如果您的目标不包括所有类别(例如,因为您是分批训练),您可以将类别数量作为参数指定:
# labels --> one-hot
one_hot = torch.nn.functional.one_hot(target, num_classes=7)