Keras的ImageDataGenerator不按预期工作

3
我将尝试使用Keras构建一个自编码器,基于文档中的[这个例子][1]。由于我的数据量很大,我想使用生成器来避免将其加载到内存中。
我的模型如下:
model = Sequential()
model.add(Convolution2D(16, 3, 3, activation='relu', border_mode='same', input_shape=(3, 256, 256)))
model.add(MaxPooling2D((2, 2), border_mode='same'))
model.add(Convolution2D(8, 3, 3, activation='relu', border_mode='same'))
model.add(MaxPooling2D((2, 2), border_mode='same'))
model.add(Convolution2D(8, 3, 3, activation='relu', border_mode='same'))
model.add(MaxPooling2D((2, 2), border_mode='same'))
model.add(Convolution2D(8, 3, 3, activation='relu', border_mode='same'))
model.add(UpSampling2D((2, 2)))
model.add(Convolution2D(8, 3, 3, activation='relu', border_mode='same'))
model.add(UpSampling2D((2, 2)))
model.add(Convolution2D(16, 3, 3, activation='relu'))
model.add(UpSampling2D((2, 2)))
model.add(Convolution2D(1, 3, 3, activation='sigmoid', border_mode='same'))

model.compile(optimizer='adadelta', loss='binary_crossentropy')

我的生成器:

from keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory('IMAGE DIRECTORY', color_mode='rgb', class_mode='binary', batch_size=32, target_size=(256, 256))

然后拟合模型:

model.fit_generator(
        train_generator,
        samples_per_epoch=1,
        nb_epoch=1,
        verbose=1,
        )

我遇到了这个错误:

异常:检查模型目标时出错:期望的convolution2d_76应该有4维,但得到的数组形状为(32,1)

这看起来像是批次的大小而不是一个样本。我做错了什么?

1个回答

3
错误最可能是由于class_mode='binary'引起的。它使生成器产生二进制类,因此输出的形状为(batch_size, 1),而您的模型产生了四维输出(因为最后一层是卷积层)。
我猜你想让标签成为图像本身。根据flow_from_directoryDirectoryIterator的来源,仅通过更改class_mode是不可能实现的。一个可能的解决方案是这样的:
train_generator_ = train_datagen.flow_from_directory('IMAGE DIRECTORY', color_mode='rgb', class_mode=None, batch_size=32, target_size=(256, 256))
def train_generator():
    for x in train_iterator_:
        yield x, x

注意,我将class_mode设置为None。这使生成器仅返回image而不是tuple(image, label)。然后我定义了一个新的生成器,将图像作为输入和标签同时返回。

太棒了,非常感谢!现在一切都正常工作了。我之前很困惑,因为我认为它是在谈论输入形状,但在实施您的解决方案后,我发现问题出在输出形状上。 - Lester

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接