plt.imshow()中图像数据的维度无效。

6

我正在使用mnist数据集在Keras背景下训练胶囊网络。 训练后,我想显示来自mnist数据集的图像。为了加载图像,使用mnist.load_data()。数据存储为(x_train, y_train),(x_test, y_test)。 现在,为了可视化图像,我的代码如下:

img_path = x_test[1]  
print(img_path.shape)
plt.imshow(img_path)
plt.show()

该代码输出如下内容:
(28, 28, 1)

plt.imshow(img_path)出现的错误如下:

TypeError: Invalid dimensions for image data

如何显示PNG格式的图像。求助!

14
你应该删除多余的通道维度。尝试使用 plt.imshow(np.squeeze(img_path)) - sdcbr
@sdcbr 非常感谢。它有效了。 - Anusha Mehta
这个回答解决了你的问题吗?使用pylab.imshow()显示图像 - iacob
4个回答

7

根据@sdcbr的评论,使用np.squeeze可以减少不必要的维度。如果图像是二维的,则imshow函数工作得很好。如果图像有三个维度,则必须减少一个额外的维度。但是,对于更高维的数据,您将不得不将其减少到二维,因此可能需要多次应用np.squeeze。(或者您可以使用其他维度缩减函数处理更高维数据)

import numpy as np  
import matplotlib.pyplot as plt
img_path = x_test[1]  
print(img_path.shape)
if(len(img_path.shape) == 3):
    plt.imshow(np.squeeze(img_path))
elif(len(img_path.shape) == 2):
    plt.imshow(img_path)
else:
    print("Higher dimensional data")

3

例子:

plt.imshow(test_images[0])

类型错误:图像数据的形状无效(28,28,1)

更正:

plt.imshow((tf.squeeze(test_images[0])))

数字7


2
你可以使用tf.squeeze来从张量的形状中移除大小为1的维度。
plt.imshow( tf.shape( tf.squeeze(x_train) ) )

点击查看TF2.0示例


0

matplotlib.pyplot.imshow()不支持形状为(h, w, 1)的图像。只需通过将图像重新调整为(h, w)来删除图像的最后一个维度:newimage = reshape(img,(h,w))


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接