如何在keras中进行多类图像分类?

8

这是我所做的。我获取了狗/猫图像分类的代码,并编译和运行,得到了80%的准确性。我向训练和验证文件夹中添加了一个更多的类(飞机)文件夹。对以下代码进行了更改。

model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')

validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')

binary class_mode 改为 categorical,将损失函数更改为 categorical_crossentropy。还将输出布局从 sigmoid 更改为 softmax。收到以下错误消息。

ValueError: Error when checking target: expected activation_10 to have shape (None, 1) but got array with shape (16, 3)

我需要明确地将训练标签更改为类别,就像下面提到的那样吗?(我从此网站上读到的使用Keras进行多标签分类
train_labels = to_categorical(train_labels, num_classes=num_classes) 

我不确定这里发生了什么,请帮忙。我对深度学习比较新。

模型

model = Sequential()

model.add(Conv2D(32, (3, 3), input_shape=input_shape))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

# this is the augmentation configuration we will use for training
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')


validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')
model.fit_generator(
    train_generator,
    steps_per_epoch=nb_train_samples // batch_size,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=nb_validation_samples // batch_size)
2个回答

11

对于多类分类,最后一个密集层的节点数必须等于类别数,接下来是softmax激活函数,即你的模型的最后两层应该是:

最后一个密集层的节点数必须等于类别数,接下来是softmax激活函数,即你的模型的最后两层应该是:

model.add(Dense(num_classes))
model.add(Activation('softmax'))

此外,你的标签(包括训练和测试)必须进行一次独热编码。假设你最初的猫和狗的标记是整数(0/1),而你的新类别(飞机)最初标记为“2”,那么你应该将它们转换如下:

train_labels = keras.utils.to_categorical(train_labels, num_classes)
test_labels = keras.utils.to_categorical(test_labels, num_classes)

最后,在术语层面上,您正在进行的是多类别而不是多标签分类(我已编辑您帖子的标题)——后者用于样本可能同时属于多个类别的问题。


1
我不需要将其更改为分类,因为我正在使用自己的图像数据集作为输入。将“steps_per_epoch = nb_train_samples // batch_size”更改为“samples_per_epoch = nb_train_samples”解决了这个问题。 - Edwin Varghese

2

对于多类分类问题,神经网络的最后一层大小必须等于类别数目。

例如,对于您的问题(3个类别),代码应该如下所示:

model = Sequential()

model.add(Conv2D(32, (3, 3), input_shape=input_shape))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(3))
model.add(Activation('softmax'))

1
“softmax” 不是用于多标签分类的正确激活函数。 - DollarAkshay
2
你正在混淆多类别和多标签分类。多标签意味着一张图片可以属于多个类别。如果在最后加上一个softmax层,就是说一个类别的概率取决于其他类别。 - DollarAkshay
不,我不是。 https://zh.wikipedia.org/wiki/Softmax函数 -> Softmax函数用于各种多类分类方法中。 - Vadim
1
你只是盲目地引用链接,而不理解我在说什么。softmax函数用于各种多类别分类方法。是的,softmax是用于多类别,而不是多标签。 - DollarAkshay
1
为什么首先要讨论多标签?问题中是否有任何指示多标签的内容?或者你只是将多类别和多标签混淆了,正如@AkshayLAradhya正确地争论的那样,它们是不同的东西? - desertnaut
显示剩余4条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接