Python matplotlib,图像数据形状无效。

4

目前我有这段代码来显示三张图片:

imshow(image1, title='1')
imshow(image2, title='2')
imshow(image3, title='3')

它的功能很好。但是我想把它们放在一行而不是列中。

这是我尝试过的代码:

f = plt.figure()
f.add_subplot(1,3,1)
plt.imshow(image1)
f.add_subplot(1,3,2)
plt.imshow(image2)
f.add_subplot(1,3,3)
plt.imshow(image3)

出现错误:

类型错误:无法将CUDA张量转换为numpy。请先使用Tensor.cpu()将张量复制到主机内存。

如果我这样做:

f = plt.figure()
f.add_subplot(1,3,1)
plt.imshow(image1.cpu())
f.add_subplot(1,3,2)
plt.imshow(image2.cpu())
f.add_subplot(1,3,3)
plt.imshow(image3.cpu())

出现错误:

类型错误:图像数据的形状无效(1, 3, 128, 128)

如何修复此问题,或者是否有更简便的实现方法?


1
使用Matplotlib的subplots函数,通过ncols参数指定所需的行数。 - Lith
@Lith 我已经尝试了这段代码:`fig, axs = plt.subplots(nrows=1, ncols=3) axs[0].imshow(image1.cpu()) axs[1].imshow(image2.cpu()) axs[2].imshow(image3.cpu())`但仍然出现TypeError: Invalid shape (1, 3, 128, 128) for image data的错误提示。 - Hekes Pekes
我不了解Pytorch,但似乎cpu()方法正在将数组image转换为另一个维度为(1,3,128,128)的数组(我猜测原数组维度为(128,128)),这对于imshow函数是无效的。该函数的参数必须是一个二维数组(如果您使用RGB数据,则为三维数组),表示图像像素的数据值。 - Lith
1个回答

11

在matplotlib中,函数'imshow'将3通道图像作为(h, w, 3)格式输入,您可以在文档中查看。

看起来您传递了一个“批次”的单张图像(第一维)和该图像的三个通道(第二维)(h和w是第三和第四维度)。

您需要对图像进行重塑或查看操作(在转换为CPU后,尝试使用以下代码:

image1.squeeze().permute(1,2,0)

将得到所需形状的图像(128, 128, 3)。

squeeze()函数将移除第一个维度。premute()函数将重新排列维度,使第一个维度移到第三个位置,而另外两个维度则移到开头。

此外,有关GPU和CPU问题的进一步讨论,请查看此链接:link

希望这能有所帮助。


1
谢谢,那很有帮助! - Hekes Pekes

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接