我曾在使用灰度图像与VGG16一起工作时遇到了同样的问题。我解决了这个问题,方法如下:
假设我们的训练图像是在train_gray_images中,每行都包含未滚动的灰度图像强度值。如果我们直接将它传递给fit函数,会产生错误,因为fit函数期望的是一个3通道(RGB)图像数据集,而不是灰度数据集。因此,在传递给fit函数之前,请执行以下操作:
创建一个虚拟的RGB图像数据集,就像灰度数据集一样具有相同的形状(这里使用dummy_RGB_image)。唯一的区别是这里使用的通道数是3。
dummy_RGB_images = np.ndarray(shape=(train_gray_images.shape[0], train_gray_images.shape[1], train_gray_images.shape[2], 3), dtype= np.uint8)
因此,只需将整个数据集复制3次到“dummy_RGB_images”的每个通道中即可。(这里的维度为[样例数量,高度,宽度,通道])
dummy_RGB_images[:, :, :, 0] = train_gray_images[:, :, :, 0]
dummy_RGB_images[:, :, :, 1] = train_gray_images[:, :, :, 0]
dummy_RGB_images[:, :, :, 2] = train_gray_images[:, :, :, 0]
最后,传递 dummy_RGB_images
而不是灰度数据集,如下所示:
model.fit(dummy_RGB_images,...)