PyTorch:'ToTensor()'将彩色图像转换为9张灰度图像

4

我发现当我使用'ToTensor'将图片转换后,一张图片会变成9张显示。我查看了官方文档但无法找到原因。所以为什么一张图片会变成9张图片?问题如下图。

a = plt.imread('test.jpg')
plt.imshow(a)
plt.show()

enter image description here

transform = transforms.Compose([transforms.ToTensor()])
b = transform(a)
b = b.view(375,500,3)
plt.imshow(b)

enter image description here

1个回答

3

当您使用transforms.ToTensor()时,默认情况下,它会将输入数组从HWC更改为CHW顺序。对于绘图,您需要将通道推回到最后一个维度。

plt.imshow(b.permute(2, 0, 1))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接